5.7 梯度消失和梯度爆炸
5.7.1?什么是梯度消失和梯度爆炸
前面我們講了反向傳播鏈式法則,通過鏈式法則,可以將誤差從末層逐層向前傳遞,最終達到調整各層模型參數(shù)的目的。鏈式法則的形式如下:

大家看到這個公式有沒有發(fā)現(xiàn)什么問題?鏈式法則是一個連乘的形式,當模型層數(shù)淺的時候可能看不出來,隨著層數(shù)的加深,梯度將以指數(shù)形式變化。
當每一層的梯度都小于1的時候,隨著網絡層數(shù)加深,梯度將趨近于0,這就是梯度消失。相反,當每一層的梯度都大于1的時候,隨著網絡層數(shù)加深,梯度將趨近于正無窮,這就是梯度爆炸。梯度消失問題和梯度爆炸問題一般隨著網絡層數(shù)的增加會變得越來越明顯,他們在本質原理上其實是一樣的。
不穩(wěn)定梯度會威脅到我們優(yōu)化算法的穩(wěn)定性。?梯度爆炸發(fā)生,參數(shù)更新過大,破壞了模型的穩(wěn)定收斂;梯度消失發(fā)生時,參數(shù)更新過小,則是更新時幾乎不會移動,導致模型無法學習。
無論是梯度趨近0還是正無窮,都會導致我們的模型訓練失敗。因此,解決梯度消失和梯度爆炸問題是深度學習的必修課。接下來我們就來實際看一下梯度消失和梯度爆炸出現(xiàn)的原因。
5.7.2?梯度消失
梯度消失是指當梯度傳遞到深層時,由于參數(shù)的初始值或激活函數(shù)的形式,梯度變得非常小,從而導致訓練難以收斂。這種情況通常發(fā)生在使用?sigmoid 或 tanh 作為激活函數(shù)的情況下,因為這兩個函數(shù)在輸入較大時,梯度會變得非常小。
下面代碼繪制了sigmoid函數(shù)和它所對應的梯度函數(shù)。
import?torch
import?matplotlib.pyplot?as?plt
x?=?torch.arange(-8.0,?8.0,?0.1, requires_grad=True)
y?=?torch.sigmoid(x)
y.backward(torch.ones_like(x))
plt.plot(x.detach().numpy(), y.detach().numpy(), label?=?'sigmoid')
plt.plot(x.detach().numpy(), x.grad.numpy() ,linestyle=':', label?=?'gradient')
plt.legend()
plt.show()

可以看出,當sigmoid函數(shù)的輸入很大或是很小時,它的梯度都是一個遠遠小于1的數(shù),非常趨近于0。當反向傳播通過許多層時,除非每一層的sigmoid函數(shù)的輸入都恰好接近于零,否則整個乘積的梯度可能會消失。 當我們的網絡有很多層時,除非我們很小心,否則在某一層可能就會切斷梯度。 因此,現(xiàn)在大家更愿意選擇更穩(wěn)定的ReLU系列函數(shù)作為激活函數(shù)。
5.7.3?梯度爆炸
與之相反的則是梯度爆炸問題。梯度爆炸是指當梯度傳遞到深層時,由于參數(shù)的初始值或激活函數(shù)的形式,梯度變得非常大,從而導致訓練難以收斂。為了更直觀的看到這個問題,我們用代碼生成了100個高斯隨機矩陣,并將這些矩陣與一個矩陣相乘,這個矩陣相當于模型的初始參數(shù)矩陣。我們設置方差為1,看一下運行結果。
Mat?=?torch.normal(0,?1, size=(5,5))
print('初始參數(shù)矩陣',Mat)
for?i?in?range(100):
????Mat?=?torch.mm(Mat, torch.normal(0,?1, size=(5,?5)))
print('計算后矩陣', Mat)
初始參數(shù)矩陣?tensor([[ 1.5515, -0.5073, -2.4602, -1.9177, ?0.4678],
????????[ 0.2328, -0.6902, ?1.4514, -2.1545, -1.1679],
????????[ 0.5522, -0.4998, ?0.9615, ?0.0644, -0.0136],
????????[ 0.9784, ?0.1601, -0.5344, ?0.7700, ?0.7958],
????????[ 0.5089, -1.4568, ?0.8654, -1.0948, ?1.3430]])
計算后矩陣?tensor([[-1.0448e+29, ?7.7966e+28, ?9.1455e+28, ?8.5349e+27, -2.6260e+29],
????????[ 4.9579e+28, -3.6996e+28, -4.3397e+28, -4.0499e+27, ?1.2461e+29],
????????[ 3.3363e+28, -2.4895e+28, -2.9203e+28, -2.7253e+27, ?8.3852e+28],
????????[-3.9100e+28, ?2.9176e+28, ?3.4225e+28, ?3.1940e+27, -9.8272e+28],
????????[ 6.7301e+28, -5.0220e+28, -5.8909e+28, -5.4976e+27, ?1.6915e+29]])
可以看到在經過100次乘法運算后,矩陣內的值發(fā)生了爆炸性增長,這就是梯度爆炸。這種情況其實是由于我們的參數(shù)初始化方法所導致的。
5.7.4?解決方法
解決梯度消失和梯度爆炸問題的方法很多,這里講兩種常見的方法,梯度裁剪和使用Relu函數(shù)。
梯度裁剪(正則化)
梯度裁剪主要是針對梯度爆炸提出。其思想也比較簡單,訓練時候設置一個閾值,梯度更新的時候,如果梯度超過閾值,那么就將梯度強制限制在該范圍內,這時可以防止梯度爆炸。
權重正則化(weithts regularization)也可以解決梯度爆炸的問題,其思想就是我們常見的正則方式。


import?torch
import?matplotlib.pyplot?as?plt
x?=?torch.arange(-8.0,?8.0,?0.1, requires_grad=True)
y?=?torch.relu(x)
y.backward(torch.ones_like(x))
plt.plot(x.detach().numpy(), y.detach().numpy(), label?=?'relu')
plt.plot(x.detach().numpy(), x.grad.numpy() ,linestyle=':', label?=?'gradient')
plt.legend()
plt.show()
?

其他方法
決梯度消失和梯度爆炸的方法還有很多,比如下面這些:
1.使用 Batch Normalization 層,這樣可以縮小梯度的范圍,避免梯度爆炸的問題。
2.初始化權重參數(shù),使用更加合理的初始化方法。
3.使用更加穩(wěn)定的優(yōu)化算法,如 Adam 優(yōu)化器或 RMSprop 優(yōu)化器,這些優(yōu)化器可以自動調整學習率,使得訓練更加穩(wěn)定。
4.增加模型的寬度或使用殘差連接,這樣可以緩解深層網絡中的梯度消失問題。
這些方法這里暫不做深入講解,在今后的學習中大家會一點一點的接觸到的。大家加油!
梗直哥提示:梯度消失和梯度爆炸是深度學習中非?;A的問題,建議初學者在理解的基礎上,親自動手實踐一下,相信你會記憶得更加深刻。如果你想了解更多內容,歡迎入群學習(加V: gengzhige99)
深度學習視頻課程已經上線,前100名報名的同學立減590元,歡迎大家訂閱。

?