5.2 Implementing a Multilayer Perceptron in Code
Implementing a multilayer perceptron (MLP) in PyTorch can be broken down into the following steps:
1. Import the required libraries and modules, including PyTorch's nn module, the torch.optim module, and the usual libraries for data loading and preprocessing, such as torchvision and torchtext.
2. Define the MLP model. This is done by subclassing PyTorch's nn.Module class and implementing the forward-pass function.
3. Load the dataset. This can be done with the data loaders PyTorch provides or with a custom loader.
4. Define the loss function and the optimizer. PyTorch's built-in loss functions and optimizers can be used for this.
5. Train the model. In the training loop, you fetch the inputs and labels, compute the model outputs, compute the loss, and update the model parameters.
6. After training, use the model for inference or save it for later use (a small save-and-load sketch is shown at the end of this section).
In the example below, 梗直哥 walks you through an MLP with two hidden layers, trained on the MNIST dataset, so you can get a feel for the whole process:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Define the MLP network
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)   # input layer -> first hidden layer
        self.relu = nn.ReLU()                           # activation function
        self.fc2 = nn.Linear(hidden_size, hidden_size)  # first hidden layer -> second hidden layer
        self.fc3 = nn.Linear(hidden_size, num_classes)  # second hidden layer -> output layer

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.relu(out)
        out = self.fc3(out)
        return out

# Define the hyperparameters
input_size = 28 * 28   # input size (a flattened 28x28 image)
hidden_size = 512      # hidden layer size
num_classes = 10       # output size (number of classes)
batch_size = 100       # batch size
learning_rate = 0.001  # learning rate
num_epochs = 10        # number of training epochs

# Load the MNIST dataset
train_dataset = datasets.MNIST(root='../data/mnist', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='../data/mnist', train=False, transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Instantiate the MLP network
model = MLP(input_size, hidden_size, num_classes)
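If a GPU is available, the model can optionally be moved onto it before training. This is not required for the example above; the lines below are only a minimal sketch of the usual pattern, and the training and test batches would then have to be moved to the same device as well:

# Optional: run on a GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Inside the loops below, each batch would also need to be moved, e.g.:
#   images, labels = images.to(device), labels.to(device)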
Now that the MLP network is defined and the MNIST dataset is loaded, we can train the model using PyTorch's automatic differentiation and an optimizer. First, define the loss function and the optimizer; then iterate over the training data and let the optimizer update the network parameters.
# Define the loss function and the optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the network
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, 28 * 28)  # flatten each image into a vector
        outputs = model(images)               # forward pass
        loss = criterion(outputs, labels)     # compute the loss

        optimizer.zero_grad()                 # clear gradients from the previous step
        loss.backward()                       # backpropagate
        optimizer.step()                      # update the parameters

        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
Epoch [1/10], Step [100/600], Loss: 0.0419
Epoch [1/10], Step [200/600], Loss: 0.0931
Epoch [1/10], Step [300/600], Loss: 0.0609
Epoch [1/10], Step [400/600], Loss: 0.0482
Epoch [1/10], Step [500/600], Loss: 0.1138
Epoch [1/10], Step [600/600], Loss: 0.0533
Epoch [2/10], Step [100/600], Loss: 0.0340
Epoch [2/10], Step [200/600], Loss: 0.0619
Epoch [2/10], Step [300/600], Loss: 0.2061
Epoch [2/10], Step [400/600], Loss: 0.0695
Epoch [2/10], Step [500/600], Loss: 0.0269
Epoch [2/10], Step [600/600], Loss: 0.0330
Epoch [3/10], Step [100/600], Loss: 0.0135
Epoch [3/10], Step [200/600], Loss: 0.0710
Epoch [3/10], Step [300/600], Loss: 0.0089
Epoch [3/10], Step [400/600], Loss: 0.0139
Epoch [3/10], Step [500/600], Loss: 0.0786
Epoch [3/10], Step [600/600], Loss: 0.0331
Epoch [4/10], Step [100/600], Loss: 0.0072
Epoch [4/10], Step [200/600], Loss: 0.0183
Epoch [4/10], Step [300/600], Loss: 0.0291
Epoch [4/10], Step [400/600], Loss: 0.0399
Epoch [4/10], Step [500/600], Loss: 0.0065
Epoch [4/10], Step [600/600], Loss: 0.0306
Epoch [5/10], Step [100/600], Loss: 0.0097
Epoch [5/10], Step [200/600], Loss: 0.0073
Epoch [5/10], Step [300/600], Loss: 0.0327
Epoch [5/10], Step [400/600], Loss: 0.0027
Epoch [5/10], Step [500/600], Loss: 0.0254
Epoch [5/10], Step [600/600], Loss: 0.0136
Epoch [6/10], Step [100/600], Loss: 0.0195
Epoch [6/10], Step [200/600], Loss: 0.0124
Epoch [6/10], Step [300/600], Loss: 0.0065
Epoch [6/10], Step [400/600], Loss: 0.0975
Epoch [6/10], Step [500/600], Loss: 0.0333
Epoch [6/10], Step [600/600], Loss: 0.0346
Epoch [7/10], Step [100/600], Loss: 0.0055
Epoch [7/10], Step [200/600], Loss: 0.0003
Epoch [7/10], Step [300/600], Loss: 0.0014
Epoch [7/10], Step [400/600], Loss: 0.0052
Epoch [7/10], Step [500/600], Loss: 0.0592
Epoch [7/10], Step [600/600], Loss: 0.0139
Epoch [8/10], Step [100/600], Loss: 0.0196
Epoch [8/10], Step [200/600], Loss: 0.0122
Epoch [8/10], Step [300/600], Loss: 0.0211
Epoch [8/10], Step [400/600], Loss: 0.0009
Epoch [8/10], Step [500/600], Loss: 0.0464
Epoch [8/10], Step [600/600], Loss: 0.0207
Epoch [9/10], Step [100/600], Loss: 0.0078
Epoch [9/10], Step [200/600], Loss: 0.0047
Epoch [9/10], Step [300/600], Loss: 0.0029
Epoch [9/10], Step [400/600], Loss: 0.0047
Epoch [9/10], Step [500/600], Loss: 0.0724
Epoch [9/10], Step [600/600], Loss: 0.0219
Epoch [10/10], Step [100/600], Loss: 0.0008
Epoch [10/10], Step [200/600], Loss: 0.0054
Epoch [10/10], Step [300/600], Loss: 0.0015
Epoch [10/10], Step [400/600], Loss: 0.0029
Epoch [10/10], Step [500/600], Loss: 0.0043
Epoch [10/10], Step [600/600], Loss: 0.0025
Finally, we can evaluate the model's accuracy on the test data:
# Evaluate the network on the test set
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28 * 28)       # flatten each image into a vector
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)  # take the class with the highest score
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')
Accuracy of the network on the 10000 test images: 97.88 %
As you can see, training went well: the model reaches an accuracy of 97.88%.
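As mentioned in step 6 of the overview, after training you can also save the trained parameters and reload them later for inference. The snippet below is a minimal sketch; the file name mlp_mnist.pt is just an example:

# Save the trained parameters to disk (the file name is arbitrary)
torch.save(model.state_dict(), 'mlp_mnist.pt')

# Later: rebuild the model and load the saved parameters for inference
model = MLP(input_size, hidden_size, num_classes)
model.load_state_dict(torch.load('mlp_mnist.pt'))
model.eval()  # switch to evaluation mode before running inference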
梗直哥's advice: this lesson deliberately kept the explanation brief and focused on the code. If something didn't click, take a moment to figure out why. If the code itself is hard to read, consider brushing up on Python. If you want a deeper understanding of how neural networks work, the neural-network chapters of 梗直哥's course 《機(jī)器學(xué)習(xí)必修課:python實(shí)戰(zhàn)》 are a good option. And if what you lack is hands-on experience with running models and tuning hyperparameters, you are welcome to join the discussion group (WeChat: gengzhige99).