最美情侣中文字幕电影,在线麻豆精品传媒,在线网站高清黄,久久黄色视频

歡迎光臨散文網(wǎng) 會員登陸 & 注冊

AIGC: Progressive Distillation 筆記

2023-08-18 09:34 作者:剎那-Ksana-  | 我要投稿

Google 出品,必屬精品?

DDIM 知識蒸餾(Knowledge Distillation)

我們先從 DDIM 的知識蒸餾開始(2101.02388),在這個知識蒸餾的設定里面,我們有一個老師 (teacher) 和一個學生 (student),學生的目標是讓自己的輸出?p_%7Bstudent%7D(%5Cmathbf%7Bx%7D_0%20%7C%20%5Cmathbf%7Bx%7D_T) 盡量地接近老師的輸出 p_%7Bteacher%7D(%5Cmathbf%7Bx%7D_0%20%7C%20%5Cmathbf%7Bx%7D_T),用數(shù)學公式表達,就是最小化:

L_%7Bstudent%7D%3D%20%5Cmathbb%7BE%7D_%7Bx_T%7D%5B%20D_%7BKL%7D(p_%7Bteacher%7D(%5Cmathbf%7Bx%7D_0%20%7C%20%5Cmathbf%7Bx%7D_T)%20%7C%7C%20p_%7Bstudent%7D(%5Cmathbf%7Bx%7D_0%20%7C%20%5Cmathbf%7Bx%7D_T))%20%5D

另外,知識蒸餾有一個要求是,輸出需要是確定的 (deterministic),所以這里采用的是 DDIM 的設定。

逐步蒸餾(Progressive Distillation)

示意圖;x 代表了我們通常意義的?x0, z 代表了中間步驟 x_t?

逐步蒸餾(2202.00512)的設定很簡單,相比于上面知識蒸餾的一步到位,逐步蒸餾采用的是分步進行蒸餾——首先有一個通過 N 步 DDIM 訓練好的老師,然后有一個長得和老師一模一樣的學生,想要以自己?1 步的輸出去貼近老師 2 步的輸出(意味著學生 DDIM 只需要 N%2F2?步),當這個學生學習結(jié)束以后,這個學生就成了新的老師,然后重復如上的過程。

論文對于擴散過程,用了一個更廣泛的設定?q(z_t%7C%5Cmathbf%7Bx_0%7D)%3D%5Cmathcal%7BN%7D(%5Calpha_t%20%5Cmathbf%7Bx%7D_0%2C%20%5Csigma_t%5E2%20%5Cmathbf%7BI%7D). 我們通常所見到的?Variance Preserving 擴散過程,是其在?%5Csigma_t%3D%5Csqrt%7B1-%5Calpha_t%5E2%7D 時的特例。z_t?是所謂的?latent, 其實就是?%5Cmathbf%7Bx%7D_0?加噪后的數(shù)據(jù).?t%5Cin%20%5B0%2C1%5D.

這里,我們在離散時間上進行訓練和蒸餾,并且采用余弦方案 %5Calpha_t%20%3D%20%5Ccos(0.5%5Cpi%20t),%5Cmathbf%7Bz_1%7D?代表了純高斯噪聲?%5Cmathcal%7BN%7D(0%2CI)(注意下標?t 的范圍是從0到1)。

我們這里再定義一個信噪比 (Signal-to-Noise Ratio)?SNR(t)%3D%5Calpha_t%5E2%20%2F%20%5Csigma_t%5E2.??在 z_1?的時候,很明顯?%5Calpha_%7B1%7D%3D0,?%5Csigma_1%5E2%3D1,?故信噪比為?0.

Loss

針對 loss 函數(shù)我們有

L_%7B%5Ctheta%7D%20%3D%20%5ClVert%20%5Cepsilon%20-%20%5Chat%7B%5Cmathbf%7B%5Cepsilon%7D%7D_%7B%5Ctheta%7D(%7B%5Cmathbf%7Bz%7D%7D_t)%5CrVert_%7B2%7D%5E%7B2%7D%20%3D%20%5Cleft%5C%7C%20%5Cfrac%7B1%7D%7B%5Csigma_t%7D(%7B%5Cmathbf%7Bz%7D%7D_t%20-%20%5Calpha_t%7B%5Cmathbf%7Bx%7D_0%7D)%20-%20%5Cfrac%7B1%7D%7B%5Csigma_t%7D(%7B%5Cmathbf%7Bz%7D%7D_t%20-%20%5Calpha_t%5Chat%7B%5Cmathbf%7Bx%7D%7D_%7B%5Ctheta%7D(%7B%5Cmathbf%7Bz%7D%7D_t))%5Cright%5C%7C_%7B2%7D%5E%7B2%7D%20%3D%20%5Cfrac%7B%5Calpha%5E%7B2%7D_t%7D%7B%5Csigma%5E%7B2%7D_t%7D%20%5ClVert%20%7B%5Cmathbf%7Bx%7D_0%7D%20-%20%5Chat%7B%5Cmathbf%7Bx%7D%7D_%7B%5Ctheta%7D(%7B%5Cmathbf%7Bz%7D%7D_t)%20%5CrVert_%7B2%7D%5E%7B2%7D

%5Chat%7B%5Cmathbf%7Bx%7D%7D_%5Ctheta(z_t%2C%20t)%3D(%5Cmathbf%7Bz%7D_t-%5Csigma_t%20%5Cepsilon_%5Ctheta%20(z_t%2Ct))%2F%20%5Calpha_t?代表了在 t 時間點所生成的圖片。在公式做了如上的變形之后,我們可以把 loss 看成是在?%5Cmathbf%7Bx%7D?空間里面的函數(shù)(預測圖像和原圖像的距離),而信噪比則控制了 loss 的權(quán)重 (weight). 這里我們把這個權(quán)重稱作權(quán)重函數(shù) (weighting function). 當然,我們還可以設計各種不同的權(quán)重函數(shù)。

這里,論文討論了一個很有趣的現(xiàn)象——當?%5Calpha_t%20%5Cto%200?時(即擴散初期),因為?%5Chat%7B%5Cmathbf%7Bx%7D%7D_%5Ctheta(z_t%2C%20t)%3D(%5Cmathbf%7Bz%7D_t-%5Csigma_t%20%5Cepsilon_%5Ctheta%20(z_t%2Ct))%2F%20%5Calpha_t, 所以 %5Cepsilon_%5Ctheta%20 任何一點小的波動都會被超級放大。在蒸餾的初期,因為我們的步數(shù)很多,早期的一些的錯誤會在后期被修復;但是越往下蒸餾,步數(shù)越少的時候,這種情況就要出問題了。在極端的情況下,如果我們這個逐步蒸餾,進行到只剩下一步了(意味著直接從純高斯噪聲一步生成圖片),那么這個時候,整個 loss 也變成 0 了,學生就學不到任何東西了。

對此,論文里面有三種解決方案:

  1. 直接預測 x (繞過了 %5Calpha_t?在分母上的問題)

  2. 預測?%5Ctilde%7B%5Cepsilon%7D_%5Ctheta?的同時,也預測 %5Ctilde%7B%5Cmathbf%7Bx%7D%7D_%7B%5Ctheta%7D,然后用公式?%5Chat%7B%5Cmathbf%7Bx%7D%7D%20%3D%20%5Csigma%5E%7B2%7D_t%5Ctilde%7B%5Cmathbf%7Bx%7D%7D_%7B%5Ctheta%7D(%7B%5Cmathbf%7Bz%7D%7D_t)%20%2B%20%5Calpha_%7Bt%7D(%7B%5Cmathbf%7Bz%7D%7D_t%20-%20%5Csigma_t%5Ctilde%7B%5Cmathbf%7B%5Cepsilon%7D%7D_%7B%5Ctheta%7D(%7B%5Cmathbf%7Bz%7D%7D_t))%3D(1-%5Calpha_t%5E2)%5C%20%5Ctilde%7B%5Cmathbf%7Bx%7D%7D_%7B%5Ctheta%7D(%5Cmathbf%7Bz%7D_t)%20%2B%20%5Calpha_t%5E2%20%20%5Chat%7B%5Cmathbf%7Bx%7D%7D_%7B%5Ctheta%7D(%5Cmathbf%7Bz%7D_t)?生成圖片。(兩種渠道預測的 %5Cmathbf%7Bx%7D?加權(quán)求和)

  3. 預測?%5Cmathbf%7Bv%7D%3A%3D%5Calpha_t%7B%5Cmathbf%7B%5Cepsilon%7D%7D%20-%20%5Csigma_%7Bt%7D%7B%5Cmathbf%7Bx%7D%7D, 然后?%5Chat%7B%5Cmathbf%7Bx%7D%7D%20%3D%20%5Calpha_t%7B%5Cmathbf%7Bz%7D%7D_t%20-%20%5Csigma_t%5Chat%7B%5Cmathbf%7Bv%7D%7D_%7B%5Ctheta%7D(%7B%5Cmathbf%7Bz%7D%7D_t)

三種解決方案+原方案的效果對比,第一個數(shù)是FID(越低越好),第二個數(shù)是IS(越高越好);論文中認為,這三種解決方案都能取得不錯的效果;并且,三種方案都可以直接用來訓練去噪擴散模型;N/A 意味著這個過程不穩(wěn)定

另外,論文里面還提出了兩種可行的 loss 的方案:

  1. ?Truncated SNR:?L_%7B%5Ctheta%7D%20%3D%20%5Ctext%7Bmax%7D(%5ClVert%20%7B%5Cmathbf%7Bx%7D%7D%20-%20%5Chat%7B%7B%5Cmathbf%7Bx%7D%7D%7D_t%20%5CrVert_%7B2%7D%5E%7B2%7D%2C%20%5ClVert%20%7B%5Cmathbf%7B%5Cepsilon%7D%7D%20-%20%5Chat%7B%7B%5Cmathbf%7B%5Cepsilon%7D%7D%7D_t%20%5CrVert_%7B2%7D%5E%7B2%7D)%20%3D%20%5Ctext%7Bmax%7D(%5Cfrac%7B%5Calpha%5E%7B2%7D_t%7D%7B%5Csigma%5E%7B2%7D_t%7D%2C%201)%5ClVert%20%7B%5Cmathbf%7Bx%7D%7D%20-%20%5Chat%7B%7B%5Cmathbf%7Bx%7D%7D%7D_t%20%5CrVert_%7B2%7D%5E%7B2%7D

  2. SNR+1:?L_%7B%5Ctheta%7D%20%3D%20%5ClVert%20%7B%5Cmathbf%7Bv%7D%7D_t%20-%20%5Chat%7B%7B%5Cmathbf%7Bv%7D%7D%7D_t%20%5CrVert_%7B2%7D%5E%7B2%7D%20%3D%20(1%2B%5Cfrac%7B%5Calpha%5E%7B2%7D_t%7D%7B%5Csigma%5E%7B2%7D_t%7D)%5ClVert%20%7B%5Cmathbf%7Bx%7D%7D%20-%20%5Chat%7B%7B%5Cmathbf%7Bx%7D%7D%7D_t%20%5CrVert_%7B2%7D%5E%7B2%7D

DDIM Angular Parameterization

這里,我們對 DDIM 從另一個角度進行切入?%5Cphi_%7Bt%7D%20%3D%20%5Ctext%7Barctan%7D(%5Csigma_%7Bt%7D%2F%5Calpha_%7Bt%7D),所以?%5Calpha_%7B%5Cphi%7D%20%3D%20%5Ccos(%5Cphi)%2C%20%5Csigma_%7B%5Cphi%7D%3D%5Csin(%5Cphi). 顯然,由 %5Cmathbf%7Bz%7D_t%20%3D%20%5Calpha_t%20%5Cmathbf%7Bx%7D_0%20%2B%20%5Csigma_t%20%5Cepsilon, 我們可得?%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D%20%3D%20%5Ccos(%5Cphi)%7B%5Cmathbf%7Bx%7D_0%7D%20%2B%20%5Csin(%5Cphi)%7B%5Cmathbf%7B%5Cepsilon%7D%7D.

接下來,我們定義?%5Cmathbf%7Bz%7D_%5Cphi?的速度 (velocity) 為:

%5Cmathbf%7Bv%7D_%5Cphi%20%3A%3D%20%5Cfrac%7Bd%20%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D%7D%7Bd%5Cphi%7D%20%3D%20%5Cfrac%7Bd%5Ccos(%5Cphi)%7D%7Bd%5Cphi%7D%7B%5Cmathbf%7Bx%7D%7D%20%2B%20%5Cfrac%7Bd%5Csin(%5Cphi)%7D%7Bd%5Cphi%7D%7B%5Cmathbf%7B%5Cepsilon%7D%7D%20%3D%5Ccos(%5Cphi)%7B%5Cmathbf%7B%5Cepsilon%7D%7D%20-%20%5Csin(%5Cphi)%7B%5Cmathbf%7Bx%7D%7D

利用三角函數(shù)的那些定理(初高中知識哦),對上面的公式變形后,我們可以得到:

%7B%5Cmathbf%7Bx%7D%7D%20%3D%20%5Ccos(%5Cphi)%7B%5Cmathbf%7Bz%7D%7D%20-%20%5Csin(%5Cphi)%7B%5Cmathbf%7Bv%7D%7D_%7B%5Cphi%7D

%5Cepsilon%20%3D%20%5Csin(%5Cphi)%20%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D%20%2B%20%5Ccos(%5Cphi)%20%7B%5Cmathbf%7Bv%7D%7D_%7B%5Cphi%7D

在這里,我們再定義一個預測速度 (predicted velocity):?

%5Chat%7B%5Cmathbf%7Bv%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D)%20%5Cequiv%20%5Ccos(%5Cphi)%5Chat%7B%5Cmathbf%7B%5Cepsilon%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D)%20-%20%5Csin(%5Cphi)%5Chat%7B%5Cmathbf%7Bx%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D)%0A

根據(jù)公式?%5Chat%7B%5Cmathbf%7Bx%7D%7D_%5Ctheta(z_t%2C%20t)%3D(%5Cmathbf%7Bz%7D_t-%5Csigma_t%20%5Cepsilon_%5Ctheta%20(z_t%2Ct))%2F%20%5Calpha_t,我們有:?

%5Chat%7B%5Cmathbf%7B%5Cepsilon%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D)%20%3D%20(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D%20-%20%5Ccos(%5Cphi)%5Chat%7B%5Cmathbf%7Bx%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi%7D))%2F%5Csin(%5Cphi)

所以這里解釋了上一節(jié)的解決方案3的公式由來。

接下來我們要做的只是一些公式變形了,最終我們會得到:

%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi_%7Bs%7D%7D%20%3D%20%5Ccos(%5Cphi_s%20-%20%5Cphi_t)%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi_%7Bt%7D%7D%20%2B%20%5Csin(%5Cphi_s%20-%20%5Cphi_t)%5Chat%7B%5Cmathbf%7Bv%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi_t%7D)%20

%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi_%7Bt%7D-%5Cdelta%7D%20%3D%20%5Ccos(%5Cdelta)%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi_%7Bt%7D%7D%20-%20%5Csin(%5Cdelta)%5Chat%7B%5Cmathbf%7Bv%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_%7B%5Cphi_t%7D)

從純高斯噪聲e到原圖像x,我們是朝著 -v 的方向沿著一個圓弧在前進

學習目標

對于每一步的更新,其方法是可以有很多種的。

這里,論文里面使用的更新公式為:

%7B%5Cmathbf%7Bz%7D%7D_s%20%3D%20%5Calpha_s%20%5Chat%7B%5Cmathbf%7Bx%7D%7D_%5Ctheta(%7B%5Cmathbf%7Bz%7D%7D_t)%20%2B%20%5Csigma_s%5Cfrac%7B%7B%5Cmathbf%7Bz%7D%7D_t-%5Calpha_t%5Chat%7B%5Cmathbf%7Bx%7D%7D_%5Ctheta(%5Cmathbf%7Bz%7D_t)%7D%7B%5Csigma_t%7D, 對其求導的話就可以得到?d%7B%5Cmathbf%7Bz%7D%7D%20%3D%20%5Bf(%7B%5Cmathbf%7Bz%7D%7D%2C%20t)%20-%20%5Cfrac%7B1%7D%7B2%7Dg%5E%7B2%7D(t)%5Cnabla_%7Bz%7D%5Clog%20p_%7Bt%7D(%7B%5Cmathbf%7Bz%7D%7D)%20%5Ddt.(這里論文假定了 score function %5Cnabla_%7Bz%7D%5Clog%20p_%7Bt%7D(%5Cmathbf%7Bz%7D)?可以用 %5Cnabla_%7Bz%7D%5Clog%20p_%7Bt%7D(%7B%5Cmathbf%7Bz%7D%7D)%20%5Capprox%20%5Cfrac%7B%5Calpha_%7Bt%7D%5Chat%7B%5Cmathbf%7Bx%7D%7D_%7B%5Ctheta%7D(%7B%5Cmathbf%7Bz%7D%7D_t)%20-%20%7B%5Cmathbf%7Bz%7D%7D_t%7D%7B%5Csigma%5E%7B2%7D_t%7D 來近似;詳細過程見論文附錄)

有了更新公式以后,接下來的事情就簡單了,根據(jù)公式先計算前一步的?%5Cmathbf%7Bz%7D_%7Bt'%7D, 根據(jù)?%5Cmathbf%7Bz%7D_%7Bt'%7D再計算前一步的?%5Cmathbf%7Bz%7D_%7Bt''%7D. 然后我們計算目標?%5Ctilde%7B%5Cmathbf%7Bx%7D%7D%20%3D%20%5Cfrac%7B%7B%5Cmathbf%7Bz%7D%7D_%7Bt''%7D%20-%20%5Cfrac%7B%5Csigma_%7Bt''%7D%7D%7B%5Csigma_%7Bt%7D%7D%7B%5Cmathbf%7Bz%7D%7D_t%7D%7B%5Calpha_%7Bt''%7D%20-%20%5Cfrac%7B%5Csigma_%7Bt''%7D%7D%7B%5Csigma_%7Bt%7D%7D%5Calpha_t%7D,最小化上述的 loss. 大功告成。

完。

注:B站的公式編輯器頻繁抽風,如果遇到一些 tex parse error 之類的錯誤時,嘗試刷新一下頁面。

AIGC: Progressive Distillation 筆記的評論 (共 條)

分享到微博請遵守國家法律
农安县| 吉水县| 蒲江县| 揭东县| 沛县| 景东| 凌源市| 武穴市| 合作市| 长春市| 泗水县| 冷水江市| 黎平县| 碌曲县| 井陉县| 丹东市| 宁海县| 金湖县| 怀远县| 松溪县| 长沙县| 海阳市| 龙泉市| 肇州县| 静乐县| 调兵山市| 玉环县| 漳浦县| 营口市| 桃园县| 鹤庆县| 马山县| 七台河市| 张北县| 汉寿县| 九龙坡区| 宜宾县| 鄂托克旗| 罗江县| 石门县| 东安县|