在注意力機(jī)制中,每個查詢都會關(guān)注所有的鍵-值對并生成一個注意力輸出。由于查詢、鍵和值來自同一組輸入,因此被稱為 自注意力(self-attention),也被稱為內(nèi)部注意力(intra-attention)。本節(jié)將使用自注意力進(jìn)行序列編碼,以及使用序列的順序作為補(bǔ)充信息。
import math
import torch
from torch import nn
from d2l import torch as d2l
10.6.1 自注意力
給定一個由詞元組成的輸入序列 x 1 , … , x n \boldsymbol{x}_1,\dots,\boldsymbol{x}_n x1?,…,xn?,其中任意 x i ∈ R d ( 1 ≤ i ≤ n ) \boldsymbol{x}_i\in\R^d\quad(1\le i\le n) xi?∈Rd(1≤i≤n) 。該序列的自注意力輸出為一個長度相同的序列 y 1 , … , y n \boldsymbol{y}_1,\dots,\boldsymbol{y}_n y1?,…,yn?,其中:
y i = f ( x i , ( x 1 , x 1 ) , … , ( x n , x n ) ) ∈ R d \boldsymbol{y}_i=f(\boldsymbol{x}_i,(\boldsymbol{x}_1,\boldsymbol{x}_1),\dots,(\boldsymbol{x}_n,\boldsymbol{x}_n))\in\R^d yi?=f(xi?,(x1?,x1?),…,(xn?,xn?))∈Rd
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, # 基于多頭注意力對一個張量完成自注意力的計算
num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
(attention): DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
(W_q): Linear(in_features=100, out_features=100, bias=False)
(W_k): Linear(in_features=100, out_features=100, bias=False)
(W_v): Linear(in_features=100, out_features=100, bias=False)
(W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens)) # 張量的形狀為(批量大小,時間步的數(shù)目或詞元序列的長度,d)。
attention(X, X, X, valid_lens).shape # 輸出與輸入的張量形狀相同
torch.Size([2, 4, 100])
10.6.2 比較卷積神經(jīng)網(wǎng)絡(luò)、循環(huán)神經(jīng)網(wǎng)絡(luò)和自注意力
-
卷積神經(jīng)網(wǎng)絡(luò)
-
計算復(fù)雜度為 O ( k n d 2 ) O(knd^2) O(knd2)
-
k k k 為卷積核大小
-
n n n 為序列長度是
-
d d d 為輸入和輸出的通道數(shù)量
-
-
并行度為 O ( n ) O(n) O(n)
-
最大路徑長度為 O ( n / k ) O(n/k) O(n/k)
-
-
循環(huán)神經(jīng)網(wǎng)絡(luò)
-
計算復(fù)雜度為 O ( n d 2 ) O(nd^2) O(nd2)
d × d d\times d d×d 權(quán)重矩陣和 d d d 維隱狀態(tài)的乘法計算復(fù)雜度為 O ( d 2 ) O(d^2) O(d2),由于序列長度為 n n n,因此循環(huán)神經(jīng)網(wǎng)絡(luò)層的計算復(fù)雜度為 O ( n d 2 ) O(nd^2) O(nd2)
-
并行度為 O ( 1 ) O(1) O(1)
有 O ( n ) O(n) O(n) 個順序操作無法并行化。
-
最大路徑長度也是 O ( n ) O(n) O(n)
-
-
自注意力
-
計算復(fù)雜度為 O ( n 2 d ) O(n^2d) O(n2d)
查詢、鍵和值都是 n × d n\times d n×d 矩陣
-
并行度為 O ( n ) O(n) O(n)
每個詞元都通過自注意力直接連接到任何其他詞元。因此有 O ( 1 ) O(1) O(1) 個順序操作可以并行計算
-
最大路徑長度也是 O ( 1 ) O(1) O(1)
-
總而言之,卷積神經(jīng)網(wǎng)絡(luò)和自注意力都擁有并行計算的優(yōu)勢,而且自注意力的最大路徑長度最短。但是因為其計算復(fù)雜度是關(guān)于序列長度的二次方,所以在很長的序列中計算會非常慢。
10.6.3 位置編碼
在處理詞元序列時,循環(huán)神經(jīng)網(wǎng)絡(luò)是逐個的重復(fù)地處理詞元的,而自注意力則因為并行計算而放棄了順序操作。為了使用序列的順序信息,通過在輸入表示中添加 位置編碼(positional encoding) 來注入絕對的或相對的位置信息。位置編碼可以通過學(xué)習(xí)得到也可以直接固定得到。
基于正弦函數(shù)和余弦函數(shù)的固定位置編碼的矩陣第 i i i 行、第 2 j 2j 2j 列和 2 j + 1 2j+1 2j+1 列上的元素為:
p i , 2 j = sin ? ( i 1000 0 2 j / d ) p i , 2 j + 1 = cos ? ( i 1000 0 2 j / d ) \begin{align} p_{i,2j}&=\sin{\left(\frac{i}{10000^{2j/d}}\right)}\\ p_{i,2j+1}&=\cos{\left(\frac{i}{10000^{2j/d}}\right)} \end{align} pi,2j?pi,2j+1??=sin(100002j/di?)=cos(100002j/di?)??
#@save
class PositionalEncoding(nn.Module):
"""位置編碼"""
def __init__(self, num_hiddens, dropout, max_len=1000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(dropout)
# 創(chuàng)建一個足夠長的P
self.P = torch.zeros((1, max_len, num_hiddens))
X = torch.arange(max_len, dtype=torch.float32).reshape(
-1, 1) / torch.pow(10000, torch.arange(
0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)
def forward(self, X):
X = X + self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X)
在位置嵌入矩陣 P \boldsymbol{P} P 中,行代表詞元在序列中的位置,列代表位置編碼的不同維度。從下面的例子中可以看到位置嵌入矩陣的第 6 列和第 7 列的頻率高于第 8 列和第 9 列。第 6 列和第 7 列之間的偏移量(第 8 列和第 9 列相同)是由于正弦函數(shù)和余弦函數(shù)的交替。
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])
?
?
10.6.3.1 絕對位置信息
打印出 0 , 1 , … , 7 0,1,\dots,7 0,1,…,7 的二進(jìn)制表示形式即可明白沿著編碼維度單調(diào)降低的頻率與絕對位置信息的關(guān)系。
每個數(shù)字、每兩個數(shù)字和每四個數(shù)字上的比特值在第一個最低位、第二個最低位和第三個最低位上分別交替。
for i in range(8):
print(f'{i}的二進(jìn)制是:{i:>03b}')
0的二進(jìn)制是:000
1的二進(jìn)制是:001
2的二進(jìn)制是:010
3的二進(jìn)制是:011
4的二進(jìn)制是:100
5的二進(jìn)制是:101
6的二進(jìn)制是:110
7的二進(jìn)制是:111
在二進(jìn)制表示中,較高比特位的交替頻率低于較低比特位,與下面的熱圖所示相似,只是位置編碼通過使用三角函數(shù)在編碼維度上降低頻率。由于輸出是浮點數(shù),因此此類連續(xù)表示比二進(jìn)制表示法更節(jié)省空間。
P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
?
?
10.6.3.2 相對位置信息
除了捕獲絕對位置信息之外,上述的位置編碼還允許模型學(xué)習(xí)得到輸入序列中相對位置信息。這是因為對于任何確定的位置偏移 δ \delta δ,位置 i + δ i+\delta i+δ 處的位置編碼可以線性投影位置 i i i 處的位置編碼來表示。
這種投影的數(shù)學(xué)解釋是,令 ω j = 1 / 1000 0 2 j / d \omega_j=1/10000^{2j/d} ωj?=1/100002j/d,對于任何確定的位置偏移 δ \delta δ,上個式子中的任何一對 ( p i , 2 j , p i , 2 j + 1 ) (p_{i,2j},p_{i,2j+1}) (pi,2j?,pi,2j+1?) 都可以線性投影到 ( p i + δ , 2 j , p i + δ , 2 j + 1 ) (p_{i+\delta,2j},p_{i+\delta,2j+1}) (pi+δ,2j?,pi+δ,2j+1?):
[ cos ? ( δ ω j ) sin ? ( δ ω j ) ? sin ? ( δ ω j ) cos ? ( δ ω j ) ] [ p i , 2 j p i , 2 j + 1 ] = [ cos ? ( δ ω j ) sin ? ( i ω j ) + sin ? ( δ ω j ) cos ? ( i ω j ) ? sin ? ( δ ω j ) sin ? ( i ω j ) + cos ? ( δ ω j ) cos ? ( i ω j ) ] = [ sin ? ( ( i + δ ) ω j ) cos ? ( ( i + δ ) ω j ) ] = [ p i , 2 j p i , 2 j + 1 ] \begin{align} &\begin{bmatrix} \cos{(\delta\omega_j)} & \sin{(\delta\omega_j)}\\ -\sin{(\delta\omega_j)} & \cos{(\delta\omega_j)} \end{bmatrix} \begin{bmatrix} p_{i,2j}\\ p_{i,2j+1} \end{bmatrix}\\ =&\begin{bmatrix} \cos{(\delta\omega_j)}\sin{(i\omega_j)}+\sin{(\delta\omega_j)}\cos{(i\omega_j)}\\ -\sin{(\delta\omega_j)}\sin{(i\omega_j)}+\cos{(\delta\omega_j)}\cos{(i\omega_j)} \end{bmatrix}\\ =&\begin{bmatrix} \sin{((i+\delta)\omega_j)}\\ \cos{((i+\delta)\omega_j)} \end{bmatrix}\\ =&\begin{bmatrix} p_{i,2j}\\ p_{i,2j+1} \end{bmatrix} \end{align} ===?[cos(δωj?)?sin(δωj?)?sin(δωj?)cos(δωj?)?][pi,2j?pi,2j+1??][cos(δωj?)sin(iωj?)+sin(δωj?)cos(iωj?)?sin(δωj?)sin(iωj?)+cos(δωj?)cos(iωj?)?][sin((i+δ)ωj?)cos((i+δ)ωj?)?][pi,2j?pi,2j+1??]??
2 × 2 2\times 2 2×2 投影矩陣不依賴于任何位置的索引 i i i。
練習(xí)
(1)假設(shè)設(shè)計一個深度架構(gòu),通過堆疊基于位置編碼的自注意力層來表示序列??赡軙嬖谑裁磫栴}?文章來源:http://www.zghlxwxcb.cn/news/detail-734599.html
(2)請設(shè)計一種可學(xué)習(xí)的位置編碼方法。文章來源地址http://www.zghlxwxcb.cn/news/detail-734599.html
到了這里,關(guān)于《動手學(xué)深度學(xué)習(xí) Pytorch版》 10.6 自注意力和位置編碼的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!