IRG Knowledge Distillation in Practice: Distilling ResNet18 with ResNet50
Abstract
High-complexity models can reach SOTA accuracy, but they are often hard to deploy directly. Model compression methods let a model trade off between efficiency and accuracy. Knowledge distillation is an effective model compression technique: its core idea is to force a lightweight student model to learn the knowledge extracted by a teacher model, thereby improving the student's performance. Existing knowledge distillation methods fall into three broad categories:
Feature-based (e.g. VID, NST, FitNets, fine-grained feature imitation)
Relation-based (e.g. IRG, relational knowledge distillation, CRD, similarity-preserving knowledge distillation)
Response-based (e.g. Hinton's seminal knowledge distillation work; a minimal sketch follows this list)
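For contrast with the relation-based IRG loss used later, here is a minimal, illustrative sketch of the classic response-based (Hinton-style) soft-target loss. It is not used in this article; the temperature T and mixing weight alpha below are illustrative assumptions, not values from this tutorial.

import torch.nn.functional as F

def hinton_kd_loss(student_logits, teacher_logits, target, T=4.0, alpha=0.9):
    # Soft targets: KL divergence between temperature-scaled distributions,
    # rescaled by T*T to keep gradient magnitudes comparable
    soft = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction='batchmean') * (T * T)
    # Hard targets: ordinary cross-entropy against the ground-truth labels
    hard = F.cross_entropy(student_logits, target)
    return alpha * soft + (1 - alpha) * hard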

In this article we walk through a hands-on example with the relation-based IRG distillation algorithm. IRG distills the model's blocks and its flattened feature layer, so the forward pass must return the output of every block plus the flattened features. We therefore modify the models to fit the IRG algorithm, and to keep the teacher's and student's layer outputs compatible we use ResNet50 as the teacher and ResNet18 as the student.
Final Results
Let's start with the conclusion: the teacher network is ResNet50 and the student network is ResNet18, as shown in the table below.
Network           Epochs   Acc
ResNet50          100      86%
ResNet18          100      89%
ResNet18 + IRG    100      89%
This result is somewhat unexpected. The ResNet50 and ResNet18 models are both my own implementations. I also tried ResNet152 and ResNet101; both scored about the same as ResNet50, around 86%, whereas ResNet18 reached 89% accuracy. ResNet18 + IRG also reached 89%.
Models
The models are not the official PyTorch ones; they are adapted from a ResNet implementation I wrote up earlier. The ResNet architecture is shown below:

ResNet18, ResNet34
ResNet18 and ResNet34 share the same residual block, whose structure is shown below:

The code (resnet.py):
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
# from torchsummary import summary

class ResidualBlock(nn.Module):
    """
    Sub-module: the basic residual block.
    """
    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super(ResidualBlock, self).__init__()
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False),
            nn.BatchNorm2d(outchannel)
        )
        self.right = shortcut

    def forward(self, x):
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        out += residual
        return F.relu(out)

class ResNet(nn.Module):
    """
    Main module: the ResNet backbone (ResNet18/34).
    The network contains several layers, each made of several residual
    blocks; ResidualBlock implements the block and _make_layer builds a layer.
    """
    def __init__(self, blocks, num_classes=1000):
        super(ResNet, self).__init__()
        self.model_name = 'resnet34'
        # Stem: initial image transformation
        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1))
        # Repeated layers, with blocks[0..3] residual blocks each
        self.layer1 = self._make_layer(64, 64, blocks[0])
        self.layer2 = self._make_layer(64, 128, blocks[1], stride=2)
        self.layer3 = self._make_layer(128, 256, blocks[2], stride=2)
        self.layer4 = self._make_layer(256, 512, blocks[3], stride=2)
        # Fully connected classifier
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, inchannel, outchannel, block_num, stride=1):
        """
        Build a layer containing several residual blocks.
        """
        shortcut = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU()
        )
        layers = []
        layers.append(ResidualBlock(inchannel, outchannel, stride, shortcut))
        for i in range(1, block_num):
            layers.append(ResidualBlock(outchannel, outchannel))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.pre(x)
        l1_out = self.layer1(x)
        l2_out = self.layer2(l1_out)
        l3_out = self.layer3(l2_out)
        l4_out = self.layer4(l3_out)
        p_out = F.avg_pool2d(l4_out, 7)
        fea = p_out.view(p_out.size(0), -1)
        out = self.fc(fea)
        # Return every block's output plus the flattened features, for IRG
        return l1_out, l2_out, l3_out, l4_out, fea, out

def ResNet18():
    return ResNet([2, 2, 2, 2])

def ResNet34():
    return ResNet([3, 4, 6, 3])

if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = ResNet34()
    model.to(device)
    # summary(model, (3, 224, 224))
The main change is to the outputs: the forward pass now returns the result of every block.
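A quick sanity check of the new return signature (a sketch; shapes assume a 224x224 input, and the import path assumes the file above lives in a model package, as in the training scripts below):

import torch
from model.resnet import ResNet18

model = ResNet18()
x = torch.randn(2, 3, 224, 224)
l1, l2, l3, l4, fea, out = model(x)
print(l1.shape, l2.shape, l3.shape, l4.shape, fea.shape, out.shape)
# Expected: [2,64,56,56] [2,128,28,28] [2,256,14,14] [2,512,7,7] [2,512] [2,1000]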
ResNet50, ResNet101, ResNet152
These three models share the same bottleneck block, whose structure is shown below:

The code (resnet_l.py):
import torch
import torch.nn as nn
import torchvision

print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)

__all__ = ['ResNet50', 'ResNet101', 'ResNet152']

def Conv1(in_planes, places, stride=2):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_planes, out_channels=places, kernel_size=7, stride=stride, padding=3, bias=False),
        nn.BatchNorm2d(places),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    )

class Bottleneck(nn.Module):
    def __init__(self, in_places, places, stride=1, downsampling=False, expansion=4):
        super(Bottleneck, self).__init__()
        self.expansion = expansion
        self.downsampling = downsampling
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels=in_places, out_channels=places, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(places),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(places),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=places, out_channels=places * self.expansion, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(places * self.expansion),
        )
        if self.downsampling:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels=in_places, out_channels=places * self.expansion, kernel_size=1, stride=stride,
                          bias=False),
                nn.BatchNorm2d(places * self.expansion)
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.bottleneck(x)
        if self.downsampling:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, blocks, num_classes=1000, expansion=4):
        super(ResNet, self).__init__()
        self.expansion = expansion
        self.conv1 = Conv1(in_planes=3, places=64)
        self.layer1 = self.make_layer(in_places=64, places=64, block=blocks[0], stride=1)
        self.layer2 = self.make_layer(in_places=256, places=128, block=blocks[1], stride=2)
        self.layer3 = self.make_layer(in_places=512, places=256, block=blocks[2], stride=2)
        self.layer4 = self.make_layer(in_places=1024, places=512, block=blocks[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(2048, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def make_layer(self, in_places, places, block, stride):
        layers = []
        layers.append(Bottleneck(in_places, places, stride, downsampling=True))
        for i in range(1, block):
            layers.append(Bottleneck(places * self.expansion, places))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        l1_out = self.layer1(x)
        l2_out = self.layer2(l1_out)
        l3_out = self.layer3(l2_out)
        l4_out = self.layer4(l3_out)
        p_out = self.avgpool(l4_out)
        fea = p_out.view(p_out.size(0), -1)
        out = self.fc(fea)
        # Return every block's output plus the flattened features, for IRG
        return l1_out, l2_out, l3_out, l4_out, fea, out

def ResNet50():
    return ResNet([3, 4, 6, 3])

def ResNet101():
    return ResNet([3, 4, 23, 3])

def ResNet152():
    return ResNet([3, 8, 36, 3])

if __name__ == '__main__':
    # model = torchvision.models.resnet50()
    model = ResNet50()
    print(model)
    input = torch.randn(1, 3, 224, 224)
    # The forward pass returns six values, so unpack before printing
    l1_out, l2_out, l3_out, l4_out, fea, out = model(input)
    print(out.shape)
As above, the output of every block is returned.
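It is worth seeing how the teacher and student line up at the layers IRG will use (a sketch; the import paths assume both files live in a model package, as in the training scripts below):

import torch
from model.resnet import ResNet18
from model.resnet_l import ResNet50

x = torch.randn(2, 3, 224, 224)
names = ['l1', 'l2', 'l3', 'l4', 'fea', 'out']
for name, t, s in zip(names, ResNet50()(x), ResNet18()(x)):
    print(name, tuple(t.shape), tuple(s.shape))
# e.g. l3: (2, 1024, 14, 14) vs (2, 256, 14, 14); fea: (2, 2048) vs (2, 512).
# IRG compares batch-wise (N x N) instance-distance matrices, so this channel
# mismatch between teacher and student is not a problem.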
Data Preparation
The data is the plant seedlings dataset I used in an earlier image classification task. First split it into a training set and a validation set by running:
import glob
import os
import shutil
from sklearn.model_selection import train_test_split

image_list = glob.glob('data1/*/*.png')
print(image_list)
file_dir = 'data'
if os.path.exists(file_dir):
    print('true')
    # os.rmdir(file_dir)
    shutil.rmtree(file_dir)  # delete, then recreate
    os.makedirs(file_dir)
else:
    os.makedirs(file_dir)

trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir = 'train'
val_dir = 'val'
train_root = os.path.join(file_dir, train_dir)
val_root = os.path.join(file_dir, val_dir)
for file in trainval_files:
    # The class name is the parent folder of each image
    file_class = file.replace("\\", "/").split('/')[-2]
    file_name = file.replace("\\", "/").split('/')[-1]
    file_class = os.path.join(train_root, file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)
for file in val_files:
    file_class = file.replace("\\", "/").split('/')[-2]
    file_name = file.replace("\\", "/").split('/')[-1]
    file_class = os.path.join(val_root, file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)
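One caveat: train_test_split here splits the file list without stratification, so class proportions can drift slightly between train and val. A stratified variant (a sketch, using the parent folder name as the label):

labels = [p.replace("\\", "/").split('/')[-2] for p in image_list]
trainval_files, val_files = train_test_split(
    image_list, test_size=0.3, random_state=42, stratify=labels)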
Training the Teacher Model
The teacher is ResNet50.
Steps
Create teacher_train.py and insert the following code.
Import the required libraries
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from model.resnet_l import ResNet50
import json
import os
Define the training and validation functions
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        l1_out, l2_out, l3_out, l4_out, fea, out = model(data)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))

Best_ACC = 0
# Validation loop
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            l1_out, l2_out, l3_out, l4_out, fea, out = model(data)
            loss = criterion(out, target)
            _, pred = torch.max(out.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.item()
            test_loss += print_loss
        correct = correct.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        # Save the model whenever validation accuracy improves
        if acc > Best_ACC:
            torch.save(model, file_dir + '/' + 'best.pth')
            Best_ACC = acc
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        return acc
Define global parameters
if __name__ == '__main__':
    # Folder for saving teacher checkpoints; kd_train.py later loads best.pth from here
    file_dir = 'CoatNet'
    os.makedirs(file_dir, exist_ok=True)
    # Global parameters
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Image preprocessing and augmentation
    # Data preprocessing and augmentation
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])
Load the data
Use PyTorch's default ImageFolder data loading.
    # Load the datasets
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    # Build the data loaders
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
Set up the model and loss
    # Instantiate the model and move it to the GPU
    criterion = nn.CrossEntropyLoss()
    model_ft = ResNet50()
    # Replace the classifier head for the 12 classes in this dataset
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 12)
    model_ft.to(DEVICE)
    # Use the plain Adam optimizer with a low learning rate
    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
    # Training loop
    val_acc_list = {}
    for epoch in range(1, EPOCHS + 1):
        train(model_ft, DEVICE, train_loader, optimizer, epoch)
        cosine_schedule.step()
        acc = val(model_ft, DEVICE, test_loader)
        val_acc_list[epoch] = acc
        with open('result.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(val_acc_list))
    torch.save(model_ft, 'CoatNet/model_final.pth')
With the code above in place, you can start training the teacher network.
Student Network
The student network is ResNet18, a fairly small model of roughly 40 MB. It is also trained for 100 epochs.
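A rough way to check the model-size claim (a sketch; assumes resnet.py lives in the model package as below):

from model.resnet import ResNet18

model = ResNet18()
n_params = sum(p.numel() for p in model.parameters())
# ~11-12M parameters, roughly 45 MB as float32, in line with the ~40 MB above
print('{:.1f}M params, ~{:.0f} MB (float32)'.format(n_params / 1e6, n_params * 4 / 1e6))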
Steps
Create student_train.py and insert the following code.
Import the required libraries
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from model.resnet import ResNet18
import json
import os
Define the training and validation functions
# Training loop
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        l1_out, l2_out, l3_out, l4_out, fea, out = model(data)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))

Best_ACC = 0
# Validation loop
def val(model, device, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            l1_out, l2_out, l3_out, l4_out, fea, out = model(data)
            loss = criterion(out, target)
            _, pred = torch.max(out.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.item()
            test_loss += print_loss
        correct = correct.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        # Save the model whenever validation accuracy improves
        if acc > Best_ACC:
            torch.save(model, file_dir + '/' + 'best.pth')
            Best_ACC = acc
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        return acc
Define global parameters
if __name__ == '__main__':
    # Folder for saving the student checkpoints
    file_dir = 'resnet'
    os.makedirs(file_dir, exist_ok=True)
    # Global parameters
    modellr = 1e-4
    BATCH_SIZE = 16
    EPOCHS = 100
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Image preprocessing and augmentation
    # Data preprocessing and augmentation
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])
Load the data
Use PyTorch's default ImageFolder data loading.
    # Load the datasets
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    # Build the data loaders
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
Set up the model and loss
    # Instantiate the model and move it to the GPU
    criterion = nn.CrossEntropyLoss()
    model_ft = ResNet18()
    print(model_ft)
    # Replace the classifier head for the 12 classes in this dataset
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 12)
    model_ft.to(DEVICE)
    # Use the plain Adam optimizer with a low learning rate
    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
????#?訓(xùn)練
????val_acc_list=?{}
????for?epoch?in?range(1,?EPOCHS?+?1):
????????train(model_ft,?DEVICE,?train_loader,?optimizer,?epoch)
????????cosine_schedule.step()
????????acc=val(model_ft,?DEVICE,?test_loader)
????????val_acc_list[epoch]=acc
????????with?open('result_student.json',?'w',?encoding='utf-8')?as?file:
????????????file.write(json.dumps(val_acc_list))
????torch.save(model_ft,?'resnet/model_final.pth')
With the code above in place, you can start training the student network.
Distilling the Student Network
The student is again ResNet18, now distilled with the teacher network for 100 epochs. The IRG distillation script is explained in detail at: https://wanghao.blog.csdn.net/article/details/127802486?spm=1001.2014.3001.5502. The code (irg.py):
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F

class IRG(nn.Module):
    '''
    Knowledge Distillation via Instance Relationship Graph
    http://openaccess.thecvf.com/content_CVPR_2019/papers/Liu_Knowledge_Distillation_via_Instance_Relationship_Graph_CVPR_2019_paper.pdf
    The official code is written in Caffe:
    https://github.com/yufanLIU/IRG
    '''
    def __init__(self, w_irg_vert, w_irg_edge, w_irg_tran):
        super(IRG, self).__init__()
        self.w_irg_vert = w_irg_vert
        self.w_irg_edge = w_irg_edge
        self.w_irg_tran = w_irg_tran

    def forward(self, irg_s, irg_t):
        fm_s1, fm_s2, feat_s, out_s = irg_s
        fm_t1, fm_t2, feat_t, out_t = irg_t
        # Vertex loss: match the student's logits to the teacher's
        loss_irg_vert = F.mse_loss(out_s, out_t)
        # Edge loss: match pairwise instance distances within the batch
        irg_edge_feat_s = self.euclidean_dist_feat(feat_s, squared=True)
        irg_edge_feat_t = self.euclidean_dist_feat(feat_t, squared=True)
        irg_edge_fm_s1  = self.euclidean_dist_fm(fm_s1, squared=True)
        irg_edge_fm_t1  = self.euclidean_dist_fm(fm_t1, squared=True)
        irg_edge_fm_s2  = self.euclidean_dist_fm(fm_s2, squared=True)
        irg_edge_fm_t2  = self.euclidean_dist_fm(fm_t2, squared=True)
        loss_irg_edge = (F.mse_loss(irg_edge_feat_s, irg_edge_feat_t) +
                         F.mse_loss(irg_edge_fm_s1,  irg_edge_fm_t1 ) +
                         F.mse_loss(irg_edge_fm_s2,  irg_edge_fm_t2 )) / 3.0
        # Transformation loss: match how instances move between the two layers
        irg_tran_s = self.euclidean_dist_fms(fm_s1, fm_s2, squared=True)
        irg_tran_t = self.euclidean_dist_fms(fm_t1, fm_t2, squared=True)
        loss_irg_tran = F.mse_loss(irg_tran_s, irg_tran_t)
        loss = (self.w_irg_vert * loss_irg_vert +
                self.w_irg_edge * loss_irg_edge +
                self.w_irg_tran * loss_irg_tran)
        return loss

    def euclidean_dist_fms(self, fm1, fm2, squared=False, eps=1e-12):
        '''
        Calculate the IRG transformation, where fm1 precedes fm2 in the network.
        '''
        if fm1.size(2) > fm2.size(2):
            # Pool the earlier feature map down to the later one's spatial size
            fm1 = F.adaptive_avg_pool2d(fm1, (fm2.size(2), fm2.size(3)))
        if fm1.size(1) < fm2.size(1):
            # Average pairs of channels so the channel counts match
            fm2 = (fm2[:, 0::2, :, :] + fm2[:, 1::2, :, :]) / 2.0
        fm1 = fm1.view(fm1.size(0), -1)
        fm2 = fm2.view(fm2.size(0), -1)
        fms_dist = torch.sum(torch.pow(fm1 - fm2, 2), dim=-1).clamp(min=eps)
        if not squared:
            fms_dist = fms_dist.sqrt()
        fms_dist = fms_dist / fms_dist.max()
        return fms_dist

    def euclidean_dist_fm(self, fm, squared=False, eps=1e-12):
        '''
        Calculate the IRG edge of a feature map.
        '''
        fm = fm.view(fm.size(0), -1)
        fm_square = fm.pow(2).sum(dim=1)
        fm_prod   = torch.mm(fm, fm.t())
        fm_dist   = (fm_square.unsqueeze(0) + fm_square.unsqueeze(1) - 2 * fm_prod).clamp(min=eps)
        if not squared:
            fm_dist = fm_dist.sqrt()
        fm_dist = fm_dist.clone()
        fm_dist[range(len(fm)), range(len(fm))] = 0
        fm_dist = fm_dist / fm_dist.max()
        return fm_dist

    def euclidean_dist_feat(self, feat, squared=False, eps=1e-12):
        '''
        Calculate the IRG edge of the flattened features.
        '''
        feat_square = feat.pow(2).sum(dim=1)
        feat_prod   = torch.mm(feat, feat.t())
        feat_dist   = (feat_square.unsqueeze(0) + feat_square.unsqueeze(1) - 2 * feat_prod).clamp(min=eps)
        if not squared:
            feat_dist = feat_dist.sqrt()
        feat_dist = feat_dist.clone()
        feat_dist[range(len(feat)), range(len(feat))] = 0
        feat_dist = feat_dist / feat_dist.max()
        return feat_dist
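A quick, self-contained sanity check of this loss with random tensors (a sketch; the shapes mirror the ResNet18/ResNet50 layer3, layer4, flattened-feature and 12-class logit outputs used below):

import torch
from irg import IRG

criterion = IRG(w_irg_vert=0.1, w_irg_edge=5.0, w_irg_tran=5.0)
n = 4  # batch size
irg_s = [torch.randn(n, 256, 14, 14),   # student layer3
         torch.randn(n, 512, 7, 7),     # student layer4
         torch.randn(n, 512),           # student flattened features
         torch.randn(n, 12)]            # student logits
irg_t = [torch.randn(n, 1024, 14, 14),  # teacher layer3
         torch.randn(n, 2048, 7, 7),    # teacher layer4
         torch.randn(n, 2048),          # teacher flattened features
         torch.randn(n, 12)]            # teacher logits
print(criterion(irg_s, irg_t))          # a single scalar loss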
Steps
Create kd_train.py and insert the following code.
Import the required libraries
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision import datasets
from model.resnet import ResNet18
import json
import os
from irg import IRG
Define the training and validation functions
# Training loop
def train(s_net, t_net, device, criterionCls, criterionKD, train_loader, optimizer, epoch):
    s_net.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        l1_out_s, l2_out_s, l3_out_s, l4_out_s, fea_s, out_s = s_net(data)
        cls_loss = criterionCls(out_s, target)
        # Run the teacher to obtain its intermediate outputs and logits
        l1_out_t, l2_out_t, l3_out_t, l4_out_t, fea_t, out_t = t_net(data)
        # IRG uses layer3, layer4, the flattened features and the logits;
        # the teacher's tensors are detached so no gradient flows into it
        kd_loss = criterionKD([l3_out_s, l4_out_s, fea_s, out_s],
                              [l3_out_t.detach(),
                               l4_out_t.detach(),
                               fea_t.detach(),
                               out_t.detach()]) * lambda_kd
        # Total loss: hard-label cross-entropy plus the weighted IRG term
        loss = cls_loss + kd_loss
        loss.backward()
        optimizer.step()
        print_loss = loss.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))

Best_ACC = 0
# Validation loop
def val(model, device, criterionCls, test_loader):
    global Best_ACC
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            l1_out_s, l2_out_s, l3_out_s, l4_out_s, fea_s, out_s = model(data)
            loss = criterionCls(out_s, target)
            _, pred = torch.max(out_s.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.item()
            test_loss += print_loss
        correct = correct.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        # Save the model whenever validation accuracy improves
        if acc > Best_ACC:
            torch.save(model, file_dir + '/' + 'best.pth')
            Best_ACC = acc
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))
        return acc
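For reference, the total objective minimized in train above is

$$\mathcal{L} = \mathcal{L}_{\mathrm{CE}}(\mathrm{out}_s,\, y) + \lambda_{kd}\,\mathcal{L}_{\mathrm{IRG}}$$

where $\mathcal{L}_{\mathrm{IRG}}$ is itself the weighted sum of the vertex, edge, and transformation terms computed in irg.py.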
定義全局參數(shù)
if?__name__?==?'__main__':
????#?創(chuàng)建保存模型的文件夾
????file_dir?=?'resnet_kd'
????if?os.path.exists(file_dir):
????????print('true')
????????os.makedirs(file_dir,?exist_ok=True)
????else:
????????os.makedirs(file_dir)
????#?設(shè)置全局參數(shù)
????modellr?=?1e-4
????BATCH_SIZE?=?16
????EPOCHS?=?100
????DEVICE?=?torch.device('cuda'?if?torch.cuda.is_available()?else?'cpu')
????w_irg_vert=0.1
????w_irg_edge=5.0
????w_irg_tran=5.0
????lambda_kd=1.0
Image preprocessing and augmentation
    # Data preprocessing and augmentation
    transform = transforms.Compose([
        transforms.RandomRotation(10),
        transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
    ])
Load the data
Use PyTorch's default ImageFolder data loading.
    # Load the datasets
    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    # Build the data loaders
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
Set up the models and losses
    model_ft = ResNet18()
    print(model_ft)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 12)
    model_ft.to(DEVICE)
    # Use the plain Adam optimizer with a low learning rate
    optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)
    # Load the trained teacher saved by teacher_train.py and put it in eval mode
    teacher_model = torch.load('./CoatNet/best.pth', map_location=DEVICE)
    teacher_model.eval()
    # Instantiate the distillation and classification losses
    criterionKD = IRG(w_irg_vert, w_irg_edge, w_irg_tran)
    criterionCls = nn.CrossEntropyLoss()
    # Training loop
    val_acc_list = {}
    for epoch in range(1, EPOCHS + 1):
        train(model_ft, teacher_model, DEVICE, criterionCls, criterionKD, train_loader, optimizer, epoch)
        cosine_schedule.step()
        acc = val(model_ft, DEVICE, criterionCls, test_loader)
        val_acc_list[epoch] = acc
        with open('result_kd.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(val_acc_list))
    torch.save(model_ft, 'resnet_kd/model_final.pth')
With the code above in place, distillation can begin!
Comparing Results
Load the saved results and plot the accuracy curves.
from matplotlib import pyplot as plt
import json

teacher_file = 'result.json'
student_file = 'result_student.json'
student_kd_file = 'result_kd.json'

def read_json(file):
    with open(file, 'r', encoding='utf8') as fp:
        json_data = json.load(fp)
        print(json_data)
    return json_data

teacher_data = read_json(teacher_file)
student_data = read_json(student_file)
student_kd_data = read_json(student_kd_file)

# Epoch numbers were saved as string keys, so convert them back to ints
x = [int(k) for k in list(dict(teacher_data).keys())]
print(x)
plt.plot(x, list(teacher_data.values()), label='teacher')
plt.plot(x, list(student_data.values()), label='student without IRG')
plt.plot(x, list(student_kd_data.values()), label='student with IRG')
plt.title('Test accuracy')
plt.legend()
plt.show()
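If you just want the best numbers rather than the curves, something like this works (a sketch):

print('teacher best:      ', max(teacher_data.values()))
print('student best:      ', max(student_data.values()))
print('student + IRG best:', max(student_kd_data.values()))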

Summary
This article focused on how to distill a student model with the IRG knowledge distillation algorithm. I hope it helps; if you found it useful, feel free to bookmark, like, and share, and leave a comment if you have questions. The code and dataset used in this walkthrough are available at:
https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/87096149