使用 DDPO 在 TRL 中微調(diào) Stable Diffusion 模型

引言
擴(kuò)散模型 (如 DALL-E 2、Stable Diffusion) 是一類文生圖模型,在生成圖像 (尤其是有照片級真實感的圖像) 方面取得了廣泛成功。然而,這些模型生成的圖像可能并不總是符合人類偏好或人類意圖。因此出現(xiàn)了對齊問題,即如何確保模型的輸出與人類偏好 (如“質(zhì)感”) 一致,或者與那種難以通過提示來表達(dá)的意圖一致?這里就有強(qiáng)化學(xué)習(xí)的用武之地了。
在大語言模型 (LLM) 領(lǐng)域,強(qiáng)化學(xué)習(xí) (RL) 已被證明是能讓目標(biāo)模型符合人類偏好的非常有效的工具。這是 ChatGPT 等系統(tǒng)卓越性能背后的主要秘訣之一。更準(zhǔn)確地說,強(qiáng)化學(xué)習(xí)是人類反饋強(qiáng)化學(xué)習(xí) (RLHF) 的關(guān)鍵要素,它使 ChatGPT 能像人類一樣聊天。
在 Training Diffusion Models with Reinforcement Learning 一文中,Black 等人展示了如何利用 RL 來對擴(kuò)散模型進(jìn)行強(qiáng)化,他們通過名為去噪擴(kuò)散策略優(yōu)化 (Denoising Diffusion Policy Optimization,DDPO) 的方法針對模型的目標(biāo)函數(shù)實施微調(diào)。
在本文中,我們討論了 DDPO 的誕生、簡要描述了其工作原理,并介紹了如何將 DDPO 加入 RLHF 工作流中以實現(xiàn)更符合人類審美的模型輸出。然后,我們切換到實戰(zhàn),討論如何使用?trl
?庫中新集成的?DDPOTrainer
?將 DDPO 應(yīng)用到模型中,并討論我們在 Stable Diffusion 上運(yùn)行 DDPO 的發(fā)現(xiàn)。
DDPO 的優(yōu)勢
DDPO 并非解決?如何使用 RL 微調(diào)擴(kuò)散模型
?這一問題的唯一有效答案。
在進(jìn)一步深入討論之前,我們強(qiáng)調(diào)一下在對 RL 解決方案進(jìn)行橫評時需要掌握的兩個關(guān)鍵點(diǎn):
計算效率是關(guān)鍵。數(shù)據(jù)分布越復(fù)雜,計算成本就越高。
近似法很好,但由于近似值不是真實值,因此相關(guān)的錯誤會累積。
在 DDPO 之前,獎勵加權(quán)回歸 (Reward-Weighted Regression,RWR) 是使用強(qiáng)化學(xué)習(xí)微調(diào)擴(kuò)散模型的主要方法。RWR 重用了擴(kuò)散模型的去噪損失函數(shù)、從模型本身采樣得的訓(xùn)練數(shù)據(jù)以及取決于最終生成樣本的獎勵的逐樣本損失權(quán)重。該算法忽略中間的去噪步驟/樣本。雖然有效,但應(yīng)該注意兩件事:
通過對逐樣本損失進(jìn)行加權(quán)來進(jìn)行優(yōu)化,這是一個最大似然目標(biāo),因此這是一種近似優(yōu)化。
加權(quán)后的損失甚至不是精確的最大似然目標(biāo),而是從重新加權(quán)的變分界中得出的近似值。
所以,根本上來講,這是一個兩階近似法,其對性能和處理復(fù)雜目標(biāo)的能力都有比較大的影響。
DDPO 始于此方法,但 DDPO 沒有將去噪過程視為僅關(guān)注最終樣本的單個步驟,而是將整個去噪過程構(gòu)建為多步馬爾可夫決策過程 (MDP),只是在最后收到獎勵而已。這樣做的好處除了可以使用固定的采樣器之外,還為讓代理策略成為各向同性高斯分布 (而不是任意復(fù)雜的分布) 鋪平了道路。因此,該方法不使用最終樣本的近似似然 (即 RWR 的做法),而是使用易于計算的每個去噪步驟的確切似然 (??)。
如果你有興趣了解有關(guān) DDPO 的更多詳細(xì)信息,我們鼓勵你閱讀 原論文 及其 附帶的博文。
DDPO 算法簡述
考慮到我們用 MDP 對去噪過程進(jìn)行建模以及其他因素,求解該優(yōu)化問題的首選工具是策略梯度方法。特別是近端策略優(yōu)化 (PPO)。整個 DDPO 算法與近端策略優(yōu)化 (PPO) 幾乎相同,僅對 PPO 的軌跡收集部分進(jìn)行了比較大的修改。
下圖總結(jié)了整個算法流程:

DDPO 和 RLHF: 合力增強(qiáng)美觀性
RLHF 的一般訓(xùn)練步驟如下:
有監(jiān)督微調(diào)“基礎(chǔ)”模型,以學(xué)習(xí)新數(shù)據(jù)的分布。
收集偏好數(shù)據(jù)并用它訓(xùn)練獎勵模型。
使用獎勵模型作為信號,通過強(qiáng)化學(xué)習(xí)對模型進(jìn)行微調(diào)。
需要指出的是,在 RLHF 中偏好數(shù)據(jù)是獲取人類反饋的主要來源。
DDPO 加進(jìn)來后,整個工作流就變成了:
從預(yù)訓(xùn)練的擴(kuò)散模型開始。
收集偏好數(shù)據(jù)并用它訓(xùn)練獎勵模型。
使用獎勵模型作為信號,通過 DDPO 微調(diào)模型
請注意,DDPO 工作流把原始 RLHF 工作流中的第 3 步省略了,這是因為經(jīng)驗表明 (后面你也會親眼見證) 這是不需要的。
下面我們實戰(zhàn)一下,訓(xùn)練一個擴(kuò)散模型來輸出更符合人類審美的圖像,我們分以下幾步來走:
從預(yù)訓(xùn)練的 Stable Diffusion (SD) 模型開始。
在 美學(xué)視覺分析 (Aesthetic Visual Analysis,AVA) ?數(shù)據(jù)集上訓(xùn)練一個帶有可訓(xùn)回歸頭的凍結(jié) CLIP 模型,用于預(yù)測人們對輸入圖像的平均喜愛程度。
使用美學(xué)預(yù)測模型作為獎勵信號,通過 DDPO 微調(diào) SD 模型。
記住這些步驟,下面開始干活:
使用 DDPO 訓(xùn)練 Stable Diffusion
環(huán)境設(shè)置
首先,要成功使用 DDPO 訓(xùn)練模型,你至少需要一個英偉達(dá) A100 GPU,低于此規(guī)格的 GPU 很容易遇到內(nèi)存不足問題。
使用 pip 安裝?trl
?庫
主庫安裝好后,再安裝所需的訓(xùn)練過程跟蹤和圖像處理相關(guān)的依賴庫。注意,安裝完?wandb
?后,請務(wù)必登錄以將結(jié)果保存到個人帳戶。
注意: 如果不想用?wandb
?,你也可以用?pip
?安裝?tensorboard
?。
演練一遍
trl
?庫中負(fù)責(zé) DDPO 訓(xùn)練的主要是?DDPOTrainer
?和?DDPOConfig
?這兩個類。有關(guān)?DDPOTrainer
?和?DDPOConfig
?的更多信息,請參閱 相應(yīng)文檔。trl
?代碼庫中有一個 示例訓(xùn)練腳本。它默認(rèn)使用這兩個類,并有一套默認(rèn)的輸入和參數(shù)用于微調(diào)?RunwayML
?中的預(yù)訓(xùn)練 Stable Diffusion 模型。
此示例腳本使用?wandb
?記錄訓(xùn)練日志,并使用美學(xué)獎勵模型,其權(quán)重是從公開的 Hugging Face 存儲庫讀取的 (因此數(shù)據(jù)收集和美學(xué)獎勵模型訓(xùn)練均已經(jīng)幫你做完了)。默認(rèn)提示數(shù)據(jù)是一系列動物名。
用戶只需要一個命令行參數(shù)即可啟動腳本。此外,用戶需要有一個 Hugging Face 用戶訪問令牌,用于將微調(diào)后的模型上傳到 Hugging Face Hub。
運(yùn)行以下 bash 命令啟動程序:
下表列出了影響微調(diào)結(jié)果的關(guān)鍵超參數(shù):

這個腳本僅僅是一個起點(diǎn)。你可以隨意調(diào)整超參數(shù),甚至徹底修改腳本以適應(yīng)不同的目標(biāo)函數(shù)。例如,可以集成一個測量 JPEG 壓縮度的函數(shù)或 使用多模態(tài)模型評估視覺文本對齊度的函數(shù) 等。
經(jīng)驗與教訓(xùn)
盡管訓(xùn)練提示很少,但其結(jié)果似乎已經(jīng)足夠泛化。對于美學(xué)獎勵函數(shù)而言,該方法已經(jīng)得到了徹底的驗證。
嘗試通過增加訓(xùn)練提示數(shù)以及改變提示來進(jìn)一步泛化美學(xué)獎勵函數(shù),似乎反而會減慢收斂速度,但對模型的泛化能力收效甚微。
雖然推薦使用久經(jīng)考驗 LoRA,但非 LoRA 也值得考慮,一個經(jīng)驗證據(jù)就是,非 LoRA 似乎確實比 LoRA 能產(chǎn)生相對更復(fù)雜的圖像。但同時,非 LoRA 訓(xùn)練的收斂穩(wěn)定性不太好,對超參選擇的要求也高很多。
對于非 LoRA 的超參建議是: 將學(xué)習(xí)率設(shè)低點(diǎn),經(jīng)驗值是大約?
1e-5
?,同時將?mixed_ precision
?設(shè)置為?None
?。
結(jié)果
以下是提示?bear
?、?heaven
?和?dune
?微調(diào)前 (左) 、后 (右) 的輸出 (每行都是一個提示的輸出):

限制
目前?
trl
?的?DDPOTrainer
?僅限于微調(diào)原始 SD 模型;在我們的實驗中,主要關(guān)注的是效果較好的 LoRA。我們也做了一些全模型訓(xùn)練的實驗,其生成的質(zhì)量會更好,但超參尋優(yōu)更具挑戰(zhàn)性。
總結(jié)
像 Stable Diffusion 這樣的擴(kuò)散模型,當(dāng)使用 DDPO 進(jìn)行微調(diào)時,可以顯著提高圖像的主觀質(zhì)感或其對應(yīng)的指標(biāo),只要其可以表示成一個目標(biāo)函數(shù)的形式。
DDPO 的計算效率及其不依賴近似優(yōu)化的能力,在擴(kuò)散模型微調(diào)方面遠(yuǎn)超之前的方法,因而成為微調(diào)擴(kuò)散模型 (如 Stable Diffusion) 的有力候選。
trl
?庫的?DDPOTrainer
?實現(xiàn)了 DDPO 以微調(diào) SD 模型。
我們的實驗表明 DDPO 對很多提示具有相當(dāng)好的泛化能力,盡管進(jìn)一步增加提示數(shù)以增強(qiáng)泛化似乎效果不大。為非 LoRA 微調(diào)找到正確超參的難度比較大,這也是我們得到的重要經(jīng)驗之一。
DDPO 是一種很有前途的技術(shù),可以將擴(kuò)散模型與任何獎勵函數(shù)結(jié)合起來,我們希望通過其在 TRL 中的發(fā)布,社區(qū)可以更容易地使用它!
致謝
感謝 Chunte Lee 提供本博文的縮略圖。
英文原文:?https://hf.co/blog/trl-ddpo
原文作者: Luke Meyers,Sayak Paul,Kashif Rasul,Leandro von Werra
譯者: Matrix Yao (姚偉峰),英特爾深度學(xué)習(xí)工程師,工作方向為 transformer-family 模型在各模態(tài)數(shù)據(jù)上的應(yīng)用及大規(guī)模模型的訓(xùn)練推理。
審校/排版: zhongdongy (阿東)