# PyTorch學習筆記11:sin(x)正弦函數曲線擬合 (PyTorch study notes 11: fitting a sin(x) curve)
#需要import的lib
import torch
import time
import platform
import cmath
import matplotlib.pyplot as plt
import numpy as np
#import CV2
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#需要import的lib
# Environment the original author ran this on: Tesla K20 / Python 3.7 /
# PyTorch "1.20" (as written in the original note; presumably 1.2.0).
# Print interpreter/library versions and details of every visible CUDA device.
print('——————————運(yùn)行環(huán)境——————————')
print('Python Version:', platform.python_version())
print('Torch Version:', torch.__version__)
# Enable together with the commented-out CV2 import above.
#print('OpenCV Version:',CV2.__version__)
print('CUDA GPU check:', torch.cuda.is_available())
if torch.cuda.is_available():
    n = torch.cuda.device_count()
    print('CUDA GPU num:', n)
    # Report devices from the highest index down to 0, like the original loop.
    for idx in reversed(range(n)):
        print('CUDA GPU name:', torch.cuda.get_device_name(idx))
        print('CUDA GPU capability:', torch.cuda.get_device_capability(idx))
        print('CUDA GPU properties:', torch.cuda.get_device_properties(idx))
    # Kept inside the guard: querying the current CUDA device would fail on a
    # CPU-only machine.
    print('CUDA GPU index:', torch.cuda.current_device())
print('——————————運(yùn)行環(huán)境——————————')
time_start = time.time()
#device=torch.device('cuda:0')
# For a workload this small the GPU is not necessarily faster, so use the CPU.
device = torch.device('cpu')

# Sampling parameters for the noisy sin(x) data set.
# (3.14 rad = 180 degrees, so x in [0, 9.9] spans a bit over 1.5 periods.)
N_SAMPLES = 100
X_STEP = 0.1
NOISE_STD = 0.01

# Row 0 holds x values, row 1 holds y values. Tensors created this way do not
# require grad by default, so no explicit requires_grad=False is needed.
data = torch.zeros(2, N_SAMPLES, device=device)
pred = torch.zeros(2, N_SAMPLES, device=device)
graph = torch.zeros(2, N_SAMPLES, device=device)
# Trainable coefficient matrix consumed by func1 (the only optimized tensor).
wb = torch.ones(2, 5, device=device, requires_grad=True)
print(wb)

# Sample sin(x) at x = 0, X_STEP, ..., (N_SAMPLES-1)*X_STEP and add Gaussian
# noise, all at once instead of one element per Python iteration.
# x is computed in float64 (same arithmetic as Python's i*0.1) and then
# stored into the float32 `data` tensor by the slice assignment.
x_samples = torch.arange(N_SAMPLES, dtype=torch.float64) * X_STEP
noise = torch.from_numpy(np.random.randn(N_SAMPLES) * NOISE_STD)  # numpy -> tensor
data[0] = x_samples
data[1] = torch.sin(x_samples) + noise
print(data)
pred[0] = data[0]
graph[0] = data[0]
print(pred)
def func1(wb, pred, scale=0.1):
    """Evaluate the learned polynomial model on the x values in ``pred[0]``.

    With ``u = pred[0] * scale`` and ``D = wb.shape[1]`` this returns::

        sum_{k=1..D} wb[0][k-1] * u**k  +  sum_{k=1..D} wb[1][k-1]

    With the default ``scale=0.1`` and a (2, 5) ``wb`` this is exactly the
    original hand-expanded degree-5 expression. It is a polynomial
    least-squares model with learned coefficients, not a Taylor expansion.
    Note that the row-1 entries only ever enter through their sum, so together
    they act as a single constant term.

    Args:
        wb: tensor of shape (2, D); row 0 holds the coefficients of u**1..u**D,
            row 1 holds bias terms that are summed into one constant.
        pred: tensor whose row 0 holds the x samples (row 1 is not used).
        scale: factor applied to x before exponentiation; in this script 0.1
            maps x in [0, 9.9] to u in [0, 0.99].

    Returns:
        1-D tensor with one model value per element of ``pred[0]``.
    """
    u = pred[0] * scale
    degree = wb.shape[1]
    # Exponents 1..D as a column so that u broadcasts to shape (D, N);
    # D == 0 yields an empty (0, N) tensor and the result is just the bias.
    exponents = torch.arange(1, degree + 1, device=u.device).unsqueeze(1)
    powers = u.unsqueeze(0) ** exponents
    return (wb[0].unsqueeze(1) * powers).sum(dim=0) + wb[1].sum()
# Mean-squared error between the model curve and the noisy samples.
loss_func = torch.nn.MSELoss()
# Adam updates only the polynomial coefficients wb.
optim = torch.optim.Adam([wb], lr=1e-3)

# Training schedule (same values as before). If training is too slow, lower
# NUM_STEPS for a quick trial run, at the cost of a noticeably worse fit.
NUM_STEPS = 3000000
LOG_EVERY = 200
for step in range(NUM_STEPS):
    loss = loss_func(func1(wb, pred), data[1])
    optim.zero_grad()
    loss.backward()
    optim.step()
    if step % LOG_EVERY == 0:
        print('wb:', wb)
        print('loss:', loss.item())

# Evaluate the fitted curve without recording an autograd graph.
with torch.no_grad():
    graph[1] = func1(wb, graph)

# Stop the clock before plt.show(): show() blocks until the window is closed,
# so timing after it would include however long the plot stayed open.
time_end = time.time()
print('Totally cost', time_end - time_start, 's')

# tensor -> numpy; .cpu() keeps this working if `device` is switched to CUDA.
a = graph[0].detach().cpu().numpy()
b = graph[1].detach().cpu().numpy()
c = data[0].detach().cpu().numpy()
d = data[1].detach().cpu().numpy()
plt.plot(a, b, label='fit')
plt.plot(c, d, label='sin(x) + noise')
plt.legend()
plt.show()