從零實(shí)現(xiàn)BERT、GPT及Difussion類算法-3:Multi-head Attention & Transformer
教程簡(jiǎn)介及目錄見: 從零實(shí)現(xiàn)BERT、GPT及Difussion類算法:文章簡(jiǎn)介及目錄
本章完整源碼見https://github.com/firechecking/CleanTransformer/blob/main/CleanTransformer/transformer.py
這一章將參考《attention is all you need》論文,實(shí)現(xiàn)Multi-head Attention、LayerNorm、TransformerBlock,有了這章的基礎(chǔ)后,在下一章就可以開始搭建Bert、GPT等模型結(jié)構(gòu)了
Multi-head Attention
參考:https://arxiv.org/abs/1706.03762
Attention介紹(選讀)
先簡(jiǎn)單介紹下一般性的Attention,如果已經(jīng)了解的同學(xué)可以跳過
Attention字面意思是注意力,也就是讓模型能夠學(xué)習(xí)到一個(gè)權(quán)重,來將輸入選擇性的傳入下一層
比較常用的操作如下:
首先假定輸入tensor為q, k, v,其中
self-attention是attention的一個(gè)特例:q=k=v
以下給出基礎(chǔ)attention的偽代碼
Multi-head Attention基礎(chǔ)原理

由以上論文截圖可知,
所以實(shí)現(xiàn)步驟如下(可以和上文基礎(chǔ)Attention對(duì)比著看):
對(duì)Q,K,V進(jìn)行Linear:得到新的Q、K、V的size不變
Multi-Head拆分:
使用Q、K計(jì)算Weight(其中第二行是Transformer在attention基礎(chǔ)上增加的一個(gè)scaling factor)
使用Weight和V,計(jì)算新的V
對(duì)V進(jìn)行維度變換
Multi-head Attention實(shí)現(xiàn)代碼
代碼不是太復(fù)雜,結(jié)合上文和注釋,應(yīng)該能很容易看懂
LayerNorm
參考
https://arxiv.org/abs/1607.06450
https://arxiv.org/abs/1607.06450
https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
https://blog.csdn.net/xinjieyuan/article/details/109587913
BatchNorm與LayerNorm的差異
batch normalization
對(duì)每一個(gè)輸入,先在mini-batch上計(jì)算輸入的均值、標(biāo)準(zhǔn)差,然后當(dāng)前層的每個(gè)輸入值使用均值、標(biāo)準(zhǔn)差進(jìn)行正則計(jì)算
公式如下
先在mini-batch上計(jì)算出每個(gè)位置的均值
、標(biāo)準(zhǔn)差
,其中
為mini-batch大小
然后對(duì)每個(gè)值應(yīng)用變換
提示:這里之所以有下標(biāo)i,是因?yàn)閎atch normalization是在batch內(nèi),對(duì)不同樣本的相同位置進(jìn)行歸一
layer normalization
batch normalization是在batch內(nèi),對(duì)不同樣本的相同位置進(jìn)行歸一;而layer normalization是在layer內(nèi),對(duì)同一個(gè)樣本的不同位置進(jìn)行歸一
batch normalization不在整個(gè)mini-batch上計(jì)算均值、標(biāo)準(zhǔn)差,而是在當(dāng)前層的當(dāng)前樣本計(jì)算輸入的均值、標(biāo)準(zhǔn)差,然后對(duì)當(dāng)前層的當(dāng)前樣本輸入值使用均值、標(biāo)準(zhǔn)差進(jìn)行正則計(jì)算(也可以理解為L(zhǎng)ayer Normalization是和batch無關(guān),而是對(duì)每個(gè)樣本單獨(dú)處理)
公式如下
先在單個(gè)樣本上計(jì)算出每一層的均值
、標(biāo)準(zhǔn)差
,其中
為當(dāng)前l(fā)ayer的大小hidden units數(shù)量
然后對(duì)每個(gè)值應(yīng)用變換
Layer Normalization代碼實(shí)現(xiàn)
代碼如下
eps為一個(gè)較小值,是為了防止標(biāo)準(zhǔn)差std為0時(shí),0作為除數(shù)
從上文公式看出,標(biāo)準(zhǔn)差是計(jì)算
的均值后開根號(hào),所以代碼中有std = self._mean((x - mean).pow(2) + self.eps).pow(0.5),是復(fù)用了self._mean()的計(jì)算均值操作
為了和pytorch的LayerNorm保持一致,這里同樣可以接受normalized_shape參數(shù),表示Normalization要在哪幾個(gè)維度上進(jìn)行計(jì)算,可以結(jié)合_mean()函數(shù)中的代碼進(jìn)行理解
TransformerBlock
參考
https://arxiv.org/abs/1706.03762
Transformer原理

從《Attention Is All You Need》論文中這張圖可以看出以下幾點(diǎn)信息:
Encoder、Decoder基本相同,最大差別是Decoder上多了一層Multi-Head Attention
每一個(gè)TransformerBlock只由Multi-Head Attention、Add、LayerNorm、Linear這4種操作組合而成
在上文已經(jīng)實(shí)現(xiàn)的Multi-Head Attention、LayerNorm基礎(chǔ)上,再來實(shí)現(xiàn)TransformerBlock就很簡(jiǎn)單了
為進(jìn)一步簡(jiǎn)化,在本章我們先只實(shí)現(xiàn)Encoder,并且省略掉mask等額外操作。到之后講到GPT時(shí)再來實(shí)現(xiàn)Decoder以及更完善的TransformerBlock
Transformer代碼實(shí)現(xiàn)
代碼主要由attention+Add+Norm,以及FFW+Add+Norm這2個(gè)部分組成,其中ffw是兩層全連接中間夾一個(gè)ReLU激活函數(shù)
從以上代碼看出TransformerBlock還是非常簡(jiǎn)潔的,而Bert、GPT等模型就是對(duì)TransformerBlock的堆疊,這部分內(nèi)容將放在下一章講解