RWKV – transformer 與 RNN 的強(qiáng)強(qiáng)聯(lián)合

在 NLP (Natural Language Processing, 自然語言處理) 領(lǐng)域,ChatGPT 和其他的聊天機(jī)器人應(yīng)用引起了極大的關(guān)注。每個(gè)社區(qū)為構(gòu)建自己的應(yīng)用,也都在持續(xù)地尋求強(qiáng)大、可靠的開源模型。自 Vaswani 等人于 2017 年首次提出?Attention Is All You Need?之后,基于 transformer 的強(qiáng)大的模型一直在不斷地涌現(xiàn),它們?cè)?NLP 相關(guān)任務(wù)上的表現(xiàn)遠(yuǎn)遠(yuǎn)超過基于 RNN (Recurrent Neural Networks, 遞歸神經(jīng)網(wǎng)絡(luò)) 的 SoTA 模型,甚至多數(shù)認(rèn)為 RNN 已死。而本文將介紹一個(gè)集 RNN 和 transformer 兩者的優(yōu)勢(shì)于一身的全新網(wǎng)絡(luò)架構(gòu)——RWKV!現(xiàn)已在 HuggingFace?transformers?庫中支持。
RWKV 項(xiàng)目概覽
RWKV 項(xiàng)目已經(jīng)啟動(dòng),由 Bo Peng 主導(dǎo)、貢獻(xiàn)和維護(hù)。同時(shí)項(xiàng)目成員在官方 Discord 也開設(shè)了不同主題的討論頻道: 如性能 (RWKV.cpp、量化等),擴(kuò)展性 (數(shù)據(jù)集收集和處理),相關(guān)研究 (chat 微調(diào)、多模態(tài)微調(diào)等)。該項(xiàng)目中訓(xùn)練 RWKV 模型所需的 GPU 資源由 Stability AI 提供。
讀者可以加入 官方 discord 頻道 了解詳情或者參與討論。如想了解 RWKV 背后的思想,可以參考這兩篇博文:
https://johanwind.github.io/2023/03/23/rwkv_overview.html
https://johanwind.github.io/2023/03/23/rwkv_details.html
Transformer 與 RNN 架構(gòu)對(duì)比
RNN 架構(gòu)是最早廣泛用于處理序列數(shù)據(jù)的神經(jīng)網(wǎng)絡(luò)架構(gòu)之一。與接收固定輸入尺寸的經(jīng)典架構(gòu)不同,RNN 接收當(dāng)前時(shí)刻的 “token”(即數(shù)據(jù)流中的當(dāng)前數(shù)據(jù)點(diǎn)) 和先前時(shí)刻的 “狀態(tài)” 作為輸入,通過網(wǎng)絡(luò)預(yù)測(cè)輸出下一時(shí)刻的 “token” 和 ?“狀態(tài)”,同時(shí)輸出的 “狀態(tài)” 還能繼續(xù)用到后續(xù)的預(yù)測(cè)中去,一直到序列末尾。RNN 還可以用于不同的 “模式”,適用于多種不同的場(chǎng)景。參考 Andrej Karpathy 的博客,RNN 可以用于: 一對(duì)一 (圖像分類),一對(duì)多 (圖像描述),多對(duì)一 (序列分類),多對(duì)多 (序列生成),等等。

由于 RNN 在計(jì)算每一時(shí)刻的預(yù)測(cè)值時(shí)使用的都是同一組網(wǎng)絡(luò)權(quán)重,因此 RNN 很難解決長距離序列信息的記憶問題,這一定程度上也是訓(xùn)練過程中梯度消失導(dǎo)致的。為解決這個(gè)問題,相繼有新的網(wǎng)絡(luò)架構(gòu)被提出,如 LSTM 或者 GRU,其中 transformer 是已被證實(shí)最有效的架構(gòu)。
在 transformer 架構(gòu)中,不同時(shí)刻的輸入 token 可以在 self-attention 模塊中并行處理。首先 token 經(jīng)過 Q、K、V 權(quán)重矩陣做線性變換投影到不同的空間,得到的 Q、K 矩陣用于計(jì)算注意力分?jǐn)?shù) (通過 softmax,如下圖所示),然后乘以 V 的隱狀態(tài)得到最終的隱狀態(tài),這種架構(gòu)設(shè)計(jì)可以有效緩解長距離序列問題,同時(shí)具有比 RNN 更快的訓(xùn)練和推理速度。


在訓(xùn)練過程中,Transformer 架構(gòu)相比于傳統(tǒng)的 RNN 和 CNN 有多個(gè)優(yōu)勢(shì),最突出的優(yōu)勢(shì)是它能夠?qū)W到上下文特征表達(dá)。不同于每次僅處理輸入序列中一個(gè) token 的 RNN 和 CNN,transformer 可以單次處理整個(gè)輸入序列,這種特性也使得 transformer 可以很好地應(yīng)對(duì)長距離序列 token 依賴問題,因此 transformer 在語言翻譯和問答等多種任務(wù)中表現(xiàn)非常亮眼。
在推理過程中,RNN 架構(gòu)在推理速度和內(nèi)存效率方面會(huì)具有一些優(yōu)勢(shì)。例如計(jì)算簡(jiǎn)單 (只需矩陣 - 向量運(yùn)算) 、內(nèi)存友好 (內(nèi)存不會(huì)隨著推理階段的進(jìn)行而增加),速度穩(wěn)定 (與上下文窗口長度一致,因?yàn)?RNN 只關(guān)注當(dāng)前時(shí)刻的 token 和狀態(tài))。
RWKV 架構(gòu)
RWKV 的靈感來自于 Apple 公司的 Attention Free Transformer。RWKV 該架構(gòu)經(jīng)過精心簡(jiǎn)化和優(yōu)化,可以轉(zhuǎn)換為 RNN。除此此外,為使 RWKV 性能媲美 GPT,還額外使用了許多技巧,例如?TokenShift
?和 ?SmallInitEmb
?(使用的完整技巧列表在 官方 GitHub 倉庫的 README 中 說明)。對(duì)于 RWKV 的訓(xùn)練,現(xiàn)有的項(xiàng)目倉庫可以將參數(shù)量擴(kuò)展到 14B,并且迭代修了 RWKV-4 的一些訓(xùn)練問題,例如數(shù)值不穩(wěn)定性等。
RWKV 是 RNN 和 Transformer 的強(qiáng)強(qiáng)聯(lián)合
如何把 transformer 和 RNN 優(yōu)勢(shì)結(jié)合起來?基于 transformer 的模型的主要缺點(diǎn)是,在接收超出上下文長度預(yù)設(shè)值的輸入時(shí),推理結(jié)果可能會(huì)出現(xiàn)潛在的風(fēng)險(xiǎn),因?yàn)樽⒁饬Ψ謹(jǐn)?shù)是針對(duì)訓(xùn)練時(shí)的預(yù)設(shè)值來同時(shí)計(jì)算整個(gè)序列的。
RNN 本身支持非常長的上下文長度。即使在訓(xùn)練時(shí)接收的上下文長度有限,RNN 也可以通過精心的編碼,來得到數(shù)百萬長度的推理結(jié)果。目前,RWKV 模型使用上下文長度上為 8192 (?ctx8192
) 和 ?ctx1024
?時(shí)的訓(xùn)練速度和內(nèi)存需求均相同。
傳統(tǒng) RNN 模型的主要缺陷,以及 RWKV 是如何避免的:
傳統(tǒng)的 RNN 模型無法利用很長距離的上下文信息 (LSTM 用作語言模型時(shí)也只能有效處理約 100 個(gè) token),而 RWKV 可以處理數(shù)千個(gè)甚至更多的 token,如下圖所示:

傳統(tǒng)的 RNN 模型無法并行訓(xùn)練,而 RWKV 更像一個(gè) “線性 GPT”,因此比 GPT 訓(xùn)練得更快。
通過將這兩個(gè)優(yōu)勢(shì)強(qiáng)強(qiáng)聯(lián)合,希望 RWKV 可以實(shí)現(xiàn) “1 + 1 > 2” 的效果。
RWKV 注意力公式
RWKV 模型架構(gòu)與經(jīng)典的 transformer 模型架構(gòu)非常相似 (例如也包含 embedding 層、Layer Normalization、用于預(yù)測(cè)下一 token 的因果語言模型頭、以及多個(gè)完全相同的網(wǎng)絡(luò)層等),唯一的區(qū)別在于注意力層,它與傳統(tǒng)的 transformer 模型架構(gòu)完全不同,因此 RWKV 的注意力計(jì)算公式也不一樣。
本文不會(huì)對(duì)注意力層過多的介紹,這里推薦一篇 Johan Sokrates Wind 的博文,里面有對(duì)注意力層的分?jǐn)?shù)計(jì)算公式等更全面的解釋。
現(xiàn)有檢查點(diǎn)
純語言模型: RWKV-4 模型
大多數(shù)采用 RWKV 架構(gòu)的語言模型參數(shù)量范圍從 170M 到 14B 不等。據(jù) RWKV 概述博文 介紹,這些模型已經(jīng)在 Pile 數(shù)據(jù)集上完成訓(xùn)練,并進(jìn)行了多項(xiàng)不同的基準(zhǔn)測(cè)試,取得了與其他 SoTA 模型表現(xiàn)相當(dāng)?shù)男阅芙Y(jié)果。

指令微調(diào)/Chat 版: RWKV-4 Raven
Bo 還訓(xùn)練了 RWKV 架構(gòu)的 “chat” 版本: RWKV-4 Raven 模型。RWKV-4 Raven 是一個(gè)在 Pile 數(shù)據(jù)集上預(yù)訓(xùn)練的模型,并在 ALPACA、CodeAlpaca、Guanaco、GPT4All、ShareGPT 等上進(jìn)行了微調(diào)。RWKV-4 Raven 模型有多個(gè)版本,如不同語言 (僅英文、英文 + 中文 + 日文、英文 + 日文等) 和不同大小 (1.5B 參數(shù)、7B 參數(shù)、14B 參數(shù)) 等。
所有 HF 版的模型都可以在 Hugging Face Hub 的 RWKV 社區(qū)主頁 找到。
集成 ?? Transformers 庫
感謝這個(gè) Pull Request 的貢獻(xiàn),RWKV 架構(gòu)現(xiàn)已集成到 ?? transformers 庫中。在作者撰寫本文之時(shí),您已經(jīng)可以通過從源代碼安裝?transformers
?庫,或者使用其?main
?分支。RWKV 架構(gòu)也會(huì)與 transformers 庫一起更新,您可以像使用任何其他架構(gòu)一樣使用它。
下面讓我們來看一些使用示例。
文本生成示例
要在給定 prompt 的情況下生成文本,您可以使用?pipeline
:
或者可以運(yùn)行下面的代碼片段:
使用 Raven 模型 (chat 模型) 示例
您可以以 alpaca 風(fēng)格使用提示 chat 版模型,示例如下:
據(jù) Bo 所述,這條 discord 消息 (訪問超鏈接時(shí)請(qǐng)確保已加入 discord 頻道) ?中有更詳細(xì)的書寫指令技巧。

權(quán)重轉(zhuǎn)換
任何用戶都可以使用?transformers
?庫中提供的轉(zhuǎn)換腳本輕松地將原始 RWKV 模型權(quán)重轉(zhuǎn)換為 HF 格式。具體步驟為: 首先,將 “原始” 權(quán)重 push 到 Hugging Face Hub (假定目標(biāo)倉庫為?RAW_HUB_REPO
,目標(biāo)權(quán)重文件為?RAW_FILE
),然后運(yùn)行以下轉(zhuǎn)換腳本:
如果您想將轉(zhuǎn)換后的模型 push 到 Hub 上 (假定推送目錄為?dummy_user/converted-rwkv
),首先請(qǐng)確保在 push 模型之前使用?huggingface-cli login
?登錄 HF 賬號(hào),然后運(yùn)行:
未來工作
多語言 RWKV
Bo 目前正在研究在多語言語料庫上訓(xùn)練 RWKV 模型,最近發(fā)布了一個(gè)新的 多語言分詞器。
社區(qū)后續(xù)研究方向
RWKV 社區(qū)非?;钴S,致力于幾個(gè)后續(xù)研究方向。項(xiàng)目清單可以在 RWKV 的 discord 專用頻道中找到 (訪問超鏈接時(shí)請(qǐng)確保已加入 discord 頻道)。歡迎加入這個(gè) RWKV 研究頻道,以及對(duì) RWKV 的積極貢獻(xiàn)!
模型壓縮與加速
由于只需要矩陣 - 向量運(yùn)算,對(duì)于非標(biāo)準(zhǔn)化和實(shí)驗(yàn)性的計(jì)算硬件,RWKV 是一個(gè)非常理想的架構(gòu)選擇,例如光子處理器/加速器。
因此自然地,RWKV 架構(gòu)也可以使用經(jīng)典的加速和壓縮技術(shù) (如 ONNX、4 位/8 位量化等)。我們希望集成了 transformer 的 RWKV 架構(gòu)能夠使更多開發(fā)者和從業(yè)者受益。
在不久的將來,RWKV 還可以使用 optimum 庫提出的加速技術(shù)。rwkv.cpp 或 rwkv-cpp-cuda 倉庫涉及的其中一些技術(shù)在庫中已標(biāo)明。
致謝
我們 Hugging Face 團(tuán)隊(duì)非常感謝 Bo 和 RWKV 社區(qū)抽出寶貴時(shí)間來回答關(guān)于架構(gòu)的問題,以及非常感謝他們的幫助和支持。我們很期待在 HF 生態(tài)中看到更多 RWKV 模型的應(yīng)用。我們還要感謝 Johan Wind 發(fā)布的關(guān)于 RWKV 的博文,這對(duì)我們理解架構(gòu)本身和其潛力有很大幫助。最后,我們著重感謝 ArEnSc 開啟 RWKV 集成到?transformers
?庫的 PR 所做的工作,以及感謝 Merve Noyan、Maria Khalusova 和 Pedro Cuenca 審閱和校對(duì)本篇文章!
引用
如果您希望在工作中使用 RWKV,請(qǐng)使用此 cff 引用。https://github.com/BlinkDL/RWKV-LM/blob/main/CITATION.cff
英文原文:?https://hf.co/blog/rwkv
作者: BlinkDL, Harrison Vanderbyl, Sylvain Gugger, Younes Belkada
譯者: SuSung-boy
審校/排版: zhongdongy (阿東)