pytorch中的鉤子(Hook)
首先明確一點,有哪些hook?
1.?torch.autograd.Variable.register_hook?(Python method, in Automatic differentiation package
2.?torch.nn.Module.register_backward_hook?(Python method, in torch.nn)
3.?torch.nn.Module.register_forward_hook
第一個是register_hook,是針對Variable對象的,后面的兩個:register_backward_hook和register_forward_hook是針對nn.Module這個對象的。
也就是說,這個函數(shù)是擁有改變梯度值的威力的!
至于register_forward_hook和register_backward_hook的用法和這個大同小異。只不過對象從Variable改成了你自己定義的nn.Module。
當你訓(xùn)練一個網(wǎng)絡(luò),想要提取中間層的參數(shù)、或者特征圖的時候,使用hook就能派上用場了
相當于插件??梢詫崿F(xiàn)一些額外的功能,而又不用修改主體代碼。把這些額外功能實現(xiàn)了掛在主代碼上,所以叫鉤子,很形象。
一、Hook函數(shù)概念
Hook 是?PyTorch?中一個十分有用的特性。利用它,我們可以不必改變網(wǎng)絡(luò)輸入輸出的結(jié)構(gòu),方便地獲取、改變網(wǎng)絡(luò)中間層變量的值和梯度。這個功能被廣泛用于可視化神經(jīng)網(wǎng)絡(luò)中間層的 feature、gradient,從而診斷神經(jīng)網(wǎng)絡(luò)中可能出現(xiàn)的問題,分析網(wǎng)絡(luò)有效性。
Hook函數(shù)機制:不改變主體,實現(xiàn)額外的功能,像一個掛件一樣;