国产 无码 综合区,色欲AV无码国产永久播放,无码天堂亚洲国产AV,国产日韩欧美女同一区二区

【論文閱讀筆記】Mamba模型代碼理解

這篇具有很好參考價(jià)值的文章主要介紹了【論文閱讀筆記】Mamba模型代碼理解。希望對(duì)大家有所幫助。如果存在錯(cuò)誤或未考慮完全的地方,請(qǐng)大家不吝賜教,您也可以點(diǎn)擊"舉報(bào)違法"按鈕提交疑問(wèn)。

【論文閱讀筆記】Mamba模型代碼理解,Mamba,論文閱讀,筆記

0.開(kāi)源代碼地址

官方實(shí)現(xiàn):state-spaces/mamba (github.com)

最簡(jiǎn)化實(shí)現(xiàn):johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch. (github.com)

直接實(shí)現(xiàn):alxndrTL/mamba.py: A simple and efficient Mamba implementation in PyTorch and MLX. (github.com)

官方代碼做了大量?jī)?yōu)化,目錄層級(jí)較多,對(duì)于理解模型含義較難,這里老師對(duì)上面第二最簡(jiǎn)化實(shí)現(xiàn)的代碼進(jìn)行了詳細(xì)注釋,該代碼性能比官方實(shí)現(xiàn)差,但是對(duì)于理解模型原理比較直白。

這段代碼的主要組成部分包括模型參數(shù)類ModelArgs、完整的Mamba模型類Mamba、殘差塊類ResidualBlock、單個(gè)Mamba塊類MambaBlock、RMSNorm歸一化類以及一些輔助函數(shù)。

1算法核心

的算法圖,原始論文中給出的Mamba(SSSM:Selective state Space model of )的前身S4(SSM:Structured State Space Model):

【論文閱讀筆記】Mamba模型代碼理解,Mamba,論文閱讀,筆記

S6模塊

S6模塊是Mamba架構(gòu)中的一個(gè)復(fù)雜組件,負(fù)責(zé)通過(guò)一系列線性變換和離散化過(guò)程處理輸入序列。它在捕獲序列的時(shí)間動(dòng)態(tài)方面起著關(guān)鍵作用,這是序列建模任務(wù)(如語(yǔ)言建模)的一個(gè)關(guān)鍵方面。這里包括張量運(yùn)算和自定義離散化方法來(lái)處理序列數(shù)據(jù)的復(fù)雜需求。

離散化def discretization(self)中有兩行代碼提出來(lái)解釋,論文中離散化采用零階保持:

A  ̄ = e x p ( Δ A ) \overline{A}=exp(\Delta A) A=exp(ΔA) :對(duì)應(yīng)代碼中的self.dA?

B  ̄ = ( Δ A ) ? 1 ( exp ? ( Δ A ) ? I ) ? Δ B \overline{B}=(\Delta A)^{-1}(\exp(\Delta A)-I)\cdot\Delta B B=(ΔA)?1(exp(ΔA)?I)?ΔB::對(duì)應(yīng)代碼中的self.dB

各個(gè)張量維度如下:

【論文閱讀筆記】Mamba模型代碼理解,Mamba,論文閱讀,筆記

2.Mamba模型定義

2.1 ModelArgs 類

ModelArgs 類是用于存儲(chǔ)和處理Mamba模型配置參數(shù)的容器。它使用Python的dataclass裝飾器來(lái)自動(dòng)生成初始化方法和類的字符串表示方法,簡(jiǎn)化了代碼的編寫(xiě)。這個(gè)類中的每個(gè)屬性對(duì)應(yīng)于構(gòu)建Mamba模型所需的一個(gè)配置參數(shù),例如模型的隱藏層維度、層數(shù)、詞匯表大小等。__post_init__方法在初始化后自動(dòng)調(diào)用,用于執(zhí)行一些額外的設(shè)置,比如計(jì)算內(nèi)部維度d_inner和自動(dòng)調(diào)整詞匯表大小,以確保模型的配置參數(shù)是有效的和一致的。

# 使用dataclass裝飾器自動(dòng)生成初始化方法和類的字符串表示方法
@dataclass
class ModelArgs:
    # @dataclass 會(huì)自動(dòng)為這個(gè)類生成初始化方法和代表類的字符串形式的方法
    d_model: int  # 定義模型的隱藏層維度
    n_layer: int # 定義模型的層數(shù)
    vocab_size: int  # 定義詞匯表的大小
    d_state: int = 16 # 定義狀態(tài)空間的維度,默認(rèn)為16
    expand: int = 2 # 定義擴(kuò)展因子,默認(rèn)為2
    dt_rank: Union[int, str] = 'auto'  # 定義輸入依賴步長(zhǎng)Δ的秩,'auto'表示自動(dòng)設(shè)置
    d_conv: int = 4   # 定義卷積核的維度,默認(rèn)為4
    pad_vocab_size_multiple: int = 8   # 定義詞匯表大小的最小公倍數(shù),默認(rèn)為8
    conv_bias: bool = True # 定義卷積層是否使用偏置項(xiàng)
    bias: bool = False # 定義其他層(如線性層)是否使用偏置項(xiàng)
    
    def __post_init__(self):
        # 在__init__后自動(dòng)被調(diào)用,用于執(zhí)行初始化之后的額外設(shè)置或驗(yàn)證
        # 計(jì)算內(nèi)部維度,即擴(kuò)展后的維度
        self.d_inner = int(self.expand * self.d_model)
        
        if self.dt_rank == 'auto':# 如果dt_rank未指定,則自動(dòng)計(jì)算設(shè)置
            # 根據(jù)隱藏層維度自動(dòng)計(jì)算Δ的秩
            self.dt_rank = math.ceil(self.d_model / 16)
        # 確保vocab_size是pad_vocab_size_multiple的倍數(shù)
        # 如果不是,調(diào)整為最近的倍數(shù)
        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)

2.2 Mamba 類

Mamba 類是Mamba模型的主干,繼承自PyTorch的nn.Module類。這個(gè)類的實(shí)例化對(duì)象將構(gòu)成整個(gè)Mamba模型的結(jié)構(gòu)和前向傳播邏輯。

在初始化方法__init__中,首先調(diào)用父類的構(gòu)造函數(shù)來(lái)初始化模型。然后,根據(jù)傳入的ModelArgs對(duì)象中的參數(shù)配置模型的各個(gè)組件:

  • self.embedding是一個(gè)嵌入層,它將輸入的詞匯ID轉(zhuǎn)換為對(duì)應(yīng)的向量表示。這些向量隨后會(huì)被送入模型的深層網(wǎng)絡(luò)中。
  • self.layers是一個(gè)模塊列表,其中包含了多個(gè)ResidualBlock殘差塊。這些殘差塊有助于訓(xùn)練深層網(wǎng)絡(luò)并防止梯度消失問(wèn)題。
  • self.norm_f是一個(gè)RMSNorm歸一化模塊,用于在模型的某些層之后進(jìn)行歸一化操作,以穩(wěn)定訓(xùn)練過(guò)程。
  • self.lm_head是一個(gè)線性層,它將模型的最終隱藏狀態(tài)映射回詞匯表的大小,以便進(jìn)行下一步的預(yù)測(cè)或分類任務(wù)。

forward方法中,定義了模型的前向傳播邏輯。輸入input_ids首先通過(guò)嵌入層轉(zhuǎn)換為向量表示,然后依次通過(guò)每個(gè)殘差塊進(jìn)行處理。經(jīng)過(guò)所有層之后,模型的輸出通過(guò)RMSNorm歸一化,最后通過(guò)線性層self.lm_head得到最終的logits輸出。這個(gè)輸出可以用于后續(xù)的損失計(jì)算或生成任務(wù)。

class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model."""
        super().__init__()
        # 保存?zhèn)魅氲腗odelArgs對(duì)象,包含模型的配置參數(shù)
        self.args = args
        # 創(chuàng)建一個(gè)嵌入層,將詞匯表中的詞轉(zhuǎn)換為對(duì)應(yīng)的向量表示
        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        # 創(chuàng)建一個(gè)包含多個(gè)殘差塊的模塊列表,殘差塊的數(shù)量等于模型層數(shù)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        # 創(chuàng)建一個(gè)RMSNorm模塊,用于歸一化操作
        self.norm_f = RMSNorm(args.d_model)
        # 創(chuàng)建一個(gè)線性層,用于最終的輸出,將隱藏層的輸出映射回詞匯表的大小
        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        # 將線性層的輸出權(quán)重與嵌入層的權(quán)重綁定,這是權(quán)重共享的一種形式,有助于減少參數(shù)數(shù)量并可能提高模型的泛化能力
        self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.
                                                     # See "Weight Tying" paper


    def forward(self, input_ids):
        """
        Args:
            input_ids (long tensor): shape (b, l)    (See Glossary at top for definitions of b, l, d_in, n...)
    
        Returns:
            logits: shape (b, l, vocab_size)

        Official Implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173

        """
        # 將輸入ID轉(zhuǎn)換為向量表示
        x = self.embedding(input_ids)
        # 遍歷所有的殘差塊,并應(yīng)用它們
        for layer in self.layers:
            x = layer(x)
        # 應(yīng)用歸一化操作
        x = self.norm_f(x)
        # 通過(guò)線性層得到最終的logits輸出
        logits = self.lm_head(x)
        # 返回模型的輸出
        return logits

解釋一下:為什么輸入的input_ids已經(jīng)是經(jīng)過(guò)分詞器(tokenizer)處理后的詞匯表索引,還需要通過(guò)nn.Embedding?

這些索引代表了輸入文本中的單詞或子詞單元在詞匯表中的位置。盡管這些索引已經(jīng)是一個(gè)相對(duì)緊湊的數(shù)值表示,但它們并不直接對(duì)應(yīng)于模型可以處理的向量表示。nn.Embedding層的作用是將這些離散的索引映射到一個(gè)連續(xù)的向量空間中。每個(gè)索引input_ids中的值都會(huì)被nn.Embedding層轉(zhuǎn)換成一個(gè)固定維度的向量,這個(gè)向量捕捉了對(duì)應(yīng)單詞或子詞的語(yǔ)義信息。這個(gè)轉(zhuǎn)換過(guò)程是模型學(xué)習(xí)的一部分,通過(guò)訓(xùn)練數(shù)據(jù)中的模式,模型可以學(xué)習(xí)到如何將這些索引映射到能夠有效表示輸入文本的向量。

2.3 ResidualBlock 類

定義了Mamba模型中的一個(gè)殘差塊。這個(gè)類的目的是為了在模型中引入殘差連接,這有助于訓(xùn)練深層網(wǎng)絡(luò),因?yàn)樗试S梯度直接流過(guò)網(wǎng)絡(luò),從而緩解了梯度消失問(wèn)題。

__init__方法中,首先調(diào)用父類nn.Module的構(gòu)造函數(shù)來(lái)初始化殘差塊。然后,根據(jù)傳入的ModelArgs對(duì)象中的參數(shù)配置殘差塊的組件:

  • self.mixer是一個(gè)MambaBlock實(shí)例,它是這個(gè)殘差塊的核心組件,負(fù)責(zé)執(zhí)行Mamba模型的大部分計(jì)算。
  • self.norm是一個(gè)RMSNorm歸一化模塊,用于在將數(shù)據(jù)送入MambaBlock之前進(jìn)行歸一化處理。

forward方法中,定義了殘差塊的前向傳播邏輯。輸入張量x首先通過(guò)RMSNorm模塊進(jìn)行歸一化,然后送入MambaBlockMambaBlock的輸出接著與原始輸入x相加,形成殘差連接。這樣做可以使得模型的學(xué)習(xí)更加靈活,因?yàn)樗试S模型學(xué)習(xí)到輸入和輸出之間的恒等映射(即不改變輸入數(shù)據(jù)),這在某些情況下是有益的。最后,殘差塊的輸出被返回。

class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Simple block wrapping Mamba block with normalization and residual connection."""
        super().__init__()
        # 保存?zhèn)魅氲腗odelArgs對(duì)象,包含模型的配置參數(shù)
        self.args = args
        # 創(chuàng)建一個(gè)MambaBlock,它是這個(gè)殘差塊的核心組件
        self.mixer = MambaBlock(args)
        # 創(chuàng)建一個(gè)RMSNorm歸一化模塊,用于歸一化操作
        self.norm = RMSNorm(args.d_model)
        

    def forward(self, x):
        """
        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
             x (Tensor): 輸入張量,形狀為(batch_size, sequence_length, hidden_size)
        Returns:
            output: shape (b, l, d)
            輸出張量,形狀與輸入相同
        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297
            
            Note: the official repo chains residual blocks that look like
                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
            where the first Add is a no-op. This is purely for performance reasons as this
            allows them to fuse the Add->Norm.

            We instead implement our blocks as the more familiar, simpler, and numerically equivalent
                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
            
        """
        # 應(yīng)用歸一化和MambaBlock,然后與輸入x進(jìn)行殘差連接
        output = self.mixer(self.norm(x)) + x

        return output

2.4 MambaBlock 類

MambaBlock 類定義了Mamba模型中的一個(gè)基本構(gòu)建塊,即Mamba塊。這個(gè)塊是模型的核心組件,負(fù)責(zé)執(zhí)行序列數(shù)據(jù)的處理和狀態(tài)空間模型的更新。

__init__方法中,首先調(diào)用父類nn.Module的構(gòu)造函數(shù)來(lái)初始化Mamba塊。然后,根據(jù)傳入的ModelArgs對(duì)象中的參數(shù)配置Mamba塊的組件:

  • self.in_proj是一個(gè)線性變換層,用于輸入的投影。
  • self.conv1d是一個(gè)一維卷積層,用于執(zhí)行深度卷積,這是Mamba模型的特色之一,用于處理序列數(shù)據(jù)。
  • self.x_projself.dt_proj是線性變換層,用于將輸入映射到狀態(tài)空間模型的參數(shù)。
  • self.A_log是矩陣A的對(duì)數(shù)值,作為一個(gè)可訓(xùn)練參數(shù)。
  • self.D是矩陣D,初始化為全1,也是一個(gè)可訓(xùn)練參數(shù)。
  • self.out_proj是一個(gè)線性變換層,用于輸出的投影。

forward方法中,定義了Mamba塊的前向傳播邏輯。輸入張量x首先通過(guò)線性變換層和深度卷積層進(jìn)行處理,然后應(yīng)用激活函數(shù)。接著,通過(guò)狀態(tài)空間模型(ssm)和選擇性掃描(selective_scan)算法更新?tīng)顟B(tài),并計(jì)算輸出。最后,輸出通過(guò)另一個(gè)線性變換層進(jìn)行投影,得到最終的輸出結(jié)果。

ssm方法負(fù)責(zé)運(yùn)行狀態(tài)空間模型,它使用矩陣A、B、C和D以及輸入x來(lái)更新?tīng)顟B(tài)并計(jì)算輸出。

【論文閱讀筆記】Mamba模型代碼理解,Mamba,論文閱讀,筆記

selective_scan方法執(zhí)行選擇性掃描算法,這是Mamba模型的關(guān)鍵特性,它允許模型根據(jù)輸入動(dòng)態(tài)調(diào)整其行為,從而更好地處理序列數(shù)據(jù)。通過(guò)這種方式,Mamba模型能夠有效地捕捉序列中的長(zhǎng)期依賴關(guān)系,同時(shí)保持線性時(shí)間復(fù)雜度。

【論文閱讀筆記】Mamba模型代碼理解,Mamba,論文閱讀,筆記

class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
        super().__init__()
        # 保存模型參數(shù)
        self.args = args
        # 輸入線性變換層
        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        # 創(chuàng)建了一個(gè)所謂的“深度卷積”,其中每個(gè)輸入通道被單獨(dú)卷積到每個(gè)輸出通道。
        # 這意味著每個(gè)輸出通道的結(jié)果是通過(guò)僅與一個(gè)輸入通道卷積得到的。
        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        # 將輸入x映射到狀態(tài)空間模型的參數(shù)Δ、B和C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)
        
        # dt_proj projects Δ from dt_rank to d_in
        # 將Δ從args.dt_rank維度映射到args.d_inner維度
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        # 創(chuàng)建一個(gè)重復(fù)的序列,用于初始化狀態(tài)空間模型的矩陣A
        # n->dxn
        A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        # 將矩陣A的對(duì)數(shù)值作為可訓(xùn)練參數(shù)保存
        self.A_log = nn.Parameter(torch.log(A))
        # 初始化矩陣D為全1的可訓(xùn)練參數(shù)
        self.D = nn.Parameter(torch.ones(args.d_inner))
        # 輸出線性變換層
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)
        

    def forward(self, x):
        """MambaBlock的前向傳播函數(shù),與Mamba論文圖3 Section 3.4相同.
    
        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
    
        Returns:
            output: shape (b, l, d)
        
        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
            
        """
        # 獲取輸入x的維度
        # batchsize,seq_len,dim
        (b, l, d) = x.shape # 獲取輸入x的維度
        # 應(yīng)用輸入線性變換
        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        # 將變換后的輸出分為兩部分x和res。
        # 得到的x分為兩個(gè)部分,一部分x繼續(xù)用于后續(xù)變換,生成所需要的參數(shù),res用于殘差部分
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)
        # 調(diào)整x的形狀
        x = rearrange(x, 'b l d_in -> b d_in l')
        # 應(yīng)用深度卷積,然后截取前l(fā)個(gè)輸出
        x = self.conv1d(x)[:, :, :l]
        # 再次調(diào)整x的形狀
        x = rearrange(x, 'b d_in l -> b l d_in')
        # 應(yīng)用SiLU激活函數(shù)
        x = F.silu(x)
        # 運(yùn)行狀態(tài)空間模型
        y = self.ssm(x)
        # 將res的SiLU激活結(jié)果與y相乘
        y = y * F.silu(res)
        # 應(yīng)用輸出線性變換
        output = self.out_proj(y)
        # 返回輸出結(jié)果
        return output

    
    def ssm(self, x):
        """運(yùn)行狀態(tài)空間模型,參考Mamba論文 Section 3.2和注釋[2]:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
    
        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
            
        """
        # 獲取A_log的維度
        # A在初始化時(shí)候經(jīng)過(guò)如下賦值:
        #  A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        #  self.A_log = nn.Parameter(torch.log(A))
        # (args.d_inner, args.d_state)
        (d_in, n) = self.A_log.shape # 獲取A_log的維度

        # 計(jì)算 ? A B C D, 這些屬于狀態(tài)空間參數(shù).
        #     A, D 是 與輸入無(wú)關(guān)的 (見(jiàn)Mamba論文Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ?, B, C 與輸入有關(guān)(這是與線性是不變模型S4最大的不同,
        #                       也是為什么Mamba被稱為 “選擇性” 狀態(tài)空間的原因)

        # 計(jì)算矩陣A
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        # 取D的值
        D = self.D.float()

        # 應(yīng)用x的投影變換
        # ( b,l,d_in) -> (b, l, dt_rank + 2*n)
        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)

        # 分割delta, B, C
        # delta: (b, l, dt_rank). B, C: (b, l, n)
        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)
        # 應(yīng)用dt_proj并計(jì)算delta
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
        # 應(yīng)用選擇性掃描算法
        y = self.selective_scan(x, delta, A, B, C, D)
        return y

    
    def selective_scan(self, u, delta, A, B, C, D):
        """執(zhí)行選擇性掃描算法,參考Mamba論文[1] Section 2和注釋[2]. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        經(jīng)典的離散狀態(tài)空間公式:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
       除了B和C (以及step size delta用于離散化) 與輸入x(t)相關(guān).
    
        參數(shù):
            u: shape (b, l, d_in)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        過(guò)程概述:
    
        Returns:
            output: shape (b, l, d_in)
    
        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly.
            
        """
        # 獲取輸入u的維度
        (b, l, d_in) = u.shape
        # 獲取矩陣A的列數(shù)
        n = A.shape[1]  #  A: shape (d_in, n)
        
        # 離散化連續(xù)參數(shù)(A, B)
        # - A 使用 zero-order hold (ZOH) 離散化 (see Section 2 Equation 4 in the Mamba paper [1])
        # - B is 使用一個(gè)簡(jiǎn)化的Euler discretization而不是ZOH.根據(jù)作者的討論:
        #   "A is the more important term and the performance doesn't change much with the simplification on B"

        # 計(jì)算離散化的A
        # 將delta和A進(jìn)行點(diǎn)乘,將A沿著delta的最后一個(gè)維度進(jìn)行廣播,然后執(zhí)行逐元素乘法
        # A:(d_in, n),delta:(b, l, d_in)
        # A廣播拓展->(b,l,d_in, n),deltaA對(duì)應(yīng)原論文中的A_bar
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        # delta、B和u,這個(gè)計(jì)算和原始論文不同
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
        
        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
        # is additionally hardware-aware (like FlashAttention).
        # 執(zhí)行選擇性掃描,初始化狀態(tài)x為零
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        # 初始化輸出列表ys
        ys = []    
        for i in range(l):
            # 更新?tīng)顟B(tài)x
            # deltaA:((b,l,d_in, n)
            # deltaB_u:( b,l,d_in,n)
            # x:
            x = deltaA[:, i] * x + deltaB_u[:, i]
            # 計(jì)算輸出y
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            # 將輸出y添加到列表ys中
            ys.append(y)
        # 將列表ys堆疊成張量y
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)
        # 將輸入u乘以D并加到輸出y上
        y = y + u * D
    
        return y

解釋1:深度卷積的幾行代碼

x = rearrange(x, ‘b l d_in -> b d_in l’) 調(diào)整x的形狀

這行代碼使用rearrange函數(shù)將輸入張量x的形狀從(batch_size, sequence_length, d_model)轉(zhuǎn)換為(batch_size, d_model, sequence_length)。這種形狀調(diào)整是為了適配后續(xù)的一維卷積層self.conv1d,該卷積層期望輸入的形狀為(batch_size, channels, length),其中channels是卷積核的深度,length是序列的長(zhǎng)度。

x = self.conv1d(x)[:, :, :l] 應(yīng)用深度卷積

self.conv1d是一個(gè)一維卷積層,它沿著序列長(zhǎng)度l的方向應(yīng)用卷積核。由于self.conv1din_channels參數(shù)設(shè)置為args.d_inner,這意味著卷積操作是在d_model維的特征空間內(nèi)獨(dú)立進(jìn)行的。卷積操作的輸出是一個(gè)三維張量,其形狀為(batch_size, d_inner, sequence_length)。然后,代碼通過(guò)切片操作[:, :, :l]只保留了序列長(zhǎng)度為l的輸出,這是因?yàn)槲覀冎粚?duì)序列中的前l個(gè)元素感興趣。

x = rearrange(x, ‘b d_in l -> b l d_in’ 再次調(diào)整x的形狀

最后,為了繼續(xù)后續(xù)的計(jì)算,需要將卷積后的張量形狀再次調(diào)整回(batch_size, sequence_length, d_model)。這樣做是為了確保數(shù)據(jù)在后續(xù)層中的流動(dòng)是連貫的,特別是當(dāng)數(shù)據(jù)傳遞給后續(xù)的Mamba塊或其他層時(shí)。這里的rearrange函數(shù)將卷積輸出的形狀從(batch_size, d_inner, sequence_length)轉(zhuǎn)換回(batch_size, sequence_length, d_inner)

解釋2:A = -torch.exp(self.A_log.float())前面的負(fù)號(hào)

這里的負(fù)號(hào)-是因?yàn)樵跔顟B(tài)空間模型中,矩陣A通常表示的是一個(gè)離散時(shí)間系統(tǒng)的轉(zhuǎn)換矩陣,它描述了系統(tǒng)狀態(tài)隨時(shí)間的演變。在許多情況下,A矩陣的元素應(yīng)該是負(fù)的,以確保系統(tǒng)的穩(wěn)定性。這是因?yàn)樵陔x散時(shí)間系統(tǒng)中,我們希望系統(tǒng)的狀態(tài)隨著時(shí)間的推移而衰減或穩(wěn)定下來(lái),而不是增長(zhǎng),從而避免系統(tǒng)變得不穩(wěn)定或發(fā)散。

解釋3:狀態(tài)空間更新代碼

這兩行代碼首先根據(jù)當(dāng)前時(shí)間步的轉(zhuǎn)換矩陣deltaA和輸入影響deltaB_u更新?tīng)顟B(tài)向量x,然后計(jì)算狀態(tài)向量x和輸出矩陣C的點(diǎn)乘,得到當(dāng)前時(shí)間步的輸出y。這個(gè)過(guò)程是狀態(tài)空間模型中的核心計(jì)算步驟,它允許模型動(dòng)態(tài)地處理序列數(shù)據(jù)并生成響應(yīng)。

  1. x = deltaA[:, i] * x + deltaB_u[:, i]
    • deltaA是一個(gè)四維張量,其形狀為(batch_size, sequence_length, d_in, n)。這里deltaA[:, i]表示我們選擇了deltaA張量中第i個(gè)時(shí)間步的切片,形狀變?yōu)?code>(batch_size, d_in, n)。
    • x是狀態(tài)向量,其形狀為(batch_size, d_in, n),代表當(dāng)前時(shí)間步的狀態(tài)。
    • deltaB_u是一個(gè)四維張量,其形狀也為(batch_size, sequence_length, d_in, n),它是通過(guò)delta、B和輸入u計(jì)算得到的,代表了輸入對(duì)狀態(tài)的直接影響。
    • 這行代碼首先執(zhí)行deltaA[:, i] * x,這是一個(gè)逐元素乘法操作,它根據(jù)當(dāng)前時(shí)間步的轉(zhuǎn)換矩陣更新?tīng)顟B(tài)向量x。由于deltaA[:, i]的形狀是(batch_size, d_in, n),它可以直接與形狀相同的x進(jìn)行逐元素乘法。
    • 接著,代碼執(zhí)行+ deltaB_u[:, i],將輸入的影響加到更新后的狀態(tài)向量x上。這里的deltaB_u[:, i]deltaB_u張量中第i個(gè)時(shí)間步的切片,形狀也是(batch_size, d_in, n)。
  2. y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
    • 這行代碼使用einsum函數(shù)來(lái)計(jì)算輸出y。einsum是PyTorch中的一個(gè)函數(shù),用于執(zhí)行復(fù)雜的張量運(yùn)算。
    • x是當(dāng)前狀態(tài)向量,形狀為(batch_size, d_in, n)
    • C[:, i, :]是從輸出參數(shù)矩陣C中取出的第i個(gè)時(shí)間步的切片,形狀為(batch_size, n, d_in)
    • 'b d_in n, b n -> b d_in'einsum的索引模式,它指示了如何執(zhí)行點(diǎn)乘和求和操作。在這個(gè)模式中,'b'表示批次維度保持不變,'d_in n'表示x的第二個(gè)和第三個(gè)維度與C的第二個(gè)維度進(jìn)行點(diǎn)乘,'b d_in'表示輸出y的形狀應(yīng)該與x的前兩個(gè)維度相同。
    • 結(jié)果y的形狀是(batch_size, d_in),它是模型在當(dāng)前時(shí)間步對(duì)輸入序列的響應(yīng)。

2.5 RMSNorm 類

這個(gè)類實(shí)現(xiàn)了基于均方根的歸一化操作。它接收輸入x,計(jì)算其均方根值,并使用這個(gè)值來(lái)歸一化輸入。這種歸一化有助于模型的訓(xùn)練穩(wěn)定性。

class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5):
        """
        初始化RMSNorm模塊,該模塊實(shí)現(xiàn)了基于均方根的歸一化操作。
        參數(shù):
        d_model (int): 模型的特征維度。
        eps (float, 可選): 為了避免除以零,添加到分母中的一個(gè)小的常數(shù)。
        """
        super().__init__()  
        self.eps = eps  # 保存輸入的eps值,用于數(shù)值穩(wěn)定性
        self.weight = nn.Parameter(torch.ones(d_model))  # 創(chuàng)建一個(gè)可訓(xùn)練的權(quán)重參數(shù),初始值為全1,維度與輸入特征維度d_model相同

    def forward(self, x):
        """
        定義RMSNorm模塊的前向傳播函數(shù)。
        參數(shù):
        x (Tensor): 輸入的張量,通常是一個(gè)特征矩陣,其形狀為(batch_size, sequence_length, d_model)。
        返回:
        output (Tensor): 歸一化后的特征矩陣。
        """
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight  # 計(jì)算歸一化輸出
        return output  # 返回歸一化后的輸出

小結(jié):狀態(tài)空間參數(shù)是如何與輸入相關(guān)的

這個(gè)是S6與S4的根本區(qū)別

在上面的MambaBlock類的代碼中,狀態(tài)空間的參數(shù)與輸入相關(guān)性體現(xiàn)在self.x_projself.dt_proj的使用上,以及在ssmselective_scan方法中的計(jì)算過(guò)程中。

  1. self.x_projself.dt_proj:
    • self.x_proj是一個(gè)線性變換層,它將輸入x映射到狀態(tài)空間模型的參數(shù)Δ、B和C。這個(gè)映射是輸入依賴的,因?yàn)?code>x是模型的輸入,它的值會(huì)直接影響到這些參數(shù)的計(jì)算。
    • self.dt_proj是一個(gè)線性變換層,用于將Δ從args.dt_rank維度映射到args.d_inner維度。雖然這個(gè)映射本身是一個(gè)固定的線性變換,但它的輸入(即x)是依賴于輸入數(shù)據(jù)的。
  2. ssm方法:
    • ssm方法中,計(jì)算了狀態(tài)空間模型的參數(shù)A、B、C和D。其中,A和D是與輸入無(wú)關(guān)的,而B(niǎo)和C是通過(guò)self.x_projself.dt_proj從輸入x中計(jì)算得到的,因此它們與輸入是相關(guān)的。
  3. selective_scan方法:
    • selective_scan方法執(zhí)行選擇性掃描算法,它是狀態(tài)空間模型的核心計(jì)算過(guò)程。在這個(gè)方法中,輸入u(實(shí)際上是x經(jīng)過(guò)一系列變換后的結(jié)果)與狀態(tài)空間參數(shù)Δ、A、B、C和D一起使用,來(lái)更新?tīng)顟B(tài)并計(jì)算輸出。
    • 方法中的deltaAdeltaB_u計(jì)算顯示了輸入u如何影響狀態(tài)空間參數(shù)。deltaA是通過(guò)einsum函數(shù)將輸入u的每個(gè)元素與矩陣A的每個(gè)元素進(jìn)行點(diǎn)乘得到的,這意味著輸入的每個(gè)元素都會(huì)影響A的每個(gè)元素。
    • deltaB_u是通過(guò)einsum函數(shù)將輸入u、矩陣B和Δ進(jìn)行三元組乘法得到的,這進(jìn)一步表明輸入u直接影響了狀態(tài)空間參數(shù)B的計(jì)算。

總的來(lái)說(shuō),狀態(tài)空間的參數(shù)與輸入相關(guān)性是通過(guò)輸入數(shù)據(jù)x直接影響Δ、B和C的計(jì)算來(lái)實(shí)現(xiàn)的。這種相關(guān)性使得Mamba模型能夠根據(jù)輸入數(shù)據(jù)的不同動(dòng)態(tài)調(diào)整其內(nèi)部狀態(tài),從而更好地捕捉序列數(shù)據(jù)的特性。這是Mamba模型區(qū)別于傳統(tǒng)的線性時(shí)不變(LTI)狀態(tài)空間模型的關(guān)鍵特性。

3.模型測(cè)試代碼

3.1 加載模型

from model import Mamba, ModelArgs
from transformers import AutoTokenizer

# One of:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = 'state-spaces/mamba-370m'

model = Mamba.from_pretrained(pretrained_model_name)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

3.2 生成文本

這個(gè)函數(shù)通過(guò)迭代地向模型提供輸入,并基于模型預(yù)測(cè)的概率分布來(lái)生成下一個(gè)令牌,直到達(dá)到指定的令牌數(shù)量。生成過(guò)程中,可以通過(guò)top_k采樣來(lái)限制概率分布,或者通過(guò)采樣來(lái)隨機(jī)選擇令牌,從而增加生成文本的多樣性。最終,函數(shù)返回生成的文本

import torch
import torch.nn.functional as F


def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 50,
             sample: bool = True,
             top_k: int = 40):
    # 將模型設(shè)置為評(píng)估模式,這通常會(huì)關(guān)閉dropout等訓(xùn)練時(shí)的特性。
    model.eval()
    # 使用分詞器將提示字符串轉(zhuǎn)換為模型可以處理的輸入ID。
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    # 循環(huán)生成指定數(shù)量的令牌。
    for token_n in range(n_tokens_to_gen):
        # 無(wú)需計(jì)算梯度,因?yàn)槲覀兪窃谏晌谋径皇怯?xùn)練模型。
        with torch.no_grad():
            # 準(zhǔn)備輸入模型的索引。
            indices_to_input = input_ids
            # 通過(guò)模型獲取當(dāng)前輸入的下一個(gè)令牌的logits。
            next_token_logits = model(indices_to_input)[:, -1]
        # 對(duì)logits應(yīng)用softmax函數(shù),將其轉(zhuǎn)換為概率分布。
        probs = F.softmax(next_token_logits, dim=-1)
         # 獲取概率分布的形狀,即批次大小和詞匯表大小。
        (batch, vocab_size) = probs.shape
        # 如果指定了top_k采樣,則獲取概率最高的k個(gè)令牌及其對(duì)應(yīng)的值和索引。
        if top_k is not None:
            (values, indices) = torch.topk(probs, k=top_k)
            # 將概率低于最低top_k令牌的概率值設(shè)置為0。
            probs[probs < values[:, -1, None]] = 0
            # 重新歸一化概率分布,使得所有概率之和為1。
            probs = probs / probs.sum(axis=1, keepdims=True)
        # 如果采樣為T(mén)rue,則通過(guò)多項(xiàng)式采樣(Multinomial Sampling)來(lái)選擇下一個(gè)令牌。
        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:# 如果不采樣,則選擇概率最高的令牌作為下一個(gè)令牌。
            next_indices = torch.argmax(probs, dim=-1)[:, None]
        # 將生成的下一個(gè)令牌添加到輸入ID列表中。
        input_ids = torch.cat([input_ids, next_indices], dim=1)
    # 將最終的輸入ID轉(zhuǎn)換為文本,并解碼為可讀的字符串。
    output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]
    # 返回生成的文本。
    return output_completions
print(generate(model, tokenizer, 'Mamba is the'))

Mamba is the world’s longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)

print(generate(model, tokenizer, 'John: Hi!\nSally:'))

John: Hi!
Sally: Hey!
John: So, when’s the wedding?
Sally: We haven’t decided.
John: It’s in September.
Sally: Yeah, we were thinking July or August.文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-849915.html

附:完整模型代碼

"""Simple, minimal implementation of Mamba in one file of PyTorch.

Suggest reading the following before/while reading the code:
    [1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)
        https://arxiv.org/abs/2312.00752
    [2] The Annotated S4 (Sasha Rush and Sidd Karamcheti)
        https://srush.github.io/annotated-s4

Glossary:
    b: batch size                       (`B` in Mamba paper [1] Algorithm 2)
    l: sequence length                  (`L` in [1] Algorithm 2)
    d or d_model: hidden dim
    n or d_state: latent state dim      (`N` in [1] Algorithm 2)
    expand: expansion factor            (`E` in [1] Section 3.4)
    d_in or d_inner: d * expand         (`D` in [1] Algorithm 2)
    A, B, C, D: state space parameters  (See any state space representation formula)
                                        (B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not)
    Δ or delta: input-dependent step size
    dt_rank: rank of Δ                  (See [1] Section 3.6 "Parameterization of ?")

"""
from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, repeat, einsum

# 使用dataclass裝飾器自動(dòng)生成初始化方法和類的字符串表示方法
@dataclass
class ModelArgs:
    # @dataclass 會(huì)自動(dòng)為這個(gè)類生成初始化方法和代表類的字符串形式的方法
    d_model: int  # 定義模型的隱藏層維度
    n_layer: int # 定義模型的層數(shù)
    vocab_size: int  # 定義詞匯表的大小
    d_state: int = 16 # 定義狀態(tài)空間的維度,默認(rèn)為16
    expand: int = 2 # 定義擴(kuò)展因子,默認(rèn)為2
    dt_rank: Union[int, str] = 'auto'  # 定義輸入依賴步長(zhǎng)Δ的秩,'auto'表示自動(dòng)設(shè)置
    d_conv: int = 4   # 定義卷積核的維度,默認(rèn)為4
    pad_vocab_size_multiple: int = 8   # 定義詞匯表大小的最小公倍數(shù),默認(rèn)為8
    conv_bias: bool = True # 定義卷積層是否使用偏置項(xiàng)
    bias: bool = False # 定義其他層(如線性層)是否使用偏置項(xiàng)
    
    def __post_init__(self):
        # 在__init__后自動(dòng)被調(diào)用,用于執(zhí)行初始化之后的額外設(shè)置或驗(yàn)證
        # 計(jì)算內(nèi)部維度,即擴(kuò)展后的維度
        self.d_inner = int(self.expand * self.d_model)
        
        if self.dt_rank == 'auto':# 如果dt_rank未指定,則自動(dòng)計(jì)算設(shè)置
            # 根據(jù)隱藏層維度自動(dòng)計(jì)算Δ的秩
            self.dt_rank = math.ceil(self.d_model / 16)
        # 確保vocab_size是pad_vocab_size_multiple的倍數(shù)
        # 如果不是,調(diào)整為最近的倍數(shù)
        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)


class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model."""
        super().__init__()
        # 保存?zhèn)魅氲腗odelArgs對(duì)象,包含模型的配置參數(shù)
        self.args = args
        # 創(chuàng)建一個(gè)嵌入層,將詞匯表中的詞轉(zhuǎn)換為對(duì)應(yīng)的向量表示
        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        # 創(chuàng)建一個(gè)包含多個(gè)殘差塊的模塊列表,殘差塊的數(shù)量等于模型層數(shù)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        # 創(chuàng)建一個(gè)RMSNorm模塊,用于歸一化操作
        self.norm_f = RMSNorm(args.d_model)
        # 創(chuàng)建一個(gè)線性層,用于最終的輸出,將隱藏層的輸出映射回詞匯表的大小
        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        # 將線性層的輸出權(quán)重與嵌入層的權(quán)重綁定,這是權(quán)重共享的一種形式,有助于減少參數(shù)數(shù)量并可能提高模型的泛化能力
        self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.
                                                     # See "Weight Tying" paper


    def forward(self, input_ids):
        """
        Args:
            input_ids (long tensor): shape (b, l)    (See Glossary at top for definitions of b, l, d_in, n...)
    
        Returns:
            logits: shape (b, l, vocab_size)

        Official Implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173

        """
        # 將輸入ID轉(zhuǎn)換為向量表示
        x = self.embedding(input_ids)
        # 遍歷所有的殘差塊,并應(yīng)用它們
        for layer in self.layers:
            x = layer(x)
        # 應(yīng)用歸一化操作
        x = self.norm_f(x)
        # 通過(guò)線性層得到最終的logits輸出
        logits = self.lm_head(x)
        # 返回模型的輸出
        return logits

    
    @staticmethod
    def from_pretrained(pretrained_model_name: str):
        """Load pretrained weights from HuggingFace into model.
    
        Args:
            pretrained_model_name: One of
                * 'state-spaces/mamba-2.8b-slimpj'
                * 'state-spaces/mamba-2.8b'
                * 'state-spaces/mamba-1.4b'
                * 'state-spaces/mamba-790m'
                * 'state-spaces/mamba-370m'
                * 'state-spaces/mamba-130m'
                            
        Returns:
            model: Mamba model with weights loaded
    
        """
        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
        from transformers.utils.hub import cached_file
        
        def load_config_hf(model_name):
            resolved_archive_file = cached_file(model_name, CONFIG_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            return json.load(open(resolved_archive_file))
        
        
        def load_state_dict_hf(model_name, device=None, dtype=None):
            resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)
        
        config_data = load_config_hf(pretrained_model_name)
        args = ModelArgs(
            d_model=config_data['d_model'],
            n_layer=config_data['n_layer'],
            vocab_size=config_data['vocab_size']
        )
        model = Mamba(args)
        
        state_dict = load_state_dict_hf(pretrained_model_name)
        new_state_dict = {}
        for key in state_dict:
            new_key = key.replace('backbone.', '')
            new_state_dict[new_key] = state_dict[key]
        model.load_state_dict(new_state_dict)
        
        return model


class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Simple block wrapping Mamba block with normalization and residual connection."""
        super().__init__()
        # 保存?zhèn)魅氲腗odelArgs對(duì)象,包含模型的配置參數(shù)
        self.args = args
        # 創(chuàng)建一個(gè)MambaBlock,它是這個(gè)殘差塊的核心組件
        self.mixer = MambaBlock(args)
        # 創(chuàng)建一個(gè)RMSNorm歸一化模塊,用于歸一化操作
        self.norm = RMSNorm(args.d_model)
        

    def forward(self, x):
        """
        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
             x (Tensor): 輸入張量,形狀為(batch_size, sequence_length, hidden_size)
        Returns:
            output: shape (b, l, d)
            輸出張量,形狀與輸入相同
        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297
            
            Note: the official repo chains residual blocks that look like
                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
            where the first Add is a no-op. This is purely for performance reasons as this
            allows them to fuse the Add->Norm.

            We instead implement our blocks as the more familiar, simpler, and numerically equivalent
                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
            
        """
        # 應(yīng)用歸一化和MambaBlock,然后與輸入x進(jìn)行殘差連接
        output = self.mixer(self.norm(x)) + x

        return output
            

class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
        super().__init__()
        # 保存模型參數(shù)
        self.args = args
        # 輸入線性變換層
        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        # 創(chuàng)建了一個(gè)所謂的“深度卷積”,其中每個(gè)輸入通道被單獨(dú)卷積到每個(gè)輸出通道。
        # 這意味著每個(gè)輸出通道的結(jié)果是通過(guò)僅與一個(gè)輸入通道卷積得到的。
        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        # 將輸入x映射到狀態(tài)空間模型的參數(shù)Δ、B和C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)
        
        # dt_proj projects Δ from dt_rank to d_in
        # 將Δ從args.dt_rank維度映射到args.d_inner維度
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        # 創(chuàng)建一個(gè)重復(fù)的序列,用于初始化狀態(tài)空間模型的矩陣A
        # n->dxn
        A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        # 將矩陣A的對(duì)數(shù)值作為可訓(xùn)練參數(shù)保存
        self.A_log = nn.Parameter(torch.log(A))
        # 初始化矩陣D為全1的可訓(xùn)練參數(shù)
        self.D = nn.Parameter(torch.ones(args.d_inner))
        # 輸出線性變換層
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)
        

    def forward(self, x):
        """MambaBlock的前向傳播函數(shù),與Mamba論文圖3 Section 3.4相同.
    
        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
    
        Returns:
            output: shape (b, l, d)
        
        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
            
        """
        # 獲取輸入x的維度
        # batchsize,seq_len,dim
        (b, l, d) = x.shape # 獲取輸入x的維度
        # 應(yīng)用輸入線性變換
        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        # 將變換后的輸出分為兩部分x和res。
        # 得到的x分為兩個(gè)部分,一部分x繼續(xù)用于后續(xù)變換,生成所需要的參數(shù),res用于殘差部分
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)
        # 調(diào)整x的形狀
        x = rearrange(x, 'b l d_in -> b d_in l')
        # 應(yīng)用深度卷積,然后截取前l(fā)個(gè)輸出
        x = self.conv1d(x)[:, :, :l]
        # 再次調(diào)整x的形狀
        x = rearrange(x, 'b d_in l -> b l d_in')
        # 應(yīng)用SiLU激活函數(shù)
        x = F.silu(x)
        # 運(yùn)行狀態(tài)空間模型
        y = self.ssm(x)
        # 將res的SiLU激活結(jié)果與y相乘
        y = y * F.silu(res)
        # 應(yīng)用輸出線性變換
        output = self.out_proj(y)
        # 返回輸出結(jié)果
        return output

    
    def ssm(self, x):
        """運(yùn)行狀態(tài)空間模型,參考Mamba論文 Section 3.2和注釋[2]:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
    
        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
            
        """
        # 獲取A_log的維度
        # A在初始化時(shí)候經(jīng)過(guò)如下賦值:
        #  A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        #  self.A_log = nn.Parameter(torch.log(A))
        # (args.d_inner, args.d_state)
        (d_in, n) = self.A_log.shape # 獲取A_log的維度

        # 計(jì)算 ? A B C D, 這些屬于狀態(tài)空間參數(shù).
        #     A, D 是 與輸入無(wú)關(guān)的 (見(jiàn)Mamba論文Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ?, B, C 與輸入有關(guān)(這是與線性是不變模型S4最大的不同,
        #                       也是為什么Mamba被稱為 “選擇性” 狀態(tài)空間的原因)

        # 計(jì)算矩陣A
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        # 取D的值
        D = self.D.float()

        # 應(yīng)用x的投影變換
        # ( b,l,d_in) -> (b, l, dt_rank + 2*n)
        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)

        # 分割delta, B, C
        # delta: (b, l, dt_rank). B, C: (b, l, n)
        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)
        # 應(yīng)用dt_proj并計(jì)算delta
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
        # 應(yīng)用選擇性掃描算法
        y = self.selective_scan(x, delta, A, B, C, D)
        return y

    
    def selective_scan(self, u, delta, A, B, C, D):
        """執(zhí)行選擇性掃描算法,參考Mamba論文[1] Section 2和注釋[2]. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        經(jīng)典的離散狀態(tài)空間公式:
            x(t + 1) = Ax(t) + Bu(t)
            y(t)     = Cx(t) + Du(t)
       除了B和C (以及step size delta用于離散化) 與輸入x(t)相關(guān).
    
        參數(shù):
            u: shape (b, l, d_in)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        過(guò)程概述:
    
        Returns:
            output: shape (b, l, d_in)
    
        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly.
            
        """
        # 獲取輸入u的維度
        (b, l, d_in) = u.shape
        # 獲取矩陣A的列數(shù)
        n = A.shape[1]  #  A: shape (d_in, n)
        
        # 離散化連續(xù)參數(shù)(A, B)
        # - A 使用 zero-order hold (ZOH) 離散化 (see Section 2 Equation 4 in the Mamba paper [1])
        # - B is 使用一個(gè)簡(jiǎn)化的Euler discretization而不是ZOH.根據(jù)作者的討論:
        #   "A is the more important term and the performance doesn't change much with the simplification on B"

        # 計(jì)算離散化的A
        # 將delta和A進(jìn)行點(diǎn)乘,將A沿著delta的最后一個(gè)維度進(jìn)行廣播,然后執(zhí)行逐元素乘法
        # A:(d_in, n),delta:(b, l, d_in)
        # A廣播拓展->(b,l,d_in, n),deltaA對(duì)應(yīng)原論文中的A_bar
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        # delta、B和u,這個(gè)計(jì)算和原始論文不同
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
        
        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
        # is additionally hardware-aware (like FlashAttention).
        # 執(zhí)行選擇性掃描,初始化狀態(tài)x為零
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        # 初始化輸出列表ys
        ys = []    
        for i in range(l):
            # 更新?tīng)顟B(tài)x
            # deltaA:((b,l,d_in, n)
            # deltaB_u:( b,l,d_in,n)
            # x:
            x = deltaA[:, i] * x + deltaB_u[:, i]
            # 計(jì)算輸出y
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            # 將輸出y添加到列表ys中
            ys.append(y)
        # 將列表ys堆疊成張量y
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)
        # 將輸入u乘以D并加到輸出y上
        y = y + u * D
    
        return y


class RMSNorm(nn.Module):
    """
    初始化RMSNorm模塊,該模塊實(shí)現(xiàn)了基于均方根的歸一化操作。

    參數(shù):
    d_model (int): 模型的特征維度。
    eps (float, 可選): 為了避免除以零,添加到分母中的一個(gè)小的常數(shù)。
    """
    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps# 保存輸入的eps值,用于數(shù)值穩(wěn)定性。
        # 創(chuàng)建一個(gè)可訓(xùn)練的權(quán)重參數(shù),初始值為全1,維度與輸入特征維度d_model相同。
        self.weight = nn.Parameter(torch.ones(d_model))


    def forward(self, x):
        """
                計(jì)算輸入x的均方根值,用于后續(xù)的歸一化操作。
                x.pow(2) 計(jì)算x中每個(gè)元素的平方。
                mean(-1, keepdim=True) 對(duì)x的最后一個(gè)維度(特征維度)進(jìn)行平方和求平均,保持維度以便進(jìn)行廣播操作。
                torch.rsqrt 對(duì)求得的平均值取倒數(shù)和平方根,得到每個(gè)特征的均方根值的逆。
                + self.eps 添加一個(gè)小的常數(shù)eps以保持?jǐn)?shù)值穩(wěn)定性,防止除以零的情況發(fā)生。
                x * ... * self.weight 將輸入x與計(jì)算得到的歸一化因子和可訓(xùn)練的權(quán)重相乘,得到最終的歸一化輸出。
                """
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output

到了這里,關(guān)于【論文閱讀筆記】Mamba模型代碼理解的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!

本文來(lái)自互聯(lián)網(wǎng)用戶投稿,該文觀點(diǎn)僅代表作者本人,不代表本站立場(chǎng)。本站僅提供信息存儲(chǔ)空間服務(wù),不擁有所有權(quán),不承擔(dān)相關(guān)法律責(zé)任。如若轉(zhuǎn)載,請(qǐng)注明出處: 如若內(nèi)容造成侵權(quán)/違法違規(guī)/事實(shí)不符,請(qǐng)點(diǎn)擊違法舉報(bào)進(jìn)行投訴反饋,一經(jīng)查實(shí),立即刪除!

領(lǐng)支付寶紅包贊助服務(wù)器費(fèi)用

相關(guān)文章

  • 《Vision mamba》論文筆記

    《Vision mamba》論文筆記

    [2401.09417] Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model (arxiv.org) Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model 我們提出了 Vision Mamba (Vim), Vim是一種基于純SSM的方法,并以序列方式對(duì)圖像進(jìn)行建模 ,它結(jié)合了 雙向 SSM 用于數(shù)據(jù)

    2024年04月15日
    瀏覽(23)
  • 【模型壓縮】 LPPN論文閱讀筆記

    【模型壓縮】 LPPN論文閱讀筆記

    LPPN: A Lightweight Network for Fast Phase Picking ?深度學(xué)習(xí)模型的問(wèn)題在于計(jì)算復(fù)雜度較高,在實(shí)際數(shù)據(jù)處理中需要面臨較高的處理代價(jià),且需要專用的加速處理設(shè)備,如GPU。隨著數(shù)據(jù)累積,迫切需要設(shè)計(jì)一種能夠保證精度的輕量化高速震相拾取模型,以提高處理海量數(shù)據(jù)的效率,這

    2024年02月16日
    瀏覽(58)
  • [論文閱讀筆記18] DiffusionDet論文筆記與代碼解讀

    [論文閱讀筆記18] DiffusionDet論文筆記與代碼解讀

    擴(kuò)散模型近期在圖像生成領(lǐng)域很火, 沒(méi)想到很快就被用在了檢測(cè)上. 打算對(duì)這篇論文做一個(gè)筆記. 論文地址: 論文 代碼: 代碼 首先介紹什么是擴(kuò)散模型. 我們考慮生成任務(wù), 即encoder-decoder形式的模型, encoder提取輸入的抽象信息, 并嘗試在decoder中恢復(fù)出來(lái). 擴(kuò)散模型就是這一類中的

    2023年04月08日
    瀏覽(26)
  • 【長(zhǎng)文閱讀】MAMBA作者博士論文<MODELING SEQUENCES WITH STRUCTURED STATE SPACES>-Chapter1

    【長(zhǎng)文閱讀】MAMBA作者博士論文<MODELING SEQUENCES WITH STRUCTURED STATE SPACES>-Chapter1

    Chapter1 Gu A. Modeling Sequences with Structured State Spaces[D]. Stanford University, 2023. 本文是MAMBA作者的博士畢業(yè)論文,為了理清楚MAMBA專門(mén)花時(shí)間拜讀這篇長(zhǎng)達(dá)330頁(yè)的博士論文,由于知識(shí)水平有限,只能盡自己所能概述記錄,并適當(dāng)補(bǔ)充一些相關(guān)數(shù)學(xué)背景,歡迎探討與批評(píng)指正。內(nèi)容多,

    2024年01月19日
    瀏覽(47)
  • 【長(zhǎng)文閱讀】MAMBA作者博士論文<MODELING SEQUENCES WITH STRUCTURED STATE SPACES>-Chapter2

    【長(zhǎng)文閱讀】MAMBA作者博士論文<MODELING SEQUENCES WITH STRUCTURED STATE SPACES>-Chapter2

    Gu A. Modeling Sequences with Structured State Spaces[D]. Stanford University, 2023. 本文是MAMBA作者的博士畢業(yè)論文,為了理清楚MAMBA專門(mén)花時(shí)間拜讀這篇長(zhǎng)達(dá)330頁(yè)的博士論文,由于知識(shí)水平有限,只能盡自己所能概述記錄,并適當(dāng)補(bǔ)充一些相關(guān)數(shù)學(xué)背景,歡迎探討與批評(píng)指正。內(nèi)容多,分章節(jié)

    2024年01月20日
    瀏覽(18)
  • 多模態(tài)大模型-CogVLm 論文閱讀筆記

    多模態(tài)大模型-CogVLm 論文閱讀筆記

    論文地址 :https://arxiv.org/pdf/2311.03079.pdf code地址 : https://github.com/THUDM/CogVLM 時(shí)間 : 2023-11 機(jī)構(gòu) : zhipuai,tsinghua : visual language model 效果:(2023-11) :CogVLM-17B achieves state-of-the-art performance on 10 classic cross-modal benchmarks, including NoCaps, Flicker30k captioning, RefCOCO, RefCOCO+, RefCOCOg, Visual7W,

    2024年02月03日
    瀏覽(20)
  • 中英雙語(yǔ)大模型ChatGLM論文閱讀筆記

    中英雙語(yǔ)大模型ChatGLM論文閱讀筆記

    論文傳送門(mén): [1] GLM: General Language Model Pretraining with Autoregressive Blank Infilling [2] Glm-130b: An open bilingual pre-trained model Github鏈接: THUDM/ChatGLM-6B GLM-130B 和 GPT-3 175B(davinci) 相比,參數(shù)量減少,但性能提升了。 INT4 quantization without post training INT4量化是一種將模型的權(quán)重和激活從使用

    2024年02月02日
    瀏覽(27)
  • CLIP原理解讀——大模型論文閱讀筆記一

    CLIP原理解讀——大模型論文閱讀筆記一

    通過(guò)自然語(yǔ)言處理來(lái)的一些監(jiān)督信號(hào),可以去訓(xùn)練一個(gè)遷移效果很好的視覺(jué)模型。 論文的作者團(tuán)隊(duì)收集了一個(gè)超級(jí)大的圖像文本配對(duì)的數(shù)據(jù)集,有400 million個(gè)圖片文本的配對(duì), 模型最大用了ViT-large,提出了CLIP(Contrastive Language-Image Pre-training),是一種從自然語(yǔ)言監(jiān)督中學(xué)習(xí)

    2024年02月08日
    瀏覽(36)
  • MiniGPT-4原理解讀——大模型論文閱讀筆記三

    MiniGPT-4原理解讀——大模型論文閱讀筆記三

    論文:https://arxiv.org/pdf/2304.10592v1.pdf 代碼:https://github.com/vision-cair/minigpt-4 GPT-4展示了非凡的多模態(tài)能力,比如直接從手寫(xiě)文本生成網(wǎng)站,以及識(shí)別圖像中的幽默元素。這些特性在以前的視覺(jué)語(yǔ)言模型中很少見(jiàn)。我們認(rèn)為GPT-4具有先進(jìn)的多模態(tài)生成能力的主要原因在于利用了更

    2024年02月11日
    瀏覽(28)
  • 論文閱讀筆記AI篇 —— Transformer模型理論+實(shí)戰(zhàn) (四)

    論文閱讀筆記AI篇 —— Transformer模型理論+實(shí)戰(zhàn) (四)

    參考文章或視頻鏈接 [1] 《論文閱讀筆記AI篇 —— Transformer模型理論+實(shí)戰(zhàn) (一)》- CSDN [2] 《論文閱讀筆記AI篇 —— Transformer模型理論+實(shí)戰(zhàn) (二)》- CSDN [3] 《論文閱讀筆記AI篇 —— Transformer模型理論+實(shí)戰(zhàn) (三)》- CSDN 如果說(shuō)鋼鐵俠中的 J.A.R.V.I.S. (賈維斯)是一個(gè)AGI通用人工智能的

    2024年01月24日
    瀏覽(17)

覺(jué)得文章有用就打賞一下文章作者

支付寶掃一掃打賞

博客贊助

微信掃一掃打賞

請(qǐng)作者喝杯咖啡吧~博客贊助

支付寶掃一掃領(lǐng)取紅包,優(yōu)惠每天領(lǐng)

二維碼1

領(lǐng)取紅包

二維碼2

領(lǐng)紅包