最美情侣中文字幕电影,在线麻豆精品传媒,在线网站高清黄,久久黄色视频

歡迎光臨散文網(wǎng) 會(huì)員登陸 & 注冊(cè)

Pytorch學(xué)習(xí)筆記12:Sinx正弦函數(shù)曲線擬合(參數(shù)保存和加載)

2021-06-30 09:09 作者:車科技2020  | 我要投稿

#需要import的lib
import numpy
import torch
import time
import platform
import cmath
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as animation

#import CV2
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

#需要import的lib
#運(yùn)行環(huán)境tesla k20/python 3.7/pytorch 1.20
print('——————————運(yùn)行環(huán)境——————————')
print('Python Version:',platform.python_version())
print('Torch Version:',torch.__version__)
#print('OpenCV Version:',CV2.__version__)
print('CUDA GPU check:',torch.cuda.is_available())
if(torch.cuda.is_available()):
print('CUDA GPU num:', torch.cuda.device_count())
n=torch.cuda.device_count()
while n > 0:
? print('CUDA GPU name:', torch.cuda.get_device_name(n-1))
? print('CUDA GPU capability:', torch.cuda.get_device_capability(n-1))
? print('CUDA GPU properties:', torch.cuda.get_device_properties(n-1))
? n -= 1
print('CUDA GPU index:', torch.cuda.current_device())
print('——————————運(yùn)行環(huán)境——————————')

time_start=time.time()

#device=torch.device('cuda:0')
#工程優(yōu)化應(yīng)用,GPU不一定更快
device=torch.device('cpu')

class net(torch.nn.Module):
? ?def __init__(self):
? ? ? ?super(net, self).__init__()
? ? ? ?self.model=torch.nn.Sequential(
? ? ? ? ? ?torch.nn.Linear(1,100),
? ? ? ? ? ?torch.nn.ReLU(inplace=True),
? ? ? ? ? ?torch.nn.Linear(100, 100),
? ? ? ? ? ?torch.nn.ReLU(inplace=True),
? ? ? ? ? ?torch.nn.Linear(100, 100),
? ? ? ? ? ?torch.nn.ReLU(inplace=True),
? ? ? ? ? ?torch.nn.Linear(100, 1),
? ? ? ?)
? ?def forward(self,x) :
? ? ? ?x=self.model(x)
? ? ? ?return x

net_inst=net()
#print(net_inst.state_dict())
#net_inst.load_state_dict(torch.load("c:/net.pth"))#加載神經(jīng)網(wǎng)絡(luò)的參數(shù)
#print(net_inst.state_dict())
optim=torch.optim.SGD(net_inst.parameters(),lr=1e-4)
loss=torch.nn.MSELoss()
epoch=96

net_inst.load_state_dict(torch.load("c:/net.pth"))
optim.load_state_dict(torch.load("c:/optim.pth"))



sinx=torch.zeros(100,2)
sinx.requires_grad=False

for i in range(0,100):#對(duì)sinx進(jìn)行采樣,加了正態(tài)分布的噪聲
? ?sinx[i][0]=i*0.1
? ?mid=torch.from_numpy(np.random.randn(1)*0.03) ?# 注意numpy轉(zhuǎn)tensor
? ?sinx[i][1]=cmath.sin(i*0.1).real+mid

#print(sinx)
#print(sinx.shape)


#t1=sinx.transpose(0,1) #轉(zhuǎn)置一下
#print(t1)
#e=t1[0].detach().numpy()
#f=t1[1].detach().numpy()
#plt.plot(e, f,'b.-')
#plt.show()


data=torch.zeros(100,1)
#print(data)
#print(data.shape)
t1=sinx.transpose(0,1)
t2=data.transpose(0,1)
t2=t1[0]
t3=t2.unsqueeze(1)
#print(t3.shape)
data=t3
#print(data.shape)

target=torch.zeros(100,1)
t2=t1[1]
t3=t2.unsqueeze(1)
target=t3
#print(target.shape)
#print(target)

logits=torch.zeros(100,1)

for j in range(20000):
? ?logits=net_inst(data)
? ?loss1=loss(logits, target)
? ?optim.zero_grad()
? ?loss1.backward()
? ?optim.step()
? ?if j%5000==0:
? ? ? ?print('loss:', loss1)

torch.save(obj=net_inst.state_dict(),f="c:/net.pth")
torch.save(obj=optim.state_dict(),f="c:/optim.pth")

tt1=data.transpose(0,1)
tt2=target.transpose(0,1)
tt3=logits.transpose(0,1)
print(tt1.shape)
print(tt2.shape)
print(tt3.shape)
ttt1=tt1.detach().numpy()
ttt2=tt2.detach().numpy()
ttt3=tt3.detach().numpy()

plt.plot(ttt1, ttt2,'b.-')
plt.plot(ttt1, ttt3,'r.-')
plt.show()

time_end=time.time()
print('Totally cost',time_end-time_start,'s')



#print(net_inst.state_dict())
#torch.save(obj=net_inst.state_dict(),f="c:/net.pth")#保存神經(jīng)網(wǎng)絡(luò)的參數(shù)

Pytorch學(xué)習(xí)筆記12:Sinx正弦函數(shù)曲線擬合(參數(shù)保存和加載)的評(píng)論 (共 條)

分享到微博請(qǐng)遵守國(guó)家法律
毕节市| 河源市| 德化县| 亚东县| 西乌珠穆沁旗| 湖南省| 青田县| 工布江达县| 新建县| 荣成市| 乌鲁木齐县| 连平县| 巴楚县| 佛冈县| 濮阳市| 乐陵市| 阳新县| 太仆寺旗| 宜宾县| 延长县| 萨迦县| 北安市| 常山县| 南昌市| 娱乐| 德安县| 呼伦贝尔市| 肥城市| 龙门县| 和田市| 遂平县| 清河县| 武汉市| 方山县| 衡南县| 龙口市| 荥阳市| 永康市| 西藏| 南丹县| 平度市|