1-當(dāng)蒸餾遇上GAN
什么是知識(shí)蒸餾
我們知道在深度學(xué)習(xí)的大部分網(wǎng)絡(luò)中,有很多神經(jīng)元是冗余的,所以很多網(wǎng)絡(luò)的參數(shù)量是巨大的,但是在很多移動(dòng)端,比如手機(jī)上,是跑不動(dòng)這么大的網(wǎng)絡(luò)的。所以知識(shí)蒸餾的一開始的目標(biāo)是做模型壓縮,它的目標(biāo)就是讓一個(gè)更小的網(wǎng)絡(luò)去擬合甚至是超越教師網(wǎng)絡(luò)的性能。在通常情況下,學(xué)生網(wǎng)絡(luò)的在蒸餾階段的目標(biāo)可以用這樣的一個(gè)函數(shù)來表示,這里的損失函數(shù)L根據(jù)算法對(duì)知識(shí)的定義不同也會(huì)有不同的函數(shù)表示。

比如Hinton在初期的知識(shí)蒸餾文章中將教師網(wǎng)絡(luò)的知識(shí)定義為網(wǎng)絡(luò)輸出的每個(gè)類別的概率,也就是軟標(biāo)簽。所以在這個(gè)算法中,損失函數(shù)就是教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)輸出軟標(biāo)簽的KL散度。這里fT就是教師的輸出,fS是學(xué)生的輸出,下面的T是溫度參數(shù),我們通過讓學(xué)生網(wǎng)絡(luò)去最小化這個(gè)函數(shù),來學(xué)習(xí)教師網(wǎng)絡(luò)的知識(shí),從而達(dá)到蒸餾的效果。

還有之前提到過的Hint,通過擬合教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)中間層的輸出來達(dá)到蒸餾的效果。這里的損失函數(shù)L就是教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)中間層輸出的L2距離,這里的β是為了讓輸出大小相同做的一個(gè)變換。我們會(huì)發(fā)現(xiàn),通過改變損失函數(shù)L,可以得到多種多樣的蒸餾方法,但是這些蒸餾方法都有個(gè)特點(diǎn),就是需要教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)在訓(xùn)練的時(shí)候有著同樣的輸入。

Why GAN?
但是在實(shí)際應(yīng)用環(huán)境下,很有可能教師端訓(xùn)練時(shí)用的是一些隱私性的數(shù)據(jù),比如醫(yī)療時(shí)所用的一些病人的數(shù)據(jù),學(xué)生網(wǎng)絡(luò)沒有辦法去獲取這些數(shù)據(jù),這樣傳統(tǒng)的一些蒸餾方法就不能使用了。為了解決這個(gè)問題,我們很自然的會(huì)想到,是不是可以生成一個(gè)假的數(shù)據(jù)集,然后用這個(gè)假數(shù)據(jù)集作為教師和學(xué)生公用的數(shù)據(jù)集,并在上面做蒸餾操作。在生成假數(shù)據(jù)方面,近幾年流行的就是生成式對(duì)抗網(wǎng)絡(luò)GAN了。這就是為什么我們想把GAN和蒸餾去做一個(gè)結(jié)合。

生成式對(duì)抗網(wǎng)絡(luò)
這里簡(jiǎn)單的介紹一下生成式對(duì)抗網(wǎng)絡(luò),GAN是由兩個(gè)網(wǎng)絡(luò)組成的,一個(gè)是左邊的生成器網(wǎng)絡(luò),另一個(gè)是右邊的判別器網(wǎng)絡(luò)。生成器接收一個(gè)隨機(jī)向量,并生成一個(gè)假樣本。以MNIST數(shù)據(jù)集為例,這里的真實(shí)數(shù)據(jù)就是MNIST數(shù)據(jù)集中的手寫數(shù)字圖像,而生成器就需要去生成類似于手寫數(shù)字的圖片。判別器就是一個(gè)二分類器,他需要盡可能的去分辨輸入的圖像是真實(shí)的圖片還是生成器生成的。當(dāng)輸入的圖像是真實(shí)圖片的時(shí)候,判別器需要給他一個(gè)很高的分?jǐn)?shù),當(dāng)輸入的圖像是生成器生成的圖片的時(shí)候,判別器需要給他一個(gè)很低的分?jǐn)?shù)。而生成器的目標(biāo),就是讓自己生成的圖像盡可能的被打出一個(gè)高分。GAN就是通過這樣一種對(duì)抗的形式,來獲得生成器和判別器彼此性能的提升。

訓(xùn)練判別器
我們可以通過GAN 的訓(xùn)練過程來對(duì)它做一個(gè)更加深入的理解。還是以MNIST數(shù)據(jù)集為例,首先我們需要更新判別器的參數(shù)。需要注意的是,在更新判別器參數(shù)的時(shí)候,生成器的參數(shù)是需要完全固定的。還是以MNIST數(shù)據(jù)集為例,我們看上面的綠色虛線框起來的部分,這里輸入判別器的是MNIST數(shù)據(jù)集中的一張真實(shí)圖像,經(jīng)過判別器之后,會(huì)輸出其屬于真假兩個(gè)類別的概率,我們假設(shè)輸出真的概率是0.6假的概率是0.4。對(duì)于輸入的真實(shí)圖像,我們希望判別器很確定的將其判斷為真,也就是真的概率是1,假的概率是0。這樣,我們就可以計(jì)算兩者的一個(gè)交叉熵作為判別器的優(yōu)化目標(biāo)。這里因?yàn)榕袆e器所作的是一個(gè)二分類任務(wù),我們可以對(duì)交叉熵做進(jìn)一步的優(yōu)化,只需要其輸出為真的概率,也就是上面的0.6盡可能的接近1就好了。后面的這個(gè)式子中的D(x)就是這里的0.6,我們希望它越大越好。同樣的,我們看到下面這里黃色虛線框起來的部分,這里輸入生成器的隨機(jī)向量,在這個(gè)例子中我們可以簡(jiǎn)單的理解為100維的服從高斯分布的向量,這個(gè)向量提供了一些跟更高維度的信息,生成器通過網(wǎng)絡(luò)添加更多的細(xì)節(jié),從而生成了一張假的圖像。所以在這個(gè)階段輸入判別器的是生成器生成的加圖像,同樣的我們也會(huì)得到它屬于真假兩個(gè)類別的概率,判別器希望其被判定為真的概率越低越好,也就是這里的D(G(z))越低越好。通過這里兩個(gè)優(yōu)化目標(biāo),我們可以對(duì)判別器計(jì)算損失并反向傳播做第一次的參數(shù)更新。

訓(xùn)練生成器
在上個(gè)階段,判別器更新完成后,我們把判別器的參數(shù)固定,開始訓(xùn)練生成器。同樣的,生成器生成一張假圖片輸入判別器,得到真假兩個(gè)類別的概率。與前面不同的是,生成器希望自己生成的圖片被盡可能的判定為真的,所以它需要讓這里的D(G(Z))越大越好。這樣,整個(gè)生成式對(duì)抗網(wǎng)路的優(yōu)化目標(biāo)就很好理解了。這個(gè)階段我們更新了生成器的參數(shù),與前面判別器的參數(shù)更新放在一起就是GAN訓(xùn)練完整的一輪過程。不斷的重復(fù)這個(gè)過程直到網(wǎng)絡(luò)達(dá)到一個(gè)收斂的狀態(tài)。

蒸餾 meets GAN
在了解了生成式對(duì)抗網(wǎng)絡(luò)之后,我們?nèi)ハ胨趺春驼麴s去結(jié)合起來。有一種很簡(jiǎn)單很直接的方法就是,直接把教師網(wǎng)絡(luò)的數(shù)據(jù)集丟給一個(gè)隨機(jī)初始化的GAN,讓他從頭開始訓(xùn)練,然后生成假圖片用來蒸餾。這種方法確實(shí)沒有問題,但是從頭開始訓(xùn)練一個(gè)GAN,是非常非常慢的。我們可以想一下現(xiàn)在所擁有的工具,除了教師端私有的數(shù)據(jù)集之外,還有一個(gè)已經(jīng)訓(xùn)練好的性能很棒的教師網(wǎng)絡(luò)。如果可以把這個(gè)教師網(wǎng)絡(luò)利用起來,就可幫助生成模型更快的達(dá)到收斂狀態(tài),從而降低計(jì)算消耗。怎么去用這個(gè)教師網(wǎng)絡(luò)呢,我們知道教師網(wǎng)絡(luò)是在真實(shí)數(shù)據(jù)上訓(xùn)練出來的,比如說我們用的是mnist數(shù)據(jù)集,那這個(gè)預(yù)訓(xùn)練好的教師網(wǎng)絡(luò)就可以很好的提取真實(shí)圖像的特征,然后用這些特征去分類輸入圖像。這一點(diǎn)和GAN中判別器所作的任務(wù)是非常的相似的。所以我們可以把教師網(wǎng)絡(luò)看作是一個(gè)判別器,但是這個(gè)判別器和傳統(tǒng)的GAN的判別器是不一樣,他的參數(shù)都是預(yù)訓(xùn)練好的,而不是隨機(jī)初始化的,如果還是用原來的訓(xùn)練方法的話,那么隨機(jī)初始化的生成器無論生成什么樣的圖片,都會(huì)被判別器以很高的置信度判定為假,這樣生成器就不知道自己優(yōu)化的方向了。還有一點(diǎn)就是,教師網(wǎng)絡(luò)作為判別器做的也不是真假的二分類任務(wù),如果是mnist數(shù)據(jù)集,那他做的就是一個(gè)10分類,而不是之前我們說過的輸出圖片屬于真假類別的概率。所以對(duì)于生成器的優(yōu)化目標(biāo),我們需要做一定的變化。

生成器損失1
既然生成器是用來生成圖片的,我們可以想一想怎么衡量生成器生成圖片的好壞呢?這其實(shí)是一個(gè)非常困難的任務(wù),我們以生成一只貓的圖片為例子,可能生成器生成了一只躺著的貓,而真實(shí)圖像是一只站著的貓,所以不是說生成器生成的圖像和數(shù)據(jù)集中某一張真實(shí)的貓的圖像一模一樣他就是好的,因此也不能的直接對(duì)兩者去計(jì)算L1L2損失。我們可以從真實(shí)圖像的某些特征上出發(fā),因?yàn)榻處熅W(wǎng)絡(luò)是在真實(shí)圖像上訓(xùn)練出來,所以一個(gè)比較好的教師網(wǎng)絡(luò),在輸入是真實(shí)圖片的時(shí)候,必定會(huì)在某一個(gè)類別的概率特別的大,在其他類別的概率特別的小。所以我們希望生成器生成的圖像也具有這樣的性質(zhì)。這里的yt就是對(duì)于生成圖像,教師網(wǎng)絡(luò)得到的其屬于每個(gè)類別的概率,小t就是取了概率最大的類別并做成one-hot的向量,用交叉熵來衡量他們的相似度。通過最小化這個(gè)函數(shù)來對(duì)生成圖像做一個(gè)限制。

生成器損失2
同樣的,我們知道卷積神經(jīng)網(wǎng)絡(luò)的卷積核就是一個(gè)特征提取器,相比于一個(gè)隨機(jī)的向量,真實(shí)的輸入圖片會(huì)有更多與之相符合的特征,也就是它被激活的神經(jīng)元會(huì)更多。所以,如果生成的圖像和真實(shí)圖像相似,那么他在中間層被激活的神經(jīng)元也應(yīng)該更多,這里用L1范數(shù)來衡量激活神經(jīng)元的個(gè)數(shù)。

生成器損失3
最后第三項(xiàng),我們從生成圖片本身出發(fā)。我們的生成器是隨機(jī)初始化的,一開始他生成的完全是沒有意義的圖像,假設(shè)生成器經(jīng)過優(yōu)化先生成了一張類似于手寫數(shù)字0的圖像,只是后判別器給他的分?jǐn)?shù)相對(duì)是比較高,這種情況下生成器會(huì)覺的0這個(gè)數(shù)字是比較好的,之后無論輸入是什么,他都朝著生成一個(gè)更好的數(shù)字0去做。這明顯和我們想要的生成器是不一樣的,所以對(duì)于他在每一個(gè)batch中生成的圖像,我們需要對(duì)其做一個(gè)限制,最好的情況就是生成的圖片在每一個(gè)類別上的分布都是均衡的,此時(shí)信息量也是最大的,所以我們用這樣的一個(gè)函數(shù)來對(duì)它生成圖片的多樣性做一個(gè)限制。加上之前的兩項(xiàng)就是生成器總的目標(biāo)函數(shù),通過這個(gè)目標(biāo)函數(shù),我們可以對(duì)生成器進(jìn)行優(yōu)化從而得到很多的假圖像。有了假圖像之后就十分簡(jiǎn)單了,只需要把之前的蒸餾方法套進(jìn)來就行了,這里作者選用的是Hinton的軟標(biāo)簽。

實(shí)驗(yàn)
我們看論文的實(shí)驗(yàn)結(jié)果,首先看紅色框框里面的三行,這里三行是說,在教師網(wǎng)絡(luò)的數(shù)據(jù)集可以被獲取到的情況下,學(xué)生網(wǎng)絡(luò)所能達(dá)到的性能。這里教師網(wǎng)絡(luò)是ResNet34,學(xué)生網(wǎng)絡(luò)是ResNet18,數(shù)據(jù)集用的是CIFAR10,可以看到第三行Hinton的知識(shí)蒸餾方法在網(wǎng)絡(luò)參數(shù)量減少一倍的情況下,仍然保持了一個(gè)較高的分類準(zhǔn)確率。但是到了第四行,在不能得到教師網(wǎng)絡(luò)數(shù)據(jù)的情況下,他只能達(dá)到14.89的準(zhǔn)確率。而用作者的方法,則可以達(dá)到92.22的準(zhǔn)確率,這說明作者提出的算法是確實(shí)有效的。

這是作者另外作的一個(gè)剝離實(shí)驗(yàn),就是去驗(yàn)證之前所說的生成器三項(xiàng)損失的有效性。我們首先看紅色的框框,這個(gè)是說生成器不進(jìn)行優(yōu)化,隨機(jī)的去生成數(shù)據(jù),這種情況下學(xué)生網(wǎng)絡(luò)也能達(dá)到88的準(zhǔn)確率。這個(gè)和之前說的14的準(zhǔn)確率不一樣是因?yàn)樗玫氖歉雍?jiǎn)單的MNIST數(shù)據(jù)集。但是我們看到綠色框框中,單獨(dú)的使用one-hot或者是激活損失,得到結(jié)果反而會(huì)比隨機(jī)的要差很多。這是因?yàn)槿狈α诵畔㈧負(fù)p失函數(shù),生成器生成的圖片是非常不均衡的,這樣學(xué)生就無法充分的學(xué)習(xí)到教師網(wǎng)絡(luò)的知識(shí)。最后我們看到在綜合使用三項(xiàng)的情況下的結(jié)果是最好的。

此外作者還做了兩個(gè)額外的實(shí)驗(yàn)。左邊這張表是說教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)結(jié)構(gòu)一樣的情況下,在不同的數(shù)據(jù)集上用作者的方法所能達(dá)到的性能。而右邊這個(gè)圖,是把教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的卷積核進(jìn)行可視化后的結(jié)果,我們可以看到,在1、3、4、5、6列學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)卷積核的可視化結(jié)果非常的相似,這就說明學(xué)生網(wǎng)絡(luò)確實(shí)對(duì)教師網(wǎng)絡(luò)進(jìn)行了有效的學(xué)習(xí)。

Mo?人工智能俱樂部是由人工智能在線建模平臺(tái)(網(wǎng)址:https://momodel.cn)的研發(fā)與產(chǎn)品設(shè)計(jì)團(tuán)隊(duì)發(fā)起、致力于降低人工智能開發(fā)與使用門檻的俱樂部。團(tuán)隊(duì)具備大數(shù)據(jù)處理分析、可視化與數(shù)據(jù)建模經(jīng)驗(yàn),已承擔(dān)多領(lǐng)域智能項(xiàng)目,具備從底層到前端的全線設(shè)計(jì)開發(fā)能力。主要研究方向?yàn)榇髷?shù)據(jù)管理分析與人工智能技術(shù),并以此來促進(jìn)數(shù)據(jù)驅(qū)動(dòng)的科學(xué)研究。
每周我們?cè)卺斸斨辈?,有浙大研究生博士分享前沿論文,群?hào)是31502598,歡迎參加!
