PyTorch Learning Notes 12: Fitting a sin(x) Curve (Saving and Loading Parameters)
# Required imports
import os
import time
import math
import platform
import torch
import numpy as np
import matplotlib.pyplot as plt
#import cv2
# Runtime environment: Tesla K20 / Python 3.7 / PyTorch 1.2.0
print('——————————Runtime environment——————————')
print('Python Version:', platform.python_version())
print('Torch Version:', torch.__version__)
#print('OpenCV Version:', cv2.__version__)
print('CUDA GPU check:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('CUDA GPU num:', torch.cuda.device_count())
    n = torch.cuda.device_count()
    while n > 0:
        print('CUDA GPU name:', torch.cuda.get_device_name(n-1))
        print('CUDA GPU capability:', torch.cuda.get_device_capability(n-1))
        print('CUDA GPU properties:', torch.cuda.get_device_properties(n-1))
        n -= 1
    print('CUDA GPU index:', torch.cuda.current_device())
print('——————————Runtime environment——————————')
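# For reproducible noise and weight initialization, both RNGs could be seeded
# here (a sketch; the original run is unseeded):
#torch.manual_seed(0)
#np.random.seed(0)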
time_start = time.time()
#device = torch.device('cuda:0')
# For a small model like this one, the GPU is not necessarily faster than the CPU.
device = torch.device('cpu')
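# Note: `device` is defined but never used below; everything stays on the CPU.
# To actually train on the GPU, the model and tensors defined later would be
# moved explicitly (a sketch):
#net_inst = net_inst.to(device)
#data, target = data.to(device), target.to(device)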
class net(torch.nn.Module):
    def __init__(self):
        super(net, self).__init__()
        # 1 -> 100 -> 100 -> 100 -> 1 fully connected network with ReLU activations
        self.model = torch.nn.Sequential(
            torch.nn.Linear(1, 100),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(100, 100),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(100, 100),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(100, 1),
        )

    def forward(self, x):
        x = self.model(x)
        return x
net_inst = net()
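# Sanity check (optional): this 1-100-100-100-1 MLP has
# (1*100+100) + (100*100+100) + (100*100+100) + (100*1+1) = 20501 parameters.
#print(sum(p.numel() for p in net_inst.parameters()))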
#print(net_inst.state_dict())
#net_inst.load_state_dict(torch.load("c:/net.pth"))  # load the saved network parameters
#print(net_inst.state_dict())
optim = torch.optim.SGD(net_inst.parameters(), lr=1e-4)
loss = torch.nn.MSELoss()
epoch = 96  # left over from earlier notes; not used below
# Resume from a previous run if checkpoints exist (the first run starts from scratch).
if os.path.exists("c:/net.pth") and os.path.exists("c:/optim.pth"):
    net_inst.load_state_dict(torch.load("c:/net.pth"))
    optim.load_state_dict(torch.load("c:/optim.pth"))
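# If a checkpoint had been saved on a GPU machine, torch.load would need
# map_location to restore it on CPU (not required here, since training is CPU-only):
#net_inst.load_state_dict(torch.load("c:/net.pth", map_location='cpu'))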
# Sample 100 points of sin(x) on [0, 10) and add Gaussian noise.
sinx = torch.zeros(100, 2)
sinx.requires_grad = False
for i in range(0, 100):
    sinx[i][0] = i * 0.1
    noise = torch.from_numpy(np.random.randn(1) * 0.03)  # note: NumPy array -> tensor
    sinx[i][1] = math.sin(i * 0.1) + noise
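# Equivalent vectorized sampling (a sketch; matches the loop above up to RNG
# differences):
#x = torch.arange(100, dtype=torch.float32).unsqueeze(1) * 0.1
#y = torch.sin(x) + torch.randn(100, 1) * 0.03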
# Quick visual check of the noisy samples (commented out):
#print(sinx)
#print(sinx.shape)
#t1 = sinx.transpose(0, 1)  # transpose to shape (2, 100)
#print(t1)
#e = t1[0].detach().numpy()
#f = t1[1].detach().numpy()
#plt.plot(e, f, 'b.-')
#plt.show()
# Split the samples into network input (x values) and regression target (noisy sin(x)).
t1 = sinx.transpose(0, 1)      # shape (2, 100): row 0 holds x, row 1 holds sin(x)+noise
data = t1[0].unsqueeze(1)      # shape (100, 1)
target = t1[1].unsqueeze(1)    # shape (100, 1)
#print(data.shape, target.shape)
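# The same split can be written as column slices (an equivalent sketch):
#data, target = sinx[:, 0:1], sinx[:, 1:2]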
# Train for 20000 steps of full-batch gradient descent.
for j in range(20000):
    logits = net_inst(data)          # predictions for all 100 samples
    loss1 = loss(logits, target)     # mean squared error against the noisy targets
    optim.zero_grad()
    loss1.backward()
    optim.step()
    if j % 5000 == 0:
        print('loss:', loss1.item())
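# If the loss plateaus, a learning-rate scheduler could be added before the
# loop (a sketch, not part of the original run):
#sched = torch.optim.lr_scheduler.StepLR(optim, step_size=5000, gamma=0.5)
# ...with sched.step() called once per iteration after optim.step().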
# Save the network and optimizer parameters so the next run can resume.
torch.save(obj=net_inst.state_dict(), f="c:/net.pth")
torch.save(obj=optim.state_dict(), f="c:/optim.pth")
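# A common alternative (sketch) is one checkpoint dict holding everything:
#ckpt = {'net': net_inst.state_dict(), 'optim': optim.state_dict()}
#torch.save(ckpt, "c:/ckpt.pth")
#net_inst.load_state_dict(torch.load("c:/ckpt.pth")['net'])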
# Plot the noisy samples (blue) against the network fit (red).
# Flattening the (100, 1) tensors to 1-D lets matplotlib draw one connected
# line per series instead of 100 single-point series.
ttt1 = data.detach().numpy().flatten()
ttt2 = target.detach().numpy().flatten()
ttt3 = logits.detach().numpy().flatten()
print(ttt1.shape, ttt2.shape, ttt3.shape)
plt.plot(ttt1, ttt2, 'b.-')
plt.plot(ttt1, ttt3, 'r.-')
plt.show()
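# To keep a record of the fit, the figure could also be written to disk
# (sketch; savefig would go before plt.show()):
#plt.savefig("fit.png", dpi=150)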
time_end = time.time()
print('Total time:', time_end - time_start, 's')