最美情侣中文字幕电影,在线麻豆精品传媒,在线网站高清黄,久久黄色视频

歡迎光臨散文網(wǎng) 會(huì)員登陸 & 注冊(cè)

Pytorch學(xué)習(xí)筆記12:Sinx正弦函數(shù)曲線擬合(參數(shù)保存和加載)

2021-06-30 09:09 作者:車科技2020  | 我要投稿

#需要import的lib
import numpy
import torch
import time
import platform
import cmath
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as animation

#import CV2
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

#需要import的lib
#運(yùn)行環(huán)境tesla k20/python 3.7/pytorch 1.20
print('——————————運(yùn)行環(huán)境——————————')
print('Python Version:',platform.python_version())
print('Torch Version:',torch.__version__)
#print('OpenCV Version:',CV2.__version__)
print('CUDA GPU check:',torch.cuda.is_available())
if(torch.cuda.is_available()):
print('CUDA GPU num:', torch.cuda.device_count())
n=torch.cuda.device_count()
while n > 0:
? print('CUDA GPU name:', torch.cuda.get_device_name(n-1))
? print('CUDA GPU capability:', torch.cuda.get_device_capability(n-1))
? print('CUDA GPU properties:', torch.cuda.get_device_properties(n-1))
? n -= 1
print('CUDA GPU index:', torch.cuda.current_device())
print('——————————運(yùn)行環(huán)境——————————')

time_start=time.time()

#device=torch.device('cuda:0')
#工程優(yōu)化應(yīng)用,GPU不一定更快
device=torch.device('cpu')

class net(torch.nn.Module):
? ?def __init__(self):
? ? ? ?super(net, self).__init__()
? ? ? ?self.model=torch.nn.Sequential(
? ? ? ? ? ?torch.nn.Linear(1,100),
? ? ? ? ? ?torch.nn.ReLU(inplace=True),
? ? ? ? ? ?torch.nn.Linear(100, 100),
? ? ? ? ? ?torch.nn.ReLU(inplace=True),
? ? ? ? ? ?torch.nn.Linear(100, 100),
? ? ? ? ? ?torch.nn.ReLU(inplace=True),
? ? ? ? ? ?torch.nn.Linear(100, 1),
? ? ? ?)
? ?def forward(self,x) :
? ? ? ?x=self.model(x)
? ? ? ?return x

net_inst=net()
#print(net_inst.state_dict())
#net_inst.load_state_dict(torch.load("c:/net.pth"))#加載神經(jīng)網(wǎng)絡(luò)的參數(shù)
#print(net_inst.state_dict())
optim=torch.optim.SGD(net_inst.parameters(),lr=1e-4)
loss=torch.nn.MSELoss()
epoch=96

net_inst.load_state_dict(torch.load("c:/net.pth"))
optim.load_state_dict(torch.load("c:/optim.pth"))



sinx=torch.zeros(100,2)
sinx.requires_grad=False

for i in range(0,100):#對(duì)sinx進(jìn)行采樣,加了正態(tài)分布的噪聲
? ?sinx[i][0]=i*0.1
? ?mid=torch.from_numpy(np.random.randn(1)*0.03) ?# 注意numpy轉(zhuǎn)tensor
? ?sinx[i][1]=cmath.sin(i*0.1).real+mid

#print(sinx)
#print(sinx.shape)


#t1=sinx.transpose(0,1) #轉(zhuǎn)置一下
#print(t1)
#e=t1[0].detach().numpy()
#f=t1[1].detach().numpy()
#plt.plot(e, f,'b.-')
#plt.show()


data=torch.zeros(100,1)
#print(data)
#print(data.shape)
t1=sinx.transpose(0,1)
t2=data.transpose(0,1)
t2=t1[0]
t3=t2.unsqueeze(1)
#print(t3.shape)
data=t3
#print(data.shape)

target=torch.zeros(100,1)
t2=t1[1]
t3=t2.unsqueeze(1)
target=t3
#print(target.shape)
#print(target)

logits=torch.zeros(100,1)

for j in range(20000):
? ?logits=net_inst(data)
? ?loss1=loss(logits, target)
? ?optim.zero_grad()
? ?loss1.backward()
? ?optim.step()
? ?if j%5000==0:
? ? ? ?print('loss:', loss1)

torch.save(obj=net_inst.state_dict(),f="c:/net.pth")
torch.save(obj=optim.state_dict(),f="c:/optim.pth")

tt1=data.transpose(0,1)
tt2=target.transpose(0,1)
tt3=logits.transpose(0,1)
print(tt1.shape)
print(tt2.shape)
print(tt3.shape)
ttt1=tt1.detach().numpy()
ttt2=tt2.detach().numpy()
ttt3=tt3.detach().numpy()

plt.plot(ttt1, ttt2,'b.-')
plt.plot(ttt1, ttt3,'r.-')
plt.show()

time_end=time.time()
print('Totally cost',time_end-time_start,'s')



#print(net_inst.state_dict())
#torch.save(obj=net_inst.state_dict(),f="c:/net.pth")#保存神經(jīng)網(wǎng)絡(luò)的參數(shù)

Pytorch學(xué)習(xí)筆記12:Sinx正弦函數(shù)曲線擬合(參數(shù)保存和加載)的評(píng)論 (共 條)

分享到微博請(qǐng)遵守國(guó)家法律
贡嘎县| 桦川县| 长垣县| 南川市| 博罗县| 汤原县| 颍上县| 桦川县| 鄂伦春自治旗| 临沭县| 潞西市| 修文县| 蕉岭县| 逊克县| 沈阳市| 万盛区| 承德市| 朔州市| 淅川县| 新邵县| 南漳县| 枣阳市| 临高县| 崇义县| 休宁县| 沛县| 文昌市| 竹北市| 中江县| 绵竹市| 塘沽区| 平潭县| 科技| 凤台县| 盐边县| 长治市| 清新县| 炎陵县| 秭归县| 博乐市| 菏泽市|