最美情侣中文字幕电影,在线麻豆精品传媒,在线网站高清黄,久久黄色视频

歡迎光臨散文網(wǎng) 會員登陸 & 注冊

RWKV: 大語言模型結(jié)構(gòu)的另一種選擇

2023-07-17 14:03 作者:小牛翻譯NiuTrans  | 我要投稿

文章首發(fā)于 網(wǎng)站機(jī)器翻譯學(xué)堂

轉(zhuǎn)載事宜請后臺詢問哦

作者 | 阮俊豪

單位 | 東北大學(xué)自然語言處理實(shí)驗(yàn)室

前言

Transformer[1]在諸多的NLP任務(wù)上產(chǎn)生了非常驚艷的效果,甚至逐漸輻射到CV領(lǐng)域(如Vision Transfomrer[2]),獲得了學(xué)術(shù)界和工業(yè)界一致的認(rèn)可。因此也被作為當(dāng)下大語言模型結(jié)構(gòu)的不二之選。無論是以BERT[3]為代表的,常用于分類任務(wù)的Encoder-only模型;亦或是解決生成類任務(wù)為主的Decoder-only模型GPT[4];或兼而有之的Encoder-Decoder架構(gòu)的T5[5]模型,他們都采用了transformer的部分或完整架構(gòu)。

盡管如此,Transformer作為大語言模型的標(biāo)準(zhǔn)架構(gòu)選擇,也存在一些不能忽視的缺陷,例如內(nèi)存和時(shí)間復(fù)雜度都與輸入序列的長度成平方,這很大程度的影響了大語言模型部署在端側(cè)或資源受限設(shè)備的可能性。

隨著Scailing Law[6]的提出,大公司開始將研究重心轉(zhuǎn)移到如何建設(shè)能容納更大參數(shù)規(guī)模的基礎(chǔ)支撐設(shè)備,更深層的訓(xùn)練技巧保證梯度和訓(xùn)練穩(wěn)定性,借助更大數(shù)據(jù)規(guī)模的有效微調(diào)方式。在力大磚飛,物理限制還沒有看到盡頭的當(dāng)下,較少有人繼續(xù)研究Transformer之外,還有沒有更適合大語言模型的結(jié)構(gòu)。一方面因?yàn)門ransformer結(jié)構(gòu)設(shè)計(jì)上具有很高效的GPU并行程度,另一方面是因?yàn)門ransformer在各種NLP任務(wù)取得良好成果的情況下嘗試其他模型的大參數(shù)訓(xùn)練有不小的試錯(cuò)成本。

一位獨(dú)立研究員彭博[7],在2021年8月份,就提出了他的原始RWKV[8]構(gòu)想,并在完善到RKWV-V2版本之后,在reddit和discord上引發(fā)業(yè)內(nèi)人員廣泛關(guān)注?,F(xiàn)今已經(jīng)演化到V4版本,并充分展現(xiàn)了RNN模型的縮放潛力。本篇博客將介紹RWKV的原理、演變流程和現(xiàn)在取得的成效。

RWKV模型原型

An Attention Free Transformer[9]

標(biāo)準(zhǔn)AFT

RWKV模型中的time-mix的設(shè)計(jì)受An Attention Free Transformer工作影響較深,所以開始之前,我們先介紹一下Attention Free Transformer(AFT)。

這篇文章之所以名字中帶有Free,是因?yàn)樗耆コ藰?biāo)準(zhǔn)transfomrer中的點(diǎn)乘注意力,同時(shí)也沒有采取其他linear transformer工作中常見的點(diǎn)乘注意力近似方法。

我們把標(biāo)準(zhǔn)的多頭自注意力中每個(gè)頭的輸出結(jié)果表示為:

f_i(x)%3D%5Csigma(%5Cfrac%7BQ_i(K_i)%5ET%7D%7B%5Csqrt%7Bd_k%7D%7D)V_i

其中Q_i%3DXW_i%5EQ%2CQ_i%3DXW_i%5EK%2CQ_i%3DXW_i%5EV%2C%5Cquad%20W_i%5EQ%5Cin%20R%5E%7Bd%5Ctimes%20d_k%7D%2CW_i%5Ek%5Cin%20R%5E%7Bd%5Ctimes%20d_k%7D%2CW_i%5EQ%5Cin%20R%5E%7Bd%5Ctimes%20d_v%7D

是第 i 個(gè)頭的線性變換。??是非線性函數(shù),默認(rèn)情況下是softmax函數(shù)。

AFT中,Q、K、V來源仍然是輸入X的線性變換結(jié)構(gòu)。

我們把Y%3Df(X)記作模型的輸出,那么模型 t 時(shí)刻的輸出在AFT中表示為:

Y_t%3D%5Csigma_q(Q_t)%5Codot%5Cfrac%7B%5Csum_%7Bt%5E%7B%5Cprime%7D%3D1%7D%5ET%5Cexp(K_%7Bt%5E%7B%5Cprime%7D%7D%2Bw_%7Bt%2Ct%5E%7B%5Cprime%7D%7D)%5Codot%20V_t%5E%7B%5Cprime%7D%7D%7B%5Csum_%7Bt%5E%7B%5Cprime%7D-1%7D%5ET%5Cexp(K_%7Bt%5E%7B%5Cprime%7D%7D%2Bw_%7Bt%2Ct%5E%7B%5Cprime%7D%7D)%7D

其中%5Codot%20指的是element-wise乘法,就是按位相乘?a%5Codot%20b%3D%5Ba_1b_1%2Ca_2b_2...a_nb_n%5D。%5Csigma%20_q?是sigmoid函數(shù),和softmax的區(qū)別是,sigmoid是一個(gè)個(gè)矩陣元素計(jì)算的,softmax是一行行計(jì)算的。w%5Cin%20R%5E%7BT%5Ctimes%20T%7D就是一個(gè)可學(xué)習(xí)的位置偏置。

我們可以粗略地將這個(gè)過程提煉為?Q%5Codot%5Csum%5Ctext%7Bsoftmax%7D(K)%5Codot%20V,和原始的點(diǎn)乘注意力???相比,元素乘的時(shí)間開銷顯著降低,且不需要顯式地計(jì)算和保存softmax出來的權(quán)重矩陣,還保留了全局的K、V之間的交互。需要特別注意的是,和原始多頭自注意力不同。AFT的softmax是以列(時(shí)間維度)做歸一化的,類似于池化。所以下文RWKV-V1也把這個(gè)設(shè)計(jì)稱為Time-mix。原始多頭自注意力是在隱藏層維度上做的。

對于每一個(gè)輸入?X的位置t,AFT形成了以K為權(quán)重的V的加權(quán)平均數(shù)。

為了進(jìn)一步闡述AFT和多頭注意力的聯(lián)系,我們按照標(biāo)準(zhǔn)多頭自注意力的形式,來描寫t時(shí)刻第i個(gè)注意力頭的輸出

Y_t%5Ei%3D%3Ca_t%5Ei%2CV%5Ei%3E%2Cs.t.a_t%5Ei%3D%5Cfrac%7B%5Csigma(Q_t%5Ei)%5Cexp(K%5Ei%2Bw_t)%7D%7B%5Csum_%7Bt'-1%7D%5ET%5Cexp(K_%7Bt'%7D%5Ei%2Bw_%7Bt%2Ct'%7D)%7D

在原始AFT的工作中,仍有兩點(diǎn)關(guān)于w的內(nèi)容需要厘清。首先,AFT中w_t,t%E2%80%99應(yīng)該指代的是從w%5Cin%20R%5E%7BT%5Ctimes%20T%7D中取出第t行和第t'列的元素。其次,公式中KW的加和基本上都采用廣播方式。以K%5Ei%2Bw_t為例子,K%5Ei%5Cin%20R%5E%7BT%5Ctimes%20d_x%7D,w_t%5Cin%20R%5E%7B1%5Ctimes%20T%7D,因此他們的加和可能實(shí)際上需要由w做轉(zhuǎn)置,然后在新的列維度上擴(kuò)充至d_k方可運(yùn)算。除此之外,為了簡單描述,上述過程并沒有討論mask的置入方案,在實(shí)際操作中,可以通過控制求和范圍來等價(jià)實(shí)現(xiàn)

AFT-local

在很多場景中,局部性都是一個(gè)很重要的歸納偏差。也有一些工作基于這個(gè)性質(zhì)開展,比如OpenAI現(xiàn)在使用在GPT系列上的sparse transformer。AFT-local發(fā)現(xiàn),訓(xùn)練后的transfomrer的注意力模式更傾向于集中局部。為了更詳細(xì)的表達(dá)這個(gè)理念,我們用一副圖可視化Vision Transformer(ViT)的注意力矩陣。這是一個(gè)12層6注意力頭數(shù)的ViT在256張圖片上的平均注意力模式,其中縱向維度是層數(shù)(2層一統(tǒng)計(jì)),橫向維度是注意力頭數(shù)。星光亮度越大的地方代表了更高的注意力權(quán)重。

可以看到,這個(gè)圖里展示了相當(dāng)強(qiáng)的局部模式,這個(gè)觀測引出了AFT的變體——AFT-local的設(shè)計(jì)。

和上文的區(qū)別是,將位置偏置w%5Cin%20R%5E%7BT%5Ctimes%20T%7D的值做了一個(gè)區(qū)域限制:

w_%7Bt%2Ct'%7D%3D%5Cbegin%7Bcases%7Dw_%7Bt%2Ct'%7D%2C%26%5Ctext%7Bif%7D%7Ct-t'%7C%3Cs%5C%5C0%2C%26otherwise%5Cend%7Bcases%7D

這里s就是一個(gè)局部的窗口大小。

AFT-simple

AFT-local的一個(gè)極端模式就是令s=0,也就是完全從AFT中抽走了位置偏置,從而得到

Y_t%3D%5Csigma_q(Q_t)%5Codot%5Cfrac%7B%5Csum_%7Bt'%3D1%7D%5ET%5Cexp(K_%7Bt'%7D)%5Codot%20V_t'%7D%7B%5Csum_%7Bt'%3D1%7D%5ET%5Cexp(K_%7Bt'%7D%2Bw_%7Bt%2Ct'%7D)%7D%3D%5Csigma(Q_t)%5Codot%5Csum_%7Bt'%3D1%7D%5ET%5Ctext%7B(softmax(K)%7D%5Codot%20V)_%7Bt'%7D

AFT還有第三種變體,叫做AFT-conv,其更適用于圖像任務(wù)。因篇幅所限,感興趣的讀者可以查閱原文了解。

GLU Variants Improve Transformer

channel-mix的部分則受該節(jié)工作啟發(fā)。尤其是其中的GeGLU。

標(biāo)準(zhǔn)Transformer中的feed-forward network采用的是如下的結(jié)構(gòu):

%5Coperatorname%7BFFN%7D(x%2CW_1%2CW_2%2Cb_1%2Cb_2)%3Dmax(0%2CxW_1%2Bb_1)W_2%2Bb_2

T5取走了最外層的偏置,修改為

%5Ctext%7BFFN%7D_%5Ctext%7BReLU%7D%7B%20(%20x%20%2C%20W%20_%201%20%2C%20W%20_%202%20%2C%20b%20_%201%20%2C%20b%20_%202%20)%20%7D%3Dmax(0%2CxW_1%2Bb_1)W_2

也有其他工作嘗試使用GELU或者Swish替代ReLU.

%5Cbegin%7Barray%7D%7Bl%7D%5Ctext%7BFFN%7D_%5Ctext%7BGELU%7D(x%2CW_1%2CW_2)%3D%5Ctext%7BGELU%7D(xW_1)W_2%5C%5C%5Ctext%7BFFN%7D_%5Ctext%7BSwish%7D(x%2CW_1%2CW_2)%3D%5Ctext%7BSwish%7D_1(xW_1)W_2%5Cend%7Barray%7D

其中%5Coperatorname%7BGELU%7D(x)%3Dx%5CPhi(x)%2CSwish_%5Cbeta(x)%3Dx%5Csigma(%5Cbeta%20x)

在上面的例子里,兩層可學(xué)習(xí)的線性變換是按順序堆疊在一起的,一層的輸出作為第二層的輸入。后續(xù)Gated Linear Units(GLU)提出了另一種形式,這種形式如果省略掉激活函數(shù)也被稱作bilinear layer。

%5Ctext%7BGLU%7D(x%2CW_1%2CW_2%2Cb%2Cc)%3D%5Csigma(xW%2Bb)%5Cotimes(xV%2Bc)

同理,GLU上也存在一個(gè)非線性激活函數(shù),我們可以使用GELU等函數(shù)去替換。RWKV-1所涉及的GeGLU就來自于GELU+GLU:

%5Ctext%7BGEGLU%7D(x%2CW%2CV%2Cb%2Cc)%3D%5Ctext%7BGELU%7D(xW%2Bb)%5Cotimes(xV%2Bc)

最后仍然采用省略偏置的結(jié)構(gòu)替換FFN:

%5Coperatorname%7BFFN_%7BGEGLU%7D%7D(x%2CW%2CV%2CW_2)%3D(%5Coperatorname%7BGELU(xW)%7D%5Cotimes%20xV)W_2

本文最后還測試了許多不同GLU變體作為FFN的效果,感興趣的讀者可以參看原文。

RWLV-V1

這個(gè)版本的工作還比較類似linear transformer的工作,而不是純粹的RNN網(wǎng)絡(luò)。在彭博的設(shè)計(jì)中,RWKV模型由交替的Time-mix和Channel-mix層組成。

%5Cmathrm%7BTime%7D-%5Cmathrm%7Bmix%7D%3A%5Cmathbf%7BTM%7D_%7Bt%2Cc%7D%3D%5Cmathrm%7Bsigmoid%7D(R_%7Bt%2Cc%7D)%5Ccdot%5Csum_%7Bu%7DW_%7Bt%2Cu%2Cc%7D%5Ccdot%5Cmathrm%7Bsoftmax%7D_%7Bt%7D(K_%7Bu%2Cc%7D)%5Ccdot%20V_%7Bu%2Cc%7D

%5Ctext%7BChannel%7D-%5Ctext%7Bmix%7D%3A%5Cmathbf%7BCM%7D_%7Bt%2Cc%7D%3D%5Ctext%7Bsigmoid%7D(R_%7Bt%2Cc%7D)%5Ccdot%5Csum_dW_%7Bc%2Cd%7D%5Ccdot%5Ctext%7Bgelu%7D(K_%7Bu%2Cc%7D)%5Ccdot%20V_%7Bu%2Cc%7D

兩者均擁有類似的R\W\KV結(jié)構(gòu)設(shè)計(jì),故此得名。其中R\K\V由輸入線性變換生成,W是一個(gè)可學(xué)習(xí)的參數(shù)矩陣。

筆者認(rèn)為,和AFT工作中的標(biāo)記方式不同,矩陣的下標(biāo)不僅代表取元素,也同時(shí)代表維度表示。例如K_%7Bu%2Cc%7D可以認(rèn)為是K%5Cin%20R%5E%7Bt%5Ctimes%20c%7D取出了第u行的元素。

可以看出,Time-mix層與AFT-simple基本相同,其區(qū)別包括,修改歸一化

%5Ctext%7Bsoftmax%7D_t(K_%7Bu%2Cc%7D)%3D%5Cfrac%7B%5Cexp(K_%7Bu%2Cc%7D)%7D%7B%5Csum%5Cnolimits_%7Bv%5Cleq%20t%7D%5Cexp(K_%7Bv%2Cc%7D)%7D

相較于原始的在完整時(shí)間序列上的歸一化,Time-mix現(xiàn)在采用的是一種只回看歷史序列的局部歸一化。

除此之外,還將W分解,支持多頭W%3AW_%7Bt%2Cu%2Cc%7D%3Df_%7Bh%7D(t-u)%5Ccdot%5Calpha_%7Bh%7D(u)%5Ccdot%5Cbeta_%7Bh%7D(t)

channel-mix和上文提到的FFN_%7BGEGLU%7D也基本相同,因?yàn)镵和V是由層輸入x線性變換得到的,因此相當(dāng)于只是增加了一個(gè)額外的self-gating R,彭博測試后發(fā)現(xiàn)確實(shí)提高了擬合性能。

最后再加上了彭博2020年8月的一個(gè)想法?time-shift就組成了RWKV-V1版本。time-shift主要涉及到relative position embedding,以及把輸入從x%5Bt%2C%3Ac%5D改成%5Ctext%7Btorch.cat%7D(x%5Bt-1%2C%3Ac%2F%2F2%5D%2Cx%5Bt%2Cc%2F%2F2%3A%5D).

表面上看,這是強(qiáng)制要求網(wǎng)絡(luò)結(jié)合 x[t] 和 x[t-1],是個(gè) inductive bias,就像強(qiáng)制要求網(wǎng)絡(luò)使用 2-gram。

但后來我做了更多實(shí)驗(yàn),發(fā)現(xiàn)它對于深層模型也有效。而且,對于 SA 和 FFN 都有效。

這乍一看有點(diǎn)奇怪,把 x[t] 的一半通道,用 x[t-1] 的通道代替,是不是有點(diǎn)過分了?經(jīng)實(shí)驗(yàn),確實(shí)可以用這么強(qiáng)的混合。

后來再想想,我們訓(xùn)練 GPT 時(shí),網(wǎng)絡(luò)的 hidden representation 實(shí)際在做兩件不同的事情:

1. 預(yù)測下一個(gè)字。有時(shí)這很簡單(有時(shí)下一個(gè)字很明顯)。

2. 收集前文的 context 信息,傳遞給后文。這永遠(yuǎn)是困難的任務(wù)。

這兩件事有明顯的區(qū)別。第2件事更難。

我的理論是:在加入 time-shift 后,沒有 time-shift 的通道主要承擔(dān) (1),被 time-shift 的通道主要承擔(dān) (2)。所以,這就實(shí)現(xiàn)了更明確的信息分離。

而且 time-shift 可讓信息快速傳遞,就像一個(gè)小卷積。在多層后可以看很長的距離。

上述這段引述自彭博對于time-shift的描述,我解釋一下最后一句話,“在多層后可以看很長的距離”,假設(shè)層數(shù)有6層,在t時(shí)刻的第6層的輸入,依賴于第5層t-1時(shí)刻和t時(shí)刻的輸入,而第五層t-1時(shí)刻的輸入,又依賴于第四層t-2時(shí)刻的輸入。。。依次類推到第一層t-5時(shí)刻的輸入。所以賦予了回看歷史信息的能力。這個(gè)觀點(diǎn)一般是在local attention的模式里會提到的,被用于解釋這一段做法也挺契合。如果認(rèn)為收集前文context的信息更難,time-mixing就在注意力回看的基礎(chǔ)上,直接讓輸入特征也一并回看了。

對于普通的 QKV 自注意力,我觀察權(quán)重,看各層對于兩種通道的使用程度,發(fā)現(xiàn) Q 偏向于使用【沒有 time-shift 的通道】,而 V 偏向于使用【被 time-shift 的通道】,符合這個(gè)理論。

這個(gè)觀測我覺得很有意思,如何觀察Q和V對于不同層的偏向使用情況,我傾向于作者可能對線性變換矩陣做了一個(gè)類似熱力圖觀察,發(fā)現(xiàn)不同通道部分存在不同的熱力,比如W_Q前半部分通道更熱,W_V后半部分更熱。

該版本在 simplebooks-92 的 character-level 性能對比了灰色基線(普通 MHA 多頭注意力 + Rotary encoding + GeGLU FFN),和加入各種魔改的MHA的黑色線版本,均有競爭力。

RWKV-V2

這個(gè)版本的改動有些大,我們先從自注意力層開始。為了方便,我把作者的偽代碼圖轉(zhuǎn)成了模型結(jié)構(gòu)圖。和RWKV-V1相比,區(qū)別在于修改了time-mix的softmax處理,顯式地加入了token-shift機(jī)制,從永遠(yuǎn)的當(dāng)前詞和歷史詞各取一半channel合并改成了可訓(xùn)練參數(shù)T調(diào)節(jié),我們也可以把這個(gè)叫做shift門。

其中W是被預(yù)先計(jì)算好的值初始化的(具體怎么預(yù)先計(jì)算的彭博未提及,但是受啟發(fā)于alibi編碼),不同的channel使用不同的W,而且更小的W被用在更低的層上。這是因?yàn)榈撞繉拥钠骄p更快,對應(yīng)短程信息;頂部層的平均衰減更慢,對應(yīng)長程信息。

作者還給了一個(gè)提醒,要clamp k【是閾值裁剪,對上界限制了60】和對d加上%5Cepsilon%20來預(yù)防overflows。

接下來是FFN層,相較于RWKV-V1,這算是完全新增的層了。

每個(gè)時(shí)間步只依賴于%5C%7Bx_t%2Ca_t%2Cb_t%5C%7D?,而a_tb_t只依賴于x_%7Bt%E2%88%921%7Dx_t,擺脫了像傳統(tǒng)RNN對歷史狀態(tài)的依賴,所以可以很方便的展開訓(xùn)練時(shí)并行。a、b分別代表kv和k的滑動平均數(shù)。c和d則是a、b加上self-attntion,同時(shí)也是記憶機(jī)制。T,K,V,R,W,X,P全是可訓(xùn)練的參數(shù)矩陣。

模型在實(shí)現(xiàn)時(shí)還有個(gè)headQK機(jī)制,會快速地看一遍前文,可以讓模型從前文復(fù)制或避免某些字。

在v2這個(gè)階段,LSTM的發(fā)明和奠基者Sepp Hochreiter也看到了這個(gè)工作,并給予了一定認(rèn)可。

為了更好地闡述V2的內(nèi)容,我們?nèi)匀恢卣故綬WKV-V2如何從O(T%5E2)降至O(T)?。同時(shí)講述RWKV更向RNN靠近的理由。筆者在RWKV-V1的時(shí)候,曾說過V1還是很類似于linearize attention的工作,每一個(gè)時(shí)間步都依賴于歷史時(shí)刻所有的輸入。而rnn類型的模型則是完全依賴于上一個(gè)時(shí)刻的輸入和當(dāng)前時(shí)刻的輸入,以及固定大小的某些狀態(tài)(在RWKV-V2里是a和b)。以下是彭博給出的RWKV-V2的自注意力層簡化表達(dá)

x_%7Bt%2B1%7D%3D%5Csigma%5Cleft(Rx_%7Bt%7D%5Cright)%5Ccdot%5Cfrac%7B%5Cexp(Kx_%7Bt%7D)%5Ccdot(V_%7Bx%7Dt)%2B%5Cexp(W)%5Ccdot%20a_%7Bt%7D%7D%7B%5Cexp(Kx_%7Bt%7D)%2B%5Cexp(W)%5Ccdot%20b_%7Bt%7D%7D

RWKV-V3

這個(gè)階段是一個(gè)非??斓倪^度階段,相較于前兩版近一年的跨度,V3只持續(xù)了兩個(gè)月左右,且文字資料較少。在RWKV-LM的項(xiàng)目中提到,R K V 的來源變化了。在V2版本中,是先生成一個(gè)mix的。

我們用一幅圖簡單展示一下這個(gè)變換。

另一個(gè)變換是,使用preLN替換postLN(更穩(wěn)定且更快收斂)。不過好像preLN在V2已經(jīng)采用了。

自注意力層的實(shí)現(xiàn)是:對x做time-shift,然后分別根據(jù)不同的矩陣映射成x_R%EF%BC%8Cx_K%2Cx_V,再映射成R\K\V

RWKV-V4

雖然這版星在22年底就出現(xiàn)了,但是我們可以認(rèn)為最后定檔應(yīng)該是出自這篇論文RWKV: Reinventing RNNs for the Transformer Era.框架的整體結(jié)構(gòu)如下圖所示:

time-mixing block的實(shí)現(xiàn)是:

%5Cbegin%7Bgathered%7D%0Ar_t%3DW_r%5Ccdot(u_rx_t%2B(1-u_r)x_%7Bt-1%7D)%20%5C%5C%0Ak_t%3DW_k%5Ccdot(u_kx_t%2B(1-u_k)x_%7Bt-1%7D)%20%5C%5C%0A%5Cbegin%7Baligned%7Dv_t%26%3DW_v%5Ccdot(u_vx_t%2B(1-u_v)x_%7Bt-1%7D)%5Cend%7Baligned%7D%20%5C%5C%0Awkv_%7Bt%7D%3D%5Cfrac%7B%5Csum_%7Bi%3D1%7D%5E%7Bt-1%7De%5E%7B-%5Cleft(t-1-i%5Cright)w%2Bki%7Dv_%7Bi%7D%2Be%5E%7Bu%2Bkt%7Dv_%7Bt%7D%7D%7B%5Csum_%7Bi%3D1%7D%5E%7Bt-1%7De%5E%7B-%5Cleft(t-1-i%5Cright)w%2Bki%7D%2Be%5E%7Bu%2Bk_%7Bt%7D%7Dv_%7Bt%7D%7D%20%5C%5C%0Ao_t%3DW_o%5Ccdot(%5Csigma(r_t)%5Codot%20wkv_t)%20%0A%5Cend%7Bgathered%7D

開始仍然是和v3一樣的映射操作,然后有一點(diǎn)小小的變化。wkv_t就起到了transformer自注意力的Attn(Q%2CK%2CV),但是時(shí)間復(fù)雜度是O(T),因?yàn)橹恍枰闅v序列,每次操作都是簡單的加法,時(shí)間復(fù)雜度是O(1)。這里面的t也是一個(gè)位置偏置,就好像我們第一次在AFT中討論的那樣,不過現(xiàn)在w%5Cin%20(R_%7B%5Cgeq%200%7D)%5Ed。要求w中的元素非負(fù)是為了讓每個(gè)通道的權(quán)重相當(dāng)于往歷史時(shí)間上衰減。且越遠(yuǎn)的時(shí)間步,比如i=1,受到的遺忘力度就越大。

而channel-mixing block的實(shí)現(xiàn)是:

%5Cbegin%7Bgathered%7D%0A%5Cboldsymbol%7Br%7D_t%20%3DW_r%5Ccdot(u_rx_t%2B(1-u_r)x_%7Bt-1%7D)%20%5C%5C%0Ak_%7Bt%7D%20%3DW_r%5Ccdot(u_kx_t%2B(1-u_k)x_%7Bt-1%7D)%20%5C%5C%0AO_%7Bt%7D%20%3D%5Csigma(r_t)%5Codot(W_v%5Ccdot%5Cmax(k_t%2C0)%5E2)%20%0A%5Cend%7Bgathered%7D

這里采用了squared ReLU。

但是還有一點(diǎn)非常重要,因?yàn)榻鉀Q了時(shí)間復(fù)雜度的問題,transformer之所以如此強(qiáng)大,是因?yàn)樗泻芎玫牟⑿行再|(zhì),能充分利用GPU。而就time-mixing block來看,wkv_t是存在時(shí)間依賴的。也就是說訓(xùn)練時(shí)給定一句話,不能利用類似teacher-forcing的方法訓(xùn)練,因?yàn)樗€依賴于前面時(shí)刻的狀態(tài)。

對于這個(gè)問題,彭博采用了Simple recurrent units for highly parallelizable recurrence?(SRU)里的思路,簡單地說,只依賴于原始輸入x的部分,都可以預(yù)先計(jì)算好,例如W_ru_rX。而wkv_t這種element-wise product則可以按照batch和dimmension兩個(gè)維度并行。達(dá)到一個(gè)相對完備的并行狀態(tài)。

上面的寫法看起來并不太像RNN,因?yàn)镽NN并不會全局的回看所有序列,而是通常依賴于當(dāng)前輸入和來自歷史的狀態(tài),不過,我們在V2的時(shí)候就知道,這些mixing都可以重寫成RNN-block的形式。transformer解碼的時(shí)候通常會利用一個(gè)kv緩存來獲得一定的速度提升,但隨著序列長度增加也會帶來很多空間占用,RNN形式卻不會碰到這種問題。比如,上面的wkv_t我們可以重寫成下面這種形式,它只依賴于輸入x_t和狀態(tài)h%3D(a%2Cb)

%5Cbegin%7Bgathered%7D%0Aa_0%2Cb_0%3D0%20%5C%5C%0Awkv_t%3D%5Cfrac%7Ba_%7Bt-1%7D%2Be%5E%7Bu%2Bk_t%7Dv_t%7D%7Bb_%7Bt-1%7D%2Be%5E%7Bu%2Bkt%7D%7D%20%5C%5C%0Aa_t%3De%5E%7B-w%7Da_%7Bt-1%7D%2Be%5E%7Bk_t%7Dv_t%20%5C%5C%0Ab_t%3De%5E%7B-w%7Db_%7Bt-1%7D%2Be%5E%7Bk_t%7D%20%0A%5Cend%7Bgathered%7D

在訓(xùn)練的時(shí)候,RWKV采取并行模式,而在推斷的時(shí)候,RWKV可以采取RNN模式。

下圖是語言建模任務(wù)下,RWKV-LM的運(yùn)行過程。

傳統(tǒng)的RNN通過使用非飽和激活函數(shù)、門控機(jī)制、梯度裁剪、添加約束等多種方法來解決梯度穩(wěn)定性問題,但RWKV通過類似于transformer和RNN的融合,本質(zhì)上地具有了更穩(wěn)定的梯度。RWKV包含全時(shí)間依賴的softmax操作有助于數(shù)值穩(wěn)定和防止梯度消失。層歸一化也在這方面起到了很重要的作用。論文附錄中,作者給出了RWKV在梯度穩(wěn)定性上的數(shù)學(xué)證明(詳見附錄F)。同時(shí),這樣的設(shè)計(jì)也能夠以超過任何現(xiàn)有RNN的能力的方式實(shí)現(xiàn)深層堆疊,模型能夠捕獲跨不同抽象級別的更復(fù)雜的模式。

上面的圖,展現(xiàn)的是位置衰減偏置e%5E%7B-w%7D在channel維度的衰減大小??梢钥闯?,在后續(xù)層中,模型的上下文信息被幾乎完整的保存和傳播。而在低層中,衰減曲線很快下滑,提示底層比較關(guān)注局部信息。

下面的圖則展示了信息的檢索和傳播路徑,采用的是Locating and editing factual associations in GPT?中提到的方法。

運(yùn)行模型一次,記錄計(jì)算過程中課程的所有狀態(tài)和激活情況。采用噪聲破壞被試的輸入嵌入(例子里用的是“埃菲爾鐵塔”)。還原計(jì)算過程中某一層在某一個(gè)令牌處的狀態(tài)和激活情況,記錄模型輸出正確答案( '巴黎')的對數(shù)概率。

與transformer不同,RWKV依賴于信息在時(shí)間維度上的遞歸傳播。在這種情況下,"埃菲爾鐵塔位于巴黎"這一事實(shí)在第4層被檢索到。然后將其傳遞給后續(xù)的層。在第20層中,信息主要通過時(shí)間進(jìn)行傳播,直到到達(dá)需要的地方。最后,將其傳遞到最后一層進(jìn)行答案的輸出。

一些局限性

RWKV解決長程依賴問題,也就是傳遞歷史上下文信息的機(jī)制,在RWKV-V4中有三種——遞歸、時(shí)間衰減和token shift。之所以在此處仍然重復(fù)已經(jīng)提到的內(nèi)容,是因?yàn)樵谶@三種機(jī)制下,RWKV的長度外推效果仍然欠佳。

RWKV是靠記憶來完成任務(wù)的,也就是只會開卷考試不會閉卷考試。所以RWKV對prompt比較敏感,要把任務(wù)描述的token放到最前面,帶著問題閱讀材料效果才比較好。

當(dāng)模型寬度繼續(xù)加大時(shí),線性RNN的時(shí)間復(fù)雜度O(Td%5E2)可能更依賴于隱藏層維度d,使得標(biāo)準(zhǔn)的attention機(jī)制也在序列上接近線性了。

結(jié)語

彭博是一個(gè)工程實(shí)力非常強(qiáng)悍的獨(dú)立研究員,早期就在知乎分享了非常多對于模型改進(jìn)的思路和實(shí)現(xiàn)方案,將talk is cheap,show me the code展現(xiàn)的淋漓盡致。他能從大量的論文閱讀中真正提取出別人的精華,并進(jìn)行嘗試,在積累一定量后進(jìn)行融合,也由此誕生了RWKV。RWKV是一個(gè)非常有意思的模型,通過SRU的思路,解決了RNN訓(xùn)練并行效率的問題,通過AFT、time-shift、相對位置編碼等多種思路融合,加強(qiáng)了RWKV的長程依賴且緩解了訓(xùn)練不穩(wěn)定的問題。通過geglu、squred ReLU、FFN with R gate、自定義初始化等多種工作進(jìn)一步強(qiáng)化了transformer中的FFN層。

在OpenAI實(shí)際上并不Open的當(dāng)下,RWKV從21年創(chuàng)立之初完全開源,既受到開源社區(qū)很多幫助,也反哺了開源社區(qū)很多的成果?,F(xiàn)在ChatRWKV已經(jīng)在同尺寸上展現(xiàn)出了相當(dāng)驚人的表現(xiàn)。對于LM基座感興趣的讀者,可以參看這個(gè)鏈接,而想在線體驗(yàn)的讀者,也可以從這個(gè)鏈接直接體驗(yàn)RWKV-4-World-7B模型。

RWKV-V5的構(gòu)想和改進(jìn)計(jì)劃也已在近日公布,相信在可預(yù)見的未來,大語言模型的結(jié)構(gòu)選擇除了transformer,也將會有完全由國人設(shè)計(jì)的RWKV的一席之地。


hi,這里是小牛翻譯~

想要看到更多我們的文章,可以關(guān)注下

機(jī)器翻譯學(xué)堂(公號或網(wǎng)站)

筆芯~

往期精彩文章



RWKV: 大語言模型結(jié)構(gòu)的另一種選擇的評論 (共 條)

分享到微博請遵守國家法律
平山县| 资中县| 奎屯市| 象山县| 旬邑县| 鄂伦春自治旗| 葫芦岛市| 安陆市| 邹平县| 清水县| 乌海市| 利津县| 凤阳县| 新郑市| 浦城县| 镇赉县| 宕昌县| 定襄县| 余干县| 巴林右旗| 武安市| 进贤县| 株洲市| 胶南市| 锦屏县| 兰溪市| 冀州市| 绥芬河市| 闽清县| 永昌县| 宜兰县| 丘北县| 崇阳县| 肇东市| 珲春市| 茶陵县| 华坪县| 马关县| 绥化市| 长泰县| 偃师市|