64 注意力機(jī)制【動(dòng)手學(xué)深度學(xué)習(xí)v2】

從生物學(xué)角度上來(lái)說,人的決定是由刻意(隨意)因素和非刻意(非隨意)因素共同決定的

之前學(xué)過的卷積、全連接和池化都只考慮刻意因素
想要查詢的東西,即隨意線索,被稱為Query
非隨意線索則為無(wú)關(guān)緊要的環(huán)境值,以多個(gè)Key和對(duì)應(yīng)的Value表示。
隨意即想要干啥,所以會(huì)關(guān)注的事情,比如想喝咖啡,然后便會(huì)關(guān)注咖啡杯query和key是同一個(gè)東西,只是查詢一個(gè)query的時(shí)候,要考慮所有的query即key
注意力池化層:根據(jù)輸入的Query來(lái)對(duì)多個(gè)Key進(jìn)行有偏向性的選擇。

實(shí)現(xiàn)方法:
x代表key,y代表value。多個(gè)由key和value組成的配對(duì)就是數(shù)據(jù)集。
以下給出幾種注意力池化方案:
1、平均池化方案:無(wú)論給定的Query是多少,總是輸出數(shù)據(jù)集中value的平均值。此方法并沒有利用到數(shù)據(jù)集中的key
2、將要查詢的Query與所有數(shù)據(jù)集中的key進(jìn)行相減,并將此值輸送給函數(shù)K,K被用于計(jì)算query與各個(gè)Key的相關(guān)性的工具函數(shù)。計(jì)算完后,將數(shù)據(jù)集中的value進(jìn)行加權(quán)相加。得到最終要輸出的value值??傮w來(lái)說,此方法是為了輸出與query最為相近的key對(duì)應(yīng)的value值。

相似度判斷函數(shù)K的選取:
使用softmax進(jìn)行相似度判斷

在注意力池化層中添加學(xué)習(xí)參數(shù)w:


代碼實(shí)現(xiàn):
import torch from torch import nn from d2l import torch as d2l

服從均值為0和標(biāo)準(zhǔn)差為0.5的正態(tài)分布。 在這里生成了50個(gè)訓(xùn)練樣本和50個(gè)測(cè)試樣本。 為了更好地可視化之后的注意力模式,需要將訓(xùn)練樣本進(jìn)行排序。
n_train = 50 # 訓(xùn)練樣本數(shù) x_train, _ = torch.sort(torch.rand(n_train) * 5) # 排序后的訓(xùn)練樣本 def f(x): return 2 * torch.sin(x) + x**0.8 y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,)) # 訓(xùn)練樣本的輸出(此處為訓(xùn)練集加上了均值為零,方差為0.5的高斯分布噪音) x_test = torch.arange(0, 5, 0.1) # 測(cè)試樣本為以0.1為步長(zhǎng),從0到5生成的數(shù)據(jù)集。 y_truth = f(x_test) # 測(cè)試樣本的真實(shí)輸出 n_test = len(x_test) # 測(cè)試樣本數(shù) n_test
50
下面的函數(shù)將繪制所有的訓(xùn)練樣本(樣本由圓圈表示), 不帶噪聲項(xiàng)的真實(shí)數(shù)據(jù)生成函數(shù)f
(標(biāo)記為“Truth”), 以及學(xué)習(xí)得到的預(yù)測(cè)函數(shù)(標(biāo)記為“Pred”)。
def plot_kernel_reg(y_hat): d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'], xlim=[0, 5], ylim=[-1, 5]) d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);



# X_repeat的形狀:(n_test,n_train), # 每一行都包含著相同的測(cè)試輸入(例如:同樣的查詢) X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train)) # x_train包含著鍵。attention_weights的形狀:(n_test,n_train), # 每一行都包含著要在給定的每個(gè)查詢的值(y_train)之間分配的注意力權(quán)重 attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1) # y_hat的每個(gè)元素都是值的加權(quán)平均值,其中的權(quán)重是注意力權(quán)重 y_hat = torch.matmul(attention_weights, y_train) plot_kernel_reg(y_hat)

現(xiàn)在來(lái)觀察注意力的權(quán)重。 這里測(cè)試數(shù)據(jù)的輸入相當(dāng)于查詢,而訓(xùn)練數(shù)據(jù)的輸入相當(dāng)于鍵。 因?yàn)閮蓚€(gè)輸入都是經(jīng)過排序的,因此由觀察可知“查詢-鍵”對(duì)越接近, 注意力匯聚的注意力權(quán)重就越高。(顏色較深表示注意力更高)
d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0), xlabel='Sorted training inputs', ylabel='Sorted testing inputs')

X= torch.ones((2, 1, 4)) Y = torch.ones((2, 4, 6)) torch.bmm(X, Y).shape
torch.Size([2, 1, 6])
注意力機(jī)制的背景中,我們可以使用小批量矩陣乘法來(lái)計(jì)算小批量數(shù)據(jù)中的加權(quán)平均值。
eights = torch.ones((2, 10)) * 0.1 values = torch.arange(20.0).reshape((2, 10)) torch.bmm(weights.unsqueeze(1), values.unsqueeze(-1))
tensor([[[ 4.5000]], [[14.5000]]])
基于?(10.2.7)中的 帶參數(shù)的注意力匯聚,使用小批量矩陣乘法, 定義Nadaraya-Watson核回歸的帶參數(shù)版本為:
class NWKernelRegression(nn.Module): def __init__(self, **kwargs): super().__init__(**kwargs) self.w = nn.Parameter(torch.rand((1,), requires_grad=True)) def forward(self, queries, keys, values): # queries和attention_weights的形狀為(查詢個(gè)數(shù),“鍵-值”對(duì)個(gè)數(shù)) queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1])) self.attention_weights = nn.functional.softmax( -((queries - keys) * self.w)**2 / 2, dim=1) # values的形狀為(查詢個(gè)數(shù),“鍵-值”對(duì)個(gè)數(shù)) return torch.bmm(self.attention_weights.unsqueeze(1), values.unsqueeze(-1)).reshape(-1)