不僅搞定“梯度消失”,還讓CNN更具泛化性:港科大開源深度神經(jīng)網(wǎng)絡(luò)訓練新方法
港科大李鐸、陳啟峰提出一種優(yōu)化模型訓練、提升模型泛化性能與模型精度方法。
◎作者系極市原創(chuàng)作者計劃特約作者Happy歡迎大家聯(lián)系極市小編(微信ID:fengcall19)加入極市原創(chuàng)作者行列

paper:?https://arxiv.org/abs/2003.10739
code:?https://github.com/d-li14/DHM
該文是港科大李鐸、陳啟峰提出的一種優(yōu)化模型訓練、提升模型泛化性能與模型精度的方法,相比之前Deeply-Supervised Networks方式,所提方法可以進一步提升模型的性能。值得一讀。
Abstract
時間見證了深度神經(jīng)網(wǎng)絡(luò)的深度的迅速提升(自LeNet的5層到ResNet的上千層),但尾端監(jiān)督的訓練方式仍是當前主流方法。之前有學者提出采用深度監(jiān)督(Deeply-supervised,DSN)方式緩解深度網(wǎng)絡(luò)的訓練難度問題,但是它不可避免的會影響深度網(wǎng)絡(luò)的分層特征表達能力,同時會導(dǎo)致前后矛盾的優(yōu)化目標。
作者提出一種動態(tài)分層模仿機制(Dynamic Hierarchical Mimicking,一種廣義特征學習機制)加速CNN訓練同時使其具有更強的泛化性能。所提方法部分受DSN啟發(fā),對給定神經(jīng)網(wǎng)絡(luò)的中間特征進行巧妙的設(shè)置邊界分支(side branches)。每個分支可以動態(tài)的出現(xiàn)在主分支的特定位置,它不僅可以保留骨干網(wǎng)絡(luò)的特征表達能力,同時還可以研其通路產(chǎn)生更多樣性的特征表達。與此同時,作者提出采用概率預(yù)測匹配損失進一步提升多分支的多級交互影響,它可以確保優(yōu)化過程的魯棒性,同時具有更好的泛化性能。
最后作者在分類與實例識別任務(wù)上驗證了所提方法的性能,均可取得一致性的性能提升。
Method
該部分內(nèi)容首先簡單介紹一下深度監(jiān)督及存在的問題,最后給出所提方法。由于該部分內(nèi)容公式較多,文字較多,故這里僅進行粗略的介紹,在后面對進行一些個人理解分析。
Analysis of Deep Supervision

通過上述上述訓練方式,中間層不僅可以從頂層損失獲取梯度信息,還可以從分支損失獲取提取信息,這使得其具有緩解“梯度消失”,加速網(wǎng)絡(luò)收斂的功能。
然而,直接在中間層添加額外的監(jiān)督信息的方式在訓練極深網(wǎng)絡(luò)時可能會導(dǎo)致模型性能下降。眾所周知,深度網(wǎng)絡(luò)具有極強的分層特征表達能力,其特征會隨網(wǎng)絡(luò)深度而變化(底層特征聚焦邊緣特征而缺乏語義信息,而高層特征則聚焦于語義信息)。在底層添加強監(jiān)督信息會導(dǎo)致深度網(wǎng)絡(luò)的上述特征表達方式被破壞,進而導(dǎo)致模型的性能下降。這從某種程度上解釋了為何上述監(jiān)督方式對模型的性能提升比較小(大概在0.5%左右,甚至無提升)。
Dynamic Hierarchical Mimicking
作者重新對上述優(yōu)化目標進行了分析并給出猜測:“最本質(zhì)的原因在于損失函數(shù)中相加的兩塊損失優(yōu)化目標不一致”。以分類為例,盡管兩者均意在優(yōu)化交叉熵損失,但兩者在中間層的優(yōu)化方向是不一致的,存在矛盾點,進而導(dǎo)致對最終模型性能產(chǎn)生負面影響。
針對上述問題,作者提出一種新穎的知識匹配損失用于正則化訓練過程,并使得不同損失對中間層的優(yōu)化目標相一致,從而確保了模型的魯棒性與泛化性能。

所提方法的優(yōu)化目標函數(shù)可以描述如下公式,其示意圖見上圖。

其中比較關(guān)鍵在于第三項的引入,也就是所提到的知識匹配損失。注:由于全文公式太多,本人只是相對粗略的看來一遍,沒有過于深度去研究。應(yīng)該不會影響對其的認知,見后續(xù)的對比分析。
Experiments
為驗證所提方法的有效性,作者在多個數(shù)據(jù)集(Cifar,ImageNet,Market1501等)上的機型了實驗對比分析。
首先,給出了CIFAR-100數(shù)據(jù)集上所提方法與DSL的性能對比,見下圖。盡管DSL可以提升模型的性能,但提提升比較少,而作者所提DHM可以得到更高的性能提升。該實驗證實了所提方法的有效性。

然后,作者給出了ImageNet數(shù)據(jù)集上的性能對比,見下圖??梢缘玫脚c前面類似的結(jié)論,但同時可以看到:對于極深網(wǎng)絡(luò)(如ResNe152),DSL的性能提升非常有限,而所提方法仍能極大的提升模型的性能超1%。

其次,作者給出了Market1501數(shù)據(jù)集上的性能對比,見下圖。結(jié)論同前,不再贅述。

最后,作者還提供了其實驗過程中的網(wǎng)絡(luò)架構(gòu),這里僅提供一個參考模型(MobileNet)作為示例以及分析說明。除了MobileNet外,作者還提供了DenseNet、ResNet、WRN等實驗?zāi)P汀?/p>
Discusion
實事求是的說,本人在看到最后的網(wǎng)絡(luò)結(jié)構(gòu)和代碼之前是沒看明白這篇論文該怎么應(yīng)用的。只是大概了解DSL破壞了深度網(wǎng)絡(luò)的分層特征表達能力,針對該問題而提出的解決方案。
看了論文和代碼后,基本上明白了作者是怎么做的。就一點:既然DSL破壞了深度網(wǎng)絡(luò)的分層特征表達能力,那么就想辦法去補償以不同損失反向傳播到中間層與底層時優(yōu)化方向是一致的。那么該怎么去補償呢?下圖給出了圖示,中間主干分支表示預(yù)定義好的網(wǎng)絡(luò)結(jié)構(gòu),左右兩個分支表示作者補償?shù)慕Y(jié)構(gòu),通過這樣的方式可以確保主損失與右分支損失傳播到layer3的優(yōu)化方向一致,主損失與做分支損失傳播到layer2的優(yōu)化方向一致。當然圖中兩個顏色layer3表示這是不同的處理過程,分支的處理過程肯定要比主分支的計算量小,否則豈不是加大了訓練難度?

我想,看到這里大家基本上都明白了DHM這篇論文所要表達的思想了。接下來,將嘗試將其與其他類似的方法進行一下對比分析。首先給出傳統(tǒng)訓練方式、DSL訓練方式與DHM的對比圖(注:圖中暗紅色區(qū)域表示損失計算,具體怎么計算不詳述)。

上圖給出了常規(guī)訓練過程、DSL訓練過程以及DHM的訓練成果對比。常規(guī)訓練過程僅在head部分有一個損失;而DSN(即DSL)則有多個損失,不同的損失回傳的速度時不一樣的,比如左分支損失直接傳給了layer2,這明顯快于中間的主損失,這是緩解“梯度消失”的原因所在;DHM類似于DSL具有多個損失,但同時為防止不同損失對中間層優(yōu)化方向的不一致,而添加了額外的輔助層,用于模擬深度網(wǎng)絡(luò)的分層特征表達。
那么DHM是如何緩解“梯度消失”現(xiàn)象的呢?個人認為,它有兩種方式:(1) ResNet與DenseNet中的緩解“梯度消失”的方式,這與網(wǎng)路結(jié)構(gòu)有關(guān);(2)分支層數(shù)少于主干層數(shù),一定程度上緩解了“梯度消失”。
最后,再補上一個與DHM極為相似的方法DML,兩者的流程圖如下所示。論文原文確實提到了DML方法,但并未與之進行對比。從圖示可以看到兩者還是比較相似的,盡管DML初衷是兩個網(wǎng)絡(luò)采用知識蒸餾的方式進行訓練,而DHM則是針對DSL存在的缺陷進行的改進。

私認為DHM是DML的特例(注:僅僅從上述圖示出發(fā)),有這么三點原因:
損失函數(shù)方面,以圖像分類為例,DML與DHM均采用交叉熵損失+KL散度計算不同分支損失;
分支數(shù)方面:盡管DML原文是借鑒識蒸餾方式,但其分支可以不止兩個,比如擴展到三個呢,四個呢?這兩種方式是不是就一樣了呢?
網(wǎng)路結(jié)構(gòu)方面:盡管DML提到的是兩個網(wǎng)絡(luò),但是兩個網(wǎng)絡(luò)如果共享stem+layer1+layer2部分呢?從這個角度來看,DHM與DML殊途同歸了。
做完上述記錄后,本人厚著臉皮去騷擾了一下李鐸大神,請教了一下。經(jīng)允許,現(xiàn)將作者的理解摘錄如下:
DSL存在的問題:(1) 特征逐級提取問題,如果像上述圖中g(shù)ooglenet/dsn那樣把head直接接在中間層立刻再接classifier,那么強制要求layer2、layer3、layer4都提取high-level語意特征,這和一般網(wǎng)絡(luò)里layer2、layer3可能還在提取更low-level的特征相違背;(2) 不同分支的gradient都會回傳到shared的主支上,如果這些gradient相互沖突甚至抵消,對于整個網(wǎng)絡(luò)的優(yōu)化是產(chǎn)生負面影響的。
DHM的解決方案:(1)第一個問題通過圖中的分支網(wǎng)絡(luò)結(jié)構(gòu)的改進來解決;(2)第二個問題則是通過KL散度損失隱式約束梯度來解決。
OK,關(guān)于DHM的介紹,全文到底結(jié)束!碼字不易,思考更不易,還請給個在看。
Reference
Going Deeper with Convolutions. https://arxiv.org/abs/1409.4842
Deeply Supervised Networks. https://arxiv.org/abs/1409.5185
Deep Mutual Learning. https://arxiv.org/abs/1706.003384