5.8 模型文件的讀寫
前面我們講了眾多方法,來(lái)訓(xùn)練好一個(gè)模型,讓模型能夠收斂,同時(shí)又不出現(xiàn)過(guò)擬合和欠擬合問(wèn)題。當(dāng)模型訓(xùn)練好以后,就需要對(duì)模型參數(shù)進(jìn)行保存,以便部署到不同的環(huán)境中去使用。通常,一個(gè)深度學(xué)習(xí)模型的訓(xùn)練需要消耗很長(zhǎng)的時(shí)間,比如幾天,如果訓(xùn)練過(guò)程中出現(xiàn)了問(wèn)題,無(wú)論是軟件問(wèn)題還是硬件問(wèn)題,又或者是外部因素比如突然斷電,造成的損失都是非常大的。因此,我們不僅要在最后對(duì)模型進(jìn)行保存,訓(xùn)練過(guò)程中也應(yīng)該定時(shí)保存中間結(jié)果,以避免損失。這一節(jié)我們就來(lái)學(xué)習(xí)一下模型保存的方法。
5.8.1?張量的保存和加載
在深度學(xué)習(xí)中,模型的參數(shù)一般是張量形式的。對(duì)于單個(gè)的張量,pytorch為我們提供了方便直接的函數(shù)來(lái)進(jìn)行讀寫。比如我們定義如下的一個(gè)張量a。
import?torch
a?=?torch.rand(10)
a
?
tensor([0.1190, 0.7933, 0.9636, 0.5436, 0.2750, 0.3664, 0.4274, 0.9336, 0.1324,
????????0.8449])
?
可以簡(jiǎn)單的用一個(gè)save函數(shù)去存儲(chǔ)這個(gè)張量a,這里需要我們給他起一個(gè)名字,我們就叫它tensor-a,把它放在model文件夾里。
torch.save(a,?'model/tensor-a')
這就完成了張量的寫入,這時(shí)我們可以在當(dāng)前路徑下的model文件夾里看到tensor-a這個(gè)文件。讀取同樣簡(jiǎn)單,只需要用一個(gè)load函數(shù)就可以完成張量的加載,傳入的參數(shù)是文件的路徑。
torch.load('model/tensor-a')
?
tensor([0.1190, 0.7933, 0.9636, 0.5436, 0.2750, 0.3664, 0.4274, 0.9336, 0.1324,
????????0.8449])
?
如果我們要存儲(chǔ)的不止一個(gè)張量,也沒(méi)有問(wèn)題,save和load函數(shù)同樣支持保存張量列表。先把張量數(shù)據(jù)存儲(chǔ)起來(lái)。
a?=?torch.rand(10)
b?=?torch.rand(10)
c?=?torch.rand(10)
torch.save([a,b,c],?'model/tensor-abc')
[a,b,c]
?
[tensor([0.0270, 0.8876, 0.4965, 0.5507, 0.9629, 0.7735, 0.9478, 0.7899, 0.7003,
?????????0.5002]),
?tensor([0.3628, 0.1818, 0.3137, 0.4671, 0.6445, 0.0022, 0.2800, 0.4637, 0.4888,
?????????0.2336]),
?tensor([0.8327, 0.3511, 0.2187, 0.6894, 0.9219, 0.7021, 0.1927, 0.0983, 0.6716,
?????????0.1998])]
?
然后再把它讀取出來(lái)。
torch.load('model/tensor-abc')
?
[tensor([0.0270, 0.8876, 0.4965, 0.5507, 0.9629, 0.7735, 0.9478, 0.7899, 0.7003,
?????????0.5002]),
?tensor([0.3628, 0.1818, 0.3137, 0.4671, 0.6445, 0.0022, 0.2800, 0.4637, 0.4888,
?????????0.2336]),
?tensor([0.8327, 0.3511, 0.2187, 0.6894, 0.9219, 0.7021, 0.1927, 0.0983, 0.6716,
?????????0.1998])]
?
對(duì)于多個(gè)張量,pytorch同樣支持以字典的形式來(lái)進(jìn)行存儲(chǔ)。比如我們建立一個(gè)字典tensor_dict,然后把它存起來(lái)。
a?=?torch.rand(10)
b?=?torch.rand(10)
c?=?torch.rand(10)
tensor_dict={'a':a,?'b':b,?'c':c}
torch.save(tensor_dict,?'model/tensor_dict')
tensor_dict
?
{'a': tensor([0.1925, 0.3094, 0.8293, 0.3449, 0.3672, 0.3616, 0.9751, 0.7442, 0.8948,
?????????0.9062]),
?'b': tensor([0.6409, 0.1292, 0.1913, 0.0356, 0.0109, 0.8862, 0.9702, 0.4830, 0.2453,
?????????0.0902]),
?'c': tensor([0.4258, 0.1488, 0.8010, 0.0061, 0.9639, 0.2933, 0.3556, 0.0569, 0.9560,
?????????0.4338])}
?
然后是讀取。
torch.load('model/tensor_dict')
?
{'a': tensor([0.1925, 0.3094, 0.8293, 0.3449, 0.3672, 0.3616, 0.9751, 0.7442, 0.8948,
?????????0.9062]),
?'b': tensor([0.6409, 0.1292, 0.1913, 0.0356, 0.0109, 0.8862, 0.9702, 0.4830, 0.2453,
?????????0.0902]),
?'c': tensor([0.4258, 0.1488, 0.8010, 0.0061, 0.9639, 0.2933, 0.3556, 0.0569, 0.9560,
?????????0.4338])}
?
張量的讀寫非常的簡(jiǎn)單,接下來(lái)我們看看模型整體參數(shù)的讀寫。
5.8.2?模型參數(shù)的保存和加載
模型參數(shù)一般都是張量形式的,雖然單個(gè)張量的保存和加載非常簡(jiǎn)單,但整個(gè)模型中包含著大大小小的若干張量,單獨(dú)保存這些張量會(huì)很困難。為了解決這個(gè)問(wèn)題,pytorch貼心的為我們準(zhǔn)備了內(nèi)置函數(shù)來(lái)保存加載整個(gè)模型參數(shù)。我們以5.2節(jié)的多層感知機(jī)為例,來(lái)看一下如何保存。
import?torch
import?torch.nn?as?nn
import?torch.optim?as?optim
from?torchvision?import?datasets, transforms
#?定義?MLP?網(wǎng)絡(luò)
class?MLP(nn.Module):
????def?__init__(self, input_size, hidden_size, num_classes):
????????super(MLP,?self).__init__()
????????self.fc1?=?nn.Linear(input_size, hidden_size)
????????self.relu?=?nn.ReLU()
????????self.fc2?=?nn.Linear(hidden_size, hidden_size)
????????self.fc3?=?nn.Linear(hidden_size, num_classes)
????
????def?forward(self, x):
????????out?=?self.fc1(x)
????????out?=?self.relu(out)
????????out?=?self.fc2(out)
????????out?=?self.relu(out)
????????out?=?self.fc3(out)
????????return?out
#?定義超參數(shù)
input_size?=?28?*?28??#?輸入大小
hidden_size?=?512??#?隱藏層大小
num_classes?=?10??#?輸出大?。悇e數(shù))
然后我們實(shí)例化一個(gè)MLP網(wǎng)絡(luò),并隨機(jī)生成一個(gè)輸入X,并計(jì)算出模型的輸出Y。
#?實(shí)例化?MLP?網(wǎng)絡(luò)
model?=?MLP(input_size, hidden_size, num_classes)
X?=?torch.randn(size=(2,?28*28))
然后同樣是調(diào)用save方法,我們把模型存儲(chǔ)到model文件夾里,取名叫做mlp.params。
torch.save(model.state_dict(),?'model/mlp.params')
接下來(lái),我們來(lái)讀取保存好的模型參數(shù),重新加載我們的模型。我們先把模型params參數(shù)讀取出來(lái),然后實(shí)例化一個(gè)模型,然后直接調(diào)用load_state_dict方法,傳入模型參數(shù)params。
params?=?torch.load('model/mlp.params')
model_load?=?MLP(input_size, hidden_size, num_classes)
model_load.load_state_dict(params)
?
<All keys matched successfully>
?
此時(shí)兩個(gè)模型model和model_load具有相同的參數(shù),我們給他輸入相同的X,看一下輸出結(jié)果。
output1?=?model(X)
output1
?
tensor([[ 0.0914, ?0.0178, ?0.0692, ?0.1486, ?0.1002, ?0.0057, -0.1099, ?0.1332,
??????????0.0241, ?0.1137],
????????[-0.0228, ?0.0446, ?0.1374, ?0.2009, -0.0978, -0.0831, -0.0193, ?0.1040,
??????????0.1097, ?0.1484]], grad_fn=<AddmmBackward0>)
?
output2?=?model_load(X)
output2
?
tensor([[ 0.0914, ?0.0178, ?0.0692, ?0.1486, ?0.1002, ?0.0057, -0.1099, ?0.1332,
??????????0.0241, ?0.1137],
????????[-0.0228, ?0.0446, ?0.1374, ?0.2009, -0.0978, -0.0831, -0.0193, ?0.1040,
??????????0.1097, ?0.1484]], grad_fn=<AddmmBackward0>)
?
可以看到,輸出的結(jié)果完全一致,說(shuō)明我們將參數(shù)成功地讀取并載入了模型中。
梗直哥提示:使用save保存的是模型參數(shù)而不是整個(gè)模型,因此在模型加載參數(shù)的時(shí)候,需要我們單獨(dú)指定模型架構(gòu),并且要保證模型結(jié)構(gòu)和保存的時(shí)候一致,否則可能會(huì)導(dǎo)致參數(shù)加載失敗。如果你想了解更多內(nèi)容,歡迎入群學(xué)習(xí)(加V: gengzhige99)
同步更新:Gitbub/公眾號(hào):梗直哥

?