RWKV: 大語言模型結(jié)構(gòu)的另一種選擇

文章首發(fā)于 網(wǎng)站機(jī)器翻譯學(xué)堂
轉(zhuǎn)載事宜請后臺詢問哦
作者 | 阮俊豪
單位 | 東北大學(xué)自然語言處理實(shí)驗(yàn)室

前言
Transformer[1]在諸多的NLP任務(wù)上產(chǎn)生了非常驚艷的效果,甚至逐漸輻射到CV領(lǐng)域(如Vision Transfomrer[2]),獲得了學(xué)術(shù)界和工業(yè)界一致的認(rèn)可。因此也被作為當(dāng)下大語言模型結(jié)構(gòu)的不二之選。無論是以BERT[3]為代表的,常用于分類任務(wù)的Encoder-only模型;亦或是解決生成類任務(wù)為主的Decoder-only模型GPT[4];或兼而有之的Encoder-Decoder架構(gòu)的T5[5]模型,他們都采用了transformer的部分或完整架構(gòu)。
盡管如此,Transformer作為大語言模型的標(biāo)準(zhǔn)架構(gòu)選擇,也存在一些不能忽視的缺陷,例如內(nèi)存和時(shí)間復(fù)雜度都與輸入序列的長度成平方,這很大程度的影響了大語言模型部署在端側(cè)或資源受限設(shè)備的可能性。

隨著Scailing Law[6]的提出,大公司開始將研究重心轉(zhuǎn)移到如何建設(shè)能容納更大參數(shù)規(guī)模的基礎(chǔ)支撐設(shè)備,更深層的訓(xùn)練技巧保證梯度和訓(xùn)練穩(wěn)定性,借助更大數(shù)據(jù)規(guī)模的有效微調(diào)方式。在力大磚飛,物理限制還沒有看到盡頭的當(dāng)下,較少有人繼續(xù)研究Transformer之外,還有沒有更適合大語言模型的結(jié)構(gòu)。一方面因?yàn)門ransformer結(jié)構(gòu)設(shè)計(jì)上具有很高效的GPU并行程度,另一方面是因?yàn)門ransformer在各種NLP任務(wù)取得良好成果的情況下嘗試其他模型的大參數(shù)訓(xùn)練有不小的試錯(cuò)成本。
一位獨(dú)立研究員彭博[7],在2021年8月份,就提出了他的原始RWKV[8]構(gòu)想,并在完善到RKWV-V2版本之后,在reddit和discord上引發(fā)業(yè)內(nèi)人員廣泛關(guān)注?,F(xiàn)今已經(jīng)演化到V4版本,并充分展現(xiàn)了RNN模型的縮放潛力。本篇博客將介紹RWKV的原理、演變流程和現(xiàn)在取得的成效。
RWKV模型原型
An Attention Free Transformer[9]
標(biāo)準(zhǔn)AFT
RWKV模型中的time-mix的設(shè)計(jì)受An Attention Free Transformer工作影響較深,所以開始之前,我們先介紹一下Attention Free Transformer(AFT)。
這篇文章之所以名字中帶有Free,是因?yàn)樗耆コ藰?biāo)準(zhǔn)transfomrer中的點(diǎn)乘注意力,同時(shí)也沒有采取其他linear transformer工作中常見的點(diǎn)乘注意力近似方法。
我們把標(biāo)準(zhǔn)的多頭自注意力中每個(gè)頭的輸出結(jié)果表示為:
其中
是第 i 個(gè)頭的線性變換。??是非線性函數(shù),默認(rèn)情況下是softmax函數(shù)。
AFT中,Q、K、V來源仍然是輸入的線性變換結(jié)構(gòu)。
我們把記作模型的輸出,那么模型 t 時(shí)刻的輸出在AFT中表示為:
其中指的是element-wise乘法,就是按位相乘?
。
?是sigmoid函數(shù),和softmax的區(qū)別是,sigmoid是一個(gè)個(gè)矩陣元素計(jì)算的,softmax是一行行計(jì)算的。
就是一個(gè)可學(xué)習(xí)的位置偏置。
我們可以粗略地將這個(gè)過程提煉為?,和原始的點(diǎn)乘注意力???相比,元素乘的時(shí)間開銷顯著降低,且不需要顯式地計(jì)算和保存softmax出來的權(quán)重矩陣,還保留了全局的K、V之間的交互。需要特別注意的是,和原始多頭自注意力不同。AFT的softmax是以列(時(shí)間維度)做歸一化的,類似于池化。所以下文RWKV-V1也把這個(gè)設(shè)計(jì)稱為Time-mix。原始多頭自注意力是在隱藏層維度上做的。
對于每一個(gè)輸入?的位置
,AFT形成了以
為權(quán)重的
的加權(quán)平均數(shù)。
為了進(jìn)一步闡述AFT和多頭注意力的聯(lián)系,我們按照標(biāo)準(zhǔn)多頭自注意力的形式,來描寫時(shí)刻第
個(gè)注意力頭的輸出
在原始AFT的工作中,仍有兩點(diǎn)關(guān)于的內(nèi)容需要厘清。首先,AFT中
,
應(yīng)該指代的是從
中取出第
行和第
列的元素。其次,公式中
與
的加和基本上都采用廣播方式。以
為例子,
,
,因此他們的加和可能實(shí)際上需要由
做轉(zhuǎn)置,然后在新的列維度上擴(kuò)充至
方可運(yùn)算。除此之外,為了簡單描述,上述過程并沒有討論mask的置入方案,在實(shí)際操作中,可以通過控制求和范圍來等價(jià)實(shí)現(xiàn)
AFT-local
在很多場景中,局部性都是一個(gè)很重要的歸納偏差。也有一些工作基于這個(gè)性質(zhì)開展,比如OpenAI現(xiàn)在使用在GPT系列上的sparse transformer。AFT-local發(fā)現(xiàn),訓(xùn)練后的transfomrer的注意力模式更傾向于集中局部。為了更詳細(xì)的表達(dá)這個(gè)理念,我們用一副圖可視化Vision Transformer(ViT)的注意力矩陣。這是一個(gè)12層6注意力頭數(shù)的ViT在256張圖片上的平均注意力模式,其中縱向維度是層數(shù)(2層一統(tǒng)計(jì)),橫向維度是注意力頭數(shù)。星光亮度越大的地方代表了更高的注意力權(quán)重。

可以看到,這個(gè)圖里展示了相當(dāng)強(qiáng)的局部模式,這個(gè)觀測引出了AFT的變體——AFT-local的設(shè)計(jì)。
和上文的區(qū)別是,將位置偏置的值做了一個(gè)區(qū)域限制:
這里s就是一個(gè)局部的窗口大小。
AFT-simple
AFT-local的一個(gè)極端模式就是令s=0,也就是完全從AFT中抽走了位置偏置,從而得到
AFT還有第三種變體,叫做AFT-conv,其更適用于圖像任務(wù)。因篇幅所限,感興趣的讀者可以查閱原文了解。
GLU Variants Improve Transformer
channel-mix的部分則受該節(jié)工作啟發(fā)。尤其是其中的GeGLU。
標(biāo)準(zhǔn)Transformer中的feed-forward network采用的是如下的結(jié)構(gòu):
T5取走了最外層的偏置,修改為
也有其他工作嘗試使用GELU或者Swish替代ReLU.
其中
在上面的例子里,兩層可學(xué)習(xí)的線性變換是按順序堆疊在一起的,一層的輸出作為第二層的輸入。后續(xù)Gated Linear Units(GLU)提出了另一種形式,這種形式如果省略掉激活函數(shù)也被稱作bilinear layer。
同理,GLU上也存在一個(gè)非線性激活函數(shù),我們可以使用GELU等函數(shù)去替換。RWKV-1所涉及的GeGLU就來自于GELU+GLU:
最后仍然采用省略偏置的結(jié)構(gòu)替換FFN:
本文最后還測試了許多不同GLU變體作為FFN的效果,感興趣的讀者可以參看原文。
RWLV-V1
這個(gè)版本的工作還比較類似linear transformer的工作,而不是純粹的RNN網(wǎng)絡(luò)。在彭博的設(shè)計(jì)中,RWKV模型由交替的Time-mix和Channel-mix層組成。
兩者均擁有類似的R\W\KV結(jié)構(gòu)設(shè)計(jì),故此得名。其中R\K\V由輸入線性變換生成,W是一個(gè)可學(xué)習(xí)的參數(shù)矩陣。
筆者認(rèn)為,和AFT工作中的標(biāo)記方式不同,矩陣的下標(biāo)不僅代表取元素,也同時(shí)代表維度表示。例如可以認(rèn)為是
取出了第
行的元素。
可以看出,Time-mix層與AFT-simple基本相同,其區(qū)別包括,修改歸一化
相較于原始的在完整時(shí)間序列上的歸一化,Time-mix現(xiàn)在采用的是一種只回看歷史序列的局部歸一化。
除此之外,還將W分解,支持多頭
channel-mix和上文提到的也基本相同,因?yàn)镵和V是由層輸入x線性變換得到的,因此相當(dāng)于只是增加了一個(gè)額外的self-gating R,彭博測試后發(fā)現(xiàn)確實(shí)提高了擬合性能。
最后再加上了彭博2020年8月的一個(gè)想法?time-shift就組成了RWKV-V1版本。time-shift主要涉及到relative position embedding,以及把輸入從改成
表面上看,這是強(qiáng)制要求網(wǎng)絡(luò)結(jié)合 x[t] 和 x[t-1],是個(gè) inductive bias,就像強(qiáng)制要求網(wǎng)絡(luò)使用 2-gram。
但后來我做了更多實(shí)驗(yàn),發(fā)現(xiàn)它對于深層模型也有效。而且,對于 SA 和 FFN 都有效。
這乍一看有點(diǎn)奇怪,把 x[t] 的一半通道,用 x[t-1] 的通道代替,是不是有點(diǎn)過分了?經(jīng)實(shí)驗(yàn),確實(shí)可以用這么強(qiáng)的混合。
后來再想想,我們訓(xùn)練 GPT 時(shí),網(wǎng)絡(luò)的 hidden representation 實(shí)際在做兩件不同的事情:
1. 預(yù)測下一個(gè)字。有時(shí)這很簡單(有時(shí)下一個(gè)字很明顯)。
2. 收集前文的 context 信息,傳遞給后文。這永遠(yuǎn)是困難的任務(wù)。
這兩件事有明顯的區(qū)別。第2件事更難。
我的理論是:在加入 time-shift 后,沒有 time-shift 的通道主要承擔(dān) (1),被 time-shift 的通道主要承擔(dān) (2)。所以,這就實(shí)現(xiàn)了更明確的信息分離。
而且 time-shift 可讓信息快速傳遞,就像一個(gè)小卷積。在多層后可以看很長的距離。
上述這段引述自彭博對于time-shift的描述,我解釋一下最后一句話,“在多層后可以看很長的距離”,假設(shè)層數(shù)有6層,在t時(shí)刻的第6層的輸入,依賴于第5層t-1時(shí)刻和t時(shí)刻的輸入,而第五層t-1時(shí)刻的輸入,又依賴于第四層t-2時(shí)刻的輸入。。。依次類推到第一層t-5時(shí)刻的輸入。所以賦予了回看歷史信息的能力。這個(gè)觀點(diǎn)一般是在local attention的模式里會提到的,被用于解釋這一段做法也挺契合。如果認(rèn)為收集前文context的信息更難,time-mixing就在注意力回看的基礎(chǔ)上,直接讓輸入特征也一并回看了。
對于普通的 QKV 自注意力,我觀察權(quán)重,看各層對于兩種通道的使用程度,發(fā)現(xiàn) Q 偏向于使用【沒有 time-shift 的通道】,而 V 偏向于使用【被 time-shift 的通道】,符合這個(gè)理論。
這個(gè)觀測我覺得很有意思,如何觀察Q和V對于不同層的偏向使用情況,我傾向于作者可能對線性變換矩陣做了一個(gè)類似熱力圖觀察,發(fā)現(xiàn)不同通道部分存在不同的熱力,比如前半部分通道更熱,
后半部分更熱。
該版本在 simplebooks-92 的 character-level 性能對比了灰色基線(普通 MHA 多頭注意力 + Rotary encoding + GeGLU FFN),和加入各種魔改的MHA的黑色線版本,均有競爭力。

RWKV-V2
這個(gè)版本的改動有些大,我們先從自注意力層開始。為了方便,我把作者的偽代碼圖轉(zhuǎn)成了模型結(jié)構(gòu)圖。和RWKV-V1相比,區(qū)別在于修改了time-mix的softmax處理,顯式地加入了token-shift機(jī)制,從永遠(yuǎn)的當(dāng)前詞和歷史詞各取一半channel合并改成了可訓(xùn)練參數(shù)T調(diào)節(jié),我們也可以把這個(gè)叫做shift門。



其中W是被預(yù)先計(jì)算好的值初始化的(具體怎么預(yù)先計(jì)算的彭博未提及,但是受啟發(fā)于alibi編碼),不同的channel使用不同的W,而且更小的W被用在更低的層上。這是因?yàn)榈撞繉拥钠骄p更快,對應(yīng)短程信息;頂部層的平均衰減更慢,對應(yīng)長程信息。
作者還給了一個(gè)提醒,要clamp k【是閾值裁剪,對上界限制了60】和對d加上來預(yù)防overflows。
接下來是FFN層,相較于RWKV-V1,這算是完全新增的層了。

每個(gè)時(shí)間步只依賴于?,而
或
只依賴于
和
,擺脫了像傳統(tǒng)RNN對歷史狀態(tài)的依賴,所以可以很方便的展開訓(xùn)練時(shí)并行。a、b分別代表kv和k的滑動平均數(shù)。c和d則是a、b加上self-attntion,同時(shí)也是記憶機(jī)制。T,K,V,R,W,X,P全是可訓(xùn)練的參數(shù)矩陣。
模型在實(shí)現(xiàn)時(shí)還有個(gè)headQK機(jī)制,會快速地看一遍前文,可以讓模型從前文復(fù)制或避免某些字。
在v2這個(gè)階段,LSTM的發(fā)明和奠基者Sepp Hochreiter也看到了這個(gè)工作,并給予了一定認(rèn)可。

為了更好地闡述V2的內(nèi)容,我們?nèi)匀恢卣故綬WKV-V2如何從降至
?。同時(shí)講述RWKV更向RNN靠近的理由。筆者在RWKV-V1的時(shí)候,曾說過V1還是很類似于linearize attention的工作,每一個(gè)時(shí)間步都依賴于歷史時(shí)刻所有的輸入。而rnn類型的模型則是完全依賴于上一個(gè)時(shí)刻的輸入和當(dāng)前時(shí)刻的輸入,以及固定大小的某些狀態(tài)(在RWKV-V2里是a和b)。以下是彭博給出的RWKV-V2的自注意力層簡化表達(dá)
RWKV-V3
這個(gè)階段是一個(gè)非??斓倪^度階段,相較于前兩版近一年的跨度,V3只持續(xù)了兩個(gè)月左右,且文字資料較少。在RWKV-LM的項(xiàng)目中提到,R K V 的來源變化了。在V2版本中,是先生成一個(gè)mix的。
我們用一幅圖簡單展示一下這個(gè)變換。

另一個(gè)變換是,使用preLN替換postLN(更穩(wěn)定且更快收斂)。不過好像preLN在V2已經(jīng)采用了。
自注意力層的實(shí)現(xiàn)是:對x做time-shift,然后分別根據(jù)不同的矩陣映射成,再映射成R\K\V
RWKV-V4
雖然這版星在22年底就出現(xiàn)了,但是我們可以認(rèn)為最后定檔應(yīng)該是出自這篇論文RWKV: Reinventing RNNs for the Transformer Era.框架的整體結(jié)構(gòu)如下圖所示:

time-mixing block的實(shí)現(xiàn)是:
開始仍然是和v3一樣的映射操作,然后有一點(diǎn)小小的變化。就起到了transformer自注意力的
,但是時(shí)間復(fù)雜度是
,因?yàn)橹恍枰闅v序列,每次操作都是簡單的加法,時(shí)間復(fù)雜度是
。這里面的
也是一個(gè)位置偏置,就好像我們第一次在AFT中討論的那樣,不過現(xiàn)在
。要求w中的元素非負(fù)是為了讓每個(gè)通道的權(quán)重相當(dāng)于往歷史時(shí)間上衰減。且越遠(yuǎn)的時(shí)間步,比如i=1,受到的遺忘力度就越大。
而channel-mixing block的實(shí)現(xiàn)是:
這里采用了squared ReLU。
但是還有一點(diǎn)非常重要,因?yàn)榻鉀Q了時(shí)間復(fù)雜度的問題,transformer之所以如此強(qiáng)大,是因?yàn)樗泻芎玫牟⑿行再|(zhì),能充分利用GPU。而就time-mixing block來看,是存在時(shí)間依賴的。也就是說訓(xùn)練時(shí)給定一句話,不能利用類似teacher-forcing的方法訓(xùn)練,因?yàn)樗€依賴于前面時(shí)刻的狀態(tài)。
對于這個(gè)問題,彭博采用了Simple recurrent units for highly parallelizable recurrence?(SRU)里的思路,簡單地說,只依賴于原始輸入x的部分,都可以預(yù)先計(jì)算好,例如。而
這種element-wise product則可以按照batch和dimmension兩個(gè)維度并行。達(dá)到一個(gè)相對完備的并行狀態(tài)。
上面的寫法看起來并不太像RNN,因?yàn)镽NN并不會全局的回看所有序列,而是通常依賴于當(dāng)前輸入和來自歷史的狀態(tài),不過,我們在V2的時(shí)候就知道,這些mixing都可以重寫成RNN-block的形式。transformer解碼的時(shí)候通常會利用一個(gè)kv緩存來獲得一定的速度提升,但隨著序列長度增加也會帶來很多空間占用,RNN形式卻不會碰到這種問題。比如,上面的我們可以重寫成下面這種形式,它只依賴于輸入
和狀態(tài)
在訓(xùn)練的時(shí)候,RWKV采取并行模式,而在推斷的時(shí)候,RWKV可以采取RNN模式。
下圖是語言建模任務(wù)下,RWKV-LM的運(yùn)行過程。

傳統(tǒng)的RNN通過使用非飽和激活函數(shù)、門控機(jī)制、梯度裁剪、添加約束等多種方法來解決梯度穩(wěn)定性問題,但RWKV通過類似于transformer和RNN的融合,本質(zhì)上地具有了更穩(wěn)定的梯度。RWKV包含全時(shí)間依賴的softmax操作有助于數(shù)值穩(wěn)定和防止梯度消失。層歸一化也在這方面起到了很重要的作用。論文附錄中,作者給出了RWKV在梯度穩(wěn)定性上的數(shù)學(xué)證明(詳見附錄F)。同時(shí),這樣的設(shè)計(jì)也能夠以超過任何現(xiàn)有RNN的能力的方式實(shí)現(xiàn)深層堆疊,模型能夠捕獲跨不同抽象級別的更復(fù)雜的模式。

上面的圖,展現(xiàn)的是位置衰減偏置在channel維度的衰減大小??梢钥闯?,在后續(xù)層中,模型的上下文信息被幾乎完整的保存和傳播。而在低層中,衰減曲線很快下滑,提示底層比較關(guān)注局部信息。
下面的圖則展示了信息的檢索和傳播路徑,采用的是Locating and editing factual associations in GPT?中提到的方法。
運(yùn)行模型一次,記錄計(jì)算過程中課程的所有狀態(tài)和激活情況。采用噪聲破壞被試的輸入嵌入(例子里用的是“埃菲爾鐵塔”)。還原計(jì)算過程中某一層在某一個(gè)令牌處的狀態(tài)和激活情況,記錄模型輸出正確答案( '巴黎')的對數(shù)概率。
與transformer不同,RWKV依賴于信息在時(shí)間維度上的遞歸傳播。在這種情況下,"埃菲爾鐵塔位于巴黎"這一事實(shí)在第4層被檢索到。然后將其傳遞給后續(xù)的層。在第20層中,信息主要通過時(shí)間進(jìn)行傳播,直到到達(dá)需要的地方。最后,將其傳遞到最后一層進(jìn)行答案的輸出。
一些局限性
RWKV解決長程依賴問題,也就是傳遞歷史上下文信息的機(jī)制,在RWKV-V4中有三種——遞歸、時(shí)間衰減和token shift。之所以在此處仍然重復(fù)已經(jīng)提到的內(nèi)容,是因?yàn)樵谶@三種機(jī)制下,RWKV的長度外推效果仍然欠佳。
RWKV是靠記憶來完成任務(wù)的,也就是只會開卷考試不會閉卷考試。所以RWKV對prompt比較敏感,要把任務(wù)描述的token放到最前面,帶著問題閱讀材料效果才比較好。
當(dāng)模型寬度繼續(xù)加大時(shí),線性RNN的時(shí)間復(fù)雜度可能更依賴于隱藏層維度d,使得標(biāo)準(zhǔn)的attention機(jī)制也在序列上接近線性了。
結(jié)語
彭博是一個(gè)工程實(shí)力非常強(qiáng)悍的獨(dú)立研究員,早期就在知乎分享了非常多對于模型改進(jìn)的思路和實(shí)現(xiàn)方案,將talk is cheap,show me the code展現(xiàn)的淋漓盡致。他能從大量的論文閱讀中真正提取出別人的精華,并進(jìn)行嘗試,在積累一定量后進(jìn)行融合,也由此誕生了RWKV。RWKV是一個(gè)非常有意思的模型,通過SRU的思路,解決了RNN訓(xùn)練并行效率的問題,通過AFT、time-shift、相對位置編碼等多種思路融合,加強(qiáng)了RWKV的長程依賴且緩解了訓(xùn)練不穩(wěn)定的問題。通過geglu、squred ReLU、FFN with R gate、自定義初始化等多種工作進(jìn)一步強(qiáng)化了transformer中的FFN層。
在OpenAI實(shí)際上并不Open的當(dāng)下,RWKV從21年創(chuàng)立之初完全開源,既受到開源社區(qū)很多幫助,也反哺了開源社區(qū)很多的成果?,F(xiàn)在ChatRWKV已經(jīng)在同尺寸上展現(xiàn)出了相當(dāng)驚人的表現(xiàn)。對于LM基座感興趣的讀者,可以參看這個(gè)鏈接,而想在線體驗(yàn)的讀者,也可以從這個(gè)鏈接直接體驗(yàn)RWKV-4-World-7B模型。
RWKV-V5的構(gòu)想和改進(jìn)計(jì)劃也已在近日公布,相信在可預(yù)見的未來,大語言模型的結(jié)構(gòu)選擇除了transformer,也將會有完全由國人設(shè)計(jì)的RWKV的一席之地。

hi,這里是小牛翻譯~
想要看到更多我們的文章,可以關(guān)注下
機(jī)器翻譯學(xué)堂(公號或網(wǎng)站)
筆芯~

往期精彩文章


