擴(kuò)散模型課程第一單元第二部分:擴(kuò)散模型從零到一

前言
于 11 月底正式開課的擴(kuò)散模型課程正在火熱進(jìn)行中,在中國社區(qū)成員們的幫助下,我們組織了「抱抱臉中文本地化志愿者小組」并完成了擴(kuò)散模型課程的中文翻譯,感謝 @darcula1993、@XhrLeokk、@hoi2022、@SuSung-boy 對課程的翻譯!
如果你還沒有開始課程的學(xué)習(xí),我們建議你從?第一單元:擴(kuò)散模型簡介?開始。
擴(kuò)散模型從零到一
這個 Notebook 我們將展示相同的步驟(向數(shù)據(jù)添加噪聲、創(chuàng)建模型、訓(xùn)練和采樣),并盡可能簡單地在 PyTorch 中從頭開始實(shí)現(xiàn)。然后,我們將這個「玩具示例」與 diffusers 版本進(jìn)行比較,并關(guān)注兩者的區(qū)別以及改進(jìn)之處。這里的目標(biāo)是熟悉不同的組件和其中的設(shè)計(jì)決策,以便在查看新的實(shí)現(xiàn)時(shí)能夠快速確定關(guān)鍵思想。
讓我們開始吧!
有時(shí),只考慮一些事務(wù)最簡單的情況會有助于更好地理解其工作原理。我們將在本筆記本中嘗試這一點(diǎn),從“玩具”擴(kuò)散模型開始,看看不同的部分是如何工作的,然后再檢查它們與更復(fù)雜的實(shí)現(xiàn)有何不同。
你將跟隨本文的 Notebook 學(xué)習(xí)到
損壞過程(向數(shù)據(jù)添加噪聲)
什么是 UNet,以及如何從零開始實(shí)現(xiàn)一個極小的 UNet
擴(kuò)散模型訓(xùn)練
抽樣理論
然后,我們將比較我們的版本與 diffusers 庫中的 DDPM 實(shí)現(xiàn)的區(qū)別
對小型 UNet 的改進(jìn)
DDPM 噪聲計(jì)劃
訓(xùn)練目標(biāo)的差異
timestep 調(diào)節(jié)
抽樣方法
這個筆記本相當(dāng)深入,如果你對從零開始的深入研究不感興趣,可以放心地跳過!
還值得注意的是,這里的大多數(shù)代碼都是出于說明的目的,我不建議直接將其用于您自己的工作(除非您只是為了學(xué)習(xí)目的而嘗試改進(jìn)這里展示的示例)。
準(zhǔn)備環(huán)境與導(dǎo)入:
數(shù)據(jù)
在這里,我們將使用一個非常小的經(jīng)典數(shù)據(jù)集 mnist 來進(jìn)行測試。如果您想在不改變?nèi)魏纹渌麅?nèi)容的情況下給模型一個稍微困難一點(diǎn)的挑戰(zhàn),請使用?torchvision.dataset
,F(xiàn)ashionMNIST 應(yīng)作為替代品。
該數(shù)據(jù)集中的每張圖都是一個數(shù)字的 28x28 像素的灰度圖,像素值的范圍是從 0 到 1。
損壞過程
假設(shè)你沒有讀過任何擴(kuò)散模型的論文,但你知道這個過程會增加噪聲。你會怎么做?
我們可能想要一個簡單的方法來控制損壞的程度。那么,如果我們要引入一個參數(shù)來控制輸入的“噪聲量”,那么我們會這么做:
如果 amount = 0,則返回輸入而不做任何更改。如果 amount = 1,我們將得到一個純粹的噪聲。通過這種方式將輸入與噪聲混合,我們將輸出保持在相同的范圍(0 to 1)。
我們可以很容易地實(shí)現(xiàn)這一點(diǎn)(但是要注意 tensor 的 shape,以防被廣播 (broadcasting) 機(jī)制不正確的影響到):
?
讓我們來可視化一下輸出的結(jié)果,以了解是否符合我們的預(yù)期:
當(dāng)噪聲量接近 1 時(shí),我們的數(shù)據(jù)開始看起來像純隨機(jī)噪聲。但對于大多數(shù)的噪聲情況下,您還是可以很好地識別出數(shù)字。你認(rèn)為這是最佳的嗎?
模型
我們想要一個模型,它可以接收 28px 的噪聲圖像,并輸出相同形狀的預(yù)測。一個比較流行的選擇是一個叫做 UNet 的架構(gòu)。最初被發(fā)明用于醫(yī)學(xué)圖像中的分割任務(wù),UNet 由一個“壓縮路徑”和一個“擴(kuò)展路徑”組成?!皦嚎s路徑”會使通過該路徑的數(shù)據(jù)被壓縮,而通過“擴(kuò)展路徑”會將數(shù)據(jù)擴(kuò)展回原始維度(類似于自動編碼器)。模型中的殘差連接也允許信息和梯度在不同層級之間流動。
一些 UNet 的設(shè)計(jì)在每個階段都有復(fù)雜的 blocks,但對于這個玩具 demo,我們只會構(gòu)建一個最簡單的示例,它接收一個單通道圖像,并通過下行路徑上的三個卷積層(圖和代碼中的 down_layers)和上行路徑上的 3 個卷積層,在下行和上行層之間具有殘差連接。我們將使用 max pooling 進(jìn)行下采樣和?nn.Upsample
?用于上采樣。某些比較復(fù)雜的 UNets 的設(shè)計(jì)會使用帶有可學(xué)習(xí)參數(shù)的上采樣和下采樣 layer。下面的結(jié)構(gòu)圖大致展示了每個 layer 的輸出通道數(shù):

代碼實(shí)現(xiàn)如下:
我們可以驗(yàn)證輸出 shape 是否如我們期望的那樣與輸入相同:
該網(wǎng)絡(luò)有 30 多萬個參數(shù):
309057
您可以嘗試更改每個 layer 中的通道數(shù)或嘗試不同的結(jié)構(gòu)設(shè)計(jì)。
訓(xùn)練模型
那么,模型到底應(yīng)該做什么呢?同樣,對這個問題有各種不同的看法,但對于這個演示,讓我們選擇一個簡單的框架:給定一個損壞的輸入?noisy_x
,模型應(yīng)該輸出它對原本?x
?的最佳猜測。我們將通過均方誤差將預(yù)測與真實(shí)值進(jìn)行比較。
我們現(xiàn)在可以嘗試訓(xùn)練網(wǎng)絡(luò)了。
獲取一批數(shù)據(jù)
添加隨機(jī)噪聲
將數(shù)據(jù)輸入模型
將模型預(yù)測與干凈圖像進(jìn)行比較,以計(jì)算 loss
更新模型的參數(shù)
你可以自由進(jìn)行修改來嘗試獲得更好的結(jié)果!
Finished epoch 0. Average loss for this epoch: 0.026736
Finished epoch 1. Average loss for this epoch: 0.020692
Finished epoch 2. Average loss for this epoch: 0.018887

我們可以嘗試通過抓取一批數(shù)據(jù),以不同的數(shù)量損壞數(shù)據(jù),然后喂進(jìn)模型獲得預(yù)測來觀察結(jié)果:

你可以看到,對于較低的噪聲水平數(shù)量,預(yù)測的結(jié)果相當(dāng)不錯!但是,當(dāng)噪聲水平非常高時(shí),模型能夠獲得的信息就開始逐漸減少。而當(dāng)我們達(dá)到 amount = 1 時(shí),模型會輸出一個模糊的預(yù)測,該預(yù)測會很接近數(shù)據(jù)集的平均值。模型通過這樣的方式來猜測原始輸入。
取樣(采樣)
如果我們在高噪聲水平下的預(yù)測不是很好,我們?nèi)绾尾拍苌蓤D像呢?
如果我們從完全隨機(jī)的噪聲開始,檢查一下模型預(yù)測的結(jié)果,然后只朝著預(yù)測方向移動一小部分,比如說 20%?,F(xiàn)在我們有一個噪聲很多的圖像,其中可能隱藏了一些關(guān)于輸入數(shù)據(jù)的結(jié)構(gòu)的提示,我們可以將其輸入到模型中以獲得新的預(yù)測。希望這個新的預(yù)測比第一個稍微好一點(diǎn)(因?yàn)槲覀冞@一次的輸入稍微減少了一點(diǎn)噪聲),所以我們可以用這個新的更好的預(yù)測再往前邁出一小步。
如果一切順利的話,以上過程重復(fù)幾次以后我們就會得到一個新的圖像!以下圖例是迭代了五次以后的結(jié)果,左側(cè)是每個階段的模型輸入的可視化,右側(cè)則是預(yù)測的去噪圖像。請注意,即使模型在第 1 步就預(yù)測了去噪圖像,我們也只是將輸入向去噪圖像變換了一小部分。重復(fù)幾次以后,圖像的結(jié)構(gòu)開始逐漸出現(xiàn)并得到改善 , 直到獲得我們的最終結(jié)果為止。

我們可以將流程分成更多步驟,并希望通過這種方式獲得更好的圖像:
<matplotlib.image.AxesImage at 0x7f27567d8210>

結(jié)果并不是非常好,但是已經(jīng)出現(xiàn)了一些可以被認(rèn)出來的數(shù)字!您可以嘗試訓(xùn)練更長時(shí)間(例如,10 或 20 個 epoch),并調(diào)整模型配置、學(xué)習(xí)率、優(yōu)化器等。此外,如果您想嘗試稍微困難一點(diǎn)的數(shù)據(jù)集,您可以嘗試一下 fashionMNIST,只需要一行代碼的替換就可以了。
與 DDPM 做比較
在本節(jié)中,我們將看看我們的“玩具”實(shí)現(xiàn)與其他筆記本中使用的基于 DDPM 論文的方法有何不同: 擴(kuò)散器簡介 Notebook。
擴(kuò)散器簡介 Notebook:
https://github.com/huggingface/diffusion-models-class/blob/main/unit1/01_introduction_to_diffusers.ipynb
我們將會看到的
模型的表現(xiàn)受限于隨迭代周期 (timesteps) 變化的控制條件,在前向傳導(dǎo)中時(shí)間步 (t) 是作為一個參數(shù)被傳入的
有很多不同的取樣策略可選擇,可能會比我們上面所使用的最簡單的版本更好
diffusers?
UNet2DModel
?比我們的 BasicUNet 更先進(jìn)損壞過程的處理方式不同
訓(xùn)練目標(biāo)不同,包括預(yù)測噪聲而不是去噪圖像
該模型通過調(diào)節(jié) timestep 來調(diào)節(jié)噪聲水平 , 其中 t 作為一個附加參數(shù)傳入前向過程中。
有許多不同的采樣策略可供選擇,它們應(yīng)該比我們上面簡單的版本更有效。
自 DDPM 論文發(fā)表以來,已經(jīng)有人提出了許多改進(jìn)建議,但這個例子對于不同的可用設(shè)計(jì)決策具有指導(dǎo)意義。讀完這篇文章后,你可能會想要深入了解這篇論文《Elucidating the Design Space of Diffusion-Based Generative Models》,它對所有這些組件進(jìn)行了詳細(xì)的探討,并就如何獲得最佳性能提出了新的建議。
Elucidating the Design Space of Diffusion-Based Generative Models 論文鏈接:
https://arxiv.org/abs/2206.00364
如果你覺得這些內(nèi)容對你來說太過深奧了,請不要擔(dān)心!你可以隨意跳過本筆記本的其余部分或?qū)⑵浔4嬉詡洳粫r(shí)之需。
UNet
diffusers 中的 UNet2DModel 模型比上述基本 UNet 模型有許多改進(jìn):
GroupNorm 層對每個 blocks 的輸入進(jìn)行了組標(biāo)準(zhǔn)化(group normalization)
Dropout 層能使訓(xùn)練更平滑
每個塊有多個 resnet 層(如果 layers_per_block 未設(shè)置為 1)
注意機(jī)制(通常僅用于輸入分辨率較低的 blocks)
timestep 的調(diào)節(jié)。
具有可學(xué)習(xí)參數(shù)的下采樣和上采樣塊
讓我們來創(chuàng)建并仔細(xì)研究一下 UNet2DModel:
正如你所看到的,還有更多!它比我們的 BasicUNet 有多得多的參數(shù)量:
1707009
我們可以用這個模型代替原來的模型來重復(fù)一遍上面展示的訓(xùn)練過程。我們需要將 x 和 timestep 傳遞給模型(這里我會傳遞 t = 0,以表明它在沒有 timestep 條件的情況下工作,并保持采樣代碼簡單,但您也可以嘗試輸入?(amount*1000)
,使 timestep 與噪聲水平相當(dāng))。如果要檢查代碼,更改的行將顯示為“#<<<
。
Finished epoch 0. Average loss for this epoch: 0.018925
Finished epoch 1. Average loss for this epoch: 0.012785
Finished epoch 2. Average loss for this epoch: 0.011694

這看起來比我們的第一組結(jié)果好多了!您可以嘗試調(diào)整 UNet 配置或更長時(shí)間的訓(xùn)練,以獲得更好的性能。
損壞過程
DDPM 論文描述了一個為每個“timestep”添加少量噪聲的損壞過程。為某些 timestep 給定??, 我們可以得到一個噪聲稍稍增加的?:

這就是說,我們?nèi)?, 給他一個??的系數(shù),然后加上帶有??系數(shù)的噪聲。這里??是根據(jù)一些管理器來為每一個 t 設(shè)定的,來決定每一個迭代周期中添加多少噪聲。現(xiàn)在,我們不想把這個推演進(jìn)行 500 次來得到?,所以我們用另一個公式來根據(jù)給出的??計(jì)算得到任意 t 時(shí)刻的?:

數(shù)學(xué)符號看起來總是很嚇人!幸運(yùn)的是,調(diào)度器為我們處理了所有這些(取消下一個單元格的注釋以檢查代碼)。我們可以畫出??(標(biāo)記為?sqrt_alpha_prod
) 和??(標(biāo)記為?sqrt_one_minus_alpha_prod
) 來看一下輸入 (x) 與噪聲是如何在不同迭代周期中量化和疊加的 :

一開始 , 噪聲 x 里絕大部分都是 x 自身的值 ?(sqrt_alpha_prod ~= 1),但是隨著時(shí)間的推移,x 的成分逐漸降低而噪聲的成分逐漸增加。與我們根據(jù)?amount
?對 x 和噪聲進(jìn)行線性混合不同,這個噪聲的增加相對較快。我們可以在一些數(shù)據(jù)上看到這一點(diǎn):

在運(yùn)行中的另一個變化:在 DDPM 版本中,加入的噪聲是取自一個高斯分布(來自均值 0 方差 1 的 torch.randn),而不是在我們原始?corrupt
?函數(shù)中使用的 0-1 之間的均勻分布(torch.rand),當(dāng)然對訓(xùn)練數(shù)據(jù)做正則化也可以理解。在另一篇筆記中,你會看到?Normalize(0.5, 0.5)
?函數(shù)在變化列表中,它把圖片數(shù)據(jù)從 (0, 1) 區(qū)間映射到 (-1, 1),對我們的目標(biāo)來說也‘足夠用了’。我們在此篇筆記中沒使用這個方法,但在上面的可視化中為了更好的展示添加了這種做法。
訓(xùn)練目標(biāo)
在我們的玩具示例中,我們讓模型嘗試預(yù)測去噪圖像。在 DDPM 和許多其他擴(kuò)散模型實(shí)現(xiàn)中,模型則會預(yù)測損壞過程中使用的噪聲(在縮放之前,因此是單位方差噪聲)。在代碼中,它看起來像是這樣:
你可能認(rèn)為預(yù)測噪聲(我們可以從中得出去噪圖像的樣子)等同于直接預(yù)測去噪圖像。那么,為什么要這么做呢?這僅僅是為了數(shù)學(xué)上的方便嗎?
這里其實(shí)還有另一些精妙之處。我們在訓(xùn)練過程中,會計(jì)算不同(隨機(jī)選擇)timestep 的 loss。這些不同的目標(biāo)將導(dǎo)致這些 loss 的不同的“隱含權(quán)重”,其中預(yù)測噪聲會將更多的權(quán)重放在較低的噪聲水平上。你可以選擇更復(fù)雜的目標(biāo)來改變這種“隱性損失權(quán)重”?;蛘?,您選擇的噪聲管理器將在較高的噪聲水平下產(chǎn)生更多的示例。也許你讓模型設(shè)計(jì)成預(yù)測 “velocity” v,我們將其定義為由噪聲水平影響的圖像和噪聲組合(請參閱“擴(kuò)散模型快速采樣的漸進(jìn)蒸餾”- 'PROGRESSIVE DISTILLATION FOR FAST SAMPLING OF DIFFUSION MODELS')。也許你將模型設(shè)計(jì)成預(yù)測噪聲,然后基于某些因子來對 loss 進(jìn)行縮放:比如有些理論指出可以參考噪聲水平(參見“擴(kuò)散模型的感知優(yōu)先訓(xùn)練”-'Perception Prioritized Training of Diffusion Models'),或者基于一些探索模型最佳噪聲水平的實(shí)驗(yàn)(參見“基于擴(kuò)散的生成模型的設(shè)計(jì)空間說明”-'Elucidating the Design Space of Diffusion-Based Generative Models')。
一句話解釋:選擇目標(biāo)對模型性能有影響,現(xiàn)在有許多研究者正在探索“最佳”選項(xiàng)是什么。目前,預(yù)測噪聲(epsilon 或 eps)是最流行的方法,但隨著時(shí)間的推移,我們很可能會看到庫中支持的其他目標(biāo),并在不同的情況下使用。
迭代周期(Timestep)調(diào)節(jié)
UNet2DModel 以 x 和 timestep 為輸入。后者被轉(zhuǎn)化為一個嵌入(embedding),并在多個地方被輸入到模型中。
這背后的理論支持是這樣的:通過向模型提供有關(guān)噪聲水平的信息,它可以更好地執(zhí)行任務(wù)。雖然在沒有這種 timestep 條件的情況下也可以訓(xùn)練模型,但在某些情況下,它似乎確實(shí)有助于性能,目前來說絕大多數(shù)的模型實(shí)現(xiàn)都包括了這一輸入。
取樣(采樣)
有一個模型可以用來預(yù)測在帶噪樣本中的噪聲(或者說能預(yù)測其去噪版本),我們怎么用它來生成圖像呢?
我們可以給入純噪聲,然后就希望模型能一步就輸出一個不帶噪聲的好圖像。但是,就我們上面所見到的來看,這通常行不通。所以,我們在模型預(yù)測的基礎(chǔ)上使用足夠多的小步,迭代著來每次去除一點(diǎn)點(diǎn)噪聲。
具體我們怎么走這些小步,取決于使用上面取樣方法。我們不會去深入討論太多的理論細(xì)節(jié),但是一些頂層想法是這樣:
每一步你想走多大?也就是說,你遵循什么樣的“噪聲計(jì)劃(噪聲管理)”?
你只使用模型當(dāng)前步的預(yù)測結(jié)果來指導(dǎo)下一步的更新方向嗎(像 DDPM,DDIM 或是其他的什么那樣)?你是否要使用模型來多預(yù)測幾次來估計(jì)一個更高階的梯度來更新一步更大更準(zhǔn)確的結(jié)果(更高階的方法和一些離散 ODE 處理器)?或者保留歷史預(yù)測值來嘗試更好的指導(dǎo)當(dāng)前步的更新(線性多步或遺傳取樣器)?
你是否會在取樣過程中額外再加一些隨機(jī)噪聲,或你完全已知的(deterministic)來添加噪聲?許多取樣器通過參數(shù)(如 DDIM 中的 'eta')來供用戶選擇。
對于擴(kuò)散模型取樣器的研究演進(jìn)的很快,隨之開發(fā)出了越來越多可以使用更少步就找到好結(jié)果的方法。勇敢和有好奇心的人可能會在瀏覽 diffusers library 中不同部署方法時(shí)感到非常有意思,可以查看 Schedulers 代碼 或看看 Schedulers 文檔,這里經(jīng)常有一些相關(guān)的論文。
Schedulers 代碼:
https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulersSchedulers 文檔:
https://huggingface.co/docs/diffusers/main/en/api/schedulers
結(jié)語
希望這可以從一些不同的角度來審視擴(kuò)散模型提供一些幫助。這篇筆記是 Jonathan Whitaker 為 Hugging Face 課程所寫的,如果你對從噪聲和約束分類來生成樣本的例子感興趣。問題與 bug 可以通過 GitHub issues 或 Discord 來交流。
致謝第一單元第二部分社區(qū)貢獻(xiàn)者
感謝社區(qū)成員們對本課程的貢獻(xiàn):
@darcula1993、@XhrLeokk:魔都強(qiáng)人工智能孵化者,二里街調(diào)參記錄保持人,一切興趣使然的 AIGC 色圖創(chuàng)作家的庇護(hù)者,圖靈神在五角場的唯一指定路上行走。
感謝茶葉蛋蛋對本文貢獻(xiàn)設(shè)計(jì)素材!
歡迎通過鏈接加入我們的本地化小組與大家共同交流:
https://bit.ly/3G40j6U