DEiT實(shí)戰(zhàn):使用DEiT實(shí)現(xiàn)圖像分類(lèi)任務(wù)(二)
@
導(dǎo)入項(xiàng)目使用的庫(kù)
設(shè)置隨機(jī)因子
設(shè)置全局參數(shù)
圖像預(yù)處理與增強(qiáng)
讀取數(shù)據(jù)
設(shè)置模型
train.py
train_dist.py
定義訓(xùn)練和驗(yàn)證函數(shù)
訓(xùn)練函數(shù)
驗(yàn)證函數(shù)
調(diào)用訓(xùn)練和驗(yàn)證方法
在上一篇文章中完成了前期的準(zhǔn)備工作,見(jiàn)鏈接:DEiT實(shí)戰(zhàn):使用DEiT實(shí)現(xiàn)圖像分類(lèi)任務(wù)(一)這篇主要是講解如何訓(xùn)練和測(cè)試
訓(xùn)練
完成上面的步驟后,就開(kāi)始train腳本的編寫(xiě),新建train.py和train_dist.py
導(dǎo)入項(xiàng)目使用的庫(kù)
在train.py導(dǎo)入
import?json
import?os
import?shutil
import?matplotlib.pyplot?as?plt
import?torch
import?torch.nn?as?nn
import?torch.nn.parallel
import?torch.optim?as?optim
import?torch.utils.data
import?torch.utils.data.distributed
import?torchvision.transforms?as?transforms
from?timm.utils?import?accuracy,?AverageMeter
from?sklearn.metrics?import?classification_report
from?timm.data.mixup?import?Mixup
from?timm.loss?import?SoftTargetCrossEntropy
from?torchvision?import?datasets
from?timm.models?import?deit_small_patch16_224
torch.backends.cudnn.benchmark?=?False
import?warnings
warnings.filterwarnings("ignore")
from?ema?import?EMA
在train_dist.py導(dǎo)入
import?json
import?os
import?matplotlib.pyplot?as?plt
import?torch
import?torch.nn?as?nn
import?torch.nn.parallel
import?torch.optim?as?optim
import?torch.utils.data
import?torch.utils.data.distributed
import?torchvision.transforms?as?transforms
from?timm.utils?import?accuracy,?AverageMeter
from?sklearn.metrics?import?classification_report
from?timm.data.mixup?import?Mixup
from?timm.loss?import?SoftTargetCrossEntropy
from?torchvision?import?datasets
from?timm.models?import?deit_small_distilled_patch16_224
torch.backends.cudnn.benchmark?=?False
import?warnings
warnings.filterwarnings("ignore")
from?ema?import?EMA
distilled表示含有蒸餾的token。
設(shè)置隨機(jī)因子
def?seed_everything(seed=42):
????os.environ['PYHTONHASHSEED']?=?str(seed)
????torch.manual_seed(seed)
????torch.cuda.manual_seed(seed)
????torch.backends.cudnn.deterministic?=?True
設(shè)置了固定的隨機(jī)因子,再次訓(xùn)練的時(shí)候就可以保證圖片的加載順序不會(huì)發(fā)生變化。
設(shè)置全局參數(shù)
設(shè)置學(xué)習(xí)率、BatchSize、epoch等參數(shù),判斷環(huán)境中是否存在GPU,如果沒(méi)有則使用CPU。建議使用GPU,CPU太慢了。
if?__name__?==?'__main__':
????#創(chuàng)建保存模型的文件夾
????file_dir?=?'checkpoints/DEiT'
????if?os.path.exists(file_dir):
????????print('true')
????????os.makedirs(file_dir,exist_ok=True)
????else:
????????os.makedirs(file_dir)
????#?設(shè)置全局參數(shù)
????model_lr?=?1e-4
????BATCH_SIZE?=?16
????EPOCHS?=?1000
????DEVICE?=?torch.device('cuda:0'?if?torch.cuda.is_available()?else?'cpu')
????use_amp?=?True??#?是否使用混合精度
????use_dp=False?#是否開(kāi)啟dp方式的多卡訓(xùn)練
????classes?=?12
????resume?=?False
????CLIP_GRAD?=?5.0
????model_path?=?'best.pth'
????Best_ACC?=?0?#記錄最高得分
????use_ema=True
????SEED=42
????seed_everything(42)
設(shè)置存放權(quán)重文件的文件夾,如果文件夾存在刪除再建立。
接下來(lái),查看全局參數(shù):
★model_lr:學(xué)習(xí)率,根據(jù)實(shí)際情況做調(diào)整。
BATCH_SIZE:batchsize,根據(jù)顯卡的大小設(shè)置。
EPOCHS:epoch的個(gè)數(shù),一般300夠用。
use_amp:是否使用混合精度。
classes:類(lèi)別個(gè)數(shù)。
resume:是否接著上次模型繼續(xù)訓(xùn)練。
model_path:模型的路徑。如果resume設(shè)置為T(mén)rue時(shí),就采用model_path定義的模型繼續(xù)訓(xùn)練。
CLIP_GRAD:梯度的最大范數(shù),在梯度裁剪里設(shè)置。
Best_ACC:記錄最高ACC得分。 use_ema:是否使用ema SEED:隨機(jī)因子,數(shù)值可以隨意設(shè)定,但是設(shè)置后,不要隨意更改,更改后,圖片加載的順序會(huì)改變,影響測(cè)試結(jié)果。
”
?file_dir?=?'checkpoints/DEiT'
這是存放DEiT模型的路徑。 在train_dist.py 則應(yīng)該設(shè)置為:
?file_dir?=?'checkpoints/DEiT_dist'
圖像預(yù)處理與增強(qiáng)
數(shù)據(jù)處理比較簡(jiǎn)單,加入了Cutout、做了Resize和歸一化,定義Mixup函數(shù)。
這里注意下Resize的大小,由于選用的MaxViT模型輸入是224×224的大小,所以要Resize為224×224。
???#?數(shù)據(jù)預(yù)處理7
????transform?=?transforms.Compose([
????????transforms.RandomRotation(10),
????????transforms.GaussianBlur(kernel_size=(5,5),sigma=(0.1,?3.0)),
????????transforms.ColorJitter(brightness=0.5,?contrast=0.5,?saturation=0.5),
????????transforms.Resize((224,?224)),
????????transforms.ToTensor(),
????????transforms.Normalize(mean=[0.44127703,?0.4712498,?0.43714803],?std=?[0.18507297,?0.18050247,?0.16784933])
????])
????transform_test?=?transforms.Compose([
????????transforms.Resize((224,?224)),
????????transforms.ToTensor(),
????????transforms.Normalize(mean=[0.44127703,?0.4712498,?0.43714803],?std=?[0.18507297,?0.18050247,?0.16784933])
????])
????mixup_fn?=?Mixup(
????????mixup_alpha=0.8,?cutmix_alpha=1.0,?cutmix_minmax=None,
????????prob=0.1,?switch_prob=0.5,?mode='batch',
????????label_smoothing=0.1,?num_classes=classes)
讀取數(shù)據(jù)
使用pytorch默認(rèn)讀取數(shù)據(jù)的方式,然后將dataset_train.class_to_idx打印出來(lái),預(yù)測(cè)的時(shí)候要用到。
將dataset_train.class_to_idx保存到txt文件或者json文件中。
????#?讀取數(shù)據(jù)
????dataset_train?=?datasets.ImageFolder('data/train',?transform=transform)
????dataset_test?=?datasets.ImageFolder("data/val",?transform=transform_test)
????with?open('class.txt',?'w')?as?file:
????????file.write(str(dataset_train.class_to_idx))
????with?open('class.json',?'w',?encoding='utf-8')?as?file:
????????file.write(json.dumps(dataset_train.class_to_idx))
????#?導(dǎo)入數(shù)據(jù)
????train_loader?=?torch.utils.data.DataLoader(dataset_train,?batch_size=BATCH_SIZE,?shuffle=True)
????test_loader?=?torch.utils.data.DataLoader(dataset_test,?batch_size=BATCH_SIZE,?shuffle=False)
class_to_idx的結(jié)果:
{'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3, 'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7, 'Scentless Mayweed': 8, 'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11}
設(shè)置模型
train.py
設(shè)置loss函數(shù),train的loss為:SoftTargetCrossEntropy,val的loss:nn.CrossEntropyLoss()。
設(shè)置模型為deit_small_patch16_224,pretrained設(shè)置為true,表示加載預(yù)訓(xùn)練模型,調(diào)用reset_classifier函數(shù)將classes設(shè)置為12。如果resume為T(mén)rue,則加載模型接著上次訓(xùn)練。
優(yōu)化器設(shè)置為adamW。
學(xué)習(xí)率調(diào)整策略選擇為余弦退火。
開(kāi)啟混合精度訓(xùn)練,聲明pytorch自帶的混合精度 torch.cuda.amp.GradScaler()。
檢測(cè)可用顯卡的數(shù)量,如果大于1,并且開(kāi)啟多卡訓(xùn)練的情況下,則要用torch.nn.DataParallel加載模型,開(kāi)啟多卡訓(xùn)練。
如果使用ema,則注冊(cè)ema
?#?實(shí)例化模型并且移動(dòng)到GPU
????criterion_train?=?SoftTargetCrossEntropy()
????criterion_val?=?torch.nn.CrossEntropyLoss()
????#設(shè)置模型
????model_ft?=?deit_small_patch16_224(pretrained=True)
????model_ft.reset_classifier(classes)
????#?num_ftrs?=?model_ft.head.in_features
????#?model_ft.head?=?nn.Linear(num_ftrs,?classes)
????if?resume:
????????model?=?torch.load(resume)
????????model_ft.load_state_dict(model['state_dict'])
????????Best_ACC?=?model['Best_ACC']
????????start_epoch?=?model['epoch']?+?1
????model_ft.to(DEVICE)
????print(model_ft)
????#?選擇簡(jiǎn)單暴力的Adam優(yōu)化器,學(xué)習(xí)率調(diào)低
????optimizer?=?optim.AdamW(model_ft.parameters(),lr=model_lr)
????cosine_schedule?=?optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,?T_max=20,?eta_min=1e-6)
????if?use_amp:
????????scaler?=?torch.cuda.amp.GradScaler()
????if?torch.cuda.device_count()?>?1?and?use_dp:
????????print("Let's?use",?torch.cuda.device_count(),?"GPUs!")
????????model_ft?=?torch.nn.DataParallel(model_ft)
????if?use_ema:
????????ema?=?EMA(model_ft,?0.9998)
????????ema.register()
注:torch.nn.DataParallel方式,默認(rèn)不能開(kāi)啟混合精度訓(xùn)練的,如果想要開(kāi)啟混合精度訓(xùn)練,則需要在模型的forward前面加上@autocast()函數(shù)。

如果不開(kāi)啟混合精度則要將@autocast()去掉,否則loss一直試nan。
train_dist.py
train_dist.py設(shè)置模型為deit_small_distilled_patch16_224。
??#設(shè)置模型
????model_ft?=?deit_small_distilled_patch16_224(pretrained=True)
????model_ft.reset_classifier(classes)
定義訓(xùn)練和驗(yàn)證函數(shù)
訓(xùn)練函數(shù)
訓(xùn)練的主要步驟:
★1、使用AverageMeter保存自定義變量,包括loss,ACC1,ACC5。
2、判斷迭代的數(shù)據(jù)是否是奇數(shù),由于mixup_fn只能接受偶數(shù),所以如果不是偶數(shù)則要減去一位,讓其變成偶數(shù)。但是有可能最后一次迭代只有一條數(shù)據(jù),減去后就變成了0,所以還要判斷不能小于2,如果小于2則直接中斷本次循環(huán)。
”
★3、將數(shù)據(jù)輸入mixup_fn生成mixup數(shù)據(jù),然后輸入model計(jì)算loss。
”
★4、 optimizer.zero_grad() 梯度清零,把loss關(guān)于weight的導(dǎo)數(shù)變成0。
”
★5、如果使用混合精度,則
★”
with torch.cuda.amp.autocast(),開(kāi)啟混合精度。
計(jì)算loss。
scaler.scale(loss).backward(),梯度放大。
torch.nn.utils.clip_grad_norm_,梯度裁剪,放置梯度爆炸。
scaler.step(optimizer) ,首先把梯度值unscale回來(lái),如果梯度值不是inf或NaN,則調(diào)用optimizer.step()來(lái)更新權(quán)重,否則,忽略step調(diào)用,從而保證權(quán)重不更新。
更新下一次迭代的scaler。
否則,直接反向傳播求梯度。torch.nn.utils.clip_grad_norm_函數(shù)執(zhí)行梯度裁剪,防止梯度爆炸。 6、如果use_ema為T(mén)rue,則執(zhí)行model_ema的updata函數(shù),更新模型。
”
★7、 torch.cuda.synchronize(),等待上面所有的操作執(zhí)行完成。
”
★8、接下來(lái),更新loss,ACC1,ACC5的值。
”
等待一個(gè)epoch訓(xùn)練完成后,計(jì)算平均loss和平均acc
#?定義訓(xùn)練過(guò)程
def?train(model,?device,?train_loader,?optimizer,?epoch):
????model.train()
????loss_meter?=?AverageMeter()
????acc1_meter?=?AverageMeter()
????acc5_meter?=?AverageMeter()
????total_num?=?len(train_loader.dataset)
????print(total_num,?len(train_loader))
????for?batch_idx,?(data,?target)?in?enumerate(train_loader):
????????data,?target?=?data.to(device,?non_blocking=True),?target.to(device,?non_blocking=True)
????????samples,?targets?=?mixup_fn(data,?target)
????????output?=?model(samples)
????????optimizer.zero_grad()
????????if?use_amp:
????????????with?torch.cuda.amp.autocast():
????????????????loss?=?criterion_train(output,?targets)
????????????scaler.scale(loss).backward()
????????????torch.nn.utils.clip_grad_norm_(model.parameters(),?CLIP_GRAD)
????????????#?Unscales?gradients?and?calls
????????????#?or?skips?optimizer.step()
????????????scaler.step(optimizer)
????????????#?Updates?the?scale?for?next?iteration
????????????scaler.update()
????????else:
????????????loss?=?criterion_train(output,?targets)
????????????loss.backward()
????????????#?torch.nn.utils.clip_grad_norm_(model.parameters(),?CLIP_GRAD)
????????????optimizer.step()
????????if?use_ema?:
????????????ema.update()
????????torch.cuda.synchronize()
????????lr?=?optimizer.state_dict()['param_groups'][0]['lr']
????????loss_meter.update(loss.item(),?target.size(0))
????????acc1,?acc5?=?accuracy(output,?target,?topk=(1,?5))
????????loss_meter.update(loss.item(),?target.size(0))
????????acc1_meter.update(acc1.item(),?target.size(0))
????????acc5_meter.update(acc5.item(),?target.size(0))
????????if?(batch_idx?+?1)?%?10?==?0:
????????????print('Train?Epoch:?{}?[{}/{}?({:.0f}%)]\tLoss:?{:.6f}\tLR:{:.9f}'.format(
????????????????epoch,?(batch_idx?+?1)?*?len(data),?len(train_loader.dataset),
???????????????????????100.?*?(batch_idx?+?1)?/?len(train_loader),?loss.item(),?lr))
????ave_loss?=loss_meter.avg
????acc?=?acc1_meter.avg
????print('epoch:{}\tloss:{:.2f}\tacc:{:.2f}'.format(epoch,?ave_loss,?acc))
????return?ave_loss,?acc
驗(yàn)證函數(shù)
驗(yàn)證集和訓(xùn)練集大致相似,主要步驟:
★1、定義參數(shù),loss_meter 測(cè)試的loss,total_num總的驗(yàn)證集的數(shù)量,val_list驗(yàn)證集的label,pred_list預(yù)測(cè)的label。
”
★2、在val的函數(shù)上面添加@torch.no_grad(),作用:所有計(jì)算得出的tensor的requires_grad都自動(dòng)設(shè)置為False。即使一個(gè)tensor(命名為x)的requires_grad = True,在with torch.no_grad計(jì)算,由x得到的新tensor(命名為w-標(biāo)量)requires_grad也為False,且grad_fn也為None,即不會(huì)對(duì)w求導(dǎo)。 3、如果use_ema 為T(mén)rue,則使用shadow字典的參數(shù)更新模型參數(shù)。 4、使用驗(yàn)證集的loss函數(shù)求出驗(yàn)證集的loss。 5、調(diào)用accuracy函數(shù)計(jì)算ACC1和ACC5
6、更新loss_meter、acc1_meter、acc5_meter的參數(shù)。 7、如果use_ema 為T(mén)rue,則清空backup字典。
”
本次epoch循環(huán)完成后,求得本次epoch的acc、loss。
如果acc比Best_ACC大,則保存模型。 保存模型的邏輯:
★如果ACC比Best_ACC高,則保存best模型 判斷模型是否為DP方式訓(xùn)練的模型。
★如果是DP方式訓(xùn)練的模型,模型參數(shù)放在model.module,則需要保存model.module。 否則直接保存model。
”接下來(lái)保存每個(gè)epoch的模型。 判斷模型是否為DP方式訓(xùn)練的模型。
★”如果是DP方式訓(xùn)練的模型,模型參數(shù)放在model.module,則需要保存model.module.state_dict()。 新建個(gè)字典,放置Best_ACC、epoch和 model.module.state_dict()等參數(shù)。然后將這個(gè)字典保存。 否則,新建個(gè)字典,放置Best_ACC、epoch和 model.state_dict()等參數(shù)。然后將這個(gè)字典保存。 在這里注意:對(duì)于每個(gè)epoch的模型只保存了state_dict參數(shù),沒(méi)有保存整個(gè)模型文件。
”
#?驗(yàn)證過(guò)程
def?val(model,?device,?test_loader):
????global?Best_ACC
????model.eval()
????loss_meter?=?AverageMeter()
????acc1_meter?=?AverageMeter()
????acc5_meter?=?AverageMeter()
????total_num?=?len(test_loader.dataset)
????print(total_num,?len(test_loader))
????val_list?=?[]
????pred_list?=?[]
????if?use_ema?:
????????ema.apply_shadow()
????for?data,?target?in?test_loader:
????????for?t?in?target:
????????????val_list.append(t.data.item())
????????data,?target?=?data.to(device,non_blocking=True),?target.to(device,non_blocking=True)
????????output?=?model(data)
????????loss?=?criterion_val(output,?target)
????????_,?pred?=?torch.max(output.data,?1)
????????for?p?in?pred:
????????????pred_list.append(p.data.item())
????????acc1,?acc5?=?accuracy(output,?target,?topk=(1,?5))
????????loss_meter.update(loss.item(),?target.size(0))
????????acc1_meter.update(acc1.item(),?target.size(0))
????????acc5_meter.update(acc5.item(),?target.size(0))
????if?use_ema?:
????????ema.restore()
????acc?=?acc1_meter.avg
????print('\nVal?set:?Average?loss:?{:.4f}\tAcc1:{:.3f}%\tAcc5:{:.3f}%\n'.format(
????????loss_meter.avg,??acc,??acc5_meter.avg))
????if?acc?>?Best_ACC:
????????if?isinstance(model,?torch.nn.DataParallel):
????????????torch.save(model.module,?file_dir?+?'/'?+?'best.pth')
????????else:
????????????torch.save(model,?file_dir?+?'/'?+?'best.pth')
????????Best_ACC?=?acc
????if?isinstance(model,?torch.nn.DataParallel):
????????state?=?{
????????????'epoch':?epoch,
????????????'state_dict':?model.module.state_dict(),
????????????'Best_ACC':?Best_ACC
????????}
????????torch.save(state,?file_dir?+?"/"?+?'model_'?+?str(epoch)?+?'_'?+?str(round(acc,?3))?+?'.pth')
????else:
????????state?=?{
????????????'epoch':?epoch,
????????????'state_dict':?model.state_dict(),
????????????'Best_ACC':?Best_ACC
????????}
????????torch.save(state,?file_dir?+?"/"?+?'model_'?+?str(epoch)?+?'_'?+?str(round(acc,?3))?+?'.pth')
????return?val_list,?pred_list,?loss_meter.avg,?acc
調(diào)用訓(xùn)練和驗(yàn)證方法
調(diào)用訓(xùn)練函數(shù)和驗(yàn)證函數(shù)的主要步驟:
★1、定義參數(shù):
”
is_set_lr,是否已經(jīng)設(shè)置了學(xué)習(xí)率,當(dāng)epoch大于一定的次數(shù)后,會(huì)將學(xué)習(xí)率設(shè)置到一定的值,并將其置為T(mén)rue。
log_dir:記錄log用的,將有用的信息保存到字典中,然后轉(zhuǎn)為json保存起來(lái)。
train_loss_list:保存每個(gè)epoch的訓(xùn)練loss。
val_loss_list:保存每個(gè)epoch的驗(yàn)證loss。
train_acc_list:保存每個(gè)epoch的訓(xùn)練acc。
val_acc_list:保存么每個(gè)epoch的驗(yàn)證acc。
epoch_list:存放每個(gè)epoch的值。
★循環(huán)epoch
★”1、調(diào)用train函數(shù),得到 train_loss, train_acc,并將分別放入train_loss_list,train_acc_list,然后存入到logdir字典中。
2、調(diào)用驗(yàn)證函數(shù),得到val_list, pred_list, val_loss, val_acc。將val_loss, val_acc分別放入val_loss_list和val_acc_list中,然后存入到logdir字典中。
3、保存log。
4、打印本次的測(cè)試報(bào)告。
5、如果epoch大于600,將學(xué)習(xí)率設(shè)置為固定的1e-6。
6、繪制loss曲線(xiàn)和acc曲線(xiàn)。
”
?????#?訓(xùn)練與驗(yàn)證
????is_set_lr?=?False
????log_dir?=?{}
????train_loss_list,?val_loss_list,?train_acc_list,?val_acc_list,?epoch_list?=?[],?[],?[],?[],?[]
????for?epoch?in?range(1,?EPOCHS?+?1):
????????epoch_list.append(epoch)
????????train_loss,?train_acc?=?train(model_ft,?DEVICE,?train_loader,?optimizer,?epoch)
????????train_loss_list.append(train_loss)
????????train_acc_list.append(train_acc)
????????log_dir['train_acc']?=?train_acc_list
????????log_dir['train_loss']?=?train_loss_list
????????val_list,?pred_list,?val_loss,?val_acc?=?val(model_ft,?DEVICE,?test_loader)
????????val_loss_list.append(val_loss)
????????val_acc_list.append(val_acc)
????????log_dir['val_acc']?=?val_acc_list
????????log_dir['val_loss']?=?val_loss_list
????????log_dir['best_acc']?=?Best_ACC
????????with?open(file_dir?+?'/result.json',?'w',?encoding='utf-8')?as?file:
????????????file.write(json.dumps(log_dir))
????????print(classification_report(val_list,?pred_list,?target_names=dataset_train.class_to_idx))
????????if?epoch?<?600:
????????????cosine_schedule.step()
????????else:
????????????if?not?is_set_lr:
????????????????for?param_group?in?optimizer.param_groups:
????????????????????param_group["lr"]?=?1e-6
????????????????????is_set_lr?=?True
????????fig?=?plt.figure(1)
????????plt.plot(epoch_list,?train_loss_list,?'r-',?label=u'Train?Loss')
????????#?顯示圖例
????????plt.plot(epoch_list,?val_loss_list,?'b-',?label=u'Val?Loss')
????????plt.legend(["Train?Loss",?"Val?Loss"],?loc="upper?right")
????????plt.xlabel(u'epoch')
????????plt.ylabel(u'loss')
????????plt.title('Model?Loss?')
????????plt.savefig(file_dir?+?"/loss.png")
????????plt.close(1)
????????fig2?=?plt.figure(2)
????????plt.plot(epoch_list,?train_acc_list,?'r-',?label=u'Train?Acc')
????????plt.plot(epoch_list,?val_acc_list,?'b-',?label=u'Val?Acc')
????????plt.legend(["Train?Acc",?"Val?Acc"],?loc="lower?right")
????????plt.title("Model?Acc")
????????plt.ylabel("acc")
????????plt.xlabel("epoch")
????????plt.savefig(file_dir?+?"/acc.png")
????????plt.close(2)
運(yùn)行以及結(jié)果查看
完成上面的所有代碼就可以開(kāi)始運(yùn)行了。點(diǎn)擊右鍵,然后選擇“run train.py”即可,運(yùn)行結(jié)果如下:

在每個(gè)epoch測(cè)試完成之后,打印驗(yàn)證集的acc、recall等指標(biāo)。
DeiT測(cè)試結(jié)果:


DeiT_dist測(cè)試結(jié)果:


測(cè)試
測(cè)試,我們采用一種通用的方式。
測(cè)試集存放的目錄如下圖:
DEiT_demo
├─test
│??├─1.jpg
│??├─2.jpg
│??├─3.jpg
│??├?......
└─test.pyimport?torch.utils.data.distributed
import?torchvision.transforms?as?transforms
from?PIL?import?Image
from?torch.autograd?import?Variable
import?os
classes?=?('Black-grass',?'Charlock',?'Cleavers',?'Common?Chickweed',
???????????'Common?wheat',?'Fat?Hen',?'Loose?Silky-bent',
???????????'Maize',?'Scentless?Mayweed',?'Shepherds?Purse',?'Small-flowered?Cranesbill',?'Sugar?beet')
transform_test?=?transforms.Compose([
????transforms.Resize((224,?224)),
????transforms.ToTensor(),
????transforms.Normalize(mean=[0.51819474,?0.5250407,?0.4945761],?std=[0.24228974,?0.24347611,?0.2530049])
])
DEVICE?=?torch.device("cuda:0"?if?torch.cuda.is_available()?else?"cpu")
model=torch.load('checkpoints/DEiT/best.pth')
model.eval()
model.to(DEVICE)
path?=?'test/'
testList?=?os.listdir(path)
for?file?in?testList:
????img?=?Image.open(path?+?file)
????img?=?transform_test(img)
????img.unsqueeze_(0)
????img?=?Variable(img).to(DEVICE)
????out?=?model(img)
????#?Predict
????_,?pred?=?torch.max(out.data,?1)
????print('Image?Name:{},predict:{}'.format(file,?classes[pred.data.item()]))
測(cè)試的主要邏輯:
★1、定義類(lèi)別,這個(gè)類(lèi)別的順序和訓(xùn)練時(shí)的類(lèi)別順序?qū)?yīng),一定不要改變順序!?。?!
”
★2、定義transforms,transforms和驗(yàn)證集的transforms一樣即可,別做數(shù)據(jù)增強(qiáng)。
”
★3、 加載model,并將模型放在DEVICE里,
”
★4、循環(huán) 讀取圖片并預(yù)測(cè)圖片的類(lèi)別,在這里注意,讀取圖片用PIL庫(kù)的Image。不要用CV2,transforms不支持。循環(huán)里面的主要邏輯:
★””
使用Image.open讀取圖片
使用transform_test對(duì)圖片做歸一化和標(biāo)椎化。
img.unsqueeze_(0) 增加一個(gè)維度,由(3,224,224)變?yōu)椋?,3,224,224)
Variable(img).to(DEVICE):將數(shù)據(jù)放入DEVICE中。
model(img):執(zhí)行預(yù)測(cè)。
_, pred = torch.max(out.data, 1):獲取預(yù)測(cè)值的最大下角標(biāo)。
運(yùn)行結(jié)果:

完整的代碼
https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/87294382