淺析圖像中的條件生成模型pix2pix

最近AIGC的爆火,不管是AI繪圖還是ChatGPT,都讓生成式模型成為了大家關(guān)注的焦點(diǎn)。而在目前主流圖像生成模型DiffusionNet之前,相信沒(méi)有人不承認(rèn)GAN(Generative Adversarial Nets)是生成模型中劃時(shí)代的作品,以至于當(dāng)時(shí)GAN的衍生模型異常之多。這篇文章就來(lái)介紹一個(gè)較為著名的GAN的衍生模型——pix2pix[1]。
本文將以以下幾個(gè)方面來(lái)對(duì)模型進(jìn)行介紹:
GAN系列入門(mén)介紹
模型結(jié)構(gòu)及基本原理
patchGAN
一些消融實(shí)驗(yàn)
一、GAN系列入門(mén)介紹
1. GAN:GAN[2]的中文全稱叫做生成對(duì)抗模型,模型包含兩個(gè)部分,生成器(Generator)與鑒別器(Discriminator)。簡(jiǎn)單來(lái)說(shuō),生成器的作用是生成假圖像,鑒別器的作用是來(lái)辨別圖像真?zhèn)危ㄟ^(guò)兩者的對(duì)抗,鑒別器不斷提高自己的鑒別能力,而生成器不斷提高自己的生成能力,最終當(dāng)鑒別器無(wú)法“有信心”地判斷輸入的真假時(shí),我們也就可以認(rèn)為生成器已經(jīng)學(xué)會(huì)了以假亂真的生成能力。
鑒別器的輸入是從數(shù)據(jù)集中取出的真實(shí)圖像與生成器生成的圖像,輸出則是鑒別器認(rèn)為該輸入是真實(shí)數(shù)據(jù)的概率。也就是說(shuō),當(dāng)此輸入穩(wěn)定在0.5附近時(shí),我們可以認(rèn)為鑒別器難以判斷輸入是真是假。
而生成器的本質(zhì)就是一個(gè)解碼器(decoder),他的輸入是一個(gè)來(lái)自n維標(biāo)準(zhǔn)正態(tài)分布的n維向量,通過(guò)解碼器解碼得到與真實(shí)數(shù)據(jù)維度相等的圖片,即生成圖片。而一旦這個(gè)生成器通過(guò)對(duì)抗的方式訓(xùn)練完成,我們即可以隨便選取一個(gè)n維標(biāo)準(zhǔn)正態(tài)分布向量,解碼得到新的圖片(數(shù)據(jù)集中未曾出現(xiàn)過(guò)的)。
GAN模型不僅結(jié)構(gòu)簡(jiǎn)單,原理不復(fù)雜,并且可以生成數(shù)據(jù)集中未曾出現(xiàn)過(guò)的圖片,因此在很長(zhǎng)一段時(shí)間(實(shí)際上可以說(shuō)時(shí)至今日)都成為了生成模型的主流研究/使用對(duì)象,并出現(xiàn)了一系列的變種來(lái)解決不同的下游問(wèn)題。
2. conditional GAN:上面講到,GAN的輸入是一個(gè)n維向量,而輸出是某類圖片,即使這些圖片都屬于同一類別,有著相似的風(fēng)格,但我們卻無(wú)法控制生成出來(lái)的數(shù)據(jù)長(zhǎng)什么樣子。而如果我們想要讓生成出來(lái)的數(shù)據(jù)可控,我們通常需要給他一個(gè)額外的輸入標(biāo)簽作為指導(dǎo)條件。這類模型一般稱作contional model,而基于這種思想衍生出來(lái)的GAN模型,被稱為conditional GAN(cGAN)[3]。簡(jiǎn)單來(lái)說(shuō),在cGAN中,指導(dǎo)條件(稱作y)也會(huì)編碼成向量形式,通過(guò)concatenate的方式與隨機(jī)向量z融合,并放入生成器中生成圖像G(z,y)。在鑒別階段,y依然會(huì)作為額外信息,通過(guò)多層映射與真實(shí)數(shù)據(jù)x、生成數(shù)據(jù)G(z,y)融合,形成新的向量,送入鑒別器進(jìn)行判斷。
3. pix2pix:有一類任務(wù)叫做image-to-image translation。也就是輸入和輸出是來(lái)自兩個(gè)不同集合(設(shè)為A和B)的圖片,且我們一般認(rèn)為它們是有對(duì)應(yīng)關(guān)系的。比如輸入黑白照片(A)輸出彩色照片(B),輸入輪廓照片(A)輸出色彩填充照片(B)等(如圖1),本文介紹的pix2pix模型所處理的就是這類任務(wù)。并且原文作者通過(guò)一系列實(shí)驗(yàn),證明了conditional GAN在這類問(wèn)題上的有效性,也就是說(shuō),pix2pix本質(zhì)上是一種特殊的conditional GAN。

二、模型結(jié)構(gòu)及基本原理
圖2給出了模型的基本結(jié)構(gòu)圖,其中G為生成器,D為鑒別器。由于我們的輸入是圖像而非低維向量,因此G不再是一個(gè)簡(jiǎn)單的解碼器,而是一個(gè)編碼-解碼的結(jié)構(gòu)(encoder-decoder)。近些年來(lái),編碼-解碼結(jié)構(gòu)用的最多的就是U-Net[4],在傳統(tǒng)的編碼-解碼結(jié)構(gòu)上添加了skip-connection結(jié)構(gòu),將encode過(guò)程中卷積得到的不同尺寸的特征圖,直接concatenate到decode過(guò)程中相應(yīng)尺寸的特征圖上,這樣避免了一些特征在下采樣過(guò)程中的損失,盡可能的保留了原始圖像在不同尺寸上的特征信息。
鑒別器D也區(qū)別于傳統(tǒng)GAN的鑒別器,使用的叫做Patch Discriminator,這個(gè)部分將在第三節(jié)進(jìn)行詳細(xì)講解。注意,在pix2pix模型中,G與D都會(huì)看到輸入x(即圖2中的輪廓圖),在G中,x作為輸入來(lái)通過(guò)編碼-解碼結(jié)構(gòu)獲得G(x),而在D中,x作為指導(dǎo)條件(conditions)來(lái)輔助鑒別器進(jìn)行判斷。所以pix2pix本質(zhì)上就是一個(gè)cGAN。
有人可能會(huì)問(wèn),如果我們不讓鑒別器看到x,只讓x作為輸入進(jìn)行編碼,模型會(huì)變差嗎?由于鑒別器中不加入x,更像傳統(tǒng)GAN(雖然生成器但從解碼器變成了編碼解碼結(jié)構(gòu)),這個(gè)問(wèn)題也可以轉(zhuǎn)換成:在這個(gè)任務(wù)中,cGAN真的要強(qiáng)于GAN嗎,加入condition真的有提升嗎?
關(guān)于這個(gè)問(wèn)題,作者做了消融實(shí)驗(yàn),并驗(yàn)證了cGAN相比于GAN確實(shí)表現(xiàn)更出色。本文將在第四節(jié)給出消融實(shí)驗(yàn)的結(jié)果。

這里面還有一個(gè)問(wèn)題,由于我們使用了encoder-decoder結(jié)構(gòu)的生成器,這樣的話,由于訓(xùn)練完成后,模型參數(shù)不再變化,這會(huì)使得任一確定的輸入圖像都會(huì)被編碼成對(duì)應(yīng)的確定的向量,再通過(guò)參數(shù)固定的解碼器,輸出圖像也將確定(deterministic)。這樣就失去了生成模型的隨機(jī)性。因此參考傳統(tǒng)GAN模型,我們需要引進(jìn)服從標(biāo)準(zhǔn)正態(tài)分布的隨機(jī)向量z來(lái)增添其隨機(jī)性。于是我們看到的損失函數(shù)公式如下

其中生成器生成的圖像是G(x,z)而非G(x),這里面的z就是一個(gè)隨機(jī)向量。另外,在過(guò)去使用的cGAN中,z一般是作為額外的輸入向量輸入到模型中,但pix2pix的作者通過(guò)實(shí)驗(yàn)發(fā)現(xiàn),在pix2pix模型里z作為輸入向量效果并不好,模型會(huì)很容易地學(xué)會(huì)如何忽視掉這個(gè)隨機(jī)向量。因此,作者將z作為每層網(wǎng)絡(luò)的dropout的形式加到了模型中來(lái)增加隨機(jī)性。作者提到,即使這么做,模型的隨機(jī)性依然不好,因此作者認(rèn)為如何通過(guò)cGAN生成隨機(jī)性很強(qiáng)的輸出,將是未來(lái)的研究方向之一。
最后,pix2pix模型的損失函數(shù)共有兩部分組成,上面列出的只是GAN loss這個(gè)部分,由于我們不僅希望輸出的圖片“看起來(lái)真”,還要讓輸出G(x,z)在構(gòu)圖結(jié)構(gòu)及細(xì)節(jié)上更貼近目標(biāo)圖像y。因此,我們還需要引入像素級(jí)別的損失函數(shù),來(lái)讓對(duì)應(yīng)像素的值盡可能接近。這類損失函數(shù)使用最多的就是L1和L2損失。于是最終損失函數(shù)如下:

對(duì)于損失函數(shù),這里面有一點(diǎn)需要簡(jiǎn)要拓展一下:無(wú)論單獨(dú)使用是L1損失還是L2損失,都會(huì)使得結(jié)果偏向模糊。這是因?yàn)檫@兩種損失函數(shù)均是對(duì)對(duì)應(yīng)像素差取均值,這樣的話會(huì)使得輸出的像素分布更加平緩,從而只能很好的保留低頻信息,卻無(wú)法生成準(zhǔn)確的高頻信息,因此從視覺(jué)感受上會(huì)明顯感覺(jué)出差異(一眼模糊)。但是好在我們有GAN損失函數(shù),他是專門(mén)處理“看著不像”的問(wèn)題的,因此L1+cGAN的損失函數(shù)可以最大程度還原我們想要的圖像。更多對(duì)比實(shí)驗(yàn)將在第四節(jié)展示。
三、patchGAN
上一節(jié)提到,與傳統(tǒng)GAN中的鑒別器將整張圖片映射到一個(gè)標(biāo)量概率值不同,pix2pix是先將圖像打成N×N的patches,再將每個(gè)patch送到鑒別器中進(jìn)行判別,最后取得判別的均值作為最終結(jié)果。這種方法并非pix2pix首創(chuàng),而在一種名叫Markovian GAN[5]的模型中已經(jīng)開(kāi)始使用。下面我們通過(guò)兩個(gè)方面來(lái)對(duì)這個(gè)方法的細(xì)節(jié)做一些粗淺的解釋:

1. 為什么要用patch discriminator?
其實(shí)圖3一幅圖就可以很簡(jiǎn)單的闡明這個(gè)方法的原理。假設(shè)圖3中坐標(biāo)系上的每一個(gè)點(diǎn)表示一個(gè)圖像,藍(lán)色點(diǎn)表示輸入數(shù)據(jù)點(diǎn),紅色點(diǎn)表示輸出數(shù)據(jù)點(diǎn)(也就是來(lái)自我們想要得到的圖像空間)。由于一般統(tǒng)計(jì)算法做了圖像分布多為標(biāo)準(zhǔn)正態(tài)分布的假設(shè)前提,那么就如圖3中的第一個(gè)圖所示,當(dāng)整體分布已經(jīng)擬合的很好的時(shí)候(可以看到紅色圈與藍(lán)色圈基本重合),模型就會(huì)停止學(xué)習(xí)。但實(shí)際上,現(xiàn)實(shí)中大多數(shù)數(shù)據(jù)分布并不是服從正態(tài)分布,而是更為復(fù)雜的分布,于是對(duì)抗學(xué)習(xí)就起了大作用,它會(huì)不斷的讓紅點(diǎn)與藍(lán)點(diǎn)重合,讓輸出分布盡量擬合輸入分布,這也就是圖3第二第三張圖所示部分。
而[5]作者將圖像分成patches再去做鑒別,相當(dāng)于是對(duì)圖像空間/分布進(jìn)行進(jìn)一步的細(xì)化,這使得輸入與輸出的圖像分布可以更進(jìn)一步、更細(xì)化地?cái)M合,從而得到更好的效果。
另外,由于鑒別器對(duì)patches進(jìn)行判別,輸入尺寸大幅減小,因此參數(shù)量也大幅降低,模型運(yùn)行速度也隨之大幅提高,這樣我們就可以在保證效率的前提下處理任意大的圖片。
2. 為什么敢用patch discriminator?
有人可能會(huì)問(wèn),把整圖打成patch,會(huì)不會(huì)影響模型對(duì)整體的把握,會(huì)不會(huì)丟失全局信息?
[1]中作者提出,由于L1損失已經(jīng)能夠很好地保留低頻信息,也就是說(shuō),就算不加cGAN損失,我們也已經(jīng)能夠得到很好的色塊分布與結(jié)構(gòu)相似度。因此,我們可以大膽的使用patch discriminator,并且把patch當(dāng)鑒別器輸入后,模型可以學(xué)習(xí)更清晰、精確的高頻特征,與L1損失達(dá)成互補(bǔ),使得輸出更加精確。
如果想更形象地表示,你可以認(rèn)為使用L1損失生成的模糊圖像,是老花眼人士眼中的圖像,雖然非常模糊,但基本可以看清大體輪廓和色彩分布。而patchGAN在這里起到的是放大鏡的作用,不斷移動(dòng)放大鏡,可以看到不同位置的細(xì)節(jié)。由于雖然沒(méi)有放大鏡看不清,但我們已經(jīng)有了大體輪廓,因此當(dāng)使用放大鏡看到每一處局部細(xì)節(jié)后,我們就可以想象出圖像整體清晰的樣子。
另外,形成共識(shí)的是,圖像可以看作Markovian Random Field(MRF),也就是某點(diǎn)像素只與其邊上的像素強(qiáng)關(guān)聯(lián),與遠(yuǎn)處的像素沒(méi)有很強(qiáng)的關(guān)聯(lián)性?;谶@個(gè)先驗(yàn)知識(shí),我們便可以大膽的將圖像打成patch進(jìn)行學(xué)習(xí)/鑒別,從而不會(huì)影響結(jié)果。關(guān)于MRF的相關(guān)知識(shí),請(qǐng)讀者自行學(xué)習(xí)。
四、一些消融實(shí)驗(yàn)
1.?使用L1+cGAN損失比單獨(dú)使用L1/cGAN都要好
圖4中,從左至右分別代表:輸入,目標(biāo)圖像,只用L1,只用cGAN,L1+cGAN??梢钥吹街挥肔1的話,圖像確實(shí)十分模糊,但是大體的色塊分布與結(jié)構(gòu)信息已經(jīng)學(xué)到了。而單獨(dú)使用cGAN,高頻信號(hào)非常多(即顏色突變多),圖片整體銳化程度過(guò)大。

2. cGAN比GAN(不讓鑒別器看到輸入)要好

圖5中可以看到,即使單獨(dú)使用L1,也要比單獨(dú)使用GAN好很多,而cGAN更是明顯優(yōu)于GAN。
3. U-Net比傳統(tǒng)的encoder-decoder結(jié)構(gòu)要好

圖6的結(jié)果表明,無(wú)論使用哪種損失函數(shù),Unet的結(jié)果均要好于傳統(tǒng)的encoder-decoder模型,這說(shuō)明Unet中的skip-connection確實(shí)將下采樣過(guò)程中丟失的信息保存了下來(lái)并提高了上采樣的精度。另外原文也給了數(shù)值指標(biāo)結(jié)果,以L1+cGAN損失為例,U-net的per-pixel,per-class的準(zhǔn)確度分別為0.55與0.20,均優(yōu)于encoder-decoder的0.29與0.09。
4. 不同尺寸的patch下,patch GAN學(xué)習(xí)的能力表現(xiàn)

圖7列出了不同patch尺寸下的學(xué)習(xí)結(jié)果,第一張圖只使用了L1,結(jié)果相當(dāng)模糊,后四張均用的是L1+patchGAN損失。第二張圖使用的是1×1大小的patch,其實(shí)就是一個(gè)像素,因此也叫pixelGAN,可以看到車的色彩發(fā)生了變化,說(shuō)明即使看不到鄰近信息,GAN依然學(xué)會(huì)了加強(qiáng)色彩變幻。16×16與70×70都或多或少加入了一些細(xì)節(jié),無(wú)論從數(shù)值結(jié)果還是視覺(jué)效果,都看得出來(lái)70×70效果最優(yōu)。286×286相當(dāng)于把整圖送進(jìn)去學(xué)習(xí),因此也稱作ImageGAN,可以看到視覺(jué)效果依然不如70×70,主要由于特征學(xué)習(xí)的過(guò)于全局化,導(dǎo)致局部細(xì)節(jié)并不和諧。
五、寫(xiě)在最后
pix2pix是我研究生畢業(yè)課題中使用的模型,用在了醫(yī)學(xué)影像相關(guān)領(lǐng)域的研究上,這也側(cè)面說(shuō)明了這個(gè)模型的泛用性。正如作者所說(shuō),在pix2pix模型之前,cGAN其實(shí)已經(jīng)廣泛用于各種生成任務(wù)中,比如圖像修復(fù)(inpaiting)、風(fēng)格遷移(style transfer)、提高分辨率(superresolution)等任務(wù)上。但作者認(rèn)為他們最大的貢獻(xiàn),是構(gòu)建了一個(gè)通用模型,可以在多種任務(wù)上取得優(yōu)異成績(jī)。其實(shí)這個(gè)模型就是U-Net、MGAN、cGAN的一個(gè)整合,但是卻有將近2w的引用量。這也印證了那句話,成功有時(shí)候真的可能只是因?yàn)檎驹诰奕说募绨蛏稀?/p>
引用:
[1]?https://arxiv.org/abs/1611.07004
[2]?https://arxiv.org/abs/1406.2661
[3]?https://arxiv.org/abs/1411.1784
[4]?https://arxiv.org/abs/1505.04597
[5] https://arxiv.org/abs/1604.04382