- References:
  https://zhuanlan.zhihu.com/p/636784644
  https://spaces.ac.cn/archives/8265 (《Transformer升級之路:2、博采眾長的旋轉(zhuǎn)式位置編碼》, the author's derivation of RoPE)
Preface: the code read in this post comes from the transformers library, specifically modeling_llama.py, located at transformers/models/llama/modeling_llama.py.
1. LlamaModel overall structure flowchart
- At a high level, LlamaModel is a token embedding layer followed by a stack of LlamaDecoderLayer blocks (each one: RMSNorm, self-attention with rotary position embeddings, RMSNorm, MLP, with residual connections around both sub-blocks) and a final LlamaRMSNorm; the components below are read in that order.
2. LlamaRMSNorm
- The code is as follows:
import torch
from torch import nn

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # learnable per-channel scale
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # mean of x_i^2 over the hidden dimension, computed in float32 for stability
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)
- The RMSNorm formula is:

$$\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\frac{1}{n}\sum\limits_{i=1}^{n} x_i^2 + eps}} \cdot weight_i$$

- The correspondence between the formula and the code: variance computes the mean of x_i^2 over the hidden dimension (with n = hidden_size), torch.rsqrt(variance + self.variance_epsilon) supplies the 1/sqrt(... + eps) factor, and self.weight is the learnable weight_i.
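To check this correspondence numerically, here is a minimal sketch (it assumes PyTorch is available and the LlamaRMSNorm class above has been defined; the tensor shapes are arbitrary toy values):

import torch

# Toy input: batch=2, seq_len=3, hidden_size=8 (made-up sizes for illustration).
x = torch.randn(2, 3, 8)
norm = LlamaRMSNorm(hidden_size=8, eps=1e-6)

# Manual computation of x_i / sqrt(mean(x_i^2) + eps) * weight_i.
manual = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight

print(torch.allclose(norm(x), manual, atol=1e-6))  # expected: True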
3. LlamaMLP
- The code is as follows:
from transformers.activations import ACT2FN  # maps an activation name (e.g. "silu") to its module

class LlamaMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        # gated branch: act_fn(gate_proj(x)); linear branch: up_proj(x); multiply, then project back down
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- Flowchart: the input x goes through gate_proj followed by the activation on one branch and through up_proj on the other branch; the two branches are multiplied elementwise and then mapped back by down_proj.
- The input is x and the output is y; a shape walkthrough is sketched after this list.
- In the code, intermediate_size is generally larger than hidden_size. Printing the Llama-13B model in a Jupyter notebook shows, for example, hidden_size=5120 and intermediate_size=13824.
- Summary: the MLP module is just a combination of a few nn.Linear layers plus a gated activation.
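To make the x-to-y flow concrete, a minimal usage sketch (it assumes the LlamaMLP class above is defined and that ACT2FN is importable from transformers.activations; the toy sizes are made up, the Llama-13B sizes appear only in the comment):

import torch

# Toy sizes for illustration; Llama-13B itself uses hidden_size=5120,
# intermediate_size=13824 and hidden_act="silu".
mlp = LlamaMLP(hidden_size=16, intermediate_size=44, hidden_act="silu")

x = torch.randn(2, 5, 16)         # [batch, seq_len, hidden_size]
h = mlp.act_fn(mlp.gate_proj(x))  # [2, 5, 44], gated branch
u = mlp.up_proj(x)                # [2, 5, 44], linear branch
y = mlp.down_proj(h * u)          # [2, 5, 16], projected back to hidden_size

print(torch.allclose(mlp(x), y))  # expected: True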
4. LlamaRotaryEmbedding
- The code is as follows:
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # one frequency per pair of dimensions: 1 / base**(2i/dim), i = 0 .. dim/2 - 1
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # outer product of positions and frequencies, [seq_len, dim/2]
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
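Before looking at how this cache is consumed, a quick sketch of what the module stores and returns (assuming the class above is defined; dim=8 and the sequence lengths are arbitrary toy values):

import torch

rope = LlamaRotaryEmbedding(dim=8, max_position_embeddings=16)

# inv_freq holds dim/2 frequencies: 1 / base**(2i/dim) for i = 0 .. dim/2 - 1.
print(rope.inv_freq.shape)    # torch.Size([4])

x = torch.randn(1, 2, 10, 8)  # [bs, num_heads, seq_len, head_dim]
cos, sin = rope(x, seq_len=10)
print(cos.shape, sin.shape)   # torch.Size([1, 1, 10, 8]) for both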
- In actual use, two additional helper functions are called, shown below:
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
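Putting the pieces together, a minimal usage sketch (toy shapes; in the real attention module, q and k come from the query and key projections and position_ids comes from the input batch):

import torch

bs, n_heads, seq_len, head_dim = 1, 2, 6, 8
q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn(bs, n_heads, seq_len, head_dim)

rope = LlamaRotaryEmbedding(dim=head_dim, max_position_embeddings=32)
cos, sin = rope(q, seq_len=seq_len)                # [1, 1, seq_len, head_dim]
position_ids = torch.arange(seq_len).unsqueeze(0)  # [1, seq_len]

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)  # both torch.Size([1, 2, 6, 8])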
- Note that this implementation differs slightly from the original derivation: rotate_half pairs dimension i with dimension i + d/2, so each (x_i, x_{i+d/2}) pair is rotated by the angle m·θ_i for a token at position m.
- The original derivation instead rotates adjacent pairs (x_{2i}, x_{2i+1}); as the code comment notes, the two layouts differ only by a permutation of the hidden dimensions and yield the same calculation. For the full derivation, see the author's blog post listed in the references (https://spaces.ac.cn/archives/8265).
- Summary: RoPE injects position information into K and Q separately, right before their inner product is computed in attention; a small check of this property follows.
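To illustrate the summary, a small sketch (reusing the class and functions above) that checks the key property of RoPE: after rotation, the Q-K inner product depends only on the relative offset between two positions, not on their absolute values:

import torch

head_dim = 8
rope = LlamaRotaryEmbedding(dim=head_dim, max_position_embeddings=64)
cos, sin = rope(torch.randn(1, 1, 64, head_dim), seq_len=64)

q = torch.randn(1, 1, 1, head_dim)  # a single query vector
k = torch.randn(1, 1, 1, head_dim)  # a single key vector

def score(m, n):
    # Rotate q to position m and k to position n, then take their inner product.
    q_m, _ = apply_rotary_pos_emb(q, q, cos, sin, torch.tensor([[m]]))
    k_n, _ = apply_rotary_pos_emb(k, k, cos, sin, torch.tensor([[n]]))
    return (q_m * k_n).sum()

# The same relative offset (5) at different absolute positions gives the same score.
print(torch.allclose(score(2, 7), score(30, 35), atol=1e-5))  # expected: True

This is why injecting the rotation into Q and K separately still yields attention scores that encode relative position.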
The end.