一個簡單的pytorch中線性回歸模型的創(chuàng)建、訓(xùn)練和測試,每一步都有明確注釋
#調(diào)用torch庫
import torch
#調(diào)用numpy庫
import numpy
#可有可無,用來調(diào)用查看線型圖的方法
import matplotlib.pyplot as plt
#定義一個訓(xùn)練集
t_data = torch.rand(100, 1)
print("t_data:", t_data)
#定義一個訓(xùn)練集的答案
p_data = t_data * 2
print("p_data:", p_data)
#定義一個線性回歸模型
class LinearRegression(torch.nn.Module):
? ?#初始化線性回歸模型,定義網(wǎng)絡(luò)的結(jié)構(gòu)
? ?def __init__(self):
? ? ? ?#初始化父類
? ? ? ?super(LinearRegression, self).__init__()
? ? ? ?#定義自己的神經(jīng)元
? ? ? ?self.Linear = torch.nn.Linear(1, 1)
? ?#定義網(wǎng)絡(luò)的計算:前向傳遞計算
? ?def forward(self, x):
? ? ? ?#將x(輸入的數(shù)據(jù),可以是訓(xùn)練集或者測試集)在自定義的全連接層中進(jìn)行計算并返回計算得到的預(yù)測值
? ? ? ?pred = self.Linear(x)
? ? ? ?return pred
#定義模型
model = LinearRegression()
#定義損失函數(shù)
mse_loss = torch.nn.MSELoss(reduction='mean')
#定義優(yōu)化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
#查看定義的模型中的參數(shù):可以看出模型的初始參數(shù)是隨機(jī)設(shè)置的,我們想要的參數(shù)是通過訓(xùn)練集來進(jìn)行更新得到的
for name, parameters in model.named_parameters():
? ?print("name:", name)
? ?print("parameters:", parameters)
#通過訓(xùn)練集對模型進(jìn)行訓(xùn)練
for epoch in range(1000):
? ?#通過訓(xùn)練集t_data得到訓(xùn)練結(jié)果pred
? ?pred = model(t_data)
? ?#通過pred和p_data計算損失
? ?loss = mse_loss(pred, p_data)
? ?#輸出經(jīng)過計算的損失值以方便查看損失是否降低來判斷訓(xùn)練是否有效(為了不過于繁瑣,只查看第0, 200, 400, 600, 800, 1000次的損失值)
? ?if (epoch + 1) % 200 == 0:
? ? ? ?print("Epoch:", epoch + 1)
? ? ? ?print("loss:", loss.item())
? ?#在反向計算梯度前,需要先進(jìn)行優(yōu)化器的梯度清零,否則梯度會在每次訓(xùn)練過程中累加
? ?optimizer.zero_grad()
? ?#進(jìn)行梯度的反向計算
? ?loss.backward()
? ?#通過優(yōu)化器更新模型中的參數(shù)
? ?optimizer.step()
? ?#輸出經(jīng)過調(diào)試后的參數(shù)以方便查看權(quán)重參數(shù)的更替(為了不過于繁瑣,只查看第0, 200, 400, 600, 800, 1000次的權(quán)重參數(shù))
? ?if (epoch + 1) % 200 == 0:
? ? ? ?for name,parameters in model.named_parameters():
? ? ? ? ? ?print("name:", name)
? ? ? ? ? ?print("parameters:", parameters, end='\n\n')
#通過自己輸入一個測試集得到預(yù)測結(jié)果來判斷模型訓(xùn)練的如何
t_test = torch.rand(100, 1)
p_test = model(t_test)
print("t_test:", t_test)
print("p_test:", p_test)
#可有可無 用來查看線型圖,通過線型圖可以發(fā)現(xiàn),測試集和測試結(jié)果符合的線性關(guān)系是否為2來判斷訓(xùn)練的效果
t_test = t_test.data.numpy()
p_test = p_test.data.numpy()
plt.scatter(t_test, p_test)
plt.show()