原理
Vanilla Transformer 與 LLaMa 的區(qū)別
Vanilla Transformer 與 LLaMa
的對比:LLaMa與普通的Transformer架構(gòu)不同的地方,包括采用了前置了層歸一化(Pre-normalization)并使用RMSNorm 歸一化函數(shù)(Normalizing Function)、使用了旋轉(zhuǎn)位置嵌入(RoPE)、激活函數(shù)由ReLU更換為SwiGLU,并且將self-attention改進(jìn)為使用KV-Cache的Grouped Query,整體Transformer架構(gòu)與GPT-2 類似。
LLaMa -> Alpaca -> Vicuna
的演進(jìn):
-
LLaMa:Meta開源的Pre-trained Model,
模型參數(shù)從7B、13B、32B、65B不等
,LLaMa-7B在大多數(shù)基準(zhǔn)測試上超過了Text-davinci-003(即GPT3-173B),相比于ChatGPT或者GPT4來說,LLaMa可能效果上還有差距,目前hugging face已集成了LLaMa的代碼實(shí)現(xiàn)和開源模型。學(xué)術(shù)界和工業(yè)界都可以在此基礎(chǔ)上進(jìn)行學(xué)習(xí)和研究。 -
Alpaca:斯坦福
在LLaMa-7B的基礎(chǔ)上監(jiān)督微調(diào)
出來的模型,斯坦福是用OpenAI的Text-davinci-003(即GPT3-173B)的API
配合self-instruct
技術(shù),使用175個(gè)提示語種子自動(dòng)生成了52K條提示-回復(fù)的指示數(shù)據(jù)集,在LLaMa-7B上微調(diào)得到的模型,在8張80G的A100上訓(xùn)練了3小時(shí)。 -
Vicuna:
在LLaMa-13B的基礎(chǔ)上使用監(jiān)督微調(diào)
得到的模型,數(shù)據(jù)集來自于ShareGPT 產(chǎn)生的用戶對話數(shù)據(jù)
,共70K條。使用Pytorch FSDP在8張A100上訓(xùn)練了一天。相較于Alpaca,Vicuna在訓(xùn)練中將序列長度由512擴(kuò)展到了2048
,并且通過梯度檢測和flash attention來解決內(nèi)存問題;調(diào)整訓(xùn)練損失考慮多輪對話,并僅根據(jù)模型的輸出進(jìn)行微調(diào)。通過GPT4來打分評測,Vicuna可以達(dá)到ChatGPT 90%的效果。 -
LLaMa2:采用了Llama 1的大部分預(yù)訓(xùn)練設(shè)置和模型架構(gòu)。LLaMa2和LLaMa1的最大差別是
增加了文本長度
,并在訓(xùn)練34B、70B
的模型中應(yīng)用了GQA
。
Embedding
Embedding的過程:word -> token_id -> embedding_vector
,其中第一步轉(zhuǎn)化使用tokenizer的詞表進(jìn)行,第二步轉(zhuǎn)化使用 learnable 的 Embedding layer。
RMS Norm
對比 Batch Norm 和 Layer Norm
:都是減去均值Mean,除以方差Var,最終將歸一化為正態(tài)分布N(0,1)
。只不過兩者是在不同的維度(batch還是feature)求均值和方差,(其中,減均值:re-centering
將均值mean變換為0,除方差:re-scaling
將方差varance變換為1)。
RMS Norm(Root Mean Layer Norm)
:RMS Norm認(rèn)為,Layer Norm成功的原因是re-scaling
,因?yàn)榉讲頥ar計(jì)算的過程中使用了均值Mean,因此RMS Norm不再使用均值Mean,而是構(gòu)造了一個(gè)特殊的統(tǒng)計(jì)量RMS
代替方差Var。為什么使用RMS Norm?(1)RMS Norm計(jì)算量更小。(2)RMS的效果和Layer Norm一樣好。
針對輸入向量 a 的RMS Norm 函數(shù)計(jì)算公式如下:
此外,RMSNorm 還可以引入可學(xué)習(xí)的縮放因子gi 和偏移參數(shù)bi,從而得到
RMSNorm 在HuggingFace Transformer 庫中代碼實(shí)現(xiàn)如下所示:
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps # eps 防止取倒數(shù)之后分母為0
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# weight 是末尾乘的可訓(xùn)練參數(shù), 即g_i
return (self.weight * hidden_states).to(input_dtype)
為了使得模型訓(xùn)練過程更加穩(wěn)定,GPT-2 相較于GPT 就提出了將Layer Norm前置,將第一個(gè)層歸一化移動(dòng)到多頭自注意力層之前,第二個(gè)層歸一化也移動(dòng)到了全連接層之前,同時(shí)殘差連接的位置也調(diào)整到了多頭自注意力層與全連接層之后。層歸一化中也采用了RMSNorm 歸一化函數(shù)。
Rotary Positional Encodding
普通絕對Positional Encodding的使用過程:word -> token_id -> embedding_vector + position_encodding -> Encoder_Input
,其中第一步轉(zhuǎn)化使用tokenizer的詞表進(jìn)行,第二步轉(zhuǎn)化使用 learnable 的 Embedding layer。將得到的embedding_vector 和 position_encodding 進(jìn)行element-wise的相加,然后才做為input送入LLM的encoder。
對比Absolute PE 和 Relative PE
:
-
Absolute PE 絕對位置編碼
:每次單獨(dú)1個(gè)token的PE,每個(gè)token的PE之間沒有關(guān)系,是一組固定的vector,反映每個(gè)token在sequence中的絕對位置。 -
Relative PE 相對位置編碼
:每次處理2個(gè)token的PE,只在計(jì)算attention時(shí)使用(在query@key
時(shí)加在key上),反映2個(gè)token的相關(guān)度。
旋轉(zhuǎn)位置編碼(RoPE)
:RoPE 借助了復(fù)數(shù)的思想,出發(fā)點(diǎn)是通過絕對位置編碼的方式實(shí)現(xiàn)相對位置編碼。其目標(biāo)是通過下述 f 運(yùn)算,來給q,k 添加絕對位置信息m和n,得到?qm 和?kn,然后進(jìn)行q@k:
實(shí)際上,我們借助了復(fù)數(shù)的思想
,尋找了一個(gè) g 運(yùn)算
來合并 f 運(yùn)算
和q@k
這兩個(gè)操作,這樣只需要token q
和k
以及兩者的在seqence中的絕對位置m
和n
即可:
可以看到與普通的相對位置編碼不同,旋轉(zhuǎn)相對位置編碼用于在計(jì)算attention_score=q@k
之后,對attention_score強(qiáng)調(diào)每個(gè)token之間的相對位置:
為什么叫旋轉(zhuǎn)位置編碼?因?yàn)槭褂?code>歐拉公式構(gòu)造旋轉(zhuǎn)矩陣
,將q@k的計(jì)算結(jié)果旋轉(zhuǎn)到空間中對應(yīng)的位置,實(shí)現(xiàn)對計(jì)算結(jié)果添加位置信息。
上面是2維的例子,只有2個(gè)token xm
和xn
,LLaMa中是n維的,n個(gè)token也是一樣操作:
由于上述旋轉(zhuǎn)矩陣Rn 具有稀疏性,有大量元素是0,因此可以使用逐位相乘?
操作進(jìn)一步加快計(jì)算速度。
RoPE 在HuggingFace Transformer 庫中代碼實(shí)現(xiàn)如下所示:
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device,
dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation
# in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`.
# Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation
# in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype),
persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype),
persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
SwiGLU Function
SwiGLU 激活函數(shù)是Shazeer 在文獻(xiàn)中提出,并在PaLM等模中進(jìn)行了廣泛應(yīng)用,并且取得了不錯(cuò)的效果,相較于ReLU 函數(shù)在大部分評測中都有不少提升。在LLaMA 中全連接層使用帶有SwiGLU 激活函數(shù)的FFN(Position-wise Feed-Forward Network)的計(jì)算公式如下:
其中,σ(x) 是Sigmoid 函數(shù)。下圖給出了Swish 激活函數(shù)在參數(shù)β 不同取值下的形狀??梢钥吹疆?dāng)β 趨近于0 時(shí),Swish 函數(shù)趨近于線性函數(shù)y = x,當(dāng)β 趨近于無窮大時(shí),Swish 函數(shù)趨近于ReLU 函數(shù),β 取值為1 時(shí),Swish 函數(shù)是光滑且非單調(diào)。
HuggingFace 的Transformer 庫中
S
w
i
s
h
β
=
1
Swish_{\beta=1}
Swishβ=1?函數(shù)使用 SILU 函數(shù) 代替。
KV-Cache
首先來了解一下LLama的訓(xùn)練(下詞預(yù)測任務(wù)):seq2seq的生成,但迭代T次,seq_len
逐漸增加。
下句預(yù)測時(shí)的Self-Attention:
- timpstep=1時(shí)
seq_len=1
,給[SOS]時(shí),預(yù)測Love; - timpstep=2時(shí)
seq_len=2
,給[SOS] 和 Love時(shí),預(yù)測that - timpstep=4時(shí)
seq_len=4
,給[SOS] 和 Love 和 can 和 quickly時(shí),預(yù)測seize…
每個(gè)timestep我們只關(guān)注生成的最后一個(gè)token
,但因?yàn)長LaMa是一個(gè)seq2seq的model,每次必須重新計(jì)算和生成前面的token,因此我們希望能將之前timestep計(jì)算生成過的token給緩存起來,下個(gè)timestep不用再次計(jì)算,這樣的背景下,KV-Cache就產(chǎn)生了。
再來分析一下,每次個(gè)timestep的self-attention中我們到底需要哪些:因?yàn)槲覀冎魂P(guān)注最后一個(gè)token的attention_output
,如下圖timestep=4,我們只需要attention_output的第4個(gè)token。
因此我們只需要Q的最后一個(gè)token和K的所有token相乘,得到最后一個(gè)token的attention_score
,然后用V的所有token再與attention_score
點(diǎn)積(相乘求和),得到最后一個(gè)token的attention_output
:
由上分析可知,每個(gè)timestep,我們的Q只需要新增的那個(gè)token即可,而K和V要緩存之前timestep的token,保證token是全的。每次計(jì)算出來的attention_output就是那個(gè)新增的token的attention。 這樣就可以節(jié)省大量計(jì)算開銷。
Grouped Multi-Query Attention
回顧原始的多頭注意力Multi-Head Attention:時(shí)間開銷的瓶頸在于矩陣的運(yùn)算matrix computation
。
當(dāng)我們使用KV-Cache后:時(shí)間開銷的瓶頸在于內(nèi)存的訪問memory access
。
Multi Query Attention
多查詢注意力(Multi Query Attention,MQA
) 是多頭注意力的一種變體。其主要區(qū)別在于,在多查詢注意力中不同的注意力頭共享一個(gè)鍵和值的集合,每個(gè)頭只單獨(dú)保留了一份查詢參數(shù)。 具體操作上,去除 K和V 的head維度,只為Q保留head維度
。因此這就是被叫做Multi Query Attention的原因。
因此K和V的矩陣僅有一份(不分head),這大幅度減少了顯存占用,使其更高效。由于多查詢注意力改變了注意力機(jī)制的結(jié)構(gòu),因此模型通常需要從訓(xùn)練開始就支持多查詢注意力。
研究結(jié)果表明,可以通過對已經(jīng)訓(xùn)練好的模型進(jìn)行微調(diào)來添加多查詢注意力支持,僅需要約 5% 的原始訓(xùn)練數(shù)據(jù)量就可以達(dá)到不錯(cuò)的效果。包括Falcon、SantaCoder、StarCoder等在內(nèi)很多模型都采用了多查詢注意力機(jī)制。
以LLM Foundry 為例,多查詢注意力實(shí)現(xiàn)代碼如下,與LLM Foundry 中實(shí)現(xiàn)的多頭自注意力代碼相對比,其區(qū)別僅在于建立Wqkv 層上:
class MultiQueryAttention(nn.Module):
"""Multi-Query self attention.
Using torch or triton attention implemetation enables user to also use
additive bias.
"""
def __init__(
self,
d_model: int,
n_heads: int,
device: Optional[str] = None,
):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.Wqkv = nn.Linear( # Multi-Query Attention 創(chuàng)建
d_model,
d_model + 2 * self.head_dim, # 只創(chuàng)建查詢的頭向量,所以只有1 個(gè)d_model
device=device, # 而鍵和值則共享各自的一個(gè)head_dim 的向量
)
self.attn_fn = scaled_multihead_dot_product_attention
self.out_proj = nn.Linear(
self.d_model,
self.d_model,
device=device
)
self.out_proj._is_residual = True # type: ignore
def forward(
self,
x,
):
qkv = self.Wqkv(x) # (1, 512, 960)
query, key, value = qkv.split( # query -> (1, 512, 768)
[self.d_model, self.head_dim, self.head_dim], # key -> (1, 512, 96)
dim=2 # value -> (1, 512, 96)
)
context, attn_weights, past_key_value = self.attn_fn(
query,
key,
value,
self.n_heads,
multiquery=True,
)
return self.out_proj(context), attn_weights, past_key_value
Grouped Multi-Query Attention
就是在 Multi-Query Attention的基礎(chǔ)上,對input進(jìn)行分組,每組都有自己的K,V,以及多頭Q。
文章來源:http://www.zghlxwxcb.cn/news/detail-823704.html
源碼
[LLMs 實(shí)踐] 01 llama、alpaca、vicuna 整體介紹及 llama 推理過程文章來源地址http://www.zghlxwxcb.cn/news/detail-823704.html
到了這里,關(guān)于LLaMa 原理+源碼——拆解 (KV-Cache, Rotary Positional Embedding, RMS Norm, Grouped Query Attention, SwiGLU)的文章就介紹完了。如果您還想了解更多內(nèi)容,請?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!