Knowledge Distillation in Practice: Distilling ResNet with CoatNet
Abstract
Knowledge Distillation (KD) transfers ("distills") the knowledge contained in an already-trained model into another model. Hinton first introduced the concept of knowledge distillation (dark-knowledge extraction) in "Distilling the Knowledge in a Neural Network": soft targets produced by a teacher network (complex, but with superior prediction accuracy) are added as part of the total loss to guide the training of a student network (compact, low-complexity, better suited for inference deployment), thereby achieving knowledge transfer. Paper link: https://arxiv.org/pdf/1503.02531.pdf
The Distillation Process
Knowledge distillation uses a teacher-student setup, where the teacher is the source of the "knowledge" and the student is its recipient. The process has two stages:
Training the original model: train the "teacher" model, Net-T for short. It is relatively complex and may even be an ensemble of several separately trained models. We place no constraints on the teacher's architecture, parameter count, or whether it is an ensemble; the only requirement is that for an input X it produces an output Y which, after a softmax mapping, gives the probability of each class.
Training the compact model: train the "student" model, Net-S for short: a single model with fewer parameters and a relatively simple structure. Likewise, for an input X it produces an output Y which, after softmax, gives the class probabilities.
The teacher has strong learning capacity and can transfer what it has learned to the weaker student, improving the student's generalization. The heavyweight but accurate teacher is never deployed; it acts purely as a mentor. What actually goes into production for inference is the light, nimble student model.
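The soft targets are what carry the teacher's "dark knowledge": raising the softmax temperature flattens the distribution so that inter-class similarities become visible. A minimal sketch (the logit values are made up for illustration):

```python
import torch
import torch.nn.functional as F

# Hypothetical teacher logits for one sample over 3 classes.
logits = torch.tensor([8.0, 2.0, 1.0])

hard = F.softmax(logits, dim=0)        # T = 1: nearly one-hot
soft = F.softmax(logits / 7.0, dim=0)  # T = 7: much flatter

print(hard)  # the top class dominates almost completely
print(soft)  # the relative similarity between classes becomes visible
```

The student is trained against such softened teacher outputs in addition to the ordinary hard-label loss; the temperature of 7.0 used later in this post falls in this regime.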

Final Results
Let's start with the conclusion. The teacher network is coatnet_2 and the student network is ResNet18, as shown in the table below:
Network         epochs   ACC
coatnet_2       50       92%
ResNet18        50       86%
ResNet18 + KD   50       89%
Under identical conditions, adding knowledge distillation raises ResNet18's accuracy by 3 points, a sizeable gain. The accuracy curves are compared in the figure below:

Data Preparation
The data comes from the plant-seedlings dataset I have used in earlier image-classification posts. First split the dataset into a training set and a validation set by running:
import glob
import os
import shutil

image_list = glob.glob('data1/*/*.png')
print(image_list)
file_dir = 'data'
if os.path.exists(file_dir):
    print('true')
    shutil.rmtree(file_dir)  # remove the old split, then recreate it
    os.makedirs(file_dir)
else:
    os.makedirs(file_dir)

from sklearn.model_selection import train_test_split

trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir = 'train'
val_dir = 'val'
train_root = os.path.join(file_dir, train_dir)
val_root = os.path.join(file_dir, val_dir)
for file in trainval_files:
    file_class = file.replace("\\", "/").split('/')[-2]
    file_name = file.replace("\\", "/").split('/')[-1]
    file_class = os.path.join(train_root, file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)
for file in val_files:
    file_class = file.replace("\\", "/").split('/')[-2]
    file_name = file.replace("\\", "/").split('/')[-1]
    file_class = os.path.join(val_root, file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)
Teacher Network
The teacher network is coatnet_2, a fairly large model at about 200 MB. Trained for 50 epochs, the best model reaches about 92% accuracy.
Steps
Create teacher_train.py and insert the code:
Import the required libraries
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from torch.autograd import Variable
from model.coatnet import coatnet_2
import json
import os
Define the training and validation functions
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data).to(device), Variable(target).to(device)
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))
Best_ACC = 0
# Validation
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = Variable(data).to(device), Variable(target).to(device)
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        if acc > Best_ACC:
            torch.save(model, file_dir + '/' + 'best.pth')
            Best_ACC = acc
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        return acc
Define the global parameters
if __name__ == '__main__':
    # Create the folder for saving models
    file_dir = 'CoatNet'
    os.makedirs(file_dir, exist_ok=True)
    # Global parameters
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 50
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Image Preprocessing and Augmentation
    # Data preprocessing and augmentation
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])
Read the Data
Use PyTorch's default dataset utilities to read the data.
    # Read the datasets
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    # Wrap the datasets in DataLoaders
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
Set Up the Model and Loss
    # Instantiate the model and move it to the GPU
    criterion = nn.CrossEntropyLoss()
    model_ft = coatnet_2()
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 12)
    model_ft.to(DEVICE)
    # Use the simple, robust Adam optimizer with a lowered learning rate
    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
    # Training loop
    val_acc_list = {}
    for epoch in range(1, EPOCHS + 1):
        train(model_ft, DEVICE, train_loader, optimizer, epoch)
        cosine_schedule.step()
        acc = val(model_ft, DEVICE, test_loader)
        val_acc_list[epoch] = acc
        with open('result.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(val_acc_list))
    torch.save(model_ft, 'CoatNet/model_final.pth')
Once the code above is complete, you can start training the teacher network.
Student Network
The student network is ResNet18, a fairly small model at about 40 MB. Trained for 50 epochs, the best model reaches about 86% accuracy.
Steps
Create student_train.py and insert the code:
Import the required libraries
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from torch.autograd import Variable
from torchvision.models.resnet import resnet18
import json
import os
Define the training and validation functions
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data).to(device), Variable(target).to(device)
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))
Best_ACC = 0
# Validation
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = Variable(data).to(device), Variable(target).to(device)
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        if acc > Best_ACC:
            torch.save(model, file_dir + '/' + 'best.pth')
            Best_ACC = acc
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        return acc
Define the global parameters
if __name__ == '__main__':
    # Create the folder for saving models
    file_dir = 'resnet'
    os.makedirs(file_dir, exist_ok=True)
    # Global parameters
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 50
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Image Preprocessing and Augmentation
    # Data preprocessing and augmentation
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])
Read the Data
Use PyTorch's default dataset utilities to read the data.
    # Read the datasets
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    # Wrap the datasets in DataLoaders
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
Set Up the Model and Loss
    # Instantiate the model and move it to the GPU
    criterion = nn.CrossEntropyLoss()
    model_ft = resnet18()
    print(model_ft)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 12)
    model_ft.to(DEVICE)
    # Use the simple, robust Adam optimizer with a lowered learning rate
    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
    # Training loop
    val_acc_list = {}
    for epoch in range(1, EPOCHS + 1):
        train(model_ft, DEVICE, train_loader, optimizer, epoch)
        cosine_schedule.step()
        acc = val(model_ft, DEVICE, test_loader)
        val_acc_list[epoch] = acc
        with open('result_student.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(val_acc_list))
    torch.save(model_ft, 'resnet/model_final.pth')
Once the code above is complete, you can start training the student network.
Distilling the Student Network
The student is again ResNet18. Distilling the student with the teacher network and training for 50 epochs, the final accuracy is 89%.
Steps
Create student_kd_train.py and insert the code:
Import the required libraries
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed by the distillation function below
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from torch.autograd import Variable
from torchvision.models.resnet import resnet18
import json
import os
Define the distillation function
def distillation(y, labels, teacher_scores, temp, alpha):
    return nn.KLDivLoss()(F.log_softmax(y / temp, dim=1), F.softmax(teacher_scores / temp, dim=1)) * (
            temp * temp * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)
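As a quick sanity check, the distillation loss can be exercised on random tensors; the shapes below assume the 12-class setup of this post, and the assertion confirms that with alpha=0 only the hard-label cross-entropy remains:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Repeated here so this sanity-check snippet runs on its own.
def distillation(y, labels, teacher_scores, temp, alpha):
    return nn.KLDivLoss()(F.log_softmax(y / temp, dim=1),
                          F.softmax(teacher_scores / temp, dim=1)) * (
        temp * temp * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)

# Random logits for a batch of 4 samples over the 12 classes of this post.
student_logits = torch.randn(4, 12)
teacher_logits = torch.randn(4, 12)
labels = torch.randint(0, 12, (4,))

loss = distillation(student_logits, labels, teacher_logits, temp=7.0, alpha=0.7)
print(loss.item())  # a finite scalar

# With alpha = 0 the KD term vanishes and only cross-entropy remains.
ce_only = distillation(student_logits, labels, teacher_logits, temp=7.0, alpha=0.0)
assert torch.allclose(ce_only, F.cross_entropy(student_logits, labels))
```

Note that nn.KLDivLoss() defaults to reduction='mean', which averages over every element; reduction='batchmean' matches the mathematical KL divergence, but the formulation above is kept as in the original script.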
Define the training and validation functions
# Training with distillation
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        teacher_output = teacher_model(data)  # get the teacher's predictions
        teacher_output = teacher_output.detach()  # cut off backprop into the teacher
        loss = distillation(output, target, teacher_output, temp=7.0, alpha=0.7)  # train the student's output against the teacher's
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))
Best_ACC = 0
# Validation
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = Variable(data).to(device), Variable(target).to(device)
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        if acc > Best_ACC:
            torch.save(model, file_dir + '/' + 'best.pth')
            Best_ACC = acc
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        return acc
Define the global parameters
if __name__ == '__main__':
    # Create the folder for saving models
    file_dir = 'resnet_kd'
    os.makedirs(file_dir, exist_ok=True)
    # Global parameters
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 50
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Image Preprocessing and Augmentation
    # Data preprocessing and augmentation
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])
Read the Data
Use PyTorch's default dataset utilities to read the data.
    # Read the datasets
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    # Wrap the datasets in DataLoaders
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
Set Up the Model and Loss
    # Instantiate the models and move them to the GPU
    criterion = nn.CrossEntropyLoss()
    # Load the trained teacher network (saved by teacher_train.py) and freeze it in eval mode
    teacher_model = torch.load('CoatNet/best.pth')
    teacher_model.to(DEVICE)
    teacher_model.eval()
    model_ft = resnet18()
    print(model_ft)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 12)
    model_ft.to(DEVICE)
    # Use the simple, robust Adam optimizer with a lowered learning rate
    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
    # Training loop
    val_acc_list = {}
    for epoch in range(1, EPOCHS + 1):
        train(model_ft, DEVICE, train_loader, optimizer, epoch)
        cosine_schedule.step()
        acc = val(model_ft, DEVICE, test_loader)
        val_acc_list[epoch] = acc
        # Write to result_kd.json so the comparison script can tell the KD run apart
        with open('result_kd.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(val_acc_list))
    torch.save(model_ft, 'resnet_kd/model_final.pth')
Once the code above is complete, you can start distillation training!
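Since the whole point of distillation is deploying the lightweight student, a minimal inference sketch may be useful. `predict` is a hypothetical helper name; in practice you would load the distilled model saved above (e.g. torch.load('resnet_kd/best.pth')) and preprocess images with the same transform_test pipeline:

```python
import torch
import torch.nn as nn

# Hypothetical helper: run one preprocessed image (a C x H x W float tensor)
# through a trained model and return the predicted class index.
def predict(model, image, device):
    model.eval()
    with torch.no_grad():
        logits = model(image.unsqueeze(0).to(device))  # add batch dimension
    return logits.argmax(dim=1).item()

# Demo with a stand-in model; replace it with the loaded student network.
dummy = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 12))
print(predict(dummy, torch.rand(3, 224, 224), torch.device('cpu')))
```

The returned index can be mapped back to a class name via the class_to_idx dictionary saved to class.json during training.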
Comparing Results
Load the saved results and plot the accuracy curves.
from matplotlib import pyplot as plt
import json

teacher_file = 'result.json'
student_file = 'result_student.json'
student_kd_file = 'result_kd.json'

def read_json(file):
    with open(file, 'r', encoding='utf8') as fp:
        json_data = json.load(fp)
        print(json_data)
    return json_data

teacher_data = read_json(teacher_file)
student_data = read_json(student_file)
student_kd_data = read_json(student_kd_file)

x = [int(x) for x in list(dict(teacher_data).keys())]
print(x)
plt.plot(x, list(teacher_data.values()), label='teacher')
plt.plot(x, list(student_data.values()), label='student without KD')
plt.plot(x, list(student_kd_data.values()), label='student with KD')
plt.title('Test accuracy')
plt.legend()
plt.show()
Summary
Knowledge distillation is a common technique for compressing and improving lightweight models. This post walked through a simple example of using a teacher network to distill a student network.
Code and dataset used in this post:
https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/87029904
Writing this up takes effort; likes, comments, and bookmarks are much appreciated!