MNIST集的分類的實(shí)現(xiàn)
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
#定義訓(xùn)練集
train_dataset = datasets.MNIST(root='./', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./', train=False, transform=transforms.ToTensor(), download=True)
#定義一個(gè)批次的大小
batch_size = 64
#裝載訓(xùn)練集(dataset是指要讀取的數(shù)據(jù)位置, batch_size是指每個(gè)批次讀入的數(shù)據(jù)大小, shuffle是指是否在讀取前進(jìn)行順序的打亂)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
#裝載測(cè)試集
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)
#數(shù)據(jù)加載
for i, data in enumerate(train_loader):
? ?inputs, labels = data
? ?print(inputs.shape)
? ?print(labels.shape)
? ?break
#定義分類模型
class Net(torch.nn.Module):
? ?def __init__(self):
? ? ? ?#繼承父類
? ? ? ?super(Net, self).__init__()
? ? ? ?#定義神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu)
? ? ? ?#定義全連接層(輸入的每個(gè)樣本的大小為28 * 28 == 784, 所以輸入神經(jīng)元個(gè)數(shù)為784, 輸出為0 - 9的10個(gè)類別的概率, 所以輸出神經(jīng)元個(gè)數(shù)為10)
? ? ? ?self.fcl = torch.nn.Linear(784, 10)
? ? ? ?#定義激活函數(shù)為softmax函數(shù), dim=1是因?yàn)橐疵恳恍械拿恳涣星笕「怕蕘?lái)判斷每一行屬于什么類
? ? ? ?self.softmax = torch.nn.Softmax(dim=1)
? ?def forward(self, x):
? ? ? ?#由于x的形狀為[64, 1, 28, 28], 而神經(jīng)網(wǎng)絡(luò)的計(jì)算取決于矩陣計(jì)算, 所以應(yīng)該先將x轉(zhuǎn)換為二維的[64, 784]的形狀再進(jìn)行計(jì)算
? ? ? ?x = x.view(x.size(0), -1)
? ? ? ?x = self.fcl(x)
? ? ? ?x = self.softmax(x)
? ? ? ?return x
#定義模型
model = Net()
#定義代價(jià)函數(shù)
mse_loss = torch.nn.MSELoss(reduction='mean')
#定義優(yōu)化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
#定義模型的訓(xùn)練
def train():
? ?for i, data in enumerate(train_loader):
? ? ? ?inputs, labels = data
? ? ? ?#獲得數(shù)據(jù)的輸入值并將其放入模型中進(jìn)行計(jì)算
? ? ? ?out = model(inputs)
? ? ? ?#將labels轉(zhuǎn)換為獨(dú)熱編碼, 先將labels數(shù)據(jù)從一維變成兩維
? ? ? ?labels = labels.view(-1, 1)
? ? ? ?#使用scatter函數(shù)得到one_hot_labels
? ? ? ?#Tensor.scatter(dim, index, src), 沿著第dim維將Tensor中的第index(索引)位替換為src
? ? ? ?one_hot_labels = torch.zeros(labels.size(0), 10).scatter(1, labels, 1)
? ? ? ?#通過(guò)獨(dú)熱編碼和經(jīng)過(guò)softmax計(jì)算的概率值(都是從0 - 1的, 可以計(jì)算損失, 這也是將標(biāo)簽繪制為獨(dú)熱編碼的意義之一)計(jì)算損失值
? ? ? ?loss = mse_loss(out, one_hot_labels)
? ? ? ?#每100個(gè)批次輸出一次損失值
? ? ? ?if (i + 1) % 100 == 0:
? ? ? ? ? ?print("第", i + 1, "批次的損失值為:", loss.item())
? ? ? ?#在反向傳播計(jì)算梯度之前, 先進(jìn)行梯度清零
? ? ? ?optimizer.zero_grad()
? ? ? ?#進(jìn)行反向傳播求解梯度
? ? ? ?loss.backward()
? ? ? ?#通過(guò)優(yōu)化器更新參數(shù)
? ? ? ?optimizer.step()
#定義模型的測(cè)試
def test():
? ?#定義一個(gè)變量correct, 用來(lái)統(tǒng)計(jì)測(cè)試集中正確的次數(shù)
? ?correct = 0
? ?for i, data in enumerate(test_loader):
? ? ? ?#獲得一個(gè)批次的輸入和標(biāo)簽
? ? ? ?inputs, labels = data
? ? ? ?#將測(cè)試集的inputs放入模型中計(jì)算得到測(cè)試集的預(yù)測(cè)結(jié)果
? ? ? ?test_pred = model(inputs)
? ? ? ?#通過(guò)max(test_pred, 1)函數(shù)得到test_pred中按照每一行取最大值的次序, _,是指無(wú)關(guān)變量, 因?yàn)閙ax函數(shù)返回最大的值和最大的值對(duì)應(yīng)的位序, 所以需要兩個(gè)返回值
? ? ? ?_, predict = torch.max(test_pred, 1)
? ? ? ?#通過(guò)predict和標(biāo)簽進(jìn)行對(duì)比(predict是第幾位是最大的概率, 標(biāo)簽是0 - 9的數(shù)字, 所以當(dāng)predict和labels相等時(shí)就相當(dāng)于預(yù)測(cè)值是正確的)來(lái)得到判斷正確的數(shù)量并賦給correct
? ? ? ?correct += (predict == labels).sum()
? ?#計(jì)算一整次測(cè)試的準(zhǔn)確率
? ?print("第", epoch + 1, "次訓(xùn)練后測(cè)試的預(yù)測(cè)準(zhǔn)確率為:", correct.item() / len(test_dataset))
#模型進(jìn)行epoch次訓(xùn)練并且每次訓(xùn)練完都進(jìn)行一次測(cè)試
for epoch in range(10):
? ?print("第", epoch + 1, "次:")
? ?train()
? ?test()