詳解半監(jiān)督學習經(jīng)典工作:邊緣生成對抗網(wǎng)絡
來源:投稿 作者:小灰灰
編輯:學姐

論文標題:MarginGAN: Adversarial Training in Semi-Supervised Learning
論文鏈接: https://papers.nips.cc/paper/2019/file/517f24c02e620d5a4dac1db388664a63-Paper.pdf
代碼鏈接:https://github.com/DJjjjhao/MarginGAN
1.摘要
針對半監(jiān)督學習問題,提出了一種邊緣生成對抗網(wǎng)絡(MarginGAN)。與TripleGAN一樣, MarginGAN由三個組件組成:生成器、鑒別器和分類器,其中出現(xiàn)了兩種形式的對抗性訓練, 鑒別器按常規(guī)進行訓練,以區(qū)分真實數(shù)據(jù)和生成器生成的假數(shù)據(jù),分類器可以增加真實樣本的邊緣,并減少假樣本的邊緣。生成器的目的是生成真實的、大幅度的數(shù)據(jù),以便同時欺騙鑒別器和分類器,偽標簽用于在訓練中生成和未標記的數(shù)據(jù)。我們的方法是基于大邊緣分類器的成功以及最近的觀點,即好的半監(jiān)督學習需要“壞”的GAN。在基準數(shù)據(jù)集上的實驗證明,MarginGAN與幾種最先進的方法正交,提供了改進的錯誤率和更短的訓練時間。
2.介紹
在現(xiàn)實世界中,未標記的數(shù)據(jù)可以相對容易地獲得,而手動標記的數(shù)據(jù)成本很高,偽標簽是未標記數(shù)據(jù)的人工標簽,其作用與人工標注數(shù)據(jù)的標簽相同,是半監(jiān)督學習中一種簡單有效的方法。幾種傳統(tǒng)的SSL方法,如自訓練[1]和協(xié)同訓練[4],都基于偽標簽。在過去幾年中,深度神經(jīng)網(wǎng)絡在SSL方面取得了巨大的進步,因此偽標簽的概念被納入深度學習中,以利用未標記的數(shù)據(jù), 在[5]中,選擇具有最大預測概率的類作為偽標簽。[6]中提出的時間集合使用集合預測作為偽標簽,這是在不同正則化和輸入增強條件下,不同時期的標簽預測的指數(shù)移動平均值. 與[6]中的標簽預測進行平均相比,在平均教師方法[7]中,模型權重進行平均。[5]中的偽標簽具有與地面真值標簽相同的效果,以最小化交叉熵損失,而[6,7]中的偽標簽用作預測目標,以實現(xiàn)一致性正則化,這可以使分類器為相似數(shù)據(jù)點提供一致的輸出。
最近,生成性對抗網(wǎng)絡(GANs)被應用于SSL,并取得了驚人的結果。[8]中提出的特征匹配(FM)GANs方法用(K+1)類分類器代替了原始的二元鑒別器。分類器(即鑒別器)的目的是將標記樣本分類為正確類,將未標記樣本分類到前K類中的任何一類,并將生成樣本分類到第(K+1)類,作為特征匹配GANs的改進,在[9]中提出的方法驗證了良好的半監(jiān)督學習需要“壞”生成器。所提出的補碼生成器可以在低密度區(qū)域中產(chǎn)生人工數(shù)據(jù)點,從而鼓勵分類器在這些區(qū)域中放置類邊界,并提高泛化性能。
盡管在深度學習中使用偽標簽的想法簡單有效,但有時可能會發(fā)生不正確的偽標簽會損害泛化性能并減慢深度網(wǎng)絡的訓練。先前的工作,如[6,7]致力于如何提高偽標簽的質量。受[9]的啟發(fā),我們提出了一種方法,鼓勵生成器在SSL中生成“壞”示例,從而提高對錯誤偽標簽的容忍度,并進一步降低錯誤率。
為了解決由錯誤偽標簽引起的問題,我們提出了MarginGAN,一種基于分類器邊緣理論的半監(jiān)督學習中的GAN模型。MarginGAN由三個組件組成——生成器、鑒別器和分類器(MarginGAN的架構見圖1)。鑒別器的作用與標準GAN中的作用相同,區(qū)分樣本是來自真實分布還是由生成器生成。訓練多類分類器以增加真實數(shù)據(jù)(包括標記數(shù)據(jù)和未標記數(shù)據(jù))的分類邊緣,同時減少生成的假樣本的邊緣,生成器的目標是生成看起來真實且具有較大余量的偽標簽,旨在同時欺騙鑒別器和分類器。
3.方法
3.1 本文的動機
在通常的GAN模型中,目標是訓練一個生成器,該生成器可以生成真實的假樣本,使得鑒別器無法辨別真實或假樣本。然而,在SSL問題中,我們的目的是訓練高精度分類器,從而獲得大量訓練示例,我們希望生成器能夠產(chǎn)生接近真實決策邊界的“信息”樣本,就像支持向量機模型中的支持向量一樣。這里出現(xiàn)了另一種對抗性訓練:生成器試圖生成大幅度的假樣本,而分類器旨在對這些假示例進行小幅度預測。
未標記樣本(和假樣本)的錯誤偽標簽大大降低了基于偽標簽的先驗方法的準確性,但我們的MarginGAN對錯誤偽標簽表現(xiàn)出更好的容忍度。由于鑒別器在通常的GAN中起著相同的作用,我們認為MarginGAN獲得的提高的準確性來自生成器和分類器之間的對抗性交互。
首先,我們消融研究中的極端訓練案例表明,MarginGAN生成的假樣本可以積極糾正錯誤偽標簽的影響。由于分類器強制執(zhí)行假樣本的小邊界值,因此生成器必須在“正確”決策邊界附近生成假樣本。這將細化和縮小圍繞真實樣本的決策邊界。
其次,我們說明了四類問題的大幅度直覺。如果分類器選擇相信錯誤的偽標簽,則決策邊界必須跨越兩類示例之間的“真實”差距。但是錯誤的偽標簽會導致邊緣值減少,從而影響泛化精度。因此,為了獲得更高的準確度,大邊緣分類器應該忽略那些錯誤的偽標簽。
3.2 Matgin
在機器學習中,單個數(shù)據(jù)點的邊緣定義為從該數(shù)據(jù)點到?jīng)Q策邊界的距離,該距離可用于限制分類器的泛化誤差。支持向量機(SVM)和boosting都可以用基于邊緣的泛化邊界來解釋,
在AdaBoost算法中,是在迭代t和a_t≥ 0中獲得的基本分類器,并且ht(x)∈?{1,??1} 是在迭代t和
中獲得的基本分類器≥?0是分配給ht的相應權重,組合分類器f是T基分類器的加權多數(shù)表決,其公式為

實例標簽對(x,y)的邊距定義為

3.3 架構總覽
GAN的原始架構由兩個組件組成,一個生成器和一個鑒別器,生成器G變換潛在變量z~p(z)到假樣本x?~(x?) 使得生成的分布
(x?) 近似于真實數(shù)據(jù)分布p(x)。鑒別器D用于區(qū)分生成的假樣本和真實樣本。為了適應半監(jiān)督學習,我們在原始架構中添加了分類器C,我們保留鑒別器,以鼓勵生成器生成視覺上真實的樣本。我們對每個組件的描述如下。MarginGAN的架構如圖1所示。

3.3.1 分類器
我們將多分類器添加到原始的GAN中,因為高精度分類是我們在SSL中的目標。 分類器接收與判別器相同的輸入--標記樣本、未標記樣本和生成的假樣本。
對于標記的樣本,分類器具有與普通多類分類器相同的目標。給定實例標簽對(x,y),分類器C嘗試最小化真實標簽y和預測標簽C(x)之間的交叉熵損失:

標記樣本的損失函數(shù)可以公式化為:

對于未標記的示例,分類器的目標是增加這些數(shù)據(jù)點的邊緣。然而,由于沒有關于相應真實標簽的信息,我們不知道哪個類概率應該達到峰值。我們在一個熱編碼中利用偽標簽來處理未標記的示例。
4.實驗
4.1 初訓練
與我們的工作類似,在先前的工作[5]中使用了偽標記,并報告了MNIST的實驗。為了清楚地顯示MarginGAN帶來的改進,我們首先對MNIST進行了初步實驗。我們使用infoGAN[22]中的生成器和鑒別器,并使用具有六層的簡單卷積網(wǎng)絡作為分類器。盡管我們使用的分類器可能比[5]中使用的更強大,但隨后的消融研究可以揭示生成的假樣本帶來的貢獻。
MNIST由60000幅圖像的訓練集和10000幅圖像的測試集組成,所有圖像均為28×28灰度像素。在設置中,我們采樣100、600、1000或3000個標記樣本,并將訓練集的其余部分用作未標記樣本。在訓練時,我們首先對分類器進行預訓練,以實現(xiàn)錯誤率低于8.0%、9.3%、9.5%和9.7%,僅使用標記樣本,分別對應于100、600、1000和3000個標記樣本。然后,未標記樣本和生成樣本參與訓練過程。表1將我們的結果與[5]中的其他競爭方法進行了比較。我們可以看到,所提出的MarginGAN在每個設置上都優(yōu)于這些基于偽樣本的先前方法,這可以歸因于生成的偽樣本的參與。盡管與現(xiàn)有算法的比較有點不公平,但我們的方法在所有設置下都實現(xiàn)了更高的精確度,隨后的消融研究進一步驗證了我們方法的改進。

4.2 MNIST的消融研究
為了找出標記樣本、未標記樣本和生成的假樣本的影響,我們在一次輸入一種或多種樣本的情況下進行消融實驗。在消融研究中,由于偽標記的不穩(wěn)定性和某些情況下缺乏標記示例,我們將學習率從0.1降低到0.01。我們測量了不同設置下訓練收斂所需的最低錯誤率和時間,結果如表2所示。

表2:MNIST算法的消融研究。本實驗中標記的示例數(shù)量為600個。L、U和G的縮寫分別對應于標記的示例、未標記的示例和生成的示例。最后兩行顯示了極端的訓練情況。
未標記示例在半監(jiān)督學習中起著重要作用。我們可以看到,添加未標記的示例可以將錯誤率從8.21%降低到4.54%,提高了3.67%。為了驗證偽標簽正確性的不確定性,我們進行了一次極端的嘗試:對分類器進行預訓練,以達到9.78%(±0.14%)的錯誤率,然后我們單獨向分類器提供未標記的示例。換句話說,分類器不能再次訪問標記的示例。令我們驚訝的是,錯誤率急劇上升,很快達到89.53%以上。不正確的偽標簽將誤導分類器并阻礙其泛化。
生成假示例我們將生成的示例反饋給分類器,使其對錯誤的偽標簽具有魯棒性,并提高了性能。我們可以看到,與只訓練標記樣本和未標記樣本相比,生成的示例可以進一步將錯誤率從4.54%提高到3.20%。此外,值得注意的是,生成的示例可以顯著減少71.8%的訓練時間。然而,當我們繼續(xù)訓練時,錯誤率開始增加,出現(xiàn)過度擬合。當生成的圖像逐漸變得更真實時,分類器仍然會減少邊緣,這可能會影響性能?;氐缴鲜鰳O端情況,當在預訓練后組合未標記圖像和生成圖像時,錯誤率確實可以提高(從9.78%到7.40%)。
4.3 Generated Fake Images
我們在圖3中顯示了當分類器的精度增加時由MarginGAN生成的圖像。正如我們所看到的,這些假圖像看起來真的很“糟糕”:例如,MNIST和SVHN中生成的大多數(shù)數(shù)字都接近決策邊界,因此無法以高置信度確定它們的標簽。這種情況符合本文的動機。
5.結論
在這項工作中,我們提出了邊緣生成對抗網(wǎng)絡(MarginGAN),它由三部分組成:一個生成器、一個鑒別器和一個分類器。關鍵是分類器可以利用生成器生成的假示例來提高泛化性能。具體而言,分類器的目標是最大化真實示例的邊緣值,最小化假樣本的邊緣。生成器試圖產(chǎn)生真實的、大幅度的示例,以欺騙鑒別器和分類器。在多個基準上的實驗結果表明,MarginGAN可以提高精度并縮短訓練時間。
參考文獻
[1] Probability of error of some adaptive pattern-recognition machines
[4] Combining labeled and unlabeled data with co-training
[5] ?D.-H. Lee. Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks. ICML Workshop, 2013.
[6] ?S. Laine and T. Aila. Temporal ensembling for semi-supervised learning. arXiv:1610.02242, 2016.
[7] ?A.TarvainenandH.Valpola.Meanteachersarebetterrolemodels:Weight-averagedconsistency targets improve semisupervised deep learning results. NeurIPS, 2017.
關注“學姐帶你玩AI”公眾號
回復“500”免費領取學姐整理的論文!