Attention Is All You Need——集中一下注意力
- Transformer其實(shí)不是完全的Self-Attention(SA,自注意力)結(jié)構(gòu),還帶有Cross-Attention(CA,交叉注意力)、殘差連接、LayerNorm、類似1維卷積的Position-wise Feed-Forward Networks(FFN)、MLP和Positional Encoding(位置編碼)等
- 本文涵蓋Transformer所采用的MHSA(多頭自注意力)、LayerNorm、FFN、位置編碼
- 對(duì)1維卷積的詳解請(qǐng)參考深入理解TDNN(Time Delay Neural Network)——兼談x-vector網(wǎng)絡(luò)結(jié)構(gòu)
- 對(duì)SA的Q、K、V運(yùn)算的詳解請(qǐng)參考深入理解Self-attention(自注意力機(jī)制)
Transformer的訓(xùn)練和推理
- 序列任務(wù)有三種:
- 序列轉(zhuǎn)錄:輸入序列長(zhǎng)度為N,輸出序列長(zhǎng)度為M,例如機(jī)器翻譯
- 序列標(biāo)注:輸入序列長(zhǎng)度為N,輸出序列長(zhǎng)度也為N,例如詞性標(biāo)注
- 序列總結(jié):輸入序列長(zhǎng)度為N,輸出為分類結(jié)果,例如聲紋識(shí)別
- 前兩個(gè)序列任務(wù),常用Transformer進(jìn)行統(tǒng)一建模,Transformer是一種Encoder-Decoder結(jié)構(gòu)。在Transformer中:
- 推理時(shí)
- Encoder負(fù)責(zé)將輸入 ( x 1 , x 2 , . . . , x n ) (x_1, x_2, ..., x_n) (x1?,x2?,...,xn?),編碼成隱藏單元(Hidden Unit) ( z 1 , z 2 , . . . , z n ) (z_1, z_2, ..., z_n) (z1?,z2?,...,zn?),Decoder根據(jù)隱藏單元和過(guò)去時(shí)刻的輸出 ( y 1 , y 2 , . . . , y t ? 1 ) (y_{1}, y_{2}, ..., y_{t-1}) (y1?,y2?,...,yt?1?), y 0 y_{0} y0?為起始符號(hào)"s"或者 y 0 = 0 y_{0}=0 y0?=0(很少見),解碼出當(dāng)前時(shí)刻的輸出 y t y_{t} yt?,Decoder全部的輸出表示為 ( y 1 , y 2 , . . . , y m ) (y_{1}, y_{2}, ..., y_{m}) (y1?,y2?,...,ym?)
- 由于當(dāng)前時(shí)刻的輸出只依賴輸入和過(guò)去時(shí)刻的輸出(不包含未來(lái)信息),因此這種輸出的生成方式是自回歸式的,也叫因果推斷(Causal Inference)
- 訓(xùn)練時(shí)
- Encoder行為不變,Decoder根據(jù)隱藏單元和過(guò)去時(shí)刻的label
(
y
^
1
,
y
^
2
,
.
.
.
,
y
^
t
?
1
)
(\hat{y}_{1}, \hat{y}_{2}, ..., \hat{y}_{t-1})
(y^?1?,y^?2?,...,y^?t?1?),解碼出當(dāng)前時(shí)刻的輸出
y
t
y_{t}
yt?,由于需要對(duì)每個(gè)
y
t
y_{t}
yt?計(jì)算損失,而系統(tǒng)必須是因果的,因此每次解碼時(shí),需要Mask掉未來(lái)的信息,也就是全部置為
?
∞
-\infty
?∞(從而Softmax運(yùn)算后接近0),當(dāng)label為“s I am a student”,則Decoder每一時(shí)刻的輸入,如下圖
- 這種將label作為Decoder的輸入的訓(xùn)練方式叫做Teacher Forcing,類似上述推理時(shí)將輸出作為Decoder的輸入的訓(xùn)練方式叫做Free Running。Teacher Forcing允許并行計(jì)算出每個(gè)時(shí)刻的輸出,因此是最常用的
- Encoder行為不變,Decoder根據(jù)隱藏單元和過(guò)去時(shí)刻的label
(
y
^
1
,
y
^
2
,
.
.
.
,
y
^
t
?
1
)
(\hat{y}_{1}, \hat{y}_{2}, ..., \hat{y}_{t-1})
(y^?1?,y^?2?,...,y^?t?1?),解碼出當(dāng)前時(shí)刻的輸出
y
t
y_{t}
yt?,由于需要對(duì)每個(gè)
y
t
y_{t}
yt?計(jì)算損失,而系統(tǒng)必須是因果的,因此每次解碼時(shí),需要Mask掉未來(lái)的信息,也就是全部置為
?
∞
-\infty
?∞(從而Softmax運(yùn)算后接近0),當(dāng)label為“s I am a student”,則Decoder每一時(shí)刻的輸入,如下圖
Transformer的Encoder和Decoder
- Transformer的Encoder行為與上述一致,設(shè)Encoder的輸入特征圖形狀為 ( n , d m o d e l ) (n, d_{model}) (n,dmodel?),即長(zhǎng)度為n的序列,序列的每個(gè)元素是 d m o d e l d_{model} dmodel?維的向量,Encoder Layer(如下圖左邊重復(fù)N次的結(jié)構(gòu))是不改變輸入特征圖形狀的,并且Encoder Layer內(nèi)部的Sub-layer也是不改變輸入特征圖形狀的,從而Encoder的輸出特征圖形狀也為 ( n , d m o d e l ) (n, d_{model}) (n,dmodel?)
- 這樣設(shè)計(jì)的原因是:每個(gè)Encoder Layer都有兩次殘差連接(如下圖中的Add運(yùn)算),殘差連接要求輸入輸出特征圖形狀不變,為了減少超參數(shù),所以這樣設(shè)計(jì)
LayerNorm
- LayerNorm(如上圖中的Norm運(yùn)算)常用在可變長(zhǎng)度序列任務(wù)里,接下來(lái)通過(guò)對(duì)比BatchNorm和LayerNorm,認(rèn)識(shí)LayerNorm
- 左圖為BN,C為單個(gè)樣本的特征維度(即特征圖的Channels,表示特征的數(shù)量),H、W為特征的形狀,因?yàn)樘卣骺梢允蔷仃囈部梢允窍蛄?,因此統(tǒng)稱特征形狀。BN希望將每個(gè)特征變成0均值1方差,再變換成新的均值和方差,因此需要在一個(gè)Batch中,找尋每個(gè)樣本的該特征,然后計(jì)算該特征的統(tǒng)計(jì)量,由于每個(gè)特征的統(tǒng)計(jì)量需要單獨(dú)維護(hù),因此構(gòu)造BN需要傳入特征的數(shù)量,也就是C。同時(shí),BN的可學(xué)習(xí)參數(shù) w e i g h t + b i a s = 2 ? C weight+bias=2*C weight+bias=2?C
- 中圖為L(zhǎng)N,LN希望不依賴Batch,將單個(gè)樣本的所有特征變成0均值1方差,再變換成新的均值和方差,因此需要指定樣本形狀,告訴LN如何計(jì)算統(tǒng)計(jì)量,由于樣本中的每個(gè)值,都進(jìn)行均值和方差的變換,因此構(gòu)造LN需要傳入樣本的形狀,也就是C、H、W。同時(shí),LN的可學(xué)習(xí)參數(shù) w e i g h t + b i a s = 2 ? C ? H ? W weight+bias=2*C*H*W weight+bias=2?C?H?W
- 示例:
>>> input=torch.rand([1, 3, 2, 2])
>>> input
tensor([[[[0.1181, 0.6704],
[0.7010, 0.8031]],
[[0.0630, 0.2088],
[0.2150, 0.6469]],
[[0.5746, 0.4949],
[0.3656, 0.7391]]]])
>>> layer_norm=torch.nn.LayerNorm((3, 2, 2), eps=1e-05)
>>> output=layer_norm(input)
>>> output
tensor([[[[-1.3912, 0.8131],
[ 0.9349, 1.3424]],
[[-1.6113, -1.0293],
[-1.0047, 0.7191]],
[[ 0.4308, 0.1126],
[-0.4035, 1.0872]]]], grad_fn=<NativeLayerNormBackward0>)
>>> output[0].mean()
tensor(-1.7385e-07, grad_fn=<MeanBackward0>)
>>> output[0].std()
tensor(1.0445, grad_fn=<StdBackward0>)
>>> layer_norm.weight.shape
torch.Size([3, 2, 2])
>>> layer_norm.bias.shape
torch.Size([3, 2, 2])
# 等價(jià)于
>>> mean=input.mean(dim=(-1, -2, -3), keepdim=True)
>>> var=input.var(dim=(-1, -2, -3), keepdim=True, unbiased=False)
>>> (input-mean)/torch.sqrt(var+1e-05)
tensor([[[[-1.3912, 0.8131],
[ 0.9349, 1.3424]],
[[-1.6113, -1.0293],
[-1.0047, 0.7191]],
[[ 0.4308, 0.1126],
[-0.4035, 1.0872]]]])
- 上述兩種情況為計(jì)算機(jī)視覺中的BN和LN,可以看出,BN訓(xùn)練時(shí)需要更新統(tǒng)計(jì)量,從而推理時(shí)使用統(tǒng)計(jì)量進(jìn)行Norm,而LN訓(xùn)練和推理時(shí)的行為是一致的
- 在序列任務(wù)中,特征形狀為1,多出來(lái)一個(gè)序列長(zhǎng)度Seq_len,其他不變,1維的BN(BatchNorm1d)在N*Seq_len個(gè)幀中,計(jì)算每個(gè)特征的統(tǒng)計(jì)量,從而序列任務(wù)中的幀形狀是C,因此LN要傳入的幀形狀是C,并且Input的形狀中,C這個(gè)維度要放在最后
- 1維的BN常用于聲紋識(shí)別,但是Transformer風(fēng)格的模型基本都采用LN,并且LN是適用于任何特征形狀的,BN則根據(jù)特征形狀不同,衍生出BatchNorm1d、BatchNorm2d等
- 示例
>>> input=torch.rand([1, 200, 80])
>>> layer_norm=torch.nn.LayerNorm(80)
>>> layer_norm(input)[0][0].mean()
tensor(8.3447e-08, grad_fn=<MeanBackward0>)
>>> layer_norm(input)[0][1].mean()
tensor(-8.0466e-08, grad_fn=<MeanBackward0>)
>>> layer_norm(input)[0][0].std()
tensor(1.0063, grad_fn=<StdBackward0>)
>>> layer_norm(input)[0][1].std()
tensor(1.0063, grad_fn=<StdBackward0>)
- 在序列任務(wù)中采用LN而不是BN的原因
- 序列任務(wù)的樣本很多時(shí)候是不等長(zhǎng)的,很多時(shí)候要補(bǔ)0幀,當(dāng)batch-size較小時(shí),BN的統(tǒng)計(jì)量波動(dòng)較大,而LN是對(duì)每一幀進(jìn)行Norm的,不受補(bǔ)0幀的影響
- 訓(xùn)練時(shí)要構(gòu)造一個(gè)Batch,因此序列長(zhǎng)度只能固定,但是推理時(shí)序列長(zhǎng)度是可變的,采用BN容易過(guò)擬合序列長(zhǎng)度,LN則不容易過(guò)擬合序列長(zhǎng)度
SA(自注意力)與CA(交叉注意力)
- 對(duì)于一個(gè)輸入序列 ( seq-len , d m o d e l ) (\text{seq-len}, d_{model}) (seq-len,dmodel?),SA通過(guò)Q、K、V計(jì)算矩陣,計(jì)算得到對(duì)應(yīng)長(zhǎng)度的Q、K、V序列,這些序列構(gòu)成Q、K、V矩陣
- 如果Q、K、V矩陣完全由同一個(gè)序列計(jì)算而來(lái),則稱為自注意力SA;如果V、K矩陣由同一個(gè)序列計(jì)算而來(lái),Q矩陣由另一個(gè)序列計(jì)算而來(lái),則稱為交叉注意力CA
-
需要注意,Decoder Layer中的第二個(gè)MHSA(如下圖),從左到右的輸入,計(jì)算順序是V、K、Q,其中V、K是根據(jù)輸入的隱藏單元進(jìn)行計(jì)算的,即
(
z
1
,
z
2
,
.
.
.
,
z
n
)
(z_1, z_2, ..., z_n)
(z1?,z2?,...,zn?),得到的V、K矩陣形狀分別為
(
n
,
d
k
)
(n, d_k)
(n,dk?)、
(
n
,
d
v
)
(n, d_v)
(n,dv?),而Q是根據(jù)輸出的隱藏單元進(jìn)行計(jì)算的,即
(
z
^
1
,
z
^
2
,
.
.
.
,
z
^
m
)
(\hat{z}_1, \hat{z}_2, ..., \hat{z}_m)
(z^1?,z^2?,...,z^m?),得到的Q矩陣形狀為
(
m
,
d
k
)
(m, d_k)
(m,dk?),因此這是一種CA運(yùn)算
- SA和CA的后續(xù)運(yùn)算是一致的,V、K、Q矩陣通過(guò)Attention函數(shù)計(jì)算Attention分?jǐn)?shù),然后對(duì)V矩陣進(jìn)行加權(quán)求和,得到Attention的輸出。Transformer用的Attention函數(shù)是Scaled Dot-Product Attention,公式如下:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V)=\text{softmax}(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk??QKT?)V - 如果是Decoder的Attention函數(shù)則需要Mask掉softmax的輸出,使得未來(lái)時(shí)刻對(duì)應(yīng)的V接近0,如下圖:
- 計(jì)算的細(xì)化過(guò)程如下圖:
- Q K T QK^T QKT內(nèi)積的含義是計(jì)算相似度,因此中間 ( m , n ) (m, n) (m,n)矩陣的第m行,表示第m個(gè)query對(duì)所有key的相似度
- 之后除以 d k \sqrt{d_k} dk??進(jìn)行Scale,并且Mask(具體操作為將未來(lái)時(shí)刻對(duì)應(yīng)的點(diǎn)積結(jié)果置為 ? ∞ -\infty ?∞,從而Softmax運(yùn)算后接近0),然后對(duì) ( m , n ) (m, n) (m,n)矩陣的每一行進(jìn)行Softmax
- 最后output矩陣的第m行,表示第m個(gè)權(quán)重對(duì)不同幀的value進(jìn)行加權(quán)求和
-
需要注意的是
- Attention最后的輸出,序列長(zhǎng)度由Q決定,向量維度由V決定
- Q和K的向量維度一致,序列長(zhǎng)度可以不同;K和V的序列長(zhǎng)度一致,向量維度可以不同
- Softmax是在計(jì)算第m個(gè)query對(duì)不同key的相似度的權(quán)重,求和為1
- 除以 d k \sqrt{d_k} dk??的原因是因?yàn)楹竺嫘枰M(jìn)行Softmax運(yùn)算,具有最大值主導(dǎo)效果。當(dāng) d k d_k dk?較小時(shí),點(diǎn)積的結(jié)果差異不大,當(dāng) d k d_k dk?較大時(shí),點(diǎn)積的結(jié)果波動(dòng)較大(假設(shè)每個(gè)query和key都是0均值1方差的多維隨機(jī)變量,則它們的點(diǎn)積 q ? k = ∑ i = 1 d k q i k i q \cdot k=\sum_{i=1}^{d_k} q_ik_i q?k=∑i=1dk??qi?ki?,為0均值 d k d_k dk?方差的多維隨機(jī)變量),從而Softmax后,大量值接近0,這樣會(huì)導(dǎo)致梯度變得很小,不利于收斂。因此除以一個(gè)值,會(huì)使得這些點(diǎn)積結(jié)果的值變小,從而Softmax運(yùn)算的最大值主導(dǎo)效果不明顯
MHSA
文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-848268.html
- 多頭注意力的動(dòng)機(jī)是:與其將輸入投影到較高的維度,計(jì)算單個(gè)注意力,不如將輸入投影到h個(gè)較低的維度,計(jì)算h個(gè)注意力,然后將h個(gè)注意力的輸出在特征維度Concat起來(lái),最后利用MLP進(jìn)行多頭特征聚合,得到MHSA的輸出。MHSA的公式如下:
MultiHead ( Q , K , V ) = Concat ( h e a d 1 , h e a d 2 , . . . , h e a d h ) W O h e a d i = Attention ( Q i , K i , V i ) \begin{aligned} \text{MultiHead}(Q, K, V)&=\text{Concat}(head_1, head_2, ..., head_h)W^O \\ head_i&=\text{Attention}(Q_i, K_i, V_i) \end{aligned} MultiHead(Q,K,V)headi??=Concat(head1?,head2?,...,headh?)WO=Attention(Qi?,Ki?,Vi?)? - 由于MHSA不能改變輸入輸出形狀,所以每個(gè)SA的設(shè)計(jì)是:當(dāng) d m o d e l = 512 d_{model}=512 dmodel?=512, h = 8 h=8 h=8時(shí), d k = d v = d m o d e l / h = 64 d_k=d_v=d_{model}/h=64 dk?=dv?=dmodel?/h=64
- 在實(shí)際運(yùn)算時(shí),可以通過(guò)一個(gè)大的矩陣運(yùn)算,將輸入投影到 ( n , d m o d e l ) (n, d_{model}) (n,dmodel?),然后在特征維度Split成h個(gè)矩陣,Q、K、V都可如此操作
- 因此一個(gè)MHSA的參數(shù)量: 4 ? d m o d e l ? d m o d e l = 4 ? d m o d e l 2 4*d_{model}*d_{model}=4*d^2_{model} 4?dmodel??dmodel?=4?dmodel2?,即Q、K、V加最后的MLP
FFN
- FFN的操作和MHSA中最后的MLP非常相似的,公式和圖如下:
FFN ( x ) = m a x ( 0 , x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x)=max(0,xW_1+b_1)W_2+b_2 FFN(x)=max(0,xW1?+b1?)W2?+b2? - 采用同一個(gè)MLP,對(duì)輸入特征的每一幀進(jìn)行維度變換(通常是增大為4倍),然后RELU,最后再采用另一個(gè)MLP,將輸入特征的每一幀恢復(fù)回原來(lái)的維度
- 因此一個(gè)FFN的參數(shù)量: d m o d e l ? 4 ? d m o d e l + 4 ? d m o d e l ? d m o d e l = 8 ? d m o d e l 2 d_{model}*4*d_{model}+4*d_{model}*d_{model}=8*d^2_{model} dmodel??4?dmodel?+4?dmodel??dmodel?=8?dmodel2?,即維度提升MLP加維度恢復(fù)MLP
- 綜合,一個(gè)Encoder Layer的參數(shù)量為: 12 ? d m o d e l 2 12*d^2_{model} 12?dmodel2?,一個(gè)Decoder Layer的參數(shù)量為: 16 ? d m o d e l 2 16*d^2_{model} 16?dmodel2?
- 上述所有參數(shù)估計(jì)忽略LayerNorm的參數(shù),因?yàn)槠鋽?shù)量級(jí)較小
Embedding Layer和Softmax
- Encoder和Decoder的Embedding Layer,以及最后的Softmax輸出前,都有一個(gè)MLP,在Transformer中,這三個(gè)MLP是共享參數(shù)的,形狀都是 ( dict-len , d m o d e l ) (\text{dict-len}, d_{model}) (dict-len,dmodel?), dict-len \text{dict-len} dict-len是字典大小
- 在Embedding Layer中,權(quán)重都被除以了 d m o d e l \sqrt{d_{model}} dmodel??,從而Embedding的輸出范圍在[-1, 1]附近,這是為了讓Embedding的值范圍靠近Positional Encoding,從而可以直接相加
Positional Encoding(位置編碼)
- Attention的輸出是不具有時(shí)序信息的,如果把輸入打亂,那么也只會(huì)導(dǎo)致對(duì)應(yīng)的輸出打亂而已,不會(huì)有導(dǎo)致值變化,但序列任務(wù)往往關(guān)注時(shí)序信息,一件事先發(fā)生和后發(fā)生,意義是不一樣的,因此需要對(duì)Attention的輸入添加位置編碼
- 位置編碼的公式如下:
PE ( p o s , 2 i ) = s i n ( p o s / 1000 0 2 i d m o d e l ) PE ( p o s , 2 i + 1 ) = c o s ( p o s / 1000 0 2 i d m o d e l ) \begin{aligned} \text{PE}(pos, 2i)=sin(pos/10000^{\frac{2i}{d_{model}}}) \\ \text{PE}(pos, 2i+1)=cos(pos/10000^{\frac{2i}{d_{model}}}) \end{aligned} PE(pos,2i)=sin(pos/10000dmodel?2i?)PE(pos,2i+1)=cos(pos/10000dmodel?2i?)? - pos表示幀的位置,第二個(gè)參數(shù)表示特征的位置,奇偶交替,也就說(shuō):不同位置的同一特征,根據(jù)位置映射不同頻率的正弦函數(shù)進(jìn)行編碼;同一位置的不同特征,根據(jù)奇偶分布映射不同頻率的正弦函數(shù)進(jìn)行編碼
- 位置編碼值的范圍是[-1, 1](Embedding的權(quán)重需要除以 d m o d e l \sqrt{d_{model}} dmodel??的原因),與Embedding對(duì)應(yīng)元素相加,即可輸入到Attention中
文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-848268.html
到了這里,關(guān)于深入理解Transformer,兼談MHSA(多頭自注意力)、Cross-Attention(交叉注意力)、LayerNorm、FFN、位置編碼的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!