0.開(kāi)源代碼地址
官方實(shí)現(xiàn):state-spaces/mamba (github.com)
最簡(jiǎn)化實(shí)現(xiàn):johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch. (github.com)
直接實(shí)現(xiàn):alxndrTL/mamba.py: A simple and efficient Mamba implementation in PyTorch and MLX. (github.com)
官方代碼做了大量?jī)?yōu)化,目錄層級(jí)較多,對(duì)于理解模型含義較難,這里老師對(duì)上面第二最簡(jiǎn)化實(shí)現(xiàn)的代碼進(jìn)行了詳細(xì)注釋,該代碼性能比官方實(shí)現(xiàn)差,但是對(duì)于理解模型原理比較直白。
這段代碼的主要組成部分包括模型參數(shù)類ModelArgs
、完整的Mamba模型類Mamba
、殘差塊類ResidualBlock
、單個(gè)Mamba塊類MambaBlock
、RMSNorm歸一化類以及一些輔助函數(shù)。
1算法核心
的算法圖,原始論文中給出的Mamba(SSSM:Selective state Space model of )的前身S4(SSM:Structured State Space Model):
S6模塊
S6模塊是Mamba架構(gòu)中的一個(gè)復(fù)雜組件,負(fù)責(zé)通過(guò)一系列線性變換和離散化過(guò)程處理輸入序列。它在捕獲序列的時(shí)間動(dòng)態(tài)方面起著關(guān)鍵作用,這是序列建模任務(wù)(如語(yǔ)言建模)的一個(gè)關(guān)鍵方面。這里包括張量運(yùn)算和自定義離散化方法來(lái)處理序列數(shù)據(jù)的復(fù)雜需求。
離散化def discretization(self)中有兩行代碼提出來(lái)解釋,論文中離散化采用零階保持:
A  ̄ = e x p ( Δ A ) \overline{A}=exp(\Delta A) A=exp(ΔA) :對(duì)應(yīng)代碼中的self.dA?
B  ̄ = ( Δ A ) ? 1 ( exp ? ( Δ A ) ? I ) ? Δ B \overline{B}=(\Delta A)^{-1}(\exp(\Delta A)-I)\cdot\Delta B B=(ΔA)?1(exp(ΔA)?I)?ΔB::對(duì)應(yīng)代碼中的self.dB
各個(gè)張量維度如下:

2.Mamba模型定義
2.1 ModelArgs 類
ModelArgs
類是用于存儲(chǔ)和處理Mamba模型配置參數(shù)的容器。它使用Python的dataclass
裝飾器來(lái)自動(dòng)生成初始化方法和類的字符串表示方法,簡(jiǎn)化了代碼的編寫(xiě)。這個(gè)類中的每個(gè)屬性對(duì)應(yīng)于構(gòu)建Mamba模型所需的一個(gè)配置參數(shù),例如模型的隱藏層維度、層數(shù)、詞匯表大小等。__post_init__
方法在初始化后自動(dòng)調(diào)用,用于執(zhí)行一些額外的設(shè)置,比如計(jì)算內(nèi)部維度d_inner
和自動(dòng)調(diào)整詞匯表大小,以確保模型的配置參數(shù)是有效的和一致的。
# 使用dataclass裝飾器自動(dòng)生成初始化方法和類的字符串表示方法
@dataclass
class ModelArgs:
# @dataclass 會(huì)自動(dòng)為這個(gè)類生成初始化方法和代表類的字符串形式的方法
d_model: int # 定義模型的隱藏層維度
n_layer: int # 定義模型的層數(shù)
vocab_size: int # 定義詞匯表的大小
d_state: int = 16 # 定義狀態(tài)空間的維度,默認(rèn)為16
expand: int = 2 # 定義擴(kuò)展因子,默認(rèn)為2
dt_rank: Union[int, str] = 'auto' # 定義輸入依賴步長(zhǎng)Δ的秩,'auto'表示自動(dòng)設(shè)置
d_conv: int = 4 # 定義卷積核的維度,默認(rèn)為4
pad_vocab_size_multiple: int = 8 # 定義詞匯表大小的最小公倍數(shù),默認(rèn)為8
conv_bias: bool = True # 定義卷積層是否使用偏置項(xiàng)
bias: bool = False # 定義其他層(如線性層)是否使用偏置項(xiàng)
def __post_init__(self):
# 在__init__后自動(dòng)被調(diào)用,用于執(zhí)行初始化之后的額外設(shè)置或驗(yàn)證
# 計(jì)算內(nèi)部維度,即擴(kuò)展后的維度
self.d_inner = int(self.expand * self.d_model)
if self.dt_rank == 'auto':# 如果dt_rank未指定,則自動(dòng)計(jì)算設(shè)置
# 根據(jù)隱藏層維度自動(dòng)計(jì)算Δ的秩
self.dt_rank = math.ceil(self.d_model / 16)
# 確保vocab_size是pad_vocab_size_multiple的倍數(shù)
# 如果不是,調(diào)整為最近的倍數(shù)
if self.vocab_size % self.pad_vocab_size_multiple != 0:
self.vocab_size += (self.pad_vocab_size_multiple
- self.vocab_size % self.pad_vocab_size_multiple)
2.2 Mamba 類
Mamba
類是Mamba模型的主干,繼承自PyTorch的nn.Module
類。這個(gè)類的實(shí)例化對(duì)象將構(gòu)成整個(gè)Mamba模型的結(jié)構(gòu)和前向傳播邏輯。
在初始化方法__init__
中,首先調(diào)用父類的構(gòu)造函數(shù)來(lái)初始化模型。然后,根據(jù)傳入的ModelArgs
對(duì)象中的參數(shù)配置模型的各個(gè)組件:
-
self.embedding
是一個(gè)嵌入層,它將輸入的詞匯ID轉(zhuǎn)換為對(duì)應(yīng)的向量表示。這些向量隨后會(huì)被送入模型的深層網(wǎng)絡(luò)中。 -
self.layers
是一個(gè)模塊列表,其中包含了多個(gè)ResidualBlock
殘差塊。這些殘差塊有助于訓(xùn)練深層網(wǎng)絡(luò)并防止梯度消失問(wèn)題。 -
self.norm_f
是一個(gè)RMSNorm歸一化模塊,用于在模型的某些層之后進(jìn)行歸一化操作,以穩(wěn)定訓(xùn)練過(guò)程。 -
self.lm_head
是一個(gè)線性層,它將模型的最終隱藏狀態(tài)映射回詞匯表的大小,以便進(jìn)行下一步的預(yù)測(cè)或分類任務(wù)。
在forward
方法中,定義了模型的前向傳播邏輯。輸入input_ids
首先通過(guò)嵌入層轉(zhuǎn)換為向量表示,然后依次通過(guò)每個(gè)殘差塊進(jìn)行處理。經(jīng)過(guò)所有層之后,模型的輸出通過(guò)RMSNorm歸一化,最后通過(guò)線性層self.lm_head
得到最終的logits輸出。這個(gè)輸出可以用于后續(xù)的損失計(jì)算或生成任務(wù)。
class Mamba(nn.Module):
def __init__(self, args: ModelArgs):
"""Full Mamba model."""
super().__init__()
# 保存?zhèn)魅氲腗odelArgs對(duì)象,包含模型的配置參數(shù)
self.args = args
# 創(chuàng)建一個(gè)嵌入層,將詞匯表中的詞轉(zhuǎn)換為對(duì)應(yīng)的向量表示
self.embedding = nn.Embedding(args.vocab_size, args.d_model)
# 創(chuàng)建一個(gè)包含多個(gè)殘差塊的模塊列表,殘差塊的數(shù)量等于模型層數(shù)
self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
# 創(chuàng)建一個(gè)RMSNorm模塊,用于歸一化操作
self.norm_f = RMSNorm(args.d_model)
# 創(chuàng)建一個(gè)線性層,用于最終的輸出,將隱藏層的輸出映射回詞匯表的大小
self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
# 將線性層的輸出權(quán)重與嵌入層的權(quán)重綁定,這是權(quán)重共享的一種形式,有助于減少參數(shù)數(shù)量并可能提高模型的泛化能力
self.lm_head.weight = self.embedding.weight # Tie output projection to embedding weights.
# See "Weight Tying" paper
def forward(self, input_ids):
"""
Args:
input_ids (long tensor): shape (b, l) (See Glossary at top for definitions of b, l, d_in, n...)
Returns:
logits: shape (b, l, vocab_size)
Official Implementation:
class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173
"""
# 將輸入ID轉(zhuǎn)換為向量表示
x = self.embedding(input_ids)
# 遍歷所有的殘差塊,并應(yīng)用它們
for layer in self.layers:
x = layer(x)
# 應(yīng)用歸一化操作
x = self.norm_f(x)
# 通過(guò)線性層得到最終的logits輸出
logits = self.lm_head(x)
# 返回模型的輸出
return logits
解釋一下:為什么輸入的input_ids已經(jīng)是經(jīng)過(guò)分詞器(tokenizer)處理后的詞匯表索引,還需要通過(guò)nn.Embedding?
這些索引代表了輸入文本中的單詞或子詞單元在詞匯表中的位置。盡管這些索引已經(jīng)是一個(gè)相對(duì)緊湊的數(shù)值表示,但它們并不直接對(duì)應(yīng)于模型可以處理的向量表示。
nn.Embedding
層的作用是將這些離散的索引映射到一個(gè)連續(xù)的向量空間中。每個(gè)索引input_ids
中的值都會(huì)被nn.Embedding
層轉(zhuǎn)換成一個(gè)固定維度的向量,這個(gè)向量捕捉了對(duì)應(yīng)單詞或子詞的語(yǔ)義信息。這個(gè)轉(zhuǎn)換過(guò)程是模型學(xué)習(xí)的一部分,通過(guò)訓(xùn)練數(shù)據(jù)中的模式,模型可以學(xué)習(xí)到如何將這些索引映射到能夠有效表示輸入文本的向量。
2.3 ResidualBlock 類
定義了Mamba模型中的一個(gè)殘差塊。這個(gè)類的目的是為了在模型中引入殘差連接,這有助于訓(xùn)練深層網(wǎng)絡(luò),因?yàn)樗试S梯度直接流過(guò)網(wǎng)絡(luò),從而緩解了梯度消失問(wèn)題。
在__init__
方法中,首先調(diào)用父類nn.Module
的構(gòu)造函數(shù)來(lái)初始化殘差塊。然后,根據(jù)傳入的ModelArgs
對(duì)象中的參數(shù)配置殘差塊的組件:
-
self.mixer
是一個(gè)MambaBlock
實(shí)例,它是這個(gè)殘差塊的核心組件,負(fù)責(zé)執(zhí)行Mamba模型的大部分計(jì)算。 -
self.norm
是一個(gè)RMSNorm
歸一化模塊,用于在將數(shù)據(jù)送入MambaBlock
之前進(jìn)行歸一化處理。
在forward
方法中,定義了殘差塊的前向傳播邏輯。輸入張量x
首先通過(guò)RMSNorm
模塊進(jìn)行歸一化,然后送入MambaBlock
。MambaBlock
的輸出接著與原始輸入x
相加,形成殘差連接。這樣做可以使得模型的學(xué)習(xí)更加靈活,因?yàn)樗试S模型學(xué)習(xí)到輸入和輸出之間的恒等映射(即不改變輸入數(shù)據(jù)),這在某些情況下是有益的。最后,殘差塊的輸出被返回。
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
"""Simple block wrapping Mamba block with normalization and residual connection."""
super().__init__()
# 保存?zhèn)魅氲腗odelArgs對(duì)象,包含模型的配置參數(shù)
self.args = args
# 創(chuàng)建一個(gè)MambaBlock,它是這個(gè)殘差塊的核心組件
self.mixer = MambaBlock(args)
# 創(chuàng)建一個(gè)RMSNorm歸一化模塊,用于歸一化操作
self.norm = RMSNorm(args.d_model)
def forward(self, x):
"""
Args:
x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...)
x (Tensor): 輸入張量,形狀為(batch_size, sequence_length, hidden_size)
Returns:
output: shape (b, l, d)
輸出張量,形狀與輸入相同
Official Implementation:
Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297
Note: the official repo chains residual blocks that look like
[Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
where the first Add is a no-op. This is purely for performance reasons as this
allows them to fuse the Add->Norm.
We instead implement our blocks as the more familiar, simpler, and numerically equivalent
[Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
"""
# 應(yīng)用歸一化和MambaBlock,然后與輸入x進(jìn)行殘差連接
output = self.mixer(self.norm(x)) + x
return output
2.4 MambaBlock 類
MambaBlock
類定義了Mamba模型中的一個(gè)基本構(gòu)建塊,即Mamba塊。這個(gè)塊是模型的核心組件,負(fù)責(zé)執(zhí)行序列數(shù)據(jù)的處理和狀態(tài)空間模型的更新。
在__init__
方法中,首先調(diào)用父類nn.Module
的構(gòu)造函數(shù)來(lái)初始化Mamba塊。然后,根據(jù)傳入的ModelArgs
對(duì)象中的參數(shù)配置Mamba塊的組件:
-
self.in_proj
是一個(gè)線性變換層,用于輸入的投影。 -
self.conv1d
是一個(gè)一維卷積層,用于執(zhí)行深度卷積,這是Mamba模型的特色之一,用于處理序列數(shù)據(jù)。 -
self.x_proj
和self.dt_proj
是線性變換層,用于將輸入映射到狀態(tài)空間模型的參數(shù)。 -
self.A_log
是矩陣A的對(duì)數(shù)值,作為一個(gè)可訓(xùn)練參數(shù)。 -
self.D
是矩陣D,初始化為全1,也是一個(gè)可訓(xùn)練參數(shù)。 -
self.out_proj
是一個(gè)線性變換層,用于輸出的投影。
在forward
方法中,定義了Mamba塊的前向傳播邏輯。輸入張量x
首先通過(guò)線性變換層和深度卷積層進(jìn)行處理,然后應(yīng)用激活函數(shù)。接著,通過(guò)狀態(tài)空間模型(ssm)和選擇性掃描(selective_scan)算法更新?tīng)顟B(tài),并計(jì)算輸出。最后,輸出通過(guò)另一個(gè)線性變換層進(jìn)行投影,得到最終的輸出結(jié)果。
ssm
方法負(fù)責(zé)運(yùn)行狀態(tài)空間模型,它使用矩陣A、B、C和D以及輸入x
來(lái)更新?tīng)顟B(tài)并計(jì)算輸出。
selective_scan
方法執(zhí)行選擇性掃描算法,這是Mamba模型的關(guān)鍵特性,它允許模型根據(jù)輸入動(dòng)態(tài)調(diào)整其行為,從而更好地處理序列數(shù)據(jù)。通過(guò)這種方式,Mamba模型能夠有效地捕捉序列中的長(zhǎng)期依賴關(guān)系,同時(shí)保持線性時(shí)間復(fù)雜度。
class MambaBlock(nn.Module):
def __init__(self, args: ModelArgs):
"""A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
super().__init__()
# 保存模型參數(shù)
self.args = args
# 輸入線性變換層
self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)
# 創(chuàng)建了一個(gè)所謂的“深度卷積”,其中每個(gè)輸入通道被單獨(dú)卷積到每個(gè)輸出通道。
# 這意味著每個(gè)輸出通道的結(jié)果是通過(guò)僅與一個(gè)輸入通道卷積得到的。
self.conv1d = nn.Conv1d(
in_channels=args.d_inner,
out_channels=args.d_inner,
bias=args.conv_bias,
kernel_size=args.d_conv,
groups=args.d_inner,
padding=args.d_conv - 1,
)
# x_proj takes in `x` and outputs the input-specific Δ, B, C
# 將輸入x映射到狀態(tài)空間模型的參數(shù)Δ、B和C
self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)
# dt_proj projects Δ from dt_rank to d_in
# 將Δ從args.dt_rank維度映射到args.d_inner維度
self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)
# 創(chuàng)建一個(gè)重復(fù)的序列,用于初始化狀態(tài)空間模型的矩陣A
# n->dxn
A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
# 將矩陣A的對(duì)數(shù)值作為可訓(xùn)練參數(shù)保存
self.A_log = nn.Parameter(torch.log(A))
# 初始化矩陣D為全1的可訓(xùn)練參數(shù)
self.D = nn.Parameter(torch.ones(args.d_inner))
# 輸出線性變換層
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)
def forward(self, x):
"""MambaBlock的前向傳播函數(shù),與Mamba論文圖3 Section 3.4相同.
Args:
x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...)
Returns:
output: shape (b, l, d)
Official Implementation:
class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
"""
# 獲取輸入x的維度
# batchsize,seq_len,dim
(b, l, d) = x.shape # 獲取輸入x的維度
# 應(yīng)用輸入線性變換
x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in)
# 將變換后的輸出分為兩部分x和res。
# 得到的x分為兩個(gè)部分,一部分x繼續(xù)用于后續(xù)變換,生成所需要的參數(shù),res用于殘差部分
(x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)
# 調(diào)整x的形狀
x = rearrange(x, 'b l d_in -> b d_in l')
# 應(yīng)用深度卷積,然后截取前l(fā)個(gè)輸出
x = self.conv1d(x)[:, :, :l]
# 再次調(diào)整x的形狀
x = rearrange(x, 'b d_in l -> b l d_in')
# 應(yīng)用SiLU激活函數(shù)
x = F.silu(x)
# 運(yùn)行狀態(tài)空間模型
y = self.ssm(x)
# 將res的SiLU激活結(jié)果與y相乘
y = y * F.silu(res)
# 應(yīng)用輸出線性變換
output = self.out_proj(y)
# 返回輸出結(jié)果
return output
def ssm(self, x):
"""運(yùn)行狀態(tài)空間模型,參考Mamba論文 Section 3.2和注釋[2]:
- Algorithm 2 in Section 3.2 in the Mamba paper [1]
- run_SSM(A, B, C, u) in The Annotated S4 [2]
Args:
x: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...)
Returns:
output: shape (b, l, d_in)
Official Implementation:
mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
"""
# 獲取A_log的維度
# A在初始化時(shí)候經(jīng)過(guò)如下賦值:
# A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
# self.A_log = nn.Parameter(torch.log(A))
# (args.d_inner, args.d_state)
(d_in, n) = self.A_log.shape # 獲取A_log的維度
# 計(jì)算 ? A B C D, 這些屬于狀態(tài)空間參數(shù).
# A, D 是 與輸入無(wú)關(guān)的 (見(jiàn)Mamba論文Section 3.5.2 "Interpretation of A" for why A isn't selective)
# ?, B, C 與輸入有關(guān)(這是與線性是不變模型S4最大的不同,
# 也是為什么Mamba被稱為 “選擇性” 狀態(tài)空間的原因)
# 計(jì)算矩陣A
A = -torch.exp(self.A_log.float()) # shape (d_in, n)
# 取D的值
D = self.D.float()
# 應(yīng)用x的投影變換
# ( b,l,d_in) -> (b, l, dt_rank + 2*n)
x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n)
# 分割delta, B, C
# delta: (b, l, dt_rank). B, C: (b, l, n)
(delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)
# 應(yīng)用dt_proj并計(jì)算delta
delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in)
# 應(yīng)用選擇性掃描算法
y = self.selective_scan(x, delta, A, B, C, D)
return y
def selective_scan(self, u, delta, A, B, C, D):
"""執(zhí)行選擇性掃描算法,參考Mamba論文[1] Section 2和注釋[2]. See:
- Section 2 State Space Models in the Mamba paper [1]
- Algorithm 2 in Section 3.2 in the Mamba paper [1]
- run_SSM(A, B, C, u) in The Annotated S4 [2]
經(jīng)典的離散狀態(tài)空間公式:
x(t + 1) = Ax(t) + Bu(t)
y(t) = Cx(t) + Du(t)
除了B和C (以及step size delta用于離散化) 與輸入x(t)相關(guān).
參數(shù):
u: shape (b, l, d_in)
delta: shape (b, l, d_in)
A: shape (d_in, n)
B: shape (b, l, n)
C: shape (b, l, n)
D: shape (d_in,)
過(guò)程概述:
Returns:
output: shape (b, l, d_in)
Official Implementation:
selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly.
"""
# 獲取輸入u的維度
(b, l, d_in) = u.shape
# 獲取矩陣A的列數(shù)
n = A.shape[1] # A: shape (d_in, n)
# 離散化連續(xù)參數(shù)(A, B)
# - A 使用 zero-order hold (ZOH) 離散化 (see Section 2 Equation 4 in the Mamba paper [1])
# - B is 使用一個(gè)簡(jiǎn)化的Euler discretization而不是ZOH.根據(jù)作者的討論:
# "A is the more important term and the performance doesn't change much with the simplification on B"
# 計(jì)算離散化的A
# 將delta和A進(jìn)行點(diǎn)乘,將A沿著delta的最后一個(gè)維度進(jìn)行廣播,然后執(zhí)行逐元素乘法
# A:(d_in, n),delta:(b, l, d_in)
# A廣播拓展->(b,l,d_in, n),deltaA對(duì)應(yīng)原論文中的A_bar
deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
# delta、B和u,這個(gè)計(jì)算和原始論文不同
deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
# Perform selective scan (see scan_SSM() in The Annotated S4 [2])
# Note that the below is sequential, while the official implementation does a much faster parallel scan that
# is additionally hardware-aware (like FlashAttention).
# 執(zhí)行選擇性掃描,初始化狀態(tài)x為零
x = torch.zeros((b, d_in, n), device=deltaA.device)
# 初始化輸出列表ys
ys = []
for i in range(l):
# 更新?tīng)顟B(tài)x
# deltaA:((b,l,d_in, n)
# deltaB_u:( b,l,d_in,n)
# x:
x = deltaA[:, i] * x + deltaB_u[:, i]
# 計(jì)算輸出y
y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
# 將輸出y添加到列表ys中
ys.append(y)
# 將列表ys堆疊成張量y
y = torch.stack(ys, dim=1) # shape (b, l, d_in)
# 將輸入u乘以D并加到輸出y上
y = y + u * D
return y
解釋1:深度卷積的幾行代碼
x = rearrange(x, ‘b l d_in -> b d_in l’) 調(diào)整x的形狀
這行代碼使用
rearrange
函數(shù)將輸入張量x
的形狀從(batch_size, sequence_length, d_model)
轉(zhuǎn)換為(batch_size, d_model, sequence_length)
。這種形狀調(diào)整是為了適配后續(xù)的一維卷積層self.conv1d
,該卷積層期望輸入的形狀為(batch_size, channels, length)
,其中channels
是卷積核的深度,length
是序列的長(zhǎng)度。x = self.conv1d(x)[:, :, :l] 應(yīng)用深度卷積
self.conv1d
是一個(gè)一維卷積層,它沿著序列長(zhǎng)度l
的方向應(yīng)用卷積核。由于self.conv1d
的in_channels
參數(shù)設(shè)置為args.d_inner
,這意味著卷積操作是在d_model
維的特征空間內(nèi)獨(dú)立進(jìn)行的。卷積操作的輸出是一個(gè)三維張量,其形狀為(batch_size, d_inner, sequence_length)
。然后,代碼通過(guò)切片操作[:, :, :l]
只保留了序列長(zhǎng)度為l
的輸出,這是因?yàn)槲覀冎粚?duì)序列中的前l
個(gè)元素感興趣。x = rearrange(x, ‘b d_in l -> b l d_in’ 再次調(diào)整x的形狀
最后,為了繼續(xù)后續(xù)的計(jì)算,需要將卷積后的張量形狀再次調(diào)整回
(batch_size, sequence_length, d_model)
。這樣做是為了確保數(shù)據(jù)在后續(xù)層中的流動(dòng)是連貫的,特別是當(dāng)數(shù)據(jù)傳遞給后續(xù)的Mamba塊或其他層時(shí)。這里的rearrange
函數(shù)將卷積輸出的形狀從(batch_size, d_inner, sequence_length)
轉(zhuǎn)換回(batch_size, sequence_length, d_inner)
解釋2:A = -torch.exp(self.A_log.float())前面的負(fù)號(hào)
這里的負(fù)號(hào)
-
是因?yàn)樵跔顟B(tài)空間模型中,矩陣A
通常表示的是一個(gè)離散時(shí)間系統(tǒng)的轉(zhuǎn)換矩陣,它描述了系統(tǒng)狀態(tài)隨時(shí)間的演變。在許多情況下,A
矩陣的元素應(yīng)該是負(fù)的,以確保系統(tǒng)的穩(wěn)定性。這是因?yàn)樵陔x散時(shí)間系統(tǒng)中,我們希望系統(tǒng)的狀態(tài)隨著時(shí)間的推移而衰減或穩(wěn)定下來(lái),而不是增長(zhǎng),從而避免系統(tǒng)變得不穩(wěn)定或發(fā)散。
解釋3:狀態(tài)空間更新代碼
這兩行代碼首先根據(jù)當(dāng)前時(shí)間步的轉(zhuǎn)換矩陣
deltaA
和輸入影響deltaB_u
更新?tīng)顟B(tài)向量x
,然后計(jì)算狀態(tài)向量x
和輸出矩陣C
的點(diǎn)乘,得到當(dāng)前時(shí)間步的輸出y
。這個(gè)過(guò)程是狀態(tài)空間模型中的核心計(jì)算步驟,它允許模型動(dòng)態(tài)地處理序列數(shù)據(jù)并生成響應(yīng)。
x = deltaA[:, i] * x + deltaB_u[:, i]
deltaA
是一個(gè)四維張量,其形狀為(batch_size, sequence_length, d_in, n)
。這里deltaA[:, i]
表示我們選擇了deltaA
張量中第i
個(gè)時(shí)間步的切片,形狀變?yōu)?code>(batch_size, d_in, n)。x
是狀態(tài)向量,其形狀為(batch_size, d_in, n)
,代表當(dāng)前時(shí)間步的狀態(tài)。deltaB_u
是一個(gè)四維張量,其形狀也為(batch_size, sequence_length, d_in, n)
,它是通過(guò)delta
、B
和輸入u
計(jì)算得到的,代表了輸入對(duì)狀態(tài)的直接影響。- 這行代碼首先執(zhí)行
deltaA[:, i] * x
,這是一個(gè)逐元素乘法操作,它根據(jù)當(dāng)前時(shí)間步的轉(zhuǎn)換矩陣更新?tīng)顟B(tài)向量x
。由于deltaA[:, i]
的形狀是(batch_size, d_in, n)
,它可以直接與形狀相同的x
進(jìn)行逐元素乘法。- 接著,代碼執(zhí)行
+ deltaB_u[:, i]
,將輸入的影響加到更新后的狀態(tài)向量x
上。這里的deltaB_u[:, i]
是deltaB_u
張量中第i
個(gè)時(shí)間步的切片,形狀也是(batch_size, d_in, n)
。y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
- 這行代碼使用
einsum
函數(shù)來(lái)計(jì)算輸出y
。einsum
是PyTorch中的一個(gè)函數(shù),用于執(zhí)行復(fù)雜的張量運(yùn)算。x
是當(dāng)前狀態(tài)向量,形狀為(batch_size, d_in, n)
。C[:, i, :]
是從輸出參數(shù)矩陣C
中取出的第i
個(gè)時(shí)間步的切片,形狀為(batch_size, n, d_in)
。'b d_in n, b n -> b d_in'
是einsum
的索引模式,它指示了如何執(zhí)行點(diǎn)乘和求和操作。在這個(gè)模式中,'b'
表示批次維度保持不變,'d_in n'
表示x
的第二個(gè)和第三個(gè)維度與C
的第二個(gè)維度進(jìn)行點(diǎn)乘,'b d_in'
表示輸出y
的形狀應(yīng)該與x
的前兩個(gè)維度相同。- 結(jié)果
y
的形狀是(batch_size, d_in)
,它是模型在當(dāng)前時(shí)間步對(duì)輸入序列的響應(yīng)。
2.5 RMSNorm 類
這個(gè)類實(shí)現(xiàn)了基于均方根的歸一化操作。它接收輸入x
,計(jì)算其均方根值,并使用這個(gè)值來(lái)歸一化輸入。這種歸一化有助于模型的訓(xùn)練穩(wěn)定性。
class RMSNorm(nn.Module):
def __init__(self, d_model: int, eps: float = 1e-5):
"""
初始化RMSNorm模塊,該模塊實(shí)現(xiàn)了基于均方根的歸一化操作。
參數(shù):
d_model (int): 模型的特征維度。
eps (float, 可選): 為了避免除以零,添加到分母中的一個(gè)小的常數(shù)。
"""
super().__init__()
self.eps = eps # 保存輸入的eps值,用于數(shù)值穩(wěn)定性
self.weight = nn.Parameter(torch.ones(d_model)) # 創(chuàng)建一個(gè)可訓(xùn)練的權(quán)重參數(shù),初始值為全1,維度與輸入特征維度d_model相同
def forward(self, x):
"""
定義RMSNorm模塊的前向傳播函數(shù)。
參數(shù):
x (Tensor): 輸入的張量,通常是一個(gè)特征矩陣,其形狀為(batch_size, sequence_length, d_model)。
返回:
output (Tensor): 歸一化后的特征矩陣。
"""
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight # 計(jì)算歸一化輸出
return output # 返回歸一化后的輸出
小結(jié):狀態(tài)空間參數(shù)是如何與輸入相關(guān)的
這個(gè)是S6與S4的根本區(qū)別
在上面的MambaBlock
類的代碼中,狀態(tài)空間的參數(shù)與輸入相關(guān)性體現(xiàn)在self.x_proj
和self.dt_proj
的使用上,以及在ssm
和selective_scan
方法中的計(jì)算過(guò)程中。
-
self.x_proj
和self.dt_proj
:-
self.x_proj
是一個(gè)線性變換層,它將輸入x
映射到狀態(tài)空間模型的參數(shù)Δ、B和C。這個(gè)映射是輸入依賴的,因?yàn)?code>x是模型的輸入,它的值會(huì)直接影響到這些參數(shù)的計(jì)算。 -
self.dt_proj
是一個(gè)線性變換層,用于將Δ從args.dt_rank
維度映射到args.d_inner
維度。雖然這個(gè)映射本身是一個(gè)固定的線性變換,但它的輸入(即x
)是依賴于輸入數(shù)據(jù)的。
-
-
ssm
方法:- 在
ssm
方法中,計(jì)算了狀態(tài)空間模型的參數(shù)A、B、C和D。其中,A和D是與輸入無(wú)關(guān)的,而B(niǎo)和C是通過(guò)self.x_proj
和self.dt_proj
從輸入x
中計(jì)算得到的,因此它們與輸入是相關(guān)的。
- 在
-
selective_scan
方法:-
selective_scan
方法執(zhí)行選擇性掃描算法,它是狀態(tài)空間模型的核心計(jì)算過(guò)程。在這個(gè)方法中,輸入u
(實(shí)際上是x
經(jīng)過(guò)一系列變換后的結(jié)果)與狀態(tài)空間參數(shù)Δ、A、B、C和D一起使用,來(lái)更新?tīng)顟B(tài)并計(jì)算輸出。 - 方法中的
deltaA
和deltaB_u
計(jì)算顯示了輸入u
如何影響狀態(tài)空間參數(shù)。deltaA
是通過(guò)einsum
函數(shù)將輸入u
的每個(gè)元素與矩陣A的每個(gè)元素進(jìn)行點(diǎn)乘得到的,這意味著輸入的每個(gè)元素都會(huì)影響A的每個(gè)元素。 -
deltaB_u
是通過(guò)einsum
函數(shù)將輸入u
、矩陣B和Δ進(jìn)行三元組乘法得到的,這進(jìn)一步表明輸入u
直接影響了狀態(tài)空間參數(shù)B的計(jì)算。
-
總的來(lái)說(shuō),狀態(tài)空間的參數(shù)與輸入相關(guān)性是通過(guò)輸入數(shù)據(jù)x
直接影響Δ、B和C的計(jì)算來(lái)實(shí)現(xiàn)的。這種相關(guān)性使得Mamba模型能夠根據(jù)輸入數(shù)據(jù)的不同動(dòng)態(tài)調(diào)整其內(nèi)部狀態(tài),從而更好地捕捉序列數(shù)據(jù)的特性。這是Mamba模型區(qū)別于傳統(tǒng)的線性時(shí)不變(LTI)狀態(tài)空間模型的關(guān)鍵特性。
3.模型測(cè)試代碼
3.1 加載模型
from model import Mamba, ModelArgs
from transformers import AutoTokenizer
# One of:
# 'state-spaces/mamba-2.8b-slimpj'
# 'state-spaces/mamba-2.8b'
# 'state-spaces/mamba-1.4b'
# 'state-spaces/mamba-790m'
# 'state-spaces/mamba-370m'
# 'state-spaces/mamba-130m'
pretrained_model_name = 'state-spaces/mamba-370m'
model = Mamba.from_pretrained(pretrained_model_name)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
3.2 生成文本
這個(gè)函數(shù)通過(guò)迭代地向模型提供輸入,并基于模型預(yù)測(cè)的概率分布來(lái)生成下一個(gè)令牌,直到達(dá)到指定的令牌數(shù)量。生成過(guò)程中,可以通過(guò)top_k采樣來(lái)限制概率分布,或者通過(guò)采樣來(lái)隨機(jī)選擇令牌,從而增加生成文本的多樣性。最終,函數(shù)返回生成的文本
import torch
import torch.nn.functional as F
def generate(model,
tokenizer,
prompt: str,
n_tokens_to_gen: int = 50,
sample: bool = True,
top_k: int = 40):
# 將模型設(shè)置為評(píng)估模式,這通常會(huì)關(guān)閉dropout等訓(xùn)練時(shí)的特性。
model.eval()
# 使用分詞器將提示字符串轉(zhuǎn)換為模型可以處理的輸入ID。
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
# 循環(huán)生成指定數(shù)量的令牌。
for token_n in range(n_tokens_to_gen):
# 無(wú)需計(jì)算梯度,因?yàn)槲覀兪窃谏晌谋径皇怯?xùn)練模型。
with torch.no_grad():
# 準(zhǔn)備輸入模型的索引。
indices_to_input = input_ids
# 通過(guò)模型獲取當(dāng)前輸入的下一個(gè)令牌的logits。
next_token_logits = model(indices_to_input)[:, -1]
# 對(duì)logits應(yīng)用softmax函數(shù),將其轉(zhuǎn)換為概率分布。
probs = F.softmax(next_token_logits, dim=-1)
# 獲取概率分布的形狀,即批次大小和詞匯表大小。
(batch, vocab_size) = probs.shape
# 如果指定了top_k采樣,則獲取概率最高的k個(gè)令牌及其對(duì)應(yīng)的值和索引。
if top_k is not None:
(values, indices) = torch.topk(probs, k=top_k)
# 將概率低于最低top_k令牌的概率值設(shè)置為0。
probs[probs < values[:, -1, None]] = 0
# 重新歸一化概率分布,使得所有概率之和為1。
probs = probs / probs.sum(axis=1, keepdims=True)
# 如果采樣為T(mén)rue,則通過(guò)多項(xiàng)式采樣(Multinomial Sampling)來(lái)選擇下一個(gè)令牌。
if sample:
next_indices = torch.multinomial(probs, num_samples=1)
else:# 如果不采樣,則選擇概率最高的令牌作為下一個(gè)令牌。
next_indices = torch.argmax(probs, dim=-1)[:, None]
# 將生成的下一個(gè)令牌添加到輸入ID列表中。
input_ids = torch.cat([input_ids, next_indices], dim=1)
# 將最終的輸入ID轉(zhuǎn)換為文本,并解碼為可讀的字符串。
output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]
# 返回生成的文本。
return output_completions
print(generate(model, tokenizer, 'Mamba is the'))
Mamba is the world’s longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-849915.html
print(generate(model, tokenizer, 'John: Hi!\nSally:'))
John: Hi!
Sally: Hey!
John: So, when’s the wedding?
Sally: We haven’t decided.
John: It’s in September.
Sally: Yeah, we were thinking July or August.文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-849915.html
附:完整模型代碼
"""Simple, minimal implementation of Mamba in one file of PyTorch.
Suggest reading the following before/while reading the code:
[1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)
https://arxiv.org/abs/2312.00752
[2] The Annotated S4 (Sasha Rush and Sidd Karamcheti)
https://srush.github.io/annotated-s4
Glossary:
b: batch size (`B` in Mamba paper [1] Algorithm 2)
l: sequence length (`L` in [1] Algorithm 2)
d or d_model: hidden dim
n or d_state: latent state dim (`N` in [1] Algorithm 2)
expand: expansion factor (`E` in [1] Section 3.4)
d_in or d_inner: d * expand (`D` in [1] Algorithm 2)
A, B, C, D: state space parameters (See any state space representation formula)
(B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not)
Δ or delta: input-dependent step size
dt_rank: rank of Δ (See [1] Section 3.6 "Parameterization of ?")
"""
from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, repeat, einsum
# 使用dataclass裝飾器自動(dòng)生成初始化方法和類的字符串表示方法
@dataclass
class ModelArgs:
# @dataclass 會(huì)自動(dòng)為這個(gè)類生成初始化方法和代表類的字符串形式的方法
d_model: int # 定義模型的隱藏層維度
n_layer: int # 定義模型的層數(shù)
vocab_size: int # 定義詞匯表的大小
d_state: int = 16 # 定義狀態(tài)空間的維度,默認(rèn)為16
expand: int = 2 # 定義擴(kuò)展因子,默認(rèn)為2
dt_rank: Union[int, str] = 'auto' # 定義輸入依賴步長(zhǎng)Δ的秩,'auto'表示自動(dòng)設(shè)置
d_conv: int = 4 # 定義卷積核的維度,默認(rèn)為4
pad_vocab_size_multiple: int = 8 # 定義詞匯表大小的最小公倍數(shù),默認(rèn)為8
conv_bias: bool = True # 定義卷積層是否使用偏置項(xiàng)
bias: bool = False # 定義其他層(如線性層)是否使用偏置項(xiàng)
def __post_init__(self):
# 在__init__后自動(dòng)被調(diào)用,用于執(zhí)行初始化之后的額外設(shè)置或驗(yàn)證
# 計(jì)算內(nèi)部維度,即擴(kuò)展后的維度
self.d_inner = int(self.expand * self.d_model)
if self.dt_rank == 'auto':# 如果dt_rank未指定,則自動(dòng)計(jì)算設(shè)置
# 根據(jù)隱藏層維度自動(dòng)計(jì)算Δ的秩
self.dt_rank = math.ceil(self.d_model / 16)
# 確保vocab_size是pad_vocab_size_multiple的倍數(shù)
# 如果不是,調(diào)整為最近的倍數(shù)
if self.vocab_size % self.pad_vocab_size_multiple != 0:
self.vocab_size += (self.pad_vocab_size_multiple
- self.vocab_size % self.pad_vocab_size_multiple)
class Mamba(nn.Module):
def __init__(self, args: ModelArgs):
"""Full Mamba model."""
super().__init__()
# 保存?zhèn)魅氲腗odelArgs對(duì)象,包含模型的配置參數(shù)
self.args = args
# 創(chuàng)建一個(gè)嵌入層,將詞匯表中的詞轉(zhuǎn)換為對(duì)應(yīng)的向量表示
self.embedding = nn.Embedding(args.vocab_size, args.d_model)
# 創(chuàng)建一個(gè)包含多個(gè)殘差塊的模塊列表,殘差塊的數(shù)量等于模型層數(shù)
self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
# 創(chuàng)建一個(gè)RMSNorm模塊,用于歸一化操作
self.norm_f = RMSNorm(args.d_model)
# 創(chuàng)建一個(gè)線性層,用于最終的輸出,將隱藏層的輸出映射回詞匯表的大小
self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
# 將線性層的輸出權(quán)重與嵌入層的權(quán)重綁定,這是權(quán)重共享的一種形式,有助于減少參數(shù)數(shù)量并可能提高模型的泛化能力
self.lm_head.weight = self.embedding.weight # Tie output projection to embedding weights.
# See "Weight Tying" paper
def forward(self, input_ids):
"""
Args:
input_ids (long tensor): shape (b, l) (See Glossary at top for definitions of b, l, d_in, n...)
Returns:
logits: shape (b, l, vocab_size)
Official Implementation:
class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173
"""
# 將輸入ID轉(zhuǎn)換為向量表示
x = self.embedding(input_ids)
# 遍歷所有的殘差塊,并應(yīng)用它們
for layer in self.layers:
x = layer(x)
# 應(yīng)用歸一化操作
x = self.norm_f(x)
# 通過(guò)線性層得到最終的logits輸出
logits = self.lm_head(x)
# 返回模型的輸出
return logits
@staticmethod
def from_pretrained(pretrained_model_name: str):
"""Load pretrained weights from HuggingFace into model.
Args:
pretrained_model_name: One of
* 'state-spaces/mamba-2.8b-slimpj'
* 'state-spaces/mamba-2.8b'
* 'state-spaces/mamba-1.4b'
* 'state-spaces/mamba-790m'
* 'state-spaces/mamba-370m'
* 'state-spaces/mamba-130m'
Returns:
model: Mamba model with weights loaded
"""
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file
def load_config_hf(model_name):
resolved_archive_file = cached_file(model_name, CONFIG_NAME,
_raise_exceptions_for_missing_entries=False)
return json.load(open(resolved_archive_file))
def load_state_dict_hf(model_name, device=None, dtype=None):
resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
_raise_exceptions_for_missing_entries=False)
return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)
config_data = load_config_hf(pretrained_model_name)
args = ModelArgs(
d_model=config_data['d_model'],
n_layer=config_data['n_layer'],
vocab_size=config_data['vocab_size']
)
model = Mamba(args)
state_dict = load_state_dict_hf(pretrained_model_name)
new_state_dict = {}
for key in state_dict:
new_key = key.replace('backbone.', '')
new_state_dict[new_key] = state_dict[key]
model.load_state_dict(new_state_dict)
return model
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
"""Simple block wrapping Mamba block with normalization and residual connection."""
super().__init__()
# 保存?zhèn)魅氲腗odelArgs對(duì)象,包含模型的配置參數(shù)
self.args = args
# 創(chuàng)建一個(gè)MambaBlock,它是這個(gè)殘差塊的核心組件
self.mixer = MambaBlock(args)
# 創(chuàng)建一個(gè)RMSNorm歸一化模塊,用于歸一化操作
self.norm = RMSNorm(args.d_model)
def forward(self, x):
"""
Args:
x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...)
x (Tensor): 輸入張量,形狀為(batch_size, sequence_length, hidden_size)
Returns:
output: shape (b, l, d)
輸出張量,形狀與輸入相同
Official Implementation:
Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297
Note: the official repo chains residual blocks that look like
[Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
where the first Add is a no-op. This is purely for performance reasons as this
allows them to fuse the Add->Norm.
We instead implement our blocks as the more familiar, simpler, and numerically equivalent
[Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
"""
# 應(yīng)用歸一化和MambaBlock,然后與輸入x進(jìn)行殘差連接
output = self.mixer(self.norm(x)) + x
return output
class MambaBlock(nn.Module):
def __init__(self, args: ModelArgs):
"""A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
super().__init__()
# 保存模型參數(shù)
self.args = args
# 輸入線性變換層
self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)
# 創(chuàng)建了一個(gè)所謂的“深度卷積”,其中每個(gè)輸入通道被單獨(dú)卷積到每個(gè)輸出通道。
# 這意味著每個(gè)輸出通道的結(jié)果是通過(guò)僅與一個(gè)輸入通道卷積得到的。
self.conv1d = nn.Conv1d(
in_channels=args.d_inner,
out_channels=args.d_inner,
bias=args.conv_bias,
kernel_size=args.d_conv,
groups=args.d_inner,
padding=args.d_conv - 1,
)
# x_proj takes in `x` and outputs the input-specific Δ, B, C
# 將輸入x映射到狀態(tài)空間模型的參數(shù)Δ、B和C
self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)
# dt_proj projects Δ from dt_rank to d_in
# 將Δ從args.dt_rank維度映射到args.d_inner維度
self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)
# 創(chuàng)建一個(gè)重復(fù)的序列,用于初始化狀態(tài)空間模型的矩陣A
# n->dxn
A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
# 將矩陣A的對(duì)數(shù)值作為可訓(xùn)練參數(shù)保存
self.A_log = nn.Parameter(torch.log(A))
# 初始化矩陣D為全1的可訓(xùn)練參數(shù)
self.D = nn.Parameter(torch.ones(args.d_inner))
# 輸出線性變換層
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)
def forward(self, x):
"""MambaBlock的前向傳播函數(shù),與Mamba論文圖3 Section 3.4相同.
Args:
x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...)
Returns:
output: shape (b, l, d)
Official Implementation:
class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
"""
# 獲取輸入x的維度
# batchsize,seq_len,dim
(b, l, d) = x.shape # 獲取輸入x的維度
# 應(yīng)用輸入線性變換
x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in)
# 將變換后的輸出分為兩部分x和res。
# 得到的x分為兩個(gè)部分,一部分x繼續(xù)用于后續(xù)變換,生成所需要的參數(shù),res用于殘差部分
(x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)
# 調(diào)整x的形狀
x = rearrange(x, 'b l d_in -> b d_in l')
# 應(yīng)用深度卷積,然后截取前l(fā)個(gè)輸出
x = self.conv1d(x)[:, :, :l]
# 再次調(diào)整x的形狀
x = rearrange(x, 'b d_in l -> b l d_in')
# 應(yīng)用SiLU激活函數(shù)
x = F.silu(x)
# 運(yùn)行狀態(tài)空間模型
y = self.ssm(x)
# 將res的SiLU激活結(jié)果與y相乘
y = y * F.silu(res)
# 應(yīng)用輸出線性變換
output = self.out_proj(y)
# 返回輸出結(jié)果
return output
def ssm(self, x):
"""運(yùn)行狀態(tài)空間模型,參考Mamba論文 Section 3.2和注釋[2]:
- Algorithm 2 in Section 3.2 in the Mamba paper [1]
- run_SSM(A, B, C, u) in The Annotated S4 [2]
Args:
x: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...)
Returns:
output: shape (b, l, d_in)
Official Implementation:
mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
"""
# 獲取A_log的維度
# A在初始化時(shí)候經(jīng)過(guò)如下賦值:
# A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
# self.A_log = nn.Parameter(torch.log(A))
# (args.d_inner, args.d_state)
(d_in, n) = self.A_log.shape # 獲取A_log的維度
# 計(jì)算 ? A B C D, 這些屬于狀態(tài)空間參數(shù).
# A, D 是 與輸入無(wú)關(guān)的 (見(jiàn)Mamba論文Section 3.5.2 "Interpretation of A" for why A isn't selective)
# ?, B, C 與輸入有關(guān)(這是與線性是不變模型S4最大的不同,
# 也是為什么Mamba被稱為 “選擇性” 狀態(tài)空間的原因)
# 計(jì)算矩陣A
A = -torch.exp(self.A_log.float()) # shape (d_in, n)
# 取D的值
D = self.D.float()
# 應(yīng)用x的投影變換
# ( b,l,d_in) -> (b, l, dt_rank + 2*n)
x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n)
# 分割delta, B, C
# delta: (b, l, dt_rank). B, C: (b, l, n)
(delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)
# 應(yīng)用dt_proj并計(jì)算delta
delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in)
# 應(yīng)用選擇性掃描算法
y = self.selective_scan(x, delta, A, B, C, D)
return y
def selective_scan(self, u, delta, A, B, C, D):
"""執(zhí)行選擇性掃描算法,參考Mamba論文[1] Section 2和注釋[2]. See:
- Section 2 State Space Models in the Mamba paper [1]
- Algorithm 2 in Section 3.2 in the Mamba paper [1]
- run_SSM(A, B, C, u) in The Annotated S4 [2]
經(jīng)典的離散狀態(tài)空間公式:
x(t + 1) = Ax(t) + Bu(t)
y(t) = Cx(t) + Du(t)
除了B和C (以及step size delta用于離散化) 與輸入x(t)相關(guān).
參數(shù):
u: shape (b, l, d_in)
delta: shape (b, l, d_in)
A: shape (d_in, n)
B: shape (b, l, n)
C: shape (b, l, n)
D: shape (d_in,)
過(guò)程概述:
Returns:
output: shape (b, l, d_in)
Official Implementation:
selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly.
"""
# 獲取輸入u的維度
(b, l, d_in) = u.shape
# 獲取矩陣A的列數(shù)
n = A.shape[1] # A: shape (d_in, n)
# 離散化連續(xù)參數(shù)(A, B)
# - A 使用 zero-order hold (ZOH) 離散化 (see Section 2 Equation 4 in the Mamba paper [1])
# - B is 使用一個(gè)簡(jiǎn)化的Euler discretization而不是ZOH.根據(jù)作者的討論:
# "A is the more important term and the performance doesn't change much with the simplification on B"
# 計(jì)算離散化的A
# 將delta和A進(jìn)行點(diǎn)乘,將A沿著delta的最后一個(gè)維度進(jìn)行廣播,然后執(zhí)行逐元素乘法
# A:(d_in, n),delta:(b, l, d_in)
# A廣播拓展->(b,l,d_in, n),deltaA對(duì)應(yīng)原論文中的A_bar
deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
# delta、B和u,這個(gè)計(jì)算和原始論文不同
deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
# Perform selective scan (see scan_SSM() in The Annotated S4 [2])
# Note that the below is sequential, while the official implementation does a much faster parallel scan that
# is additionally hardware-aware (like FlashAttention).
# 執(zhí)行選擇性掃描,初始化狀態(tài)x為零
x = torch.zeros((b, d_in, n), device=deltaA.device)
# 初始化輸出列表ys
ys = []
for i in range(l):
# 更新?tīng)顟B(tài)x
# deltaA:((b,l,d_in, n)
# deltaB_u:( b,l,d_in,n)
# x:
x = deltaA[:, i] * x + deltaB_u[:, i]
# 計(jì)算輸出y
y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
# 將輸出y添加到列表ys中
ys.append(y)
# 將列表ys堆疊成張量y
y = torch.stack(ys, dim=1) # shape (b, l, d_in)
# 將輸入u乘以D并加到輸出y上
y = y + u * D
return y
class RMSNorm(nn.Module):
"""
初始化RMSNorm模塊,該模塊實(shí)現(xiàn)了基于均方根的歸一化操作。
參數(shù):
d_model (int): 模型的特征維度。
eps (float, 可選): 為了避免除以零,添加到分母中的一個(gè)小的常數(shù)。
"""
def __init__(self,
d_model: int,
eps: float = 1e-5):
super().__init__()
self.eps = eps# 保存輸入的eps值,用于數(shù)值穩(wěn)定性。
# 創(chuàng)建一個(gè)可訓(xùn)練的權(quán)重參數(shù),初始值為全1,維度與輸入特征維度d_model相同。
self.weight = nn.Parameter(torch.ones(d_model))
def forward(self, x):
"""
計(jì)算輸入x的均方根值,用于后續(xù)的歸一化操作。
x.pow(2) 計(jì)算x中每個(gè)元素的平方。
mean(-1, keepdim=True) 對(duì)x的最后一個(gè)維度(特征維度)進(jìn)行平方和求平均,保持維度以便進(jìn)行廣播操作。
torch.rsqrt 對(duì)求得的平均值取倒數(shù)和平方根,得到每個(gè)特征的均方根值的逆。
+ self.eps 添加一個(gè)小的常數(shù)eps以保持?jǐn)?shù)值穩(wěn)定性,防止除以零的情況發(fā)生。
x * ... * self.weight 將輸入x與計(jì)算得到的歸一化因子和可訓(xùn)練的權(quán)重相乘,得到最終的歸一化輸出。
"""
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
return output
到了這里,關(guān)于【論文閱讀筆記】Mamba模型代碼理解的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!