国产 无码 综合区,色欲AV无码国产永久播放,无码天堂亚洲国产AV,国产日韩欧美女同一区二区

PaLM中ROPE位置編碼實現(xiàn)源碼解析

這篇具有很好參考價值的文章主要介紹了PaLM中ROPE位置編碼實現(xiàn)源碼解析。希望對大家有所幫助。如果存在錯誤或未考慮完全的地方,請大家不吝賜教,您也可以點擊"舉報違法"按鈕提交疑問。

1、源碼

import torch
from einops import rearrange
from torch import einsum, nn

class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) # [d/2]
        # inv_freq形式化為 [theta_0, theta_1, ..., theta_(d/2-1)]
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        # 計算m
        seq = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype) # [length]

        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        # freqs形式化為 [m*theta_0, m*theta_1, ..., m*theta_d/2],其中 m=0,1,...,length-1

        # return結(jié)果形式化為 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1),m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1
        return torch.cat((freqs, freqs), dim=-1) # [length, d]


def rotate_half(x):
    # x為q或k, 形式化為[q0, q1, .., qd-1]
    # x: [bs, head, length, d]--> [bs, head, length, 2, d/2]
    # 下式將x進行劃分,前半部分形式化為[q0, q1, .., q(d/2-1)]
    x = rearrange(x, "... (j d) -> ... j d", j=2)
    # x1形式化為[q0, q1, .., q(d/2-1)]
    # x2形式化為[q(d/2), q(d/2+1), .., q(d-1)]
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1) # [-q(d/2), -q(d/2+1), .., -q(d-1), q0, q1, .., q(d/2-1)]

def apply_rotary_pos_emb(pos, t):
    # t: [bs, head, length, d], [q0, q1, .., qd-1]
    # pos: [length, d], [m*theta_0, m*theta_1, ..., m*theta_(d/2-1),m*theta_0, m*theta_1, ..., m*theta_(d/2-1)]
    rotate_half(t)
    # 以第一個為例,q0*cos(m*theta_0) - q(d/2)*sin(m*theta_0)
    #  第二個,q1*cos(m*theta_1) - q(d/2+1)*sin(m*theta_1)
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())


if __name__=='__main__':
    # (bs, head, length, d)
    q = torch.randn((2, 12, 10, 32))  # q=[q0, q1, .., qd-1]
    k = torch.randn((2, 12, 10, 32))
    v = torch.randn((2, 12, 10, 32))
    print('q:', q[0][0][0])
    print('k:', k[0][0][0])
    rotary_emb = RotaryEmbedding(dim=32)
    pos_emb = rotary_emb(max_seq_len=10, device=torch.device('cpu'))  # [length, d]
    q_new, k_new = map(lambda t: apply_rotary_pos_emb(pos_emb, t), (q, k))
    print()

2、公式
( q 0 q 1 . . . q d / 2 ? 1 q d / 2 . . . q d ? 2 q d ? 1 ) ? ( c o s ( m θ 0 ) c o s ( m θ 1 ) . . . c o s ( m θ d / 2 ? 1 ) c o s ( m θ 0 ) . . . c o s ( m θ d / 2 ? 2 ) c o s ( m θ d / 2 ? 1 ) ) + ( ? q d / 2 ? q d / 2 + 1 . . . ? q d ? 1 q 0 . . . q d / 2 ? 2 q d / 2 ? 1 ) ( s i n ( m θ 0 ) s i n ( m θ 1 ) . . . s i n ( m θ d / 2 ? 1 ) s i n ( m θ 0 ) . . . s i n ( m θ d / 2 ? 2 ) s i n ( m θ d / 2 ? 1 ) ) \left( \begin{array}{cccc} q_0 \\ q_1\\ ...\\ q_{d/2-1}\\ q_{d/2}\\. ..\\ q_{d-2}\\ q_{d-1} \end{array} \right)* \left( \begin{array}{cccc} cos(m\theta_0) \\ cos(m\theta_1)\\ ...\\ cos(m\theta_{d/2-1})\\ cos(m\theta_0)\\. ..\\ cos(m\theta_{d/2-2})\\ cos(m\theta_{d/2-1}) \end{array} \right)+ \left( \begin{array}{cccc} -q_{d/2} \\ -q_{d/2+1}\\ ...\\ -q_{d-1}\\ q_{0}\\. ..\\ q_{d/2-2}\\ q_{d/2-1} \end{array} \right) \left( \begin{array}{cccc} sin(m\theta_0) \\ sin(m\theta_1)\\ ...\\ sin(m\theta_{d/2-1})\\ sin(m\theta_0)\\. ..\\ sin(m\theta_{d/2-2})\\ sin(m\theta_{d/2-1}) \end{array} \right) ?q0?q1?...qd/2?1?qd/2?...qd?2?qd?1?? ?? ?cos(mθ0?)cos(mθ1?)...cos(mθd/2?1?)cos(mθ0?)...cos(mθd/2?2?)cos(mθd/2?1?)? ?+ ??qd/2??qd/2+1?...?qd?1?q0?...qd/2?2?qd/2?1?? ? ?sin(mθ0?)sin(mθ1?)...sin(mθd/2?1?)sin(mθ0?)...sin(mθd/2?2?)sin(mθd/2?1?)? ?
3、圖形
觀察上圖,可以發(fā)現(xiàn) q 0 q_0 q0? q d / 2 q_{d/2} qd/2?相互作用,生成新的 q 0 n e w q^{new}_0 q0new? q d / 2 n e w q^{new}_{d/2} qd/2new?,拆解后可以得到下式

q 0 n e w = q 0 ? c o s ( m θ 0 ) ? q d / 2 ? s i n ( m θ 0 ) q^{new}_0=q_0*cos(m\theta_0)-q_{d/2}*sin(m\theta_0) q0new?=q0??cos(mθ0?)?qd/2??sin(mθ0?)
q d / 2 n e w = q 0 ? s i n ( m θ 0 ) + q d / 2 ? c o s ( m θ 0 ) q^{new}_{d/2}=q_0*sin(m\theta_0)+q_{d/2}*cos(m\theta_0) qd/2new?=q0??sin(mθ0?)+qd/2??cos(mθ0?)
也即向量 ( q 0 n e w , q d / 2 n e w ) (q^{new}_0,q^{new}_{d/2}) (q0new?,qd/2new?)由向量 ( q 0 , q d / 2 ) (q_0,q_{d/2}) (q0?,qd/2?)逆時針旋轉(zhuǎn) m θ 0 m\theta_0 mθ0?得到
可于下面鏈接中LLaMA中ROPE實現(xiàn)做對比
LLaMA中ROPE位置編碼實現(xiàn)源碼解析文章來源地址http://www.zghlxwxcb.cn/news/detail-672773.html

到了這里,關(guān)于PaLM中ROPE位置編碼實現(xiàn)源碼解析的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!

本文來自互聯(lián)網(wǎng)用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務(wù),不擁有所有權(quán),不承擔(dān)相關(guān)法律責(zé)任。如若轉(zhuǎn)載,請注明出處: 如若內(nèi)容造成侵權(quán)/違法違規(guī)/事實不符,請點擊違法舉報進行投訴反饋,一經(jīng)查實,立即刪除!

領(lǐng)支付寶紅包贊助服務(wù)器費用

相關(guān)文章

覺得文章有用就打賞一下文章作者

支付寶掃一掃打賞

博客贊助

微信掃一掃打賞

請作者喝杯咖啡吧~博客贊助

支付寶掃一掃領(lǐng)取紅包,優(yōu)惠每天領(lǐng)

二維碼1

領(lǐng)取紅包

二維碼2

領(lǐng)紅包