ResNet復(fù)現(xiàn)(基于ImageNet分類任務(wù))

?1/60 補(bǔ)充一下之前的專欄,主要以代碼為主,文章本體上我就不做結(jié)構(gòu)化的闡述了。
Introduction
????????這篇文章以前看過的,這次是復(fù)現(xiàn)代碼。大白話介紹一下resnet,思想非常簡(jiǎn)潔,下面這一張圖就是resnet的所有東西了!其“學(xué)習(xí)殘差”的這個(gè)思想?yún)s極為先進(jìn),一舉解決深度網(wǎng)絡(luò)難訓(xùn)練的問題,真正讓神經(jīng)網(wǎng)絡(luò)變成了深度學(xué)習(xí),可以說是現(xiàn)代深度學(xué)習(xí)的半壁江山,基本現(xiàn)在熱門模型大多數(shù)都用到了殘差學(xué)習(xí)的思想,如transformer。

????????傳統(tǒng)網(wǎng)絡(luò)有這樣的問題:隨著深度增加,精度不增反減,作者稱之為退化,且這種減少(可能)不是由于 overfitting 問題造成的(不好解決)。?
????????總之,從常理來說,深度增加理應(yīng)是會(huì)使網(wǎng)絡(luò)的表征能力變強(qiáng),精度提升,但事實(shí)卻并非如此,所以提出了一種學(xué)習(xí)策略—?dú)埐顚W(xué)習(xí),那就是只向能使精度變好的方向去學(xué)習(xí),換句 話說,這種結(jié)構(gòu)可以理解成,你學(xué)不到更好的你就別學(xué),直接用 shortcut 分支作 identity mapping。這種結(jié)構(gòu)還會(huì)很好的解決梯度的爆炸與消失問題,因?yàn)殒準(zhǔn)椒▌t求導(dǎo)后由連乘變成了連加,得以保證前面的層也能很好的接收到梯度。
ResNet Architecture:?
使用一個(gè)初等的基本網(wǎng)絡(luò)架構(gòu) ---- "Our plain baselines (Fig. 3, middle) are mainly inspired by the philosophy of VGG net?"
降維時(shí)使用 s = 2 的 1*1 卷積做下采樣,可以說是忽略了很多特征點(diǎn)?---- "We perform downsampling directly by convolutional layers that have a stride of 2. "
在shortcut需要降維時(shí)使用1*1卷積作線性映射(不在此作非線性操作),帶可訓(xùn)練參數(shù)的shortcut 提升不太明顯,同時(shí)會(huì)增加模型復(fù)雜度。?
大量使用 BN 而不使用 dropout ---- "We adopt batch normalization (BN) right after each convolution and before activation."
使用 1*1 卷積進(jìn)行降維后升維(bottleneck block)?
復(fù)雜度降低:?BottleNeck 結(jié)構(gòu)

????????這個(gè)設(shè)計(jì)在resnet中有比較詳細(xì)的解釋,目的是為了減少參數(shù)和計(jì)算量。
????????由于對(duì)這種一下將維度縮減的操作,理論上,這種降維肯定是會(huì)伴隨著信息損失的,但是對(duì)于特定任務(wù)來說,如此高的維度可以說是沒有必要的,被幾十倍放大過的特征里面有大部分其實(shí)都是"無用"的,有些甚至?xí)a(chǎn)生負(fù)面影響,這里可以和"壞死神經(jīng)元"的思想進(jìn)行一個(gè)類比(網(wǎng)絡(luò)中大多數(shù)神經(jīng)元其實(shí)都是未被使用到的)。所以我們可以通過降維操作去選取一些我們更應(yīng)該需要關(guān)注的特征。而緊接著的重新升維是為了提升網(wǎng)絡(luò)的表征能力,這里升維后的特征與降維前的特征雖然通道數(shù)相同,但是后者實(shí)際上是從低位特征還原得到的,同時(shí)也經(jīng)過了又一次的非線性激活函數(shù),可以認(rèn)為升維后的特征針對(duì)當(dāng)前任務(wù)更具有特異性。
?代碼復(fù)現(xiàn):基于卷積神經(jīng)網(wǎng)絡(luò)的圖片分類
任務(wù)描述:
在ImageNet數(shù)據(jù)集上,使用卷積神經(jīng)網(wǎng)絡(luò)進(jìn)行圖像分類。
數(shù)據(jù)集上篩選了100類,每一類有帶標(biāo)簽圖片1000張。train文件夾:訓(xùn)練集,子文件夾名即為標(biāo)簽;val文件夾:包含了圖片和對(duì)應(yīng)的標(biāo)簽;test文件夾:包含了5000張無標(biāo)簽圖片,需要將5000張圖片按照 xxx.jpg (labels)的格式輸入到一個(gè)txt文件中。
數(shù)據(jù)處理:
由于計(jì)算資源和圖片的size限制,在實(shí)驗(yàn)中并未使用過多的數(shù)據(jù)增強(qiáng)操作,只對(duì)測(cè)試集中部分圖像做了鏡像翻轉(zhuǎn)操作作為數(shù)據(jù)增強(qiáng)。已知可以使用的方法,包括中心(四角)裁剪,圖像增強(qiáng),高斯模糊等操作,此函數(shù)都封裝在torchvision.transforms庫中。
同時(shí)對(duì)于訓(xùn)練集,驗(yàn)證集集和測(cè)試集都作了nomalization操作,其中參數(shù)選取于ImageNet整體的均值與方差(網(wǎng)上可查),其中mean = (0.485, 0.456, 0.406),std = (0.229, 0.224, 0.225)。
在數(shù)據(jù)讀取上,使用了庫函數(shù)torchvision.datasets.ImageFolder(),其以標(biāo)好標(biāo)簽的文件夾為處理對(duì)象,可與torch.utils.data中的DataLoader搭配使用,十分方便。
以上操作被封裝在函數(shù)get_loader()中,函數(shù)參數(shù)接收文件夾的路徑和batchsize以及worker_nums,返回兩個(gè)dataloader。
網(wǎng)絡(luò)結(jié)構(gòu):
????????我的resnet18的實(shí)現(xiàn),封裝在net.resnet中。如果你是為了學(xué)習(xí),我更推薦去看torchvision.models.resnet中的官方源碼。
????????在數(shù)據(jù)進(jìn)入第一個(gè)塊之前,使用64個(gè)kernelsize = 3*3,stride=1與padding=1的卷積核對(duì)數(shù)據(jù)進(jìn)行通道上的擴(kuò)充。
????? ? 通過kernelsize = 3*3,stride=1與padding=1操作,數(shù)據(jù)在塊內(nèi)size和channel不變。每個(gè)塊將上一塊的輸出的size減半,channel翻倍。在數(shù)據(jù)輸入時(shí)候,通過設(shè)置每個(gè)塊的第一層的conv層的stride=2,來完成下采樣操作使size減半,同時(shí)使用兩倍的卷積核完成升維操作。此時(shí)的shortcut也需要進(jìn)行相應(yīng)的變換
????????在最后一層使用kernelsize為4的全局均值池化層,與常用flatten操作相比可以顯著減少參數(shù)量,之后將512個(gè)值進(jìn)行.view操作,直接接入接入輸出為100的全連接層輸出,而不再使用softmax。(這里有個(gè)點(diǎn)需要注意,為了更好的性能,部分框架選擇了在使用交叉熵?fù)p失函數(shù)時(shí)默認(rèn)加上softmax,這樣無論你的輸出層是什么,只要用了nn.CrossEntropyLoss就默認(rèn)加上了softmax。不僅是Pytorch,國內(nèi)的PaddlePaddle等框架也是這樣。)
resnet.py
class ResidualBlock(nn.Module):
? ?def __init__(self, in_channel, out_channel, stride=1):
? ? ? ?super(ResidualBlock, self).__init__()
? ? ? ?self.left = nn.Sequential(
? ? ? ? ? ?nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False),
? ? ? ? ? ?nn.BatchNorm2d(out_channel),
? ? ? ? ? ?nn.ReLU(inplace=True),
? ? ? ? ? ?nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
? ? ? ? ? ?nn.BatchNorm2d(out_channel)
? ? ? ?)
? ? ? ?self.shortcut = nn.Sequential()
? ? ? ?if stride != 1 or in_channel != out_channel:
? ? ? ? ? ?self.shortcut = nn.Sequential(
? ? ? ? ? ? ? ?nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, bias=False),
? ? ? ? ? ? ? ?nn.BatchNorm2d(out_channel)
? ? ? ? ? ?)
? ?def forward(self, x):
? ? ? ?out = self.left(x)
? ? ? ?out += self.shortcut(x)
? ? ? ?out = F.relu(out)
? ? ? ?return out
class _resnet18(nn.Module):
? ?def __init__(self, ResidualBlock, num_classes=100):
? ? ? ?super(_resnet18, self).__init__()
? ? ? ?self.in_channel = 64
? ? ? ?self.conv1 = nn.Sequential(
? ? ? ? ? ?nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
? ? ? ? ? ?nn.BatchNorm2d(64),
? ? ? ? ? ?nn.ReLU(),
? ? ? ?)
? ? ? ?self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=2)
? ? ? ?self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
? ? ? ?self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
? ? ? ?self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
? ? ? ?self.avg = nn.AvgPool2d(kernel_size=4)
? ? ? ?self.fc = nn.Linear(512, num_classes)
? ?def make_layer(self, block, channels, num_blocks, stride):
? ? ? ?# strides=[1, 1] or [2, 1]
? ? ? ?strides = [stride] + [1] * (num_blocks - 1)
? ? ? ?layers = []
? ? ? ?for stride in strides:
? ? ? ? ? ?layers.append(block(self.in_channel, channels, stride))
? ? ? ? ? ?self.in_channel = channels
? ? ? ?return nn.Sequential(*layers)
? ?def forward(self, x):
? ? ? ?out = self.conv1(x)
? ? ? ?out = self.layer1(out)
? ? ? ?out = self.layer2(out)
? ? ? ?out = self.layer3(out)
? ? ? ?out = self.layer4(out)
? ? ? ?out = self.avg(out)
? ? ? ?out = out.view(out.size(0), -1)
? ? ? ?out = self.fc(out)
? ? ? ?return out
def ResNet18():
? ?return _resnet18(ResidualBlock)
訓(xùn)練:

最下面為腳本入口,在設(shè)置好了參數(shù)后載入主函數(shù)。
主函數(shù)為負(fù)責(zé)調(diào)用各個(gè)模塊完成訓(xùn)練任務(wù),并將模型保存:
import random
import argparse
import numpy as np
import torch
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as tf
from torch.utils.data import DataLoader
from net.resnet import ResNet18
from tqdm.auto import tqdm
from torchsummary import summary
from matplotlib import pyplot as plt
import matplotlib.ticker as m_tick
device = "cuda" if torch.cuda.is_available() else "cpu"
def set_seed(seed: int):
? ?random.seed(seed)
? ?# Numpy
? ?np.random.seed(seed)
? ?# Torch
? ?torch.manual_seed(seed)
? ?if torch.cuda.is_available():
? ? ? ?torch.cuda.manual_seed(seed)
? ? ? ?torch.cuda.manual_seed_all(seed)
? ?torch.backends.cudnn.benchmark = False
? ?torch.backends.cudnn.deterministic = True
def init_weights(m: nn.Module):
? ?class_name = m.__class__.__name__
? ?if class_name.find('Conv2d') != -1 or class_name.find('ConvTranspose2d') != -1:
? ? ? ?nn.init.kaiming_uniform_(m.weight)
? ? ? ?nn.init.zeros_(m.bias)
? ?elif class_name.find('BatchNorm') != -1:
? ? ? ?nn.init.normal_(m.weight, 1.0, 0.02)
? ? ? ?nn.init.zeros_(m.bias)
? ?elif class_name.find('Linear') != -1:
? ? ? ?nn.init.xavier_normal_(m.weight)
? ? ? ?nn.init.zeros_(m.bias)
def get_loader(train_path, val_path, batch_size, workers):
? ?imagenet_norm_mean = (0.485, 0.456, 0.406)
? ?imagenet_norm_std = (0.229, 0.224, 0.225)
? ?# Can do Data augmentation here
? ?t_transform = tf.Compose([
? ? ? ?tf.Resize((64, 64)),
? ? ? ?# reverse
? ? ? ?tf.RandomHorizontalFlip(),
? ? ? ?tf.ToTensor(),
? ? ? ?tf.Normalize(imagenet_norm_mean, imagenet_norm_std)
? ?])
? ?# Validation don't need any operation
? ?v_transform = tf.Compose([
? ? ? ?tf.Resize((64, 64)),
? ? ? ?tf.ToTensor(),
? ? ? ?tf.Normalize(imagenet_norm_mean, imagenet_norm_std)
? ?])
? ?# ImageFloder instead of 'class getData()'
? ?train_dataset = tv.datasets.ImageFolder(root=train_path, transform=t_transform)
? ?train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
? ?val_dataset = tv.datasets.ImageFolder(root=val_path, transform=v_transform)
? ?val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)
? ?return train_loader, val_loader
def draw_acc_loss(train_acc, val_acc, train_loss):
? ?x1 = np.arange(args.epochs)
? ?fig = plt.figure(1)
? ?# Set Y as type of %
? ?ax1 = fig.add_subplot()
? ?fmt = '%.2f%%'
? ?y_ticks = m_tick.FormatStrFormatter(fmt)
? ?ax1.yaxis.set_major_formatter(y_ticks)
? ?# plt.figure(figsize=(9, 6), dpi=300)
? ?ax1.plot(x1, train_acc.reshape(-1), label='train_acc')
? ?ax1.plot(x1, val_acc.reshape(-1), '-', label='val_acc')
? ?ax1.set_ylabel('acc')
? ?ax1.set_xlabel('iter')
? ?ax1.set_ylim([0, 1]) ?# 設(shè)置y軸取值范圍
? ?# This is the important function, twin image.
? ?ax2 = ax1.twinx()
? ?ax2.set_ylim([0, 4]) ?# 設(shè)置y軸取值范圍
? ?ax2.set_ylabel('loss')
? ?ax2.plot(x1, train_loss.reshape(-1), '--', label='train_loss')
? ?# The loc of description
? ?ax1.legend(loc=(1 / 32, 16 / 19))
? ?ax2.legend(loc=(1 / 32, 12 / 19))
? ?# plt.savefig('./model/iters.png')
? ?plt.show()
def main(args: argparse.Namespace):
? ?print('---------Train on: ' + device + '----------')
? ?# Set random seed
? ?if args.seed is not None:
? ? ? ?set_seed(args.seed)
? ?# Loading data and transform. ImageNet -- mean & std
? ?train_loader, val_loader = get_loader(args.train_path, args.val_path, args.batch_size, args.workers)
? ?# Create model
? ?model = ResNet18().to(device)
? ?# We don't use init_weight here -- some bugs.
? ?# model.apply(init_weights)
? ?# Visualize model
? ?summary(model, input_size=(3, 64, 64))
? ?# Define Optimizer and Loss
? ?criterion = nn.CrossEntropyLoss()
? ?optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
? ?# Define list to record acc & loss for plt
? ?train_loss = np.array([])
? ?train_acc = np.array([])
? ?val_acc = np.array([])
? ?# For epoch in range(args.epochs):
? ?for epoch in range(args.epochs):
? ? ? ?# train
? ? ? ?train_batch_loss, train_batch_acc = train(epoch, train_loader, model, optimizer, criterion, args)
? ? ? ?train_loss = np.append(train_loss, train_batch_loss)
? ? ? ?train_acc = np.append(train_acc, train_batch_acc)
? ? ? ?# validate
? ? ? ?val_batch_acc = validate(epoch, val_loader, model, args)
? ? ? ?val_acc = np.append(val_acc, val_batch_acc)
? ? ? ?# Save model
? ? ? ?if epoch % 2 == 0:
? ? ? ? ? ?torch.save({
? ? ? ? ? ? ? ?'epoch': epoch,
? ? ? ? ? ? ? ?'model_state_dict': model.state_dict(),
? ? ? ? ? ? ? ?'train_loss': train_loss,
? ? ? ? ? ? ? ?'train_acc': train_acc,
? ? ? ? ? ? ? ?'val_acc': val_acc,
? ? ? ? ? ? ? ?'seed': args.seed
? ? ? ? ? ?}, './model/model.pth')
? ?# Draw loss & acc
? ?# draw_acc_loss(train_acc, val_acc, train_loss)
def train(epoch: int, train_loader: DataLoader, model, optimizer, criterion, args: argparse.Namespace):
? ?model.train()
? ?train_loss_lis = np.array([])
? ?train_acc_lis = np.array([])
? ?for batch in tqdm(train_loader):
? ? ? ?imgs, labels = batch
? ? ? ?imgs, labels = imgs.to(device), labels.to(device)
? ? ? ?# labels = torch.nn.functional.one_hot(labels).long().to(device)
? ? ? ?logits = model(imgs)
? ? ? ?loss = criterion(logits, labels)
? ? ? ?# Compute gradient and do GD step
? ? ? ?optimizer.zero_grad()
? ? ? ?loss.backward()
? ? ? ?optimizer.step()
? ? ? ?# Calculate the batch acc
? ? ? ?acc = (logits.argmax(dim=-1) == labels).float().mean()
? ? ? ?# Record the batch loss and accuracy.
? ? ? ?train_loss_lis = np.append(train_loss_lis, loss.item())
? ? ? ?train_acc_lis = np.append(train_acc_lis, acc.cpu())
? ?train_loss = sum(train_loss_lis) / len(train_loss_lis)
? ?train_acc = sum(train_acc_lis) / len(train_acc_lis)
? ?# Print the information.
? ?print(f"[ Train | {epoch + 1:03d}/{args.epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")
? ?return train_loss, train_acc
def validate(epoch: int, val_loader: DataLoader, model, args: argparse.Namespace):
? ?model.eval()
? ?val_acc_lis = np.array([])
? ?for batch in tqdm(val_loader):
? ? ? ?imgs, labels = batch
? ? ? ?with torch.no_grad():
? ? ? ? ? ?imgs, labels = imgs.to(device), labels.to(device)
? ? ? ? ? ?logits = model(imgs)
? ? ? ? ? ?# Calculate the batch acc
? ? ? ? ? ?acc = (logits.argmax(dim=-1) == labels).float().mean()
? ? ? ? ? ?# Record the batch loss and accuracy.
? ? ? ? ? ?val_acc_lis = np.append(val_acc_lis, acc.cpu())
? ?val_acc = sum(val_acc_lis) / len(val_acc_lis)
? ?# Print the information.
? ?print(f"[ Validation | {epoch + 1:03d}/{args.epochs:03d} ] ?acc = {val_acc:.5f}")
? ?return val_acc
if __name__ == '__main__':
? ?parser = argparse.ArgumentParser(description='Source for ImageNet Classification')
? ?parser.add_argument('-sd', '--seed', default=17, type=int, help='seed for initializing training. ')
? ?# dataset parameters
? ?parser.add_argument('-tp', '--train_path', default='dataset/ImageData2/train', help='the path of training data.')
? ?parser.add_argument('-vp', '--val_path', default='dataset/ImageData2/val', help='the path of validation data.')
? ?parser.add_argument('-wn', '--workers', type=int, default=2, help='number of data loading workers (default: 2)')
? ?# train parameters
? ?parser.add_argument('-bs', '--batch_size', type=int, default=256, help='the size of batch.')
? ?parser.add_argument('-ep', '--epochs', type=int, default=10, help='the num of epochs.')
? ?# model parameters
? ?parser.add_argument('-lr', '--lr', type=float, default=0.001, help='initial learning rate', dest='lr')
? ?parser.add_argument('-mm', '--momentum', type=float, default=0.9, help='initial momentum')
? ?parser.add_argument('-wd', '--weight_decay', type=float, default=1e-5, help='initial momentum')
? ?args = parser.parse_args()
? ?main(args)
測(cè)試:
????????測(cè)試文件由于輸出txt的格式要求,所以沒法用dataloader了,只能一個(gè)一個(gè)圖片讀取然后寫出。
????????此處畫圖函數(shù)正式起作用,根據(jù)保存的list畫圖并保存到文件夾。
import os
import random
from glob import glob
import numpy as np
import torch
import torch.nn as nn
import argparse
import torchvision as tv
import torchvision.transforms as tf
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, Dataset
from net.resnet import ResNet18
from tqdm.auto import tqdm
from torchsummary import summary
import matplotlib.ticker as m_tick
device = "cuda" if torch.cuda.is_available() else "cpu"
imagenet_norm_mean = (0.485, 0.456, 0.406)
imagenet_norm_std = (0.229, 0.224, 0.225)
def set_seed(seed: int):
? ?random.seed(seed)
? ?# Numpy
? ?np.random.seed(seed)
? ?# Torch
? ?torch.manual_seed(seed)
? ?if torch.cuda.is_available():
? ? ? ?torch.cuda.manual_seed(seed)
? ? ? ?torch.cuda.manual_seed_all(seed)
? ?torch.backends.cudnn.benchmark = False
? ?torch.backends.cudnn.deterministic = True
def draw_acc_loss(train_acc, val_acc, train_loss):
? ?x1 = np.arange(len(train_acc))
? ?fig = plt.figure(1)
? ?# Set Y as type of %
? ?ax1 = fig.add_subplot()
? ?fmt = '%.2f%%'
? ?y_ticks = m_tick.FormatStrFormatter(fmt)
? ?ax1.yaxis.set_major_formatter(y_ticks)
? ?# plt.figure(figsize=(9, 6), dpi=300)
? ?ax1.plot(x1, train_acc.reshape(-1), label='train_acc')
? ?ax1.plot(x1, val_acc.reshape(-1), '-', label='val_acc')
? ?ax1.set_ylabel('acc')
? ?ax1.set_xlabel('iter')
? ?ax1.set_ylim([0, 1]) ?# 設(shè)置y軸取值范圍
? ?# This is the important function, twin image.
? ?ax2 = ax1.twinx()
? ?ax2.set_ylim([0, 4]) ?# 設(shè)置y軸取值范圍
? ?ax2.set_ylabel('loss')
? ?ax2.plot(x1, train_loss.reshape(-1), '--', label='train_loss')
? ?# The loc of description
? ?ax1.legend(loc=(1 / 32, 16 / 19))
? ?ax2.legend(loc=(1 / 32, 12 / 19))
? ?plt.savefig('./model/iters.png')
? ?plt.show()
def main(model, args: argparse.Namespace):
? ?model_state_dict = model['model_state_dict']
? ?train_loss = model['train_loss']
? ?train_acc = model['train_acc']
? ?val_acc = model['val_acc']
? ?draw_acc_loss(train_acc, val_acc, train_loss)
? ?#Define Optimizer and Loss
? ?print(args)
? ?print('---------Test on: ' + device + '----------')
? ?# Set random seed
? ?if args.seed is not None:
? ? ? ?set_seed(args.seed)
? ?#Create model
? ?test_model = ResNet18().to(device)
? ?test_model.load_state_dict(model_state_dict)
? ?test_model.eval()
? ?# Create txt
? ?fw = open('./dataset/ImageData2/test.txt', 'w')
? ?# 5000 pics
? ?for i in range(0, 5000):
? ? ? ?f_name = os.path.join(r'./dataset/ImageData2/test/' + str(i) + '.jpg')
? ? ? ?img = tv.io.read_image(f_name)
? ? ? ?test_transform = tf.Compose([
? ? ? ? ? ?tf.ToPILImage(),
? ? ? ? ? ?tf.Resize((64, 64)),
? ? ? ? ? ?tf.ToTensor(),
? ? ? ? ? ?tf.Normalize(imagenet_norm_mean, imagenet_norm_std)
? ? ? ?])
? ? ? ?img = test_transform(img)
? ? ? ?img = torch.unsqueeze(img, dim=0).float()
? ? ? ?with torch.no_grad():
? ? ? ? ? ?logits = test_model(img.to(device))
? ? ? ? ? ?test_batch_labels = logits.argmax(dim=-1)
? ? ? ?# Write in txt
? ? ? ?fw.write('test/' + str(i) + '.jpg ' + str(test_batch_labels.item()))
? ? ? ?fw.write("\n")
if __name__ == '__main__':
? ?pth_file = r'./model/model.pth'
? ?model = torch.load(pth_file)
? ?'''
? ?{
? ? ? ?'epoch': epoch,
? ? ? ?'model_state_dict': model.state_dict(),
? ? ? ?'optimizer_state_dict': optimizer.state_dict(),
? ? ? ?'train_loss': train_loss,
? ? ? ?'train_acc': train_acc,
? ? ? ?'val_acc': val_acc,
? ?}
? ?'''
? ?parser = argparse.ArgumentParser(description='Source for ImageNet Classification')
? ?parser.add_argument('-sd', '--seed', default=17, type=int, help='seed for initializing training. ')
? ?#dataset parameters
? ?parser.add_argument('-tp', '--train_path', default='./dataset/ImageData2/train', help='the path of training data.')
? ?parser.add_argument('-vp', '--val_path', default='./dataset/ImageData2/val', help='the path of validation data.')
? ?parser.add_argument('-tsp', '--test_path', default='./dataset/ImageData2/test', help='the path of test data.')
? ?parser.add_argument('-wn', '--workers', type=int, default=2, help='number of data loading workers (default: 2)')
? ?#train parameters
? ?parser.add_argument('-bs', '--batch_size', type=int, default=256, help='the size of batch.')
? ?parser.add_argument('-ep', '--epochs', type=int, default=10, help='the num of epochs.')
? ?#model parameters
? ?parser.add_argument('-lr', '--lr', type=float, default=0.001, help='initial learning rate', dest='lr')
? ?parser.add_argument('-mm', '--momentum', type=float, default=0.9, help='initial momentum')
? ?parser.add_argument('-wd', '--weight_decay', type=float, default=1e-5, help='initial momentum')
? ?args = parser.parse_args()
? ?main(model, args)
實(shí)驗(yàn)結(jié)果:

實(shí)驗(yàn)心得:
????????使用了matplotlib對(duì)三個(gè)記錄的數(shù)據(jù)進(jìn)行可視化,如上圖,我的電腦跑一個(gè)epoch需要2分多鐘,而且跑9,10個(gè)epoch就會(huì)崩掉…淚目,所以也沒做對(duì)比試驗(yàn)。不過個(gè)人決定還是從這次實(shí)驗(yàn)中學(xué)到了不少東西。
????????模型的保存與讀取很重要,當(dāng)訓(xùn)練數(shù)據(jù)和規(guī)模上來后不能說再像以前那樣幾十秒跑一個(gè)模型然后畫圖,復(fù)現(xiàn)什么的,現(xiàn)在隨便一個(gè)模型都需要跑少1個(gè)小時(shí),還是控制了數(shù)據(jù)量的情況下…,根據(jù)epoch實(shí)時(shí)保存模型太重要了,而且也可以防止某個(gè)epoch時(shí)候內(nèi)存溢出程序崩掉后沒有結(jié)果前功盡棄;
????????命令行傳參很重要,這個(gè)初使用多少有點(diǎn)不習(xí)慣,不過相信未來在遠(yuǎn)程服務(wù)器跑model時(shí)肯定能用的著;
????????重構(gòu)了代碼結(jié)構(gòu),封裝了主函數(shù),train函數(shù)和validate函數(shù)等等,不像曾經(jīng)那樣一個(gè)腳本從頭跑到尾,代碼可讀性明顯提升了;
????????盡可能的去使用多個(gè).py文件封裝不同種類的函數(shù)和模型,比如test和train分開,這樣測(cè)試的時(shí)候就會(huì)簡(jiǎn)介很多。