Swin Transformer 自用筆記
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
原文鏈接:https://arxiv.org/abs/2103.14030
摘要:從ViT的理念出發(fā),提出分級(jí)的transformer架構(gòu)和移動(dòng)窗口的方式,優(yōu)化一般基于ViT的全局自注意力計(jì)算方式,證明了Transformer在多種視覺(jué)領(lǐng)域任務(wù)的適用性。


SwinTransformer 模型? 右圖為Block中的詳細(xì)2層結(jié)構(gòu)

這里走一下前向過(guò)程,理解模型:
a. 224*224*3 的rgb圖像經(jīng)過(guò)patch size(代碼實(shí)現(xiàn)中使用卷積操作)為 4*4 的切塊后(ViT中為16*16)變成 56*56*48的特征圖,這里4*4*3=48,224/4=56。
b. 然后經(jīng)過(guò)線性變換(上圖中的Linear Embedding)為56*56*C的特征圖,這里的C的值是個(gè)超參數(shù),而在實(shí)際計(jì)算中前面的56*56其實(shí)是一個(gè)一維的長(zhǎng)向量,也就是一張圖在此處本質(zhì)上是3136*C?的一個(gè)“句子”(后面進(jìn)入Encoder 和 Decoder時(shí)候也就和在NLP領(lǐng)域中的操作一致)。本文中C常用96。
c. 此時(shí)3136的句子長(zhǎng)度對(duì)于任務(wù)來(lái)說(shuō)不可接受,為了實(shí)現(xiàn)降低計(jì)算復(fù)雜度,在Swin Transformer Block中引入滑動(dòng)窗口操作(Shifted?Windows)來(lái)簡(jiǎn)化計(jì)算,這里詳細(xì)講解Swin操作。(注意該Block不改變張量的形狀。所以56*56*C的特征圖的形狀不發(fā)生改變)

????在原文中 3.2節(jié)?Shifted Window based Self-Attention?詳細(xì)解釋了使用該機(jī)制的動(dòng)機(jī)和其運(yùn)作方式?!癟he standard Transformer architecture and its adaptation for image classification both conduct global selfattention, where the relationships between a token and all other tokens are computed. The global computation leads to quadratic complexity with respect to the number of tokens, making it unsuitable for many vision problems requiring an immense set of tokens for dense prediction or to represent a high-resolution image。” 我們可以看到傳統(tǒng)的transformer架構(gòu)被移植到CV領(lǐng)域時(shí)候,做的self-attention操作并沒(méi)有改變,就是圖片有多長(zhǎng),自注意力就算多長(zhǎng),本文有個(gè)較為獨(dú)特的認(rèn)識(shí)見(jiàn)解,即對(duì)于圖片而言,我們是不是不用像文本那樣考慮這么長(zhǎng)的信息,因?yàn)樵趫D像中,可能各個(gè)token區(qū)域的相關(guān)性就沒(méi)有那么的強(qiáng),不值得花費(fèi)如此高昂的計(jì)算代價(jià)。特別是在密集型下游任務(wù)中,這樣的全局自注意力計(jì)算操作導(dǎo)致了巨量的運(yùn)算,甚至難以進(jìn)行。由此作者提出只在局部區(qū)域(以窗口為單位)做自注意力;同時(shí)使用滑動(dòng)窗口的方式,來(lái)建立全局聯(lián)系,總的思路是化整為2部分。復(fù)雜度的差異在原文中有計(jì)算,如下圖,差異就是M^2 與 h*w 能差多少的問(wèn)題,這里選我們上面的h*w=3136為例,在該問(wèn)題中 M =?7, M^2 = 49,可以看到差了2個(gè)數(shù)量級(jí)。

? ? 上面的解釋主要是對(duì)于“窗口”,下面解釋“滑動(dòng)”,在上面我們只用局部窗口做自注意力,那到最后的幾層網(wǎng)絡(luò)中一個(gè)特征圖的一個(gè)“像素”,只能映射到最初特征圖的某一小片區(qū)域,和全局就失去了聯(lián)系。為了彌補(bǔ)窗口自注意力造成的全局關(guān)聯(lián)信息的丟失,引入滑動(dòng)操作,大致就是第一層的transformer中做左上角的窗口自注意力,第二層的transformer中將窗口的劃分往右下方挪動(dòng)兩個(gè)patch后再做窗口自注意力計(jì)算,這樣也解釋了為什么上述的transformer block下面都有*2 以兩層為基礎(chǔ)block的原因;同時(shí)這樣堆疊操作后,最后計(jì)算得到的每個(gè)位置的數(shù)值就是具有全局信息的。

????????然而還有個(gè)問(wèn)題需要解決,為了使特征融合而進(jìn)行的滑動(dòng)窗口操作,會(huì)使得特征圖出現(xiàn)下圖中最左邊的切分方式,這樣不規(guī)則的區(qū)域不易于做自注意力操作,于是本文中進(jìn)行了下面的矩陣拼接操作,使得計(jì)算方式保持一致性,然而這樣拼接后并不能直接進(jìn)行計(jì)算,因?yàn)橄聢D中的不同顏色之間,并沒(méi)有位置的相關(guān)性,就是不可以與同框內(nèi)的其他色塊做交互計(jì)算(如果C是天空,而與C同框的那塊很有可能是地面或者其他在空間上與其并沒(méi)有直接關(guān)系的語(yǔ)義信息的區(qū)域)這里就需要用到掩膜操作,也就是下圖中的Masked MSA,具體掩膜操作可以查閱,就是讓不相關(guān)聯(lián)的小方塊只與自己做自注意力,加其他的權(quán)值大負(fù)數(shù)后,經(jīng)過(guò)sigmoid函數(shù)逼近0舍棄。

? ? 本文中也用到了相對(duì)位置編碼這個(gè)技術(shù),這里展開(kāi)講有很多,不過(guò)不是本文的重點(diǎn),后面會(huì)專門寫個(gè)位置編碼的筆記(絕對(duì)位置編碼,相對(duì)位置編碼)。
d. 由于transformer的架構(gòu)對(duì)于特征圖的形狀不會(huì)造成變化,這里為了和CNN的池化操作對(duì)齊,引入Patch Merging操作(詳解于上圖),具體操作是先對(duì)圖像做四次下采樣,生成4張?zhí)卣鲌D,將其堆疊并歸一化后,使用1*1卷積將其降維,這樣就將原來(lái)一個(gè)W*H*C的特征圖變?yōu)橐粋€(gè)W/2 * H/2 * 2C 的形狀,和CNN中的”大小減半?通道翻倍“對(duì)等。此操作使得張量形狀由56*56*C 變?yōu)?28*28*2C 。
e.?最后得到的一組特征圖 7*7*8C ,根據(jù)下游任務(wù)來(lái)選擇怎么將其操作,若是分類,可以用GlobalAverage pooling,分割檢測(cè)之類的也可以直接拿7*7*8C的張量去用。
小技術(shù)?testing time augmentation:測(cè)試時(shí)增強(qiáng),指的是在推理(預(yù)測(cè))階段,將原始圖片進(jìn)行水平翻轉(zhuǎn)、垂直翻轉(zhuǎn)、對(duì)角線翻轉(zhuǎn)、旋轉(zhuǎn)角度等數(shù)據(jù)增強(qiáng)操作,得到多張圖,分別進(jìn)行推理,再對(duì)多個(gè)結(jié)果進(jìn)行綜合分析,得到最終輸出結(jié)果。
文章的主要貢獻(xiàn)在于結(jié)果和對(duì)transformer架構(gòu)在CV任務(wù)的應(yīng)用上的思考,可細(xì)讀原文實(shí)驗(yàn)。
參考鏈接:
· Swin Transformer之PatchMerging原理及源碼_patch merging_白話先生的博客-CSDN博客
· 簡(jiǎn)單聊聊 Test Time Augmentation - 知乎用戶yP1hFG的文章 - 知乎 https://zhuanlan.zhihu.com/p/383005472