90天學(xué)會(huì)GAN--Day4--從MNIST數(shù)據(jù)集開(kāi)始

GAN的變體之一:CGAN
CGAN是GAN的一種變體,主要就是加入了 label 來(lái)影響生成器生成的圖片,以達(dá)成一定程度的分類效果。
以 MNIST 數(shù)據(jù)集為例,MNIST 數(shù)據(jù)集中有 0-9 共10個(gè)數(shù)字, 所以可以給每一個(gè)數(shù)據(jù)加上一個(gè)標(biāo)簽再放入 generator 生成 。
這樣在最后輸出的時(shí)候就可以通過(guò)插入標(biāo)簽來(lái)生成指定的圖片種類。 比如我們可以通過(guò) nn.Embedding() 函數(shù)來(lái)實(shí)現(xiàn)這個(gè)功能。

通過(guò)這個(gè)方法,我們就可以將 labels 的信息插入 generator 和 discriminator,實(shí)現(xiàn)CGAN的功能。
于是我們生成的時(shí)候只需要在原本生成的噪聲Z 后面在插入這些 labels 即可。
此處,筆者產(chǎn)生了個(gè)疑問(wèn):為啥要在discriminator里也插入 labels 啊,不應(yīng)該只需要generator插入就行了嗎?
此處我們需要回歸 generator 和 discriminator 的定義。
generator 其實(shí)就是將一個(gè)分布映射到另一個(gè)分布的函數(shù),所以我們做的是將一個(gè)隨機(jī)數(shù)輸進(jìn) generator 產(chǎn)生一張假的圖片來(lái)交給 discriminator 判斷這是來(lái)自原圖像還是生成的圖像。
因此,discriminator 實(shí)際上要做的工作就是判斷生成的圖像是否與原圖接近。 那么如果我們?cè)诮唤o discriminator 判斷前,在原圖像旁邊插上一排 labels, 如下圖:

并將這個(gè)作為 “原圖” 輸入進(jìn) discriminator, 與來(lái)自 generator 的圖片進(jìn)行比較,并減小差距 (loss) ;
這時(shí)候 generator 就會(huì)知道要去生成和這張 “原圖” 相近的圖片。
由于 generator 中下半部分已經(jīng)確定了 , 因此最后 generator 只能讓上半部分的圖片更加接近原圖。
所以 generator 能夠通過(guò)這個(gè)標(biāo)簽生成類似 1 的圖片。