Transformer在Masked Self-attention中做的什么?(實(shí)現(xiàn)細(xì)節(jié))
Transformer是一個(gè)訓(xùn)練與預(yù)測(cè)相互獨(dú)立的模型,訓(xùn)練和預(yù)測(cè)的不同主要反應(yīng)在masked self-attention模塊的代碼上,經(jīng)過幾個(gè)小時(shí)的研究終于搞懂,下面對(duì)該部分的實(shí)現(xiàn)細(xì)節(jié)記錄。需要注意的是接下來提到的全部代碼并非來自原始transformer項(xiàng)目,因此可能并不具有普適性,僅作為一種可行的思路介紹。
1.transformer中的訓(xùn)練與預(yù)測(cè)策略
在訓(xùn)練階段一般采用teacher forcing的策略,所謂teacher forcing,就是直接將ground truth作為decoder的input,通過masked self-attention計(jì)算各時(shí)間步的context feature,然后將context與encoder_out共同送入cross-attention中進(jìn)行跨模態(tài)建模。其中在Mashed SA內(nèi)部使用mask策略避免計(jì)算context的過程看到待預(yù)測(cè)結(jié)果(即避免看到gt中當(dāng)前時(shí)間步之后的單詞)。
在預(yù)測(cè)階段常用的有兩種策略,一種被稱為beam_search,另一種忘記了名字,預(yù)測(cè)策略不是本文的重點(diǎn),因此不對(duì)這方面做更多敘述,僅介紹不使用任何策略的最普通的預(yù)測(cè)方式。做預(yù)測(cè)時(shí),需要從無到有進(jìn)行語言的生成,即需要設(shè)置若干個(gè)時(shí)間步,每個(gè)時(shí)間步中將已預(yù)測(cè)出的結(jié)果作為輸入,以此為指導(dǎo)預(yù)測(cè)下一個(gè)單詞,這個(gè)過程與RNN非常相似,因此transformer在預(yù)測(cè)時(shí)無法體現(xiàn)如訓(xùn)練階段一般的高并行優(yōu)勢(shì)。
2.直覺與實(shí)現(xiàn)的區(qū)別
從直覺上看,預(yù)測(cè)與訓(xùn)練并無太大不同,不過要增加一個(gè)有關(guān)時(shí)間步的訓(xùn)練而已,而在實(shí)現(xiàn)中卻并非如此,以時(shí)間步step=3來舉例:
2.1直覺
當(dāng)step=3時(shí),已經(jīng)完成了對(duì)前面3個(gè)單詞的預(yù)測(cè),計(jì)劃預(yù)測(cè)第4個(gè)單詞,因此需要將前3個(gè)單詞作為decoder的input,假設(shè)batch_size=10,那么從直覺上講送入decoder中input形狀應(yīng)該為(10,3),其中3表示已知單詞的索引。后續(xù)在decoder中將通過word_embedding方法將單詞索引轉(zhuǎn)化為dim=512的向量,經(jīng)過word_embedding處理后即可送入decoder_layers,其形狀為(10,3,512).
在第1個(gè)decoder_layer中,首先使用masked self-attention進(jìn)行處理,將形狀為(10,3,512)的input作為query、key、value,經(jīng)過運(yùn)算后得到(10,3,512),記為context;接下來,使用cross-attention進(jìn)行處理,context作為query,encoder_out作為key和value,經(jīng)過運(yùn)算后得到形狀為(10,3,512)的張量。
在第2個(gè)decoder_layer中,進(jìn)行同樣的操作;第3個(gè)decoder_layer中,進(jìn)行同樣的操作。最終得到形狀為(10,3,512)的張量,取(:,2,:)過Linear(512,voc_len),而后做softmax歸一化,作為預(yù)測(cè)結(jié)果的概率。
2.2實(shí)現(xiàn)
當(dāng)閱讀代碼時(shí),卻發(fā)現(xiàn)在實(shí)現(xiàn)細(xì)節(jié)方面并非如上面敘述的一般。
其中第1行暫時(shí)忽略,稍后解釋。第2行和第3行顯然是根據(jù)計(jì)劃生成的最大序列長(zhǎng)度建立for循環(huán),在每個(gè)循環(huán)中進(jìn)行一次預(yù)測(cè)。進(jìn)行預(yù)測(cè)時(shí)調(diào)用了self.iter方法,在iter方法中真正做預(yù)測(cè)的代碼如下
此處的self.model表示transformer類的對(duì)象,self.model.step是定義在transformer中的一個(gè)方法,實(shí)際上是對(duì)transformer.decoder進(jìn)行調(diào)用,在self.model.step的最后一行代碼為
顯然,在調(diào)用transformer.decoder的傳入?yún)?shù)中,it即為decoder的input,self.enc_output為編碼器的輸出。依照直覺,此處的it的形狀應(yīng)該隨著時(shí)間步的變化而變化,例如當(dāng)step=1時(shí),由于已經(jīng)預(yù)測(cè)出的單詞為一個(gè),故it.shape=(bs,1);當(dāng)step=3時(shí),由于已經(jīng)預(yù)測(cè)出了三個(gè)單詞,故it.shape=(bs,3).
然而,經(jīng)過對(duì)現(xiàn)有代碼的調(diào)試發(fā)現(xiàn)it的形狀始終為(bs,1),不隨時(shí)間步step的增大而變化,這與直覺不符,因?yàn)檫@相當(dāng)于每個(gè)時(shí)間步中僅將上一個(gè)時(shí)間步的預(yù)測(cè)結(jié)果進(jìn)行傳入,而無法關(guān)注到以往預(yù)測(cè)到的全部單詞,這與預(yù)期嚴(yán)重不符。(經(jīng)過研究后發(fā)現(xiàn)雖然it的shape始終為(bs,1),但在實(shí)際預(yù)測(cè)時(shí)還是對(duì)所有已知的單詞進(jìn)行了考慮,下面對(duì)方法進(jìn)行介紹。)
經(jīng)過觀察,發(fā)現(xiàn)在decoder layer中的時(shí)間代碼為,顯然其可以分為三個(gè)部分,即Masked self-attention,Cross-attention,F(xiàn)FN,其中masked中作為query,key,value的傳入?yún)?shù)均為input,其形狀為(bs,1,dim)
將參數(shù)傳入上述的self.self_att后,對(duì)于接收到的query、key、value使用下面的代碼,再次進(jìn)行一次調(diào)用,在這次調(diào)用的self.attention方法內(nèi)才會(huì)真正進(jìn)行softmax(QK)V的注意力運(yùn)算。
值得注意的是,在此處調(diào)用self.attention時(shí)傳入的queries、keys、value形狀分別為(以step=3為例):(bs,1,dim)、(bs,3,dim)、(bs,3,dim),即在key和value處神奇的對(duì)已有的全部單詞做了考慮,而在上一步中分明將同一個(gè)形狀為(bs,1,dim)的input同時(shí)作為self.self_att傳入?yún)?shù)的query、key、value,這中間發(fā)生了什么?
可以看到在接收到keys和value后,先令其與self.running_keys和self.running_values拼接,然后再賦值給keys和values。由于每個(gè)時(shí)間步中均會(huì)進(jìn)行這樣的操作,因此當(dāng)處于step=3時(shí),self.running_keys和self.running_value中將會(huì)存儲(chǔ)有第1個(gè)單詞和第2個(gè)單詞,新接收到的keys和values中儲(chǔ)存有第三個(gè)單詞,將二者拼接即為全部的已知單詞。
到這里,已經(jīng)介紹完了transformer在進(jìn)行預(yù)測(cè)時(shí)的操作,下面補(bǔ)充一些細(xì)節(jié)
3.補(bǔ)充細(xì)節(jié)
3.1需要在對(duì)每個(gè)batch預(yù)測(cè)前將self.running_keys和self.running_values置為(bs,0,dim)的形式,如何做到的?
在初始化函數(shù)init中使用上述代碼進(jìn)行定義,其中self.register_state的定義如下
可以看到,會(huì)將第一個(gè)參數(shù)name保存到列表self._state_names中,將第二個(gè)參數(shù)default保存到字典self._state_defaults中。然后調(diào)用nn.model中的self.register_buffer方法,創(chuàng)建一個(gè)名為self.name的變量,并使用default為其賦值,同時(shí)令其梯度為False,即不會(huì)被優(yōu)化。
通過上述代碼,可以得到形狀為(0,dim)的張量,還缺少batch_size維度,同時(shí)上面代碼僅說明了如何進(jìn)行初始化,而為介紹如何在對(duì)每個(gè)batch預(yù)測(cè)前進(jìn)行置0.
在本文貼出的第一段代碼中,曾說暫時(shí)忽略第一行。此處再次將該段代碼貼出,以針對(duì)第一行進(jìn)行說明。
該行代碼對(duì)應(yīng)的操作如下
通過遞歸的方式令所有的child調(diào)用enable_statefullness,然后調(diào)用self._init_states()方法,代碼如下
其效果是利用self._state_defaults中儲(chǔ)存的數(shù)據(jù)為self._buffers賦值,上面我們介紹register_state時(shí)提到過利用(0,dim)的張量為self._state_defaults列表賦值,后續(xù)沒有對(duì)其修改的操縱,因此其中儲(chǔ)存的數(shù)據(jù)始終為(0,dim)的張量。
self._buffers中儲(chǔ)存的是self.running_keys和self.running_value的值,在這個(gè)init函數(shù)中利用self.state_defaults為self._buffers賦值,實(shí)際上就重置self.running_keys和self.running_value中儲(chǔ)存的數(shù)據(jù)。
因此,在對(duì)時(shí)間步建立for循環(huán)的外面進(jìn)行這個(gè)操作,就可以實(shí)現(xiàn)對(duì)每個(gè)batch預(yù)測(cè)前將self.running_keys和self.running_values置為(bs,0,dim)的形式。
3.2 self._is_stateful與self.can_be_stateful
之前提到的代碼中,有寫分支if需要對(duì)這兩個(gè)值做判別,下面介紹這兩個(gè)值的賦值情況。
self.can_be_stateful是初始化對(duì)象時(shí)的傳入?yún)?shù),僅在self-attention用作masked SA時(shí)將其設(shè)置為True,其余時(shí)候均為False。
self._is_stateful在上述的enable_statefullness方法中被置為True,位置通常為對(duì)每個(gè)batch的時(shí)間步建立循環(huán)之前。
(完)