[Quant 1.7] 從數(shù)理的角度理解一下RNN (2)
我先放資源
第一個是MIT的一本深度學(xué)習(xí)教材的RNN章節(jié),作者是Ian Goodfellow & Yoshua Bengio &?Aaron Courville
https://www.deeplearningbook.org/contents/rnn.html#pf7
第二個是吳恩達(dá) (Andrew Ng)博士的網(wǎng)課

吳博士的網(wǎng)課主要從NLP的角度出發(fā)介紹RNN,里面有幫助的理解的例子。但是它的notation我看不太習(xí)慣。然后那本教材我感覺是給更厲害的quant或者數(shù)據(jù)科學(xué)/AI專業(yè)的學(xué)生讀的,細(xì)節(jié)非常之多,我個人啃起來感覺吃力。這篇想把零零散散的知識點總結(jié)一下。(以下內(nèi)容可能有很多理解錯誤的地方,后續(xù)我發(fā)現(xiàn)問題的話會編輯。然后我會有很多中英文穿插著,畢竟面試可能會被問到,我也借這個專欄強化一下記憶。)

4. RNN和它的兩個弱點:梯度消失和梯度爆炸
這一步基本上我見過的所有視頻都會一筆帶過。他們只會講RNN的梯度消失和梯度爆炸會影響SGD optimizer對參數(shù)的更新,從而引出不存在這兩種問題的GRU和LSTM。但我認(rèn)為這部分的推導(dǎo)是整個RNN最高難最有趣的地方。仔細(xì)研究一下不僅能讓你復(fù)習(xí)很多忘掉的數(shù)學(xué),還能對RNN的運行過程有更深刻的理解。
這里我會用上面介紹的第一種RNN來作為例子求梯度。

我們要求的梯度,就是當(dāng)前的總損失對參數(shù)的梯度。求導(dǎo)的過程要用到我之前兩個專欄的內(nèi)容


所謂backward propogation,就是逆著unfolded computational graph箭頭方向求導(dǎo)的過程。從圖中我們能看出來,想對參數(shù)求梯度,
是必經(jīng)之路,所以我們先嘗試求一下總損失
對
的梯度,進(jìn)而對
的梯度。下一步再求解
或
對這5個參數(shù)的梯度。

假設(shè)RNN已經(jīng)訓(xùn)練到了時間,那么當(dāng)前的總損失就是
對
的梯度可以寫成
這個用到了之前說過的鏈?zhǔn)椒▌t。只要我們有關(guān)于梯度輸入和輸出維度的規(guī)定,梯度的鏈?zhǔn)椒▌t和求導(dǎo)的鏈?zhǔn)椒▌t是一致。在等號右邊的三個梯度中,第一個根據(jù)上面的總損失公式就等于1;第二個需要求一個Vector-to-scalar函數(shù)的梯度,因為和
的函數(shù)關(guān)系不涉及矩陣運算,我們只能用梯度的定義來求。之前我們假設(shè)我們的RNN用的是交叉熵?fù)p失函數(shù)cross entropy loss function,再假設(shè)RNN的輸出是n維列向量,因此;
第三個是求一個Vector-to-vector函數(shù)的梯度,我們需要求的是Jacobian矩陣。和
的關(guān)系是softmax函數(shù),即
Jacobian矩陣的第i行j列可以表示為
所以的計算實質(zhì)上是一個
的梯度向量乘以一個
的Jacobian矩陣
我們現(xiàn)在成功獲得了在
上面的梯度,下一步就可以求
在別的參數(shù)上的梯度了。
我們先來看,?
只從
上面遺傳下來。因為RNN中有明確的
關(guān)于
的矩陣表達(dá)式,所以我們可以用之前學(xué)過的Vector-to-vector求梯度的方法來計算:
所以
但是問題來了,是只從
上面遺傳下來,但是
不是。從圖中可以看出來,
的傳播方向不僅有
,還有
。所以為了鏈?zhǔn)椒▌t能夠連接到
,我們需要找到
在
上的梯度
這個梯度的求解過程可以參考Quant 1.5里面的最后一個例子
因此我們能夠得到關(guān)于梯度的迭代公式
我們逐步迭代,直到,我們就能獲得
關(guān)于
梯度的確定的值。

所有的準(zhǔn)備工作都做完了,我們現(xiàn)在獲得了關(guān)于和
的全部梯度,需要繼續(xù)利用鏈?zhǔn)椒▌t求
的梯度:
多說一句以防大家忘記,這里的鏈?zhǔn)椒▌t的情形比前面的要復(fù)雜一些。之前例子中的鏈?zhǔn)椒▌t都是單鏈的,也就是說從復(fù)合函數(shù)的自變量變換到因變量的過程只有一個路徑,但是在RNN中時間的存在使得
到損失
的變換在每一個時間
都有一條平行的路徑。多鏈的鏈?zhǔn)椒▌t需要我們把每個平行路徑的梯度加和來求總梯度。
然后上面我們已經(jīng)求的了五個梯度中最簡單的兩個。但是在求另外三個的時候就會遇到一個問題...按照正常的鏈?zhǔn)椒▌t,求關(guān)于的梯度的時候應(yīng)該有
但是根據(jù)我目前學(xué)到的知識,這個我是算不出來的,因為是一個Matrix-to-vector函數(shù)的梯度,首先它的梯度求出來應(yīng)該是一個三維tensor的形式,其次即使它的梯度有矩陣的表達(dá)形式,但是我認(rèn)為這種表達(dá)形式放入chain rule進(jìn)行計算是錯誤的。因此,上面的梯度無法這么求。
那我們應(yīng)該怎么辦呢?我們雖然不確定Matrix-to-vector函數(shù)的梯度能不能插入鏈?zhǔn)椒▌t里面,但是我們能確定Matrix-to-scalar函數(shù)肯定是可以的!那我們可以不可以修改一下路徑,使得由到
的梯度變成Matrix-to-vector函數(shù)的梯度呢?可以!我們只要把
的每一項拿出來作為中間項過度就可以了。大致變化差不多下面這張圖,需要把
展開:

最后按照右邊的圖片,關(guān)于U的梯度就變成了
接下來是全篇最蹩腳的地方:如何化簡上面的等式?首先我們要求一個Matrix-to-scalar函數(shù)的梯度
這個就不是很好理解。我們先看一下矩陣變換到
的過程:第一步對于矩陣U進(jìn)行線性變換
,對得到的向量只保留第i項
,第二部是
。所以根據(jù)鏈?zhǔn)椒▌t:
關(guān)于后面這個Matrix-to-scalar函數(shù)的梯度,因為這個函數(shù)里面的沒有trace,所以我們只能通過Matrix-to-scalar函數(shù)梯度的定義來理解了。因為我們最后只取向量的第i項,所以矩陣U只有第i行和有關(guān)系。將
對矩陣
的第
行各項求導(dǎo),得到的就是向量
每一項。因此,得到的梯度矩陣除了第i行是
以外,其余項都是0。
下一步依然是難理解的一步,我們想要取消內(nèi)部的求和符號就需要把所有關(guān)于
都加起來:
所以
如果線代學(xué)的不是特別好的話,從求和符號到矩陣乘法之間的轉(zhuǎn)換可能需要比較長的時間消化,這里大家可以多總結(jié)一下。
然后關(guān)于矩陣的梯度和U十分類似,我就不再推一遍了

梯度求完了,我們現(xiàn)在要討論的是梯度為什么會消失 (gradient vanishing),梯度又為什么會爆炸 (gradient exploding)。我在網(wǎng)上能找到的所有視頻關(guān)于這兩點都是泛泛而談,包括上面的教材和視頻,沒有一個解釋的很詳細(xì)的,我感覺很無語。所以這里試著解釋一下:
我們從上面一系列的梯度的推導(dǎo)中可以看出,每一個time step對參數(shù)的梯度都是有貢獻(xiàn)的。但是如果我們的神經(jīng)網(wǎng)絡(luò)深度足夠大的話,早期的time step對梯度的貢獻(xiàn)會過小或者過大。這就是梯度消失/爆炸。
難道早期的time step對梯度的貢獻(xiàn)小不是應(yīng)該的嗎?不一定!吳教授的視頻中舉了一個例子,假如我們現(xiàn)在想要預(yù)測一句不完整的話的下一個詞是什么:
The cats, which already ate,?were?full.
The cat, which already ate,?was full.
我們在預(yù)測這個was/were的時候,最重要的參考信息不是中間的非限制性定語從句,而是前面的cat/cats。因此,模型參數(shù)應(yīng)該在cat/cats這個詞對應(yīng)time step的hidden state上面的梯度很大,從而可以顯著的修改神經(jīng)網(wǎng)絡(luò)中的矩陣參數(shù)。然而,如果這個cat/cats離was/were太遠(yuǎn)了,這個梯度就會產(chǎn)生爆炸或者消失,導(dǎo)致我們最后訓(xùn)練出來的神經(jīng)網(wǎng)絡(luò)參數(shù)沒有預(yù)測價值。畢竟語法上非限制性定語從句可以無限長。
那么如何從梯度的公式上理解梯度消失和梯度爆炸呢?根據(jù)我們之前推導(dǎo)過的梯度公式,可以知道損失函數(shù)關(guān)于各個矩陣/向量參數(shù)的梯度大多與有關(guān)系,而
還有遞推公式:
我們多寫出來一步
這是兩個time step的遞推公式,我們可以知道當(dāng)time step的間隔足夠大的時候,多個矩陣會累乘到一起作為晚期time step梯度的系數(shù)。這就導(dǎo)致早期的time step梯度受矩陣
的影響很大。就像是
和
差距很大一樣,梯度下降的進(jìn)程會受參數(shù)初始化的影響,如果
初始化太大會造成梯度爆炸,太小則會造成梯度消失。