DEiT實戰(zhàn):使用DEiT實現(xiàn)圖像分類任務(wù)(一)
@
安裝timm
摘要
DEiT是FaceBook在2020年提出的一篇Transformer模型。該模型解決了Transformer難以訓(xùn)練的問題,三天內(nèi)使用4塊GPU,完成了ImageNet的訓(xùn)練,并且沒有使用外部數(shù)據(jù),達(dá)到了SOTA水平。 DEiT提出的蒸餾策略只增加了對token的蒸餾,沒有引入其他的重要架構(gòu)。如下圖:

蒸餾令牌與類令牌的使用類似:它通過自注意力與其他令牌交互,并在最后一層后由網(wǎng)絡(luò)輸出。蒸餾令牌允許模型從老師的輸出中學(xué)習(xí),就像在常規(guī)蒸餾中一樣,同時與類令牌保持互補(bǔ)。這一點我們可以代碼中找到:
?self.dist_token?=?nn.Parameter(torch.zeros(1,?1,?self.embed_dim))
?self.cls_token?=?nn.Parameter(torch.zeros(1,?1,?embed_dim))
cls_tokens 是類令牌,dist_token 是蒸餾令牌,確實很像是,仔細(xì)看都沒有找到差別。
?def?forward(self,?x):
????????x,?x_dist?=?self.forward_features(x)
????????x?=?self.head(x)
????????x_dist?=?self.head_dist(x_dist)
????????if?self.training:
????????????return?x,?x_dist
????????else:
????????????#?during?inference,?return?the?average?of?both?classifier?predictions
????????????return?(x?+?x_dist)?/?2
如果想讓模型從Teacher模型里學(xué)習(xí),將 if self.training:設(shè)置為true,這樣我們就可以像論文中那樣使用RegNet做Teacher,使用DeiT模型做Student去蒸餾x_dist,否則將兩者做平均實現(xiàn)二者的互補(bǔ)。 timm里面的代碼和官方的代碼有所不通,不過邏輯上是一樣的,下面是timm的代碼:
?def?forward_head(self,?x,?pre_logits:?bool?=?False)?->?torch.Tensor:
????????if?pre_logits:
????????????return?(x[:,?0]?+?x[:,?1])?/?2
????????x,?x_dist?=?self.head(x[:,?0]),?self.head_dist(x[:,?1])
????????if?self.distilled_training?and?self.training?and?not?torch.jit.is_scripting():
????????????#?only?return?separate?classification?predictions?when?training?in?distilled?mode
????????????return?x,?x_dist
????????else:
????????????#?during?standard?train?/?finetune,?inference?average?the?classifier?predictions
????????????return?(x?+?x_dist)?/?2
(終于搞明白了。憋了好幾天了,直到看了官方的代碼才理解。) 等我有時間了再寫一篇使用外部模型蒸餾的教程。
這篇文章主要講解如何使用DEiT完成圖像分類任務(wù),接下來我們一起完成項目的實戰(zhàn)。本例選用的模型是deit_small_patch16_224和deit_small_distilled_patch16_224,在植物幼苗數(shù)據(jù)集上實現(xiàn)了96%和97%的準(zhǔn)確率。deit_small_patch16_224是沒有蒸餾token的操作,deit_small_distilled_patch16_224有蒸餾token的操作,從結(jié)果上看蒸餾還是有不錯的效果。 論文鏈接:https://arxiv.org/abs/2012.12877v2 論文翻譯:https://wanghao.blog.csdn.net/article/details/128180419?spm=1001.2014.3001.5502 視頻講解:https://www.zhihu.com/zvideo/1587194506348040192 DeiT測試結(jié)果:


DeiT_dist測試結(jié)果:


通過這篇文章能讓你學(xué)到:
如何使用數(shù)據(jù)增強(qiáng),包括transforms的增強(qiáng)、CutOut、MixUp、CutMix等增強(qiáng)手段?
如何實現(xiàn)DEit模型和DEiT_dist模型實現(xiàn)訓(xùn)練?
如何使用pytorch自帶混合精度?
如何使用梯度裁剪防止梯度爆炸?
如何使用DP多顯卡訓(xùn)練?
如何繪制loss和acc曲線?
如何生成val的測評報告?
如何編寫測試腳本測試測試集?
如何使用余弦退火策略調(diào)整學(xué)習(xí)率?
如何使用AverageMeter類統(tǒng)計ACC和loss等自定義變量?
如何理解和統(tǒng)計ACC1和ACC5?
如何使用EMA?
安裝包
安裝timm
使用pip就行,命令:
pip?install?timm
本文實戰(zhàn)用的timm里面的模型。
數(shù)據(jù)增強(qiáng)Cutout和Mixup
為了提高成績我在代碼中加入Cutout和Mixup這兩種增強(qiáng)方式。實現(xiàn)這兩種增強(qiáng)需要安裝torchtoolbox。安裝命令:
pip?install?torchtoolbox
Cutout實現(xiàn),在transforms中。
from?torchtoolbox.transform?import?Cutout
#?數(shù)據(jù)預(yù)處理
transform?=?transforms.Compose([
????transforms.Resize((224,?224)),
????Cutout(),
????transforms.ToTensor(),
????transforms.Normalize([0.5,?0.5,?0.5],?[0.5,?0.5,?0.5])
])
需要導(dǎo)入包:from timm.data.mixup import Mixup,
定義Mixup,和SoftTargetCrossEntropy
??mixup_fn?=?Mixup(
????mixup_alpha=0.8,?cutmix_alpha=1.0,?cutmix_minmax=None,
????prob=0.1,?switch_prob=0.5,?mode='batch',
????label_smoothing=0.1,?num_classes=12)
?criterion_train?=?SoftTargetCrossEntropy()
參數(shù)詳解:
★mixup_alpha (float): mixup alpha 值,如果 > 0,則 mixup 處于活動狀態(tài)。
cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 處于活動狀態(tài)。
cutmix_minmax (List[float]):cutmix 最小/最大圖像比率,cutmix 處于活動狀態(tài),如果不是 None,則使用這個 vs alpha。
如果設(shè)置了 cutmix_minmax 則cutmix_alpha 默認(rèn)為1.0
prob (float): 每批次或元素應(yīng)用 mixup 或 cutmix 的概率。
switch_prob (float): 當(dāng)兩者都處于活動狀態(tài)時切換cutmix 和mixup 的概率 。
mode (str): 如何應(yīng)用 mixup/cutmix 參數(shù)(每個'batch','pair'(元素對),'elem'(元素)。
correct_lam (bool): 當(dāng) cutmix bbox 被圖像邊框剪裁時應(yīng)用。 lambda 校正
label_smoothing (float):將標(biāo)簽平滑應(yīng)用于混合目標(biāo)張量。
num_classes (int): 目標(biāo)的類數(shù)。
”
EMA
EMA(Exponential Moving Average)是指數(shù)移動平均值。在深度學(xué)習(xí)中的做法是保存歷史的一份參數(shù),在一定訓(xùn)練階段后,拿歷史的參數(shù)給目前學(xué)習(xí)的參數(shù)做一次平滑。具體實現(xiàn)如下:
class?EMA():
????def?__init__(self,?model,?decay):
????????self.model?=?model
????????self.decay?=?decay
????????self.shadow?=?{}
????????self.backup?=?{}
????def?register(self):
????????for?name,?param?in?self.model.named_parameters():
????????????if?param.requires_grad:
????????????????self.shadow[name]?=?param.data.clone()
????def?update(self):
????????for?name,?param?in?self.model.named_parameters():
????????????if?param.requires_grad:
????????????????assert?name?in?self.shadow
????????????????new_average?=?(1.0?-?self.decay)?*?param.data?+?self.decay?*?self.shadow[name]
????????????????self.shadow[name]?=?new_average.clone()
????def?apply_shadow(self):
????????for?name,?param?in?self.model.named_parameters():
????????????if?param.requires_grad:
????????????????assert?name?in?self.shadow
????????????????self.backup[name]?=?param.data
????????????????param.data?=?self.shadow[name]
????def?restore(self):
????????for?name,?param?in?self.model.named_parameters():
????????????if?param.requires_grad:
????????????????assert?name?in?self.backup
????????????????param.data?=?self.backup[name]
????????self.backup?=?{}
加入到模型中。
#?初始化
ema?=?EMA(model,?0.999)
ema.register()
#?訓(xùn)練過程中,更新完參數(shù)后,同步update?shadow?weights
def?train():
????optimizer.step()
????ema.update()
#?eval前,apply?shadow?weights;eval之后,恢復(fù)原來模型的參數(shù)
def?evaluate():
????ema.apply_shadow()
????#?evaluate
????ema.restore()
針對沒有預(yù)訓(xùn)練的模型,容易出現(xiàn)EMA不上分的情況,這點大家要注意?。?/p>
項目結(jié)構(gòu)
DEiT_demo
├─data1
│??├─Black-grass
│??├─Charlock
│??├─Cleavers
│??├─Common?Chickweed
│??├─Common?wheat
│??├─Fat?Hen
│??├─Loose?Silky-bent
│??├─Maize
│??├─Scentless?Mayweed
│??├─Shepherds?Purse
│??├─Small-flowered?Cranesbill
│??└─Sugar?beet
├─mean_std.py
├─makedata.py
├─ema.py
├─train.py
├─train_dist.py
└─test.py
mean_std.py:計算mean和std的值。 makedata.py:生成數(shù)據(jù)集。 ema.py:EMA腳本 train.py:訓(xùn)練DEiT模型 train_dist.py:訓(xùn)練蒸餾策略的DEiT模型。
為了能在DP方式中使用混合精度,還需要在模型的forward函數(shù)前增加@autocast(),如果使用GPU訓(xùn)練導(dǎo)入包from torch.cuda.amp import autocast,如果使用CPU,則導(dǎo)入from torch.cpu.amp import autocast。

計算mean和std
為了使模型更加快速的收斂,我們需要計算出mean和std的值,新建mean_std.py,插入代碼:
from?torchvision.datasets?import?ImageFolder
import?torch
from?torchvision?import?transforms
def?get_mean_and_std(train_data):
????train_loader?=?torch.utils.data.DataLoader(
????????train_data,?batch_size=1,?shuffle=False,?num_workers=0,
????????pin_memory=True)
????mean?=?torch.zeros(3)
????std?=?torch.zeros(3)
????for?X,?_?in?train_loader:
????????for?d?in?range(3):
????????????mean[d]?+=?X[:,?d,?:,?:].mean()
????????????std[d]?+=?X[:,?d,?:,?:].std()
????mean.div_(len(train_data))
????std.div_(len(train_data))
????return?list(mean.numpy()),?list(std.numpy())
if?__name__?==?'__main__':
????train_dataset?=?ImageFolder(root=r'data1',?transform=transforms.ToTensor())
????print(get_mean_and_std(train_dataset))
數(shù)據(jù)集結(jié)構(gòu):

運行結(jié)果:
([0.3281186,?0.28937867,?0.20702125],?[0.09407319,?0.09732835,?0.106712654])
把這個結(jié)果記錄下來,后面要用!
生成數(shù)據(jù)集
我們整理還的圖像分類的數(shù)據(jù)集結(jié)構(gòu)是這樣的
data
├─Black-grass
├─Charlock
├─Cleavers
├─Common?Chickweed
├─Common?wheat
├─Fat?Hen
├─Loose?Silky-bent
├─Maize
├─Scentless?Mayweed
├─Shepherds?Purse
├─Small-flowered?Cranesbill
└─Sugar?beet
pytorch和keras默認(rèn)加載方式是ImageNet數(shù)據(jù)集格式,格式是
├─data
│??├─val
│??│???├─Black-grass
│??│???├─Charlock
│??│???├─Cleavers
│??│???├─Common?Chickweed
│??│???├─Common?wheat
│??│???├─Fat?Hen
│??│???├─Loose?Silky-bent
│??│???├─Maize
│??│???├─Scentless?Mayweed
│??│???├─Shepherds?Purse
│??│???├─Small-flowered?Cranesbill
│??│???└─Sugar?beet
│??└─train
│??????├─Black-grass
│??????├─Charlock
│??????├─Cleavers
│??????├─Common?Chickweed
│??????├─Common?wheat
│??????├─Fat?Hen
│??????├─Loose?Silky-bent
│??????├─Maize
│??????├─Scentless?Mayweed
│??????├─Shepherds?Purse
│??????├─Small-flowered?Cranesbill
│??????└─Sugar?beet
新增格式轉(zhuǎn)化腳本makedata.py,插入代碼:
import?glob
import?os
import?shutil
image_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if?os.path.exists(file_dir):
????print('true')
????#os.rmdir(file_dir)
????shutil.rmtree(file_dir)#刪除再建立
????os.makedirs(file_dir)
else:
????os.makedirs(file_dir)
from?sklearn.model_selection?import?train_test_split
trainval_files,?val_files?=?train_test_split(image_list,?test_size=0.3,?random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for?file?in?trainval_files:
????file_class=file.replace("\\","/").split('/')[-2]
????file_name=file.replace("\\","/").split('/')[-1]
????file_class=os.path.join(train_root,file_class)
????if?not?os.path.isdir(file_class):
????????os.makedirs(file_class)
????shutil.copy(file,?file_class?+?'/'?+?file_name)
for?file?in?val_files:
????file_class=file.replace("\\","/").split('/')[-2]
????file_name=file.replace("\\","/").split('/')[-1]
????file_class=os.path.join(val_root,file_class)
????if?not?os.path.isdir(file_class):
????????os.makedirs(file_class)
????shutil.copy(file,?file_class?+?'/'?+?file_name)
完成上面的內(nèi)容就可以開啟訓(xùn)練和測試了。