最美情侣中文字幕电影,在线麻豆精品传媒,在线网站高清黄,久久黄色视频

歡迎光臨散文網(wǎng) 會(huì)員登陸 & 注冊(cè)

5.8 模型文件的讀寫

2023-02-15 11:23 作者:梗直哥丶  | 我要投稿

前面我們講了眾多方法,來(lái)訓(xùn)練好一個(gè)模型,讓模型能夠收斂,同時(shí)又不出現(xiàn)過(guò)擬合和欠擬合問(wèn)題。當(dāng)模型訓(xùn)練好以后,就需要對(duì)模型參數(shù)進(jìn)行保存,以便部署到不同的環(huán)境中去使用。通常,一個(gè)深度學(xué)習(xí)模型的訓(xùn)練需要消耗很長(zhǎng)的時(shí)間,比如幾天,如果訓(xùn)練過(guò)程中出現(xiàn)了問(wèn)題,無(wú)論是軟件問(wèn)題還是硬件問(wèn)題,又或者是外部因素比如突然斷電,造成的損失都是非常大的。因此,我們不僅要在最后對(duì)模型進(jìn)行保存,訓(xùn)練過(guò)程中也應(yīng)該定時(shí)保存中間結(jié)果,以避免損失。這一節(jié)我們就來(lái)學(xué)習(xí)一下模型保存的方法。

5.8.1?張量的保存和加載

在深度學(xué)習(xí)中,模型的參數(shù)一般是張量形式的。對(duì)于單個(gè)的張量,pytorch為我們提供了方便直接的函數(shù)來(lái)進(jìn)行讀寫。比如我們定義如下的一個(gè)張量a。

import?torch

a?=?torch.rand(10)
a

?

tensor([0.1190, 0.7933, 0.9636, 0.5436, 0.2750, 0.3664, 0.4274, 0.9336, 0.1324,
????????0.8449])

?

可以簡(jiǎn)單的用一個(gè)save函數(shù)去存儲(chǔ)這個(gè)張量a,這里需要我們給他起一個(gè)名字,我們就叫它tensor-a,把它放在model文件夾里。

torch.save(a,?'model/tensor-a')

這就完成了張量的寫入,這時(shí)我們可以在當(dāng)前路徑下的model文件夾里看到tensor-a這個(gè)文件。讀取同樣簡(jiǎn)單,只需要用一個(gè)load函數(shù)就可以完成張量的加載,傳入的參數(shù)是文件的路徑。

torch.load('model/tensor-a')

?

tensor([0.1190, 0.7933, 0.9636, 0.5436, 0.2750, 0.3664, 0.4274, 0.9336, 0.1324,
????????0.8449])

?

如果我們要存儲(chǔ)的不止一個(gè)張量,也沒(méi)有問(wèn)題,save和load函數(shù)同樣支持保存張量列表。先把張量數(shù)據(jù)存儲(chǔ)起來(lái)。

a?=?torch.rand(10)
b?=?torch.rand(10)
c?=?torch.rand(10)
torch.save([a,b,c],?'model/tensor-abc')
[a,b,c]

?

[tensor([0.0270, 0.8876, 0.4965, 0.5507, 0.9629, 0.7735, 0.9478, 0.7899, 0.7003,
?????????0.5002]),
?tensor([0.3628, 0.1818, 0.3137, 0.4671, 0.6445, 0.0022, 0.2800, 0.4637, 0.4888,
?????????0.2336]),
?tensor([0.8327, 0.3511, 0.2187, 0.6894, 0.9219, 0.7021, 0.1927, 0.0983, 0.6716,
?????????0.1998])]

?

然后再把它讀取出來(lái)。

torch.load('model/tensor-abc')

?

[tensor([0.0270, 0.8876, 0.4965, 0.5507, 0.9629, 0.7735, 0.9478, 0.7899, 0.7003,
?????????0.5002]),
?tensor([0.3628, 0.1818, 0.3137, 0.4671, 0.6445, 0.0022, 0.2800, 0.4637, 0.4888,
?????????0.2336]),
?tensor([0.8327, 0.3511, 0.2187, 0.6894, 0.9219, 0.7021, 0.1927, 0.0983, 0.6716,
?????????0.1998])]

?

對(duì)于多個(gè)張量,pytorch同樣支持以字典的形式來(lái)進(jìn)行存儲(chǔ)。比如我們建立一個(gè)字典tensor_dict,然后把它存起來(lái)。

a?=?torch.rand(10)
b?=?torch.rand(10)
c?=?torch.rand(10)
tensor_dict={'a':a,?'b':b,?'c':c}
torch.save(tensor_dict,?'model/tensor_dict')
tensor_dict

?

{'a': tensor([0.1925, 0.3094, 0.8293, 0.3449, 0.3672, 0.3616, 0.9751, 0.7442, 0.8948,
?????????0.9062]),
?'b': tensor([0.6409, 0.1292, 0.1913, 0.0356, 0.0109, 0.8862, 0.9702, 0.4830, 0.2453,
?????????0.0902]),
?'c': tensor([0.4258, 0.1488, 0.8010, 0.0061, 0.9639, 0.2933, 0.3556, 0.0569, 0.9560,
?????????0.4338])}

?

然后是讀取。

torch.load('model/tensor_dict')

?

{'a': tensor([0.1925, 0.3094, 0.8293, 0.3449, 0.3672, 0.3616, 0.9751, 0.7442, 0.8948,
?????????0.9062]),
?'b': tensor([0.6409, 0.1292, 0.1913, 0.0356, 0.0109, 0.8862, 0.9702, 0.4830, 0.2453,
?????????0.0902]),
?'c': tensor([0.4258, 0.1488, 0.8010, 0.0061, 0.9639, 0.2933, 0.3556, 0.0569, 0.9560,
?????????0.4338])}

?

張量的讀寫非常的簡(jiǎn)單,接下來(lái)我們看看模型整體參數(shù)的讀寫。

5.8.2?模型參數(shù)的保存和加載

模型參數(shù)一般都是張量形式的,雖然單個(gè)張量的保存和加載非常簡(jiǎn)單,但整個(gè)模型中包含著大大小小的若干張量,單獨(dú)保存這些張量會(huì)很困難。為了解決這個(gè)問(wèn)題,pytorch貼心的為我們準(zhǔn)備了內(nèi)置函數(shù)來(lái)保存加載整個(gè)模型參數(shù)。我們以5.2節(jié)的多層感知機(jī)為例,來(lái)看一下如何保存。

import?torch
import?torch.nn?as?nn
import?torch.optim?as?optim
from?torchvision?import?datasets, transforms

#?定義?MLP?網(wǎng)絡(luò)
class?MLP(nn.Module):
????def?__init__(self, input_size, hidden_size, num_classes):
????????super(MLP,?self).__init__()
????????self.fc1?=?nn.Linear(input_size, hidden_size)
????????self.relu?=?nn.ReLU()
????????self.fc2?=?nn.Linear(hidden_size, hidden_size)
????????self.fc3?=?nn.Linear(hidden_size, num_classes)
????
????def?forward(self, x):
????????out?=?self.fc1(x)
????????out?=?self.relu(out)
????????out?=?self.fc2(out)
????????out?=?self.relu(out)
????????out?=?self.fc3(out)
????????return?out

#?定義超參數(shù)
input_size?=?28?*?28??#?輸入大小
hidden_size?=?512??#?隱藏層大小
num_classes?=?10??#?輸出大?。悇e數(shù))

然后我們實(shí)例化一個(gè)MLP網(wǎng)絡(luò),并隨機(jī)生成一個(gè)輸入X,并計(jì)算出模型的輸出Y。

#?實(shí)例化?MLP?網(wǎng)絡(luò)
model?=?MLP(input_size, hidden_size, num_classes)
X?=?torch.randn(size=(2,?28*28))

然后同樣是調(diào)用save方法,我們把模型存儲(chǔ)到model文件夾里,取名叫做mlp.params。

torch.save(model.state_dict(),?'model/mlp.params')

接下來(lái),我們來(lái)讀取保存好的模型參數(shù),重新加載我們的模型。我們先把模型params參數(shù)讀取出來(lái),然后實(shí)例化一個(gè)模型,然后直接調(diào)用load_state_dict方法,傳入模型參數(shù)params。

params?=?torch.load('model/mlp.params')
model_load?=?MLP(input_size, hidden_size, num_classes)
model_load.load_state_dict(params)

?

<All keys matched successfully>

?

此時(shí)兩個(gè)模型model和model_load具有相同的參數(shù),我們給他輸入相同的X,看一下輸出結(jié)果。

output1?=?model(X)
output1

?

tensor([[ 0.0914, ?0.0178, ?0.0692, ?0.1486, ?0.1002, ?0.0057, -0.1099, ?0.1332,
??????????0.0241, ?0.1137],
????????[-0.0228, ?0.0446, ?0.1374, ?0.2009, -0.0978, -0.0831, -0.0193, ?0.1040,
??????????0.1097, ?0.1484]], grad_fn=<AddmmBackward0>)

?

output2?=?model_load(X)
output2

?

tensor([[ 0.0914, ?0.0178, ?0.0692, ?0.1486, ?0.1002, ?0.0057, -0.1099, ?0.1332,
??????????0.0241, ?0.1137],
????????[-0.0228, ?0.0446, ?0.1374, ?0.2009, -0.0978, -0.0831, -0.0193, ?0.1040,
??????????0.1097, ?0.1484]], grad_fn=<AddmmBackward0>)

?

可以看到,輸出的結(jié)果完全一致,說(shuō)明我們將參數(shù)成功地讀取并載入了模型中。

梗直哥提示:使用save保存的是模型參數(shù)而不是整個(gè)模型,因此在模型加載參數(shù)的時(shí)候,需要我們單獨(dú)指定模型架構(gòu),并且要保證模型結(jié)構(gòu)和保存的時(shí)候一致,否則可能會(huì)導(dǎo)致參數(shù)加載失敗。如果你想了解更多內(nèi)容,歡迎入群學(xué)習(xí)(加V: gengzhige99)

同步更新:Gitbub/公眾號(hào):梗直哥



?


5.8 模型文件的讀寫的評(píng)論 (共 條)

分享到微博請(qǐng)遵守國(guó)家法律
长沙县| 眉山市| 桐梓县| 手机| 甘泉县| 兴宁市| 德阳市| 德庆县| 石首市| 罗田县| 大冶市| 会昌县| 广平县| 外汇| 宁强县| 和硕县| 康马县| 祁阳县| 故城县| 平谷区| 杨浦区| 铜川市| 乡宁县| 克东县| 临潭县| 巫山县| 辽阳县| 习水县| 游戏| 广东省| 鹤岗市| 苏尼特左旗| 内江市| 遵化市| 廊坊市| 五台县| 图木舒克市| 永泰县| 潞城市| 施秉县| 慈利县|