自然語言處理: 第五章Attention注意力機(jī)制
理論基礎(chǔ)
Attention(來自2017年google發(fā)表的[1706.03762] Attention Is All You Need (arxiv.org) ),顧名思義是注意力機(jī)制,字面意思就是你所關(guān)注的東西,比如我們看到一個(gè)非常非常的故事的時(shí)候,但是其實(shí)我們一般能用5W2H就能很好的歸納這個(gè)故事,所以我們在復(fù)述或者歸納一段文字的時(shí)候,我們肯定有我們所關(guān)注的點(diǎn),這些關(guān)注的點(diǎn)就是我們的注意力,而類似How 或者when 這種不同的形式就成為了Attention里的多頭的機(jī)制。 下圖是引自GPT3.5對注意力的一種直觀的解釋,簡而言之其實(shí)就是各種不同(多頭)我們關(guān)注的點(diǎn)(注意力)構(gòu)成了注意力機(jī)制,這個(gè)奠定現(xiàn)代人工智能基石的基礎(chǔ)。
那么注意力機(jī)制的優(yōu)點(diǎn)是什么呢? (下面的對比是相對于上一節(jié)的Seq2Seq模型)
- 解決了長距離依賴問題,由于Seq2Seq模型一般是以時(shí)序模型eg RNN / Lstm / GRU 作為基礎(chǔ), 所以就會必然導(dǎo)致模型更傾向新的輸入 – 多頭注意力機(jī)制允許模型在解碼階段關(guān)注輸入序列中的不同部分
- 信息損失:很難將所有信息壓縮到一個(gè)固定長度的向量中(encorder 輸出是一個(gè)定長的向量) – 注意力機(jī)制動態(tài)地選擇輸入序列的關(guān)鍵部分
- 復(fù)雜度和計(jì)算成本:順序處理序列的每個(gè)時(shí)間步 – 全部網(wǎng)絡(luò)都是以全連接層或者點(diǎn)積操作,沒有時(shí)序模型
- 對齊問題:源序列和目標(biāo)序列可能存在不對齊的情況 – 注意力機(jī)制能夠?yàn)槟P吞峁└?xì)的詞匯級別對齊信
注意力可以拆解成下面6個(gè)部分,下面會在代碼實(shí)現(xiàn)部分逐個(gè)解釋
- 縮放點(diǎn)積注意力
兩個(gè)向量相乘可以得到其相似度, 一般常用的是直接點(diǎn)積也比較簡單,原論文里還提出里還提出了下面兩種計(jì)算相似度分?jǐn)?shù)的方式, 也可以參考下圖。
其實(shí)相似度分?jǐn)?shù),直觀理解就是兩個(gè)向量點(diǎn)積可以得到相似度的分?jǐn)?shù),加權(quán)求和得到輸出。
然后就是對點(diǎn)積注意力的拆解首先我們要明確目標(biāo)我們要求解的是X1關(guān)于X2的注意力輸出,所以首先需要確定的是X1 和 X2 的特征維度以及batch_size肯定要相同,然后是seq長度可以不同。 然后我們計(jì)算原始注意力權(quán)重即X1(N , seq1 , embedding) · X2(N , seq2 , embedding) -> attention( N , seq1 , seq2) , 可以看到我們得到了X1中每個(gè)單詞對X2中每個(gè)單詞的注意力權(quán)重矩陣所以維度是(seq1 , seq2)。
當(dāng)特征維度尺寸比較大時(shí),注意力的值會變得非常大,從而導(dǎo)致后期計(jì)算softmax的時(shí)候梯度消失,所以這里會對注意力的值做一個(gè)縮小,也就是將點(diǎn)積的結(jié)果/scaler_factor, 這個(gè)scaler_factor一般是embedding_size 的開根號。
然后我們對X2的維度做一個(gè)softmax 得到歸一化的注意力權(quán)重,至于為什么是X2,是因?yàn)槲覀冇?jì)算的是X1關(guān)于X2的注意力,所以在下一步我會會讓整個(gè)attention 權(quán)重與X2做點(diǎn)積也就是加權(quán)求和,這里需要把X2對應(yīng)的權(quán)重做歸一化所以要對X2的權(quán)重做歸一化。
也就是X1 與 X2 之間相互每個(gè)單詞的相似度(score) 因?yàn)槲覀兦蟮氖荴1 關(guān)于X2的注意力,所以最后我們將歸一化后的權(quán)重與X2做一個(gè)加權(quán)求和(點(diǎn)積)即Attention_scaled(N , seq1 , seq2) · X2(N , seq2 , embedding) -> X1_attention_X2( N , seq1 , embedding) 這個(gè)時(shí)候我們可以看到最后的輸出與X1的維度相同,但是里面的信息已經(jīng)是整合了X2的信息的的X1’。
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-2cDP4HFU-1689687821802)(image/06_attention/1689604378729.png)]
-
編解碼注意力
這個(gè)僅存在Seq2Seq的架構(gòu)中,也就是將編碼器的最后輸出與解碼器的隱藏狀態(tài)相互結(jié)合,下圖進(jìn)行了解釋可以看到encoder將輸入的上下文進(jìn)行編碼后整合成一個(gè)context 向量,由于我們最后的輸出是decoder 所以這里X1 是解碼器的隱層狀態(tài),X2 是編碼器的隱層狀態(tài)。
4. QKV
下面介紹的就是注意力中一個(gè)經(jīng)常被弄混的概念,QKV, 根據(jù)前面的只是其實(shí)query 就是X1也就是我們需要查詢的目標(biāo), Key 和 Value 也就是X2,只是X2不同的表現(xiàn)形式,K 可以等于 V 也可以不等,上面的做法都是相等的。
-
自注意力
最后就是注意力機(jī)制最核心的內(nèi)容,也就是自注意力機(jī)制,那么為什么多了一個(gè)自呢?其實(shí)就是X1 = X2 ,換句話說就是自己對自己做了一個(gè)升華,文本對自己的內(nèi)容做了一個(gè)類似summary的機(jī)制,得到了精華。就如同下圖一樣,自注意力的QKV 都來自同一個(gè)輸入X向量,最后得到的X’, 它是自己整合了自己全部信息向量,它讀完了自己全部的內(nèi)容,并不只是單獨(dú)的一個(gè)字或者一段話,而是去其糟粕后的X。
而多頭自注意力也就是同樣的X切分好幾個(gè)QKV, 可以捕捉不同的重點(diǎn)(類似一個(gè)qkv捕捉when , 一個(gè)qkv捕捉how),所以多頭是有助于網(wǎng)絡(luò)的表達(dá),然后這里需要注意的是多頭是將embedding - > (n_head , embedding // n_head ) , 不是(n_head , embedding)。
所以多頭中得到的注意力權(quán)重的shape也會變成(N , n_head , seq , seq ) 這里由于是自注意力 所以seq1 = seq2 = seq 。
代碼實(shí)現(xiàn)
這里只介紹了核心代碼實(shí)現(xiàn), 下面是多頭注意力的實(shí)現(xiàn):
import torch.nn as nn # 導(dǎo)入torch.nn庫
# 創(chuàng)建一個(gè)Attention類,用于計(jì)算注意力權(quán)重
class Mult_attention(nn.Module):
def __init__(self, n_head):
super(Mult_attention, self).__init__()
self.n_head = n_head
def forward(self, decoder_context, encoder_context , dec_enc_attn_mask): # decoder_context : x1(q) , encoder_context : x2(k , v)
# print(decoder_context.shape, encoder_context.shape) # X1(N , seq_1 , embedding) , X2(N , seq_2 , embedding)
# 進(jìn)行切分 , (N , seq_len_X , emb_dim) -> (N , num_head , seq_len_X , head_dim) / head_dim * num_head = emb_dim
Q = self.split_heads(decoder_context) # X1
K = self.split_heads(encoder_context) # X2
V = self.split_heads(encoder_context) # X2
# print(Q.shape , 0)
# 將注意力掩碼復(fù)制到多頭 attn_mask: [batch_size, n_heads, len_q, len_k]
attn_mask = dec_enc_attn_mask.unsqueeze(1).repeat(1, self.n_head, 1, 1)
# 計(jì)算decoder_context和encoder_context的點(diǎn)積,得到多頭注意力分?jǐn)?shù),其實(shí)就是在原有的基礎(chǔ)上多加一個(gè)尺度
scores = torch.matmul(Q, K.transpose(-2, -1)) # -> (N , num_head , seq_len_1 , seq_len_2)
scores.masked_fill_(attn_mask , -1e9)
# print(scores.shape ,1 )
# 自注意力原始權(quán)重進(jìn)行縮放
scale_factor = K.size(-1) ** 0.5
scaled_weights = scores / scale_factor # -> (N , num_head , seq_len_1 , seq_len_2)
# print(scaled_weights.shape , 2)
# 歸一化分?jǐn)?shù)
attn_weights = nn.functional.softmax(scaled_weights, dim=-1) # -> (N , num_head , seq_len_1 , seq_len_2)
# print(attn_weights.shape , 3)
# 將注意力權(quán)重乘以encoder_context,得到加權(quán)的上下文向量
attn_outputs = torch.matmul(attn_weights, V) # -> (N , num_head , seq_len_1 , embedding // num_head)
# print(attn_outputs.shape , 4)
# 將多頭合并下(output & attention)
attn_outputs = self.combine_heads(attn_outputs) # 與Q的尺度是一樣的
attn_weights = self.combine_heads(attn_weights)
# print(attn_weights.shape , attn_outputs.shape , 5) #
return attn_outputs, attn_weights
# 將所有頭的結(jié)果拼接起來,就是把n_head 這個(gè)維度去掉,
def combine_heads(self , tensor):
# print(tensor.size())
batch_size, num_heads, seq_len, head_dim = tensor.size()
feature_dim = num_heads * head_dim
return tensor.transpose(1, 2).contiguous().view(batch_size, seq_len, feature_dim)
def split_heads(self , tensor):
batch_size, seq_len, feature_dim = tensor.size()
head_dim = feature_dim // self.n_head
# print(tensor.shape, head_dim , self.n_head)
return tensor.view(batch_size, seq_len, self.n_head, head_dim).transpose(1, 2)
多頭注意力的解碼器,這里添加了mask機(jī)制。
class DecoderWithMutliHeadAttention(nn.Module):
def __init__(self, hidden_size, output_size , n_head):
super(DecoderWithMutliHeadAttention, self).__init__()
self.hidden_size = hidden_size # 設(shè)置隱藏層大小
self.n_head = n_head # 多頭
self.embedding = nn.Embedding(output_size, hidden_size) # 創(chuàng)建詞嵌入層
self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True) # 創(chuàng)建RNN層
self.multi_attention = Mult_attention(n_head = n_head)
self.out = nn.Linear(2 * hidden_size, output_size) # 修改線性輸出層,考慮隱藏狀態(tài)和上下文向量
def forward(self, inputs, hidden, encoder_outputs , encoder_input):
embedded = self.embedding(inputs) # 將輸入轉(zhuǎn)換為嵌入向量
rnn_output, hidden = self.rnn(embedded, hidden) # 將嵌入向量輸入RNN層并獲取輸出
dec_enc_attn_mask = self.get_attn_pad_mask(inputs, encoder_input) # 解碼器-編碼器掩碼
context, attn_weights = self.multi_attention(rnn_output, encoder_outputs , dec_enc_attn_mask) # 計(jì)算注意力上下文向量
output = torch.cat((rnn_output, context), dim=-1) # 將上下文向量與解碼器的輸出拼接
output = self.out(output) # 使用線性層生成最終輸出
return output, hidden, attn_weights
def get_attn_pad_mask(self , seq_q, seq_k):
#-------------------------維度信息--------------------------------
# seq_q 的維度是 [batch_size, len_q]
# seq_k 的維度是 [batch_size, len_k]
#-----------------------------------------------------------------
# print(seq_q.size(), seq_k.size())
batch_size, len_q = seq_q.size()
batch_size, len_k = seq_k.size()
# 生成布爾類型張量[batch_size,1,len_k(=len_q)]
pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) #<PAD> Token的編碼值為0
# 變形為何注意力分?jǐn)?shù)相同形狀的張量 [batch_size,len_q,len_k]
pad_attn_mask = pad_attn_mask.expand(batch_size, len_q, len_k)
#-------------------------維度信息--------------------------------
# pad_attn_mask 的維度是 [batch_size,len_q,len_k]
#-----------------------------------------------------------------
return pad_attn_mask # [batch_size,len_q,len_k]
結(jié)果
整體實(shí)驗(yàn)結(jié)果如下,可能是因?yàn)檎w語料庫太小了,所以翻譯結(jié)果不是太好,但是多頭注意力機(jī)制還是都跑通了:文章來源:http://www.zghlxwxcb.cn/news/detail-580641.html
文章來源地址http://www.zghlxwxcb.cn/news/detail-580641.html
到了這里,關(guān)于自然語言處理: 第五章Attention注意力機(jī)制的文章就介紹完了。如果您還想了解更多內(nèi)容,請?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!