Swin Transformer源碼解析(二)
二、Transformer Block
數(shù)據(jù)經(jīng)過patch_embed后接著進(jìn)入TransformerBlock模塊,TransformerBlock主要包含四個(gè)部分:NormLayer==>W-MSA/SW-MSA==>NormLayer==>MLP,內(nèi)部各部分還使用殘差連接。
1. Norm Layer
NormLayer默認(rèn)使用LayerNorm,對(duì)最后一維歸一化,即模型的維度C
2. W-MSA/SW-MSA
窗口自注意力和移位窗口自注意力,將patch的特征圖劃分成一個(gè)個(gè)window,然后再在每個(gè)window內(nèi)部做自注意力,但是這樣window和window之間無交互,所以又使用了移位窗口自注意力。
2.1 window_partition
類似把圖片分成pacth的操作,這里將patch_embedding操作后的特征圖按window劃分,但不同的是patch_embedding中有個(gè)embedding的過程,是通過卷積實(shí)現(xiàn)的但是這里不需要,只是簡單的分成window。window_partition操作是將圖片的形式由(2,56,56,96)==>(2*8*8,7,7,96) 8*8就是window的數(shù)量,可以看出維度沒有變化,且內(nèi)部也沒有任何神經(jīng)元的連接。
至于為什么要乘以8*8,是因?yàn)楹竺嬉趙indow內(nèi)部做注意力,window與window之間無關(guān),所以直接乘到batch_size里面。
2.2 window_reverse
和window_partition的操作相反,將劃分后的windows轉(zhuǎn)回去,形狀一樣,對(duì)應(yīng)位置也一樣。因?yàn)樽鐾曜宰⒁饬χ笠兂芍暗男螤?,因?yàn)楹竺嬉鰌atch_merge,要轉(zhuǎn)成patch的格式
2.3 window_attention
3、Norm Layer
4、MLP
很簡單,就是全連接==>激活==>dropout==>全連接==>dropout
三、Patch Merge
就是將patch特征圖變小,但是維度增加