国产 无码 综合区,色欲AV无码国产永久播放,无码天堂亚洲国产AV,国产日韩欧美女同一区二区

Llama深入淺出

這篇具有很好參考價(jià)值的文章主要介紹了Llama深入淺出。希望對(duì)大家有所幫助。如果存在錯(cuò)誤或未考慮完全的地方,請(qǐng)大家不吝賜教,您也可以點(diǎn)擊"舉報(bào)違法"按鈕提交疑問。

前方干貨預(yù)警:這可能是你能夠找到的最容易懂最具實(shí)操性學(xué)習(xí)開源LLM模型源碼的教程。

本例從零開始基于transformers庫(kù)逐模塊搭建和解讀Llama模型源碼(中文可以翻譯成羊駝)。

并且訓(xùn)練它來(lái)實(shí)現(xiàn)一個(gè)有趣的實(shí)例:兩數(shù)之和。

輸入輸出類似如下:

輸入:"12345+54321="

輸出:"66666"

我們把這個(gè)任務(wù)當(dāng)做一個(gè)文本生成任務(wù)來(lái)進(jìn)行。輸入是一個(gè)序列的上半部分,輸出其下半部分.

這和文本生成的輸入輸出結(jié)構(gòu)是類似的,所以可以用Llama來(lái)做。

目前大部分開源LLM模型都是基于transformers庫(kù)來(lái)做的,它們的結(jié)構(gòu)大部分都和Llama大同小異。

俗話說,魔鬼隱藏在細(xì)節(jié)中,深入理解Llama模型的的源碼細(xì)節(jié),將會(huì)幫助你打通和開源LLM模型相關(guān)的基礎(chǔ)原理(如旋轉(zhuǎn)位置編碼以及長(zhǎng)度外推),并讓你熟悉各種參數(shù)的配置和使用(如past_key_value,attention_mask的使用等等)。

????公眾號(hào)算法美食屋后臺(tái)回復(fù)關(guān)鍵詞:torchkeras,獲取本文notebook源碼。

一,準(zhǔn)備數(shù)據(jù)

import?random

import?numpy?as?np
import?torch
from?torch.utils.data?import?Dataset,DataLoader

#?定義字典
words?=?'<PAD>,<BOS>,<EOS>,1,2,3,4,5,6,7,8,9,0,+,='
vocab?=?{word:?i?for?i,?word?in?enumerate(words.split(','))}
vocab_r?=?[k?for?k,?v?in?vocab.items()]?#反查詞典
#兩數(shù)相加數(shù)據(jù)集
def?get_data(min_length=10,max_length=20):
????#?定義詞集合
????words?=?['0',?'1',?'2',?'3',?'4',?'5',?'6',?'7',?'8',?'9']

????#?每個(gè)詞被選中的概率
????p?=?np.array([7,?5,?5,?7,?6,?5,?7,?6,?5,?7])
????p?=?p?/?p.sum()

????#?隨機(jī)采樣n1個(gè)詞作為s1
????n1?=?random.randint(min_length,?max_length)
????s1?=?np.random.choice(words,?size=n1,?replace=True,?p=p)
????s1?=?s1.tolist()

????#?隨機(jī)采樣n2個(gè)詞作為s2
????n2?=?random.randint(min_length,?max_length)
????s2?=?np.random.choice(words,?size=n2,?replace=True,?p=p)
????s2?=?s2.tolist()

????#?x等于s1和s2字符上的相加
????x?=?s1?+?['+']?+?s2?+?['=']
????
????#?y等于s1和s2數(shù)值上的相加
????y?=?int(''.join(s1))?+?int(''.join(s2))
????y?=?list(str(y))
????
????#?加上首尾符號(hào)
????x?=?['<BOS>']?+?x?
????y?=??y?+?['<EOS>']
????
????return?x,y

x,y?=?get_data()?
print(''.join(x)+''.join(y),"\n")
<BOS>3914835626735057733+318829464988=3914835945564522721<EOS>
#?定義數(shù)據(jù)集
class?TwoSumDataset(torch.utils.data.Dataset):
????def?__init__(self,size?=?100000,?min_length=10,max_length=20):
????????super(Dataset,?self).__init__()
????????self.size?=?size
????????self.min_length=min_length
????????self.max_length=max_length

????def?__len__(self):
????????return?self.size

????def?__getitem__(self,?i):
????????x,y?=?self.get(i)
????????
????????#?編碼成token
????????context_ids?=?[vocab[i]?for?i?in?x]
????????target_ids?=?[vocab[i]?for?i?in?y]
????????
????????input_ids?=?context_ids?+?target_ids
????????
????????#-100標(biāo)志位后面會(huì)在計(jì)算loss時(shí)會(huì)被忽略不貢獻(xiàn)損失,我們集中優(yōu)化target部分生成的loss
????????labels?=?[-100]*len(context_ids)+?target_ids
????????masks?=?[0?if?t==vocab['<PAD>']?else?1?for?t?in?input_ids]
????????
????????example?=?{'input_ids':input_ids,
??????????????????'labels':labels,'attention_mask':masks}
????????
????????return?example
????
????def?get(self,i):
????????return?get_data(self.min_length,self.max_length)
????
????
????def?show_example(self,example):
????????input_ids,labels?=?example['input_ids'],example['labels']
????????x?=?''.join([vocab_r[a]?for?a,b?in?zip(input_ids,labels)?if?b==-100])
????????y?=?''.join([vocab_r[a]?for?a,b?in?zip(input_ids,labels)?if?b!=-100])
????????print(x+y)
????????
????????
????
ds_train?=?TwoSumDataset(size?=?100000,min_length=10,max_length=20)
ds_val?=?TwoSumDataset(size?=?10000,min_length=10,max_length=20)
example?=?ds_train[0]
ds_train.show_example(example)
<BOS>12878683929048906366+11274414130675477=12889958343179581843<EOS>
def?data_collator(examples:?list):
????len_ids?=?[len(example["input_ids"])?for?example?in?examples]
????longest?=?max(len_ids)?#之后按照batch中最長(zhǎng)的input_ids進(jìn)行padding
????
????input_ids?=?[]
????labels_list?=?[]
????masks_list?=?[]
????
????for?length,?example?in?sorted(zip(len_ids,?examples),?key=lambda?x:?-x[0]):
????????ids?=?example["input_ids"]
????????labs?=?example["labels"]
????????masks?=?example['attention_mask']
????????
????????ids?=?[vocab['<PAD>']]?*?(longest?-?length)+ids?
????????labs?=?[-100]?*?(longest?-?length)+labs
????????masks?=?[0]*(longest?-?length)+masks
????????
????????input_ids.append(torch.LongTensor(ids))
????????labels_list.append(torch.LongTensor(labs))
????????masks_list.append(torch.LongTensor(masks))
??????????
????input_ids?=?torch.stack(input_ids)
????labels?=?torch.stack(labels_list)
????attention_mask?=?torch.stack(masks_list)
????return?{
????????"input_ids":?input_ids,
????????"labels":?labels,
????????"attention_mask":attention_mask
????}

#?數(shù)據(jù)加載器
dl_train?=?DataLoader(dataset=ds_train,
?????????batch_size=200,
?????????drop_last=True,
?????????shuffle=True,
?????????collate_fn?=?data_collator????????
????????)

dl_val?=?DataLoader(dataset=ds_val,
?????????batch_size=200,
?????????drop_last=True,
?????????shuffle=False,
?????????collate_fn?=?data_collator??
????????)
for?batch?in?dl_train:
????break
batch
{'input_ids': tensor([[ 1, 11,  6,  ...,  7, 11,  2],
         [ 0,  1,  6,  ...,  5,  4,  2],
         [ 0,  1,  7,  ...,  8,  8,  2],
         ...,
         [ 0,  0,  0,  ..., 10, 11,  2],
         [ 0,  0,  0,  ..., 12,  3,  2],
         [ 0,  0,  0,  ..., 11, 12,  2]]),
 'labels': tensor([[-100, -100, -100,  ...,    7,   11,    2],
         [-100, -100, -100,  ...,    5,    4,    2],
         [-100, -100, -100,  ...,    8,    8,    2],
         ...,
         [-100, -100, -100,  ...,   10,   11,    2],
         [-100, -100, -100,  ...,   12,    3,    2],
         [-100, -100, -100,  ...,   11,   12,    2]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
         [0, 1, 1,  ..., 1, 1, 1],
         [0, 1, 1,  ..., 1, 1, 1],
         ...,
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 1, 1, 1]])}

二,定義模型

下面,我們會(huì)像搭積木建城堡那樣從低往高地構(gòu)建LLaMA模型。

先構(gòu)建4個(gè)基礎(chǔ)組件:旋轉(zhuǎn)位置編碼,多頭注意力、前饋網(wǎng)絡(luò)、層歸一化。類似用最基礎(chǔ)的積木塊搭建了 墻壁,房頂,房門,窗戶 這樣的模塊。

然后用這4個(gè)基礎(chǔ)組件構(gòu)建中間成品: 解碼層。類似用基礎(chǔ)組件構(gòu)建了房間。

接著用多個(gè)中間成品解碼層的堆疊組裝成了LlamaModel完整模型,相當(dāng)于通過構(gòu)建多個(gè)房間建成了城堡的主體結(jié)構(gòu)。

最后我們?cè)贚lamaModel基礎(chǔ)上設(shè)計(jì)了兩種不同的輸出head,一種是語(yǔ)言模型Head,得到了LlamaForCausalLM,可用于文本生成。

另外一種是分類head,得到了LlamaForSequenceClassification,可用于文本分類。

相當(dāng)于我們?cè)诔潜ぶ黧w結(jié)構(gòu)完成的基礎(chǔ)上設(shè)計(jì)了兩種不同的裝修風(fēng)格,一種是加裝了一些游樂設(shè)施以便用于商業(yè)活動(dòng),另一種則是加裝了一些武器以便用于軍事活動(dòng)。


1, 旋轉(zhuǎn)位置編碼: RoPE (使用旋轉(zhuǎn)矩陣實(shí)現(xiàn)的絕對(duì)位置編碼,可以起到相對(duì)位置編碼的效果)

2, 多頭注意力: LlamaAttention (用于融合不同token之間的信息)

3, 前饋網(wǎng)絡(luò): LlamaMLP (用于逐位置將多頭注意力融合后的信息進(jìn)行高維映射變換)

4, 層歸一化: LlamaRMSNorm (用于穩(wěn)定輸入,相當(dāng)于保持每個(gè)詞向量的方向不變,但對(duì)模長(zhǎng)標(biāo)準(zhǔn)化。)


5, Llama解碼層: LlamaDecoderLayer (同時(shí)具備信息融合,信息轉(zhuǎn)換功能的基本結(jié)構(gòu)單元)


6, Llama解碼器: LlamaModel (多個(gè)解碼層的堆疊)


7,Llama語(yǔ)言模型: LlamaForCausalLM (解碼器加上語(yǔ)言模型head,可用于文本生成)

8,Llama分類模型: LlamaForSequenceClassification (解碼器加上分類head,可用于文本分類)


import?math
from?typing?import?List,?Optional,?Tuple,?Union

import?torch
import?torch.nn.functional?as?F
import?torch.utils.checkpoint
from?torch?import?nn
from?torch.nn?import?BCEWithLogitsLoss,?CrossEntropyLoss,?MSELoss

from?transformers.activations?import?ACT2FN
from?transformers.modeling_outputs?import?BaseModelOutputWithPast,?CausalLMOutputWithPast,?SequenceClassifierOutputWithPast
from?transformers.modeling_utils?import?PreTrainedModel
from?transformers.utils?import?add_start_docstrings,?add_start_docstrings_to_model_forward,?logging,?replace_return_docstrings

from?transformers.models.llama.configuration_llama??import?LlamaConfig
from?transformers.models.llama.modeling_llama?import?LLAMA_INPUTS_DOCSTRING,LLAMA_START_DOCSTRING

logger?=?logging.get_logger('llama')

config?=?LlamaConfig(
????vocab_size=len(vocab),
????hidden_size=512,
????intermediate_size=2752,
????num_hidden_layers=8,
????num_attention_heads=16,
????hidden_act='silu',
????max_position_embeddings=128,
????initializer_range=0.02,
????rms_norm_eps=1e-06,
????use_cache=True,
????pad_token_id=0,
????bos_token_id=1,
????eos_token_id=2,
????tie_word_embeddings=False
)

1,旋轉(zhuǎn)位置編碼 RoPE

旋轉(zhuǎn)位置編碼即使用旋轉(zhuǎn)矩陣表示位置編碼(Rotary Position Encoding),簡(jiǎn)稱RoPE。

關(guān)于RoPE的3個(gè)核心要點(diǎn)知識(shí)如下:

  • RoPE的設(shè)計(jì)思想是使用絕對(duì)位置編碼來(lái)達(dá)到相對(duì)位置編碼的效果。

  • RoPE的實(shí)現(xiàn)方式是使用旋轉(zhuǎn)矩陣來(lái)表示絕對(duì)位置編碼。

  • 使用NTK擴(kuò)展方法可以讓RoPE在短文本上訓(xùn)練并在長(zhǎng)文本上做預(yù)測(cè)。

參考文章:

《博采眾長(zhǎng)的旋轉(zhuǎn)式位置編碼》https://kexue.fm/archives/8265

《RoPE是一種進(jìn)制編碼》https://kexue.fm/archives/9675

(1)絕對(duì)位置編碼和相對(duì)位置編碼

位置編碼一般可以分成絕對(duì)位置編碼和相對(duì)位置編碼。

絕對(duì)位置編碼的優(yōu)點(diǎn)是計(jì)算簡(jiǎn)單高效,缺點(diǎn)是一般效果不如相對(duì)位置編碼。

相對(duì)位置編碼的優(yōu)點(diǎn)是效果較好,缺點(diǎn)是計(jì)算效率不如絕對(duì)位置編碼。

絕對(duì)位置編碼:

相對(duì)位置編碼:

在相對(duì)位置編碼中,注意力權(quán)重的結(jié)果僅僅和參與注意力計(jì)算的token向量的相對(duì)位置有關(guān),不和絕對(duì)位置直接關(guān)聯(lián)。

這符合NLP領(lǐng)域在序列長(zhǎng)度方向上具有平移不變性的特點(diǎn),所以相對(duì)位置編碼一般效果會(huì)優(yōu)于絕對(duì)位置編碼。

不過絕對(duì)位置編碼并非一無(wú)是處,絕對(duì)位置編碼只需要初始化時(shí)對(duì)序列的每個(gè)位置(數(shù)量正比于序列長(zhǎng)度)賦予位置編碼即可,后續(xù)無(wú)需干預(yù)。

而相對(duì)位置編碼要在計(jì)算過程中獲取許多個(gè)(數(shù)量正比于序列長(zhǎng)度平方)相對(duì)位置。

因此絕對(duì)位置編碼更加簡(jiǎn)單高效。

(2)使用旋轉(zhuǎn)矩陣表示位置編碼

上述討論可以看到,絕對(duì)位置編碼和相對(duì)位置編碼互有優(yōu)劣,那么有沒有什么辦法能夠?qū)Χ哌M(jìn)行取長(zhǎng)補(bǔ)短呢?

有的,這個(gè)方法就是RoPE,它的設(shè)計(jì)思想就是使用絕對(duì)位置編碼來(lái)達(dá)到相對(duì)位置編碼的效果。

那么旋轉(zhuǎn)位置編碼如何使用絕對(duì)位置編碼來(lái)達(dá)到相對(duì)位置編碼的效果的呢?答案是使用旋轉(zhuǎn)矩陣來(lái)表示位置編碼。

其中 為旋轉(zhuǎn)矩陣,滿足性質(zhì) 。于是,有:

符合 相對(duì)位置編碼形式。

perfect! 我們用絕對(duì)位置編碼實(shí)現(xiàn)了相對(duì)位置編碼的效果。

那么,旋轉(zhuǎn)矩陣長(zhǎng)什么樣呢?

在二維情形長(zhǎng)下面樣子。

在NLP領(lǐng)域,詞向量的維度一般會(huì)很高(例如4096)。

利用矩陣的分塊思想,可以證明高維情形下擴(kuò)展成下述形式依舊滿足旋轉(zhuǎn)矩陣性質(zhì)

其中 ,即越高的維度對(duì)應(yīng)三角函數(shù)的系數(shù)越小,周期越大,變化越緩慢。

由于旋轉(zhuǎn)矩陣是稀疏矩陣,直接使用乘法計(jì)算會(huì)很浪費(fèi)算力,可以將旋轉(zhuǎn)位置編碼過程由矩陣乘法運(yùn)算簡(jiǎn)化成兩次向量的哈達(dá)瑪積求和。

(3)旋轉(zhuǎn)位置編碼的長(zhǎng)度擴(kuò)展

在LLM的應(yīng)用中,有一個(gè)非常重要的參數(shù),叫做LLM支持的上下文長(zhǎng)度(max context length)。

更長(zhǎng)的上下文長(zhǎng)度允許我們進(jìn)行更多輪次的對(duì)話,允許我們對(duì)更長(zhǎng)的本文進(jìn)行總結(jié)分析,也允許我們生成更長(zhǎng)的文章。

但是在訓(xùn)練LLM的時(shí)候,我們的訓(xùn)練語(yǔ)料大部分是不夠長(zhǎng)的,許多LLM訓(xùn)練時(shí)候設(shè)計(jì)的最大文本長(zhǎng)度都是只有2k,也就是最長(zhǎng)2048個(gè)token。

那么,能否在訓(xùn)練的時(shí)候使用較短的文本,而在推理的時(shí)候擴(kuò)展到長(zhǎng)文本上呢?

是有可能的,我們可以對(duì)RoPE進(jìn)行長(zhǎng)度擴(kuò)展。

我們介紹3種擴(kuò)展方案。

第一種是直接外推:直接外推其實(shí)就是繼續(xù)沿用現(xiàn)有的位置編碼公式,不做任何修改。

在擴(kuò)展長(zhǎng)度不太長(zhǎng)的時(shí)候,例如由2k擴(kuò)展到2.5k時(shí),這種方法可能對(duì)性能的影響并不大。

因?yàn)樾D(zhuǎn)位置編碼只和相對(duì)位置m-n的大小有關(guān),一般具有遠(yuǎn)程衰減性,即相對(duì)距離越大的兩個(gè)token,其相關(guān)性一般越弱。

因此如果我們的模型已經(jīng)從訓(xùn)練數(shù)據(jù)那里學(xué)習(xí)到了token之間的相關(guān)性相對(duì)于相對(duì)距離在0-2k的一個(gè)合適的衰減規(guī)律的時(shí)候,可以設(shè)想把這個(gè)規(guī)律應(yīng)用到0-2.5k也是沒有太大的問題的。

但是如果我們要擴(kuò)展到更長(zhǎng)的長(zhǎng)度,例如從2k擴(kuò)展到32k,這種直接外推的方案通常會(huì)嚴(yán)重地影響性能。因?yàn)槲覀儗W(xué)習(xí)到的衰減規(guī)律有可能在5k的那里就完全衰減截?cái)嗷窘禐?了,這樣我們就無(wú)法捕捉相對(duì)距離長(zhǎng)于5k的兩個(gè)token之間的相互作用,外推就會(huì)導(dǎo)致性能下降。

總結(jié)一下,直接外推對(duì)衰減規(guī)律在長(zhǎng)距離情況下的使用容易出現(xiàn)問題,導(dǎo)致性能下降。

為了減少長(zhǎng)度外推對(duì)性能的影響,我們可以讓訓(xùn)練好的模型在更長(zhǎng)的上下文上做少許步驟的微調(diào)。

第二種是線性內(nèi)插:線性內(nèi)插需要改變位置編碼公式,等效于將位置序號(hào)等比例縮小。

編碼公式變化如 ,當(dāng)從2k擴(kuò)展到32k,等效于需要將位置序號(hào)變成原來(lái)的1/16.

線性內(nèi)插沒有改變模型學(xué)習(xí)到的衰減規(guī)律的應(yīng)用范圍,不考慮微調(diào)的話,其效果一般好于直接外推方案。

但是,擴(kuò)展倍數(shù)非常大的時(shí)候,例如從2k擴(kuò)展到32k,其性能也會(huì)明顯的受到影響。

因?yàn)樵谶@種情況下,衰減規(guī)律在短距離情況下的使用會(huì)受到較嚴(yán)重的影響,本來(lái)距離為1的兩個(gè)token,長(zhǎng)度擴(kuò)展后相當(dāng)于變成了距離為1/16,衰減規(guī)律在短距離時(shí)可能具有非常大的變化率,因此對(duì)相關(guān)性的評(píng)估可能會(huì)極端地偏離合理值。

應(yīng)用線性內(nèi)插時(shí),在長(zhǎng)文本上做少許步驟的微調(diào)也能夠明顯地改善性能。

第三種是NTK擴(kuò)展方式:這種方式綜合了外推和內(nèi)插的優(yōu)點(diǎn),做長(zhǎng)度擴(kuò)展后即使不微調(diào)也能夠保持較好的性能。

前面的分析我們知道直接外推對(duì)衰減規(guī)律在長(zhǎng)距離情況下的使用容易出問題,在短距離情況下的使用不受影響。

而線性內(nèi)插對(duì)衰減規(guī)律在短距離情況下的使用容易出現(xiàn)問題,在長(zhǎng)距離的情況下影響較小。

我們能否將它們綜合起來(lái),在短距離情況下具有外推特性(與擴(kuò)展前基本一致),在長(zhǎng)距離情況下具有內(nèi)插特性(縮放到擴(kuò)展前的范圍),從而使得長(zhǎng)距離情況下和短距離情況下衰減規(guī)律的使用都不太受到影響呢。

我們觀察RoPE位置編碼第行的元素計(jì)算公式 ,可以發(fā)現(xiàn)越大,三角函數(shù)對(duì)應(yīng)的角頻率系數(shù)越小,或者說越低頻,對(duì)應(yīng)的三角函數(shù)變化越慢。

容易得到如下直觀結(jié)論:短距離之間的差異(例如1和5的差異),主要體現(xiàn)在高頻分量(i比較小)上,長(zhǎng)距離之間的差異(例如5000和10000的差異),主要體現(xiàn)在低頻分量(i比較大)上。

為了在短距離情況下具有外推特性,而在長(zhǎng)距離情況下具有內(nèi)插特性,我們可以設(shè)計(jì)一個(gè)和有關(guān)的位置序號(hào)縮放因子,使得在最高頻()時(shí)取值為1(與擴(kuò)展前基本一致),而在最低頻時(shí)()恰好為縮放倍數(shù)的倒數(shù)(縮放到擴(kuò)展前的范圍)。

一種有效的選擇方案是的指數(shù)函數(shù),其效果相當(dāng)于對(duì)中的做一個(gè)縮放,根據(jù)邊界條件容易求得合適的縮放因子為 。

NTK擴(kuò)展方式的要點(diǎn)是高頻外推,低頻內(nèi)插,實(shí)現(xiàn)方法是直接對(duì)底數(shù)base進(jìn)行縮放,類似進(jìn)制編碼轉(zhuǎn)換。

采用NTK擴(kuò)展到長(zhǎng)文本,即使不做微調(diào),性能會(huì)只會(huì)略有下降。

下面是RoPE以及三種長(zhǎng)度擴(kuò)展方式的實(shí)現(xiàn)。

class?LlamaRotaryEmbedding(torch.nn.Module):
????def?__init__(self,?dim,?max_position_embeddings=2048,?base=10000,?device=None):
????????super().__init__()
????????self.dim?=?dim
????????self.max_position_embeddings?=?max_position_embeddings
????????self.base?=?base
????????inv_freq?=?1.0?/?(self.base?**?(torch.arange(0,?self.dim,?2).float().to(device)?/?self.dim))
????????self.register_buffer("inv_freq",?inv_freq,?persistent=False)?#persistent=False將不會(huì)作為state_dict

????????#?Build?here?to?make?`torch.jit.trace`?work.
????????self._set_cos_sin_cache(
????????????seq_len=max_position_embeddings,?device=self.inv_freq.device,?dtype=torch.get_default_dtype()
????????)

????def?_set_cos_sin_cache(self,?seq_len,?device,?dtype):
????????self.max_seq_len_cached?=?seq_len
????????t?=?torch.arange(self.max_seq_len_cached,?device=device,?dtype=self.inv_freq.dtype)

????????freqs?=?torch.einsum("i,j->ij",?t,?self.inv_freq)
????????#?Different?from?paper,?but?it?uses?a?different?permutation?in?order?to?obtain?the?same?calculation
????????emb?=?torch.cat((freqs,?freqs),?dim=-1)
????????self.register_buffer("cos_cached",?emb.cos()[None,?None,?:,?:].to(dtype),?persistent=False)
????????self.register_buffer("sin_cached",?emb.sin()[None,?None,?:,?:].to(dtype),?persistent=False)

????def?forward(self,?x,?seq_len=None):
????????#?x:?[bs,?num_attention_heads,?seq_len,?head_size]
????????#超過預(yù)設(shè)的max_position_embeddings則重新計(jì)算更大的Rope緩存,否則直接在緩存上切片
????????if?seq_len?>?self.max_seq_len_cached:?
????????????self._set_cos_sin_cache(seq_len=seq_len,?device=x.device,?dtype=x.dtype)

????????return?(
????????????self.cos_cached[:,?:,?:seq_len,?...].to(dtype=x.dtype),
????????????self.sin_cached[:,?:,?:seq_len,?...].to(dtype=x.dtype),
????????)

????
class?LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
????"""LlamaRotaryEmbedding?extended?with?linear?scaling.?Credits?to?the?Reddit?user?/u/kaiokendev"""

????def?__init__(self,?dim,?max_position_embeddings=2048,?base=10000,?device=None,?scaling_factor=1.0):
????????self.scaling_factor?=?scaling_factor
????????super().__init__(dim,?max_position_embeddings,?base,?device)

????def?_set_cos_sin_cache(self,?seq_len,?device,?dtype):
????????self.max_seq_len_cached?=?seq_len
????????t?=?torch.arange(self.max_seq_len_cached,?device=device,?dtype=self.inv_freq.dtype)
????????t?=?t?/?self.scaling_factor?#線性內(nèi)插相當(dāng)于將位置序號(hào)等比例縮小

????????freqs?=?torch.einsum("i,j->ij",?t,?self.inv_freq)
????????#?Different?from?paper,?but?it?uses?a?different?permutation?in?order?to?obtain?the?same?calculation
????????emb?=?torch.cat((freqs,?freqs),?dim=-1)
????????self.register_buffer("cos_cached",?emb.cos()[None,?None,?:,?:].to(dtype),?persistent=False)
????????self.register_buffer("sin_cached",?emb.sin()[None,?None,?:,?:].to(dtype),?persistent=False)


class?LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
????"""LlamaRotaryEmbedding?extended?with?Dynamic?NTK?scaling.?Credits?to?the?Reddit?users?/u/bloc97?and?/u/emozilla"""

????def?__init__(self,?dim,?max_position_embeddings=2048,?base=10000,?device=None,?scaling_factor=1.0):
????????self.scaling_factor?=?scaling_factor
????????super().__init__(dim,?max_position_embeddings,?base,?device)

????def?_set_cos_sin_cache(self,?seq_len,?device,?dtype):
????????self.max_seq_len_cached?=?seq_len

????????if?seq_len?>?self.max_position_embeddings:
????????????base?=?self.base?*?(
????????????????(self.scaling_factor?*?seq_len?/?self.max_position_embeddings)?-?(self.scaling_factor?-?1)
????????????)?**?(self.dim?/?(self.dim?-?2))??#NTK擴(kuò)展方式直接對(duì)base進(jìn)行縮放
????????????inv_freq?=?1.0?/?(base?**?(torch.arange(0,?self.dim,?2).float().to(device)?/?self.dim))
????????????self.register_buffer("inv_freq",?inv_freq,?persistent=False)

????????t?=?torch.arange(self.max_seq_len_cached,?device=device,?dtype=self.inv_freq.dtype)

????????freqs?=?torch.einsum("i,j->ij",?t,?self.inv_freq)
????????
????????#此處處理邏輯與原始的ROPE有差異,原始邏輯如下
????????#emb?=?torch.cat((freqs,?freqs),?dim=-1)
????????#emb[...,0::2]=freqs
????????#emb[...,1::2]=freqs
????????
????????
????????#?Different?from?paper,?but?it?uses?a?different?permutation?in?order?to?obtain?the?same?calculation
????????emb?=?torch.cat((freqs,?freqs),?dim=-1)
????????self.register_buffer("cos_cached",?emb.cos()[None,?None,?:,?:].to(dtype),?persistent=False)
????????self.register_buffer("sin_cached",?emb.sin()[None,?None,?:,?:].to(dtype),?persistent=False)
????????
????????
def?rotate_half(x):
????"""Rotates?half?the?hidden?dims?of?the?input."""
????
????#此處邏輯與原始的ROPE有所差異,原始邏輯如下
????#x1?=?x[...,?0::2]?
????#x2?=?x[...,?1::2]
????#res?=?torch.cat((x1,?x2),?dim=-1)
????#res[...,0::2]=-x2
????#res[...,1::2]=x1
????#return?res
????
????x1?=?x[...,?:?x.shape[-1]?//?2]?
????x2?=?x[...,?x.shape[-1]?//?2?:]
????return?torch.cat((-x2,?x1),?dim=-1)


def?apply_rotary_pos_emb(q,?k,?cos,?sin,?position_ids):
????#?The?first?two?dimensions?of?cos?and?sin?are?always?1,?so?we?can?`squeeze`?them.
????cos?=?cos.squeeze(1).squeeze(0)??#?[seq_len,?dim]
????sin?=?sin.squeeze(1).squeeze(0)??#?[seq_len,?dim]
????cos?=?cos[position_ids].unsqueeze(1)??#?[bs,?1,?seq_len,?dim]
????sin?=?sin[position_ids].unsqueeze(1)??#?[bs,?1,?seq_len,?dim]
????q_embed?=?(q?*?cos)?+?(rotate_half(q)?*?sin)
????k_embed?=?(k?*?cos)?+?(rotate_half(k)?*?sin)
????return?q_embed,?k_embed
x?=?torch.randn(1,8,4,2)
rope?=?LlamaRotaryEmbedding(dim=8)
cos,sin?=?rope.forward(x,seq_len=4)
print(cos.shape)?
print(cos)
torch.Size([1, 1, 4, 8])
tensor([[[[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
            1.0000],
          [ 0.5403,  0.9950,  0.9999,  1.0000,  0.5403,  0.9950,  0.9999,
            1.0000],
          [-0.4161,  0.9801,  0.9998,  1.0000, -0.4161,  0.9801,  0.9998,
            1.0000],
          [-0.9900,  0.9553,  0.9996,  1.0000, -0.9900,  0.9553,  0.9996,
            1.0000]]]])

2,多頭注意力 LlamaAttention

這里的LlamaAttention 基本上和《Attention Is All You Need》論文里的是一致的,主要差異有以下一些。

1,k和v的head數(shù)量可以是q的head數(shù)量的幾分之一,類似分組卷積的思想,可以減少參數(shù)規(guī)模。

2,rope位置編碼是每次做多頭注意力時(shí)都進(jìn)行一次,而不是原論文只在輸入的時(shí)候進(jìn)行一次。

3,允許傳入key和value的states的緩存past_key_value,這在多輪對(duì)話中可以減少重復(fù)計(jì)算,起到加速效果。

4,attention_mask是通過加法形式作用到softmax之前的attention矩陣上的。

def?repeat_kv(hidden_states:?torch.Tensor,?n_rep:?int)?->?torch.Tensor:
????"""
????This?is?the?equivalent?of?torch.repeat_interleave(x,?dim=1,?repeats=n_rep).?The?hidden?states?go?from?(batch,
????num_key_value_heads,?seqlen,?head_dim)?to?(batch,?num_attention_heads,?seqlen,?head_dim)
????"""
????batch,?num_key_value_heads,?slen,?head_dim?=?hidden_states.shape
????if?n_rep?==?1:
????????return?hidden_states
????hidden_states?=?hidden_states[:,?:,?None,?:,?:].expand(batch,?num_key_value_heads,?n_rep,?slen,?head_dim)
????return?hidden_states.reshape(batch,?num_key_value_heads?*?n_rep,?slen,?head_dim)


class?LlamaAttention(nn.Module):
????"""Multi-headed?attention?from?'Attention?Is?All?You?Need'?paper"""

????def?__init__(self,?config:?LlamaConfig):
????????super().__init__()
????????self.config?=?config
????????self.hidden_size?=?config.hidden_size
????????self.num_heads?=?config.num_attention_heads
????????self.head_dim?=?self.hidden_size?//?self.num_heads
????????self.num_key_value_heads?=?config.num_key_value_heads
????????self.num_key_value_groups?=?self.num_heads?//?self.num_key_value_heads
????????self.max_position_embeddings?=?config.max_position_embeddings

????????if?(self.head_dim?*?self.num_heads)?!=?self.hidden_size:
????????????raise?ValueError(
????????????????f"hidden_size?must?be?divisible?by?num_heads?(got?`hidden_size`:?{self.hidden_size}"
????????????????f"?and?`num_heads`:?{self.num_heads})."
????????????)
????????self.q_proj?=?nn.Linear(self.hidden_size,?self.num_heads?*?self.head_dim,?bias=False)
????????self.k_proj?=?nn.Linear(self.hidden_size,?self.num_key_value_heads?*?self.head_dim,?bias=False)
????????self.v_proj?=?nn.Linear(self.hidden_size,?self.num_key_value_heads?*?self.head_dim,?bias=False)
????????self.o_proj?=?nn.Linear(self.num_heads?*?self.head_dim,?self.hidden_size,?bias=False)
????????self._init_rope()

????def?_init_rope(self):
????????if?self.config.rope_scaling?is?None:
????????????self.rotary_emb?=?LlamaRotaryEmbedding(self.head_dim,?max_position_embeddings=self.max_position_embeddings)
????????else:
????????????scaling_type?=?self.config.rope_scaling["type"]
????????????scaling_factor?=?self.config.rope_scaling["factor"]
????????????if?scaling_type?==?"linear":
????????????????self.rotary_emb?=?LlamaLinearScalingRotaryEmbedding(
????????????????????self.head_dim,?max_position_embeddings=self.max_position_embeddings,?scaling_factor=scaling_factor
????????????????)
????????????elif?scaling_type?==?"dynamic":
????????????????self.rotary_emb?=?LlamaDynamicNTKScalingRotaryEmbedding(
????????????????????self.head_dim,?max_position_embeddings=self.max_position_embeddings,?scaling_factor=scaling_factor
????????????????)
????????????else:
????????????????raise?ValueError(f"Unknown?RoPE?scaling?type?{scaling_type}")

????def?_shape(self,?tensor:?torch.Tensor,?seq_len:?int,?bsz:?int):
????????return?tensor.view(bsz,?seq_len,?self.num_heads,?self.head_dim).transpose(1,?2).contiguous()

????def?forward(
????????self,
????????hidden_states:?torch.Tensor,
????????attention_mask:?Optional[torch.Tensor]?=?None,
????????position_ids:?Optional[torch.LongTensor]?=?None,
????????past_key_value:?Optional[Tuple[torch.Tensor]]?=?None,
????????output_attentions:?bool?=?False,
????????use_cache:?bool?=?False,
????)?->?Tuple[torch.Tensor,?Optional[torch.Tensor],?Optional[Tuple[torch.Tensor]]]:
????????bsz,?q_len,?_?=?hidden_states.size()

????????if?self.config.pretraining_tp?>?1:
????????????key_value_slicing?=?(self.num_key_value_heads?*?self.head_dim)?//?self.config.pretraining_tp
????????????query_slices?=?self.q_proj.weight.split(
????????????????(self.num_heads?*?self.head_dim)?//?self.config.pretraining_tp,?dim=0
????????????)
????????????key_slices?=?self.k_proj.weight.split(key_value_slicing,?dim=0)
????????????value_slices?=?self.v_proj.weight.split(key_value_slicing,?dim=0)

????????????query_states?=?[F.linear(hidden_states,?query_slices[i])?for?i?in?range(self.config.pretraining_tp)]
????????????query_states?=?torch.cat(query_states,?dim=-1)

????????????key_states?=?[F.linear(hidden_states,?key_slices[i])?for?i?in?range(self.config.pretraining_tp)]
????????????key_states?=?torch.cat(key_states,?dim=-1)

????????????value_states?=?[F.linear(hidden_states,?value_slices[i])?for?i?in?range(self.config.pretraining_tp)]
????????????value_states?=?torch.cat(value_states,?dim=-1)

????????else:
????????????query_states?=?self.q_proj(hidden_states)
????????????key_states?=?self.k_proj(hidden_states)
????????????value_states?=?self.v_proj(hidden_states)

????????query_states?=?query_states.view(bsz,?q_len,?self.num_heads,?self.head_dim).transpose(1,?2)
????????key_states?=?key_states.view(bsz,?q_len,?self.num_key_value_heads,?self.head_dim).transpose(1,?2)
????????value_states?=?value_states.view(bsz,?q_len,?self.num_key_value_heads,?self.head_dim).transpose(1,?2)

????????kv_seq_len?=?key_states.shape[-2]
????????if?past_key_value?is?not?None:
????????????kv_seq_len?+=?past_key_value[0].shape[-2]
????????cos,?sin?=?self.rotary_emb(value_states,?seq_len=kv_seq_len)
????????query_states,?key_states?=?apply_rotary_pos_emb(query_states,?key_states,?cos,?sin,?position_ids)

????????if?past_key_value?is?not?None:
????????????#?reuse?k,?v,?self_attention
????????????key_states?=?torch.cat([past_key_value[0],?key_states],?dim=2)
????????????value_states?=?torch.cat([past_key_value[1],?value_states],?dim=2)

????????past_key_value?=?(key_states,?value_states)?if?use_cache?else?None

????????#?repeat?k/v?heads?if?n_kv_heads?<?n_heads
????????key_states?=?repeat_kv(key_states,?self.num_key_value_groups)
????????value_states?=?repeat_kv(value_states,?self.num_key_value_groups)

????????attn_weights?=?torch.matmul(query_states,?key_states.transpose(2,?3))?/?math.sqrt(self.head_dim)

????????if?attn_weights.size()?!=?(bsz,?self.num_heads,?q_len,?kv_seq_len):
????????????raise?ValueError(
????????????????f"Attention?weights?should?be?of?size?{(bsz,?self.num_heads,?q_len,?kv_seq_len)},?but?is"
????????????????f"?{attn_weights.size()}"
????????????)

????????if?attention_mask?is?not?None:
????????????if?attention_mask.size()?!=?(bsz,?1,?q_len,?kv_seq_len):
????????????????raise?ValueError(
????????????????????f"Attention?mask?should?be?of?size?{(bsz,?1,?q_len,?kv_seq_len)},?but?is?{attention_mask.size()}"
????????????????)
????????????attn_weights?=?attn_weights?+?attention_mask

????????#?upcast?attention?to?fp32
????????attn_weights?=?nn.functional.softmax(attn_weights,?dim=-1,?dtype=torch.float32).to(query_states.dtype)
????????attn_output?=?torch.matmul(attn_weights,?value_states)

????????if?attn_output.size()?!=?(bsz,?self.num_heads,?q_len,?self.head_dim):
????????????raise?ValueError(
????????????????f"`attn_output`?should?be?of?size?{(bsz,?self.num_heads,?q_len,?self.head_dim)},?but?is"
????????????????f"?{attn_output.size()}"
????????????)

????????attn_output?=?attn_output.transpose(1,?2).contiguous()
????????attn_output?=?attn_output.reshape(bsz,?q_len,?self.hidden_size)

????????if?self.config.pretraining_tp?>?1:
????????????attn_output?=?attn_output.split(self.hidden_size?//?self.config.pretraining_tp,?dim=2)
????????????o_proj_slices?=?self.o_proj.weight.split(self.hidden_size?//?self.config.pretraining_tp,?dim=1)
????????????attn_output?=?sum([F.linear(attn_output[i],?o_proj_slices[i])?for?i?in?range(self.config.pretraining_tp)])
????????else:
????????????attn_output?=?self.o_proj(attn_output)

????????if?not?output_attentions:
????????????attn_weights?=?None

????????return?attn_output,?attn_weights,?past_key_value

3,前饋網(wǎng)絡(luò) LlamaMLP

前饋網(wǎng)絡(luò)是一個(gè)2層的感知機(jī)MLP。

先從hidden_size維度up_proj到intermediate_size維度,然后再down_proj還原為hidden_size維度。

這里的主要特色是引入了一個(gè)gate_proj配合激活函數(shù)來(lái)實(shí)現(xiàn)一個(gè)門控注意力的作用。

class?LlamaMLP(nn.Module):
????def?__init__(self,?config):
????????super().__init__()
????????self.config?=?config
????????self.hidden_size?=?config.hidden_size
????????self.intermediate_size?=?config.intermediate_size
????????self.gate_proj?=?nn.Linear(self.hidden_size,?self.intermediate_size,?bias=False)
????????self.up_proj?=?nn.Linear(self.hidden_size,?self.intermediate_size,?bias=False)
????????self.down_proj?=?nn.Linear(self.intermediate_size,?self.hidden_size,?bias=False)
????????self.act_fn?=?ACT2FN[config.hidden_act]

????def?forward(self,?x):
????????if?self.config.pretraining_tp?>?1:
????????????slice?=?self.intermediate_size?//?self.config.pretraining_tp
????????????gate_proj_slices?=?self.gate_proj.weight.split(slice,?dim=0)
????????????up_proj_slices?=?self.up_proj.weight.split(slice,?dim=0)
????????????down_proj_slices?=?self.down_proj.weight.split(slice,?dim=1)

????????????gate_proj?=?torch.cat(
????????????????[F.linear(x,?gate_proj_slices[i])?for?i?in?range(self.config.pretraining_tp)],?dim=-1
????????????)
????????????up_proj?=?torch.cat([F.linear(x,?up_proj_slices[i])?for?i?in?range(self.config.pretraining_tp)],?dim=-1)

????????????intermediate_states?=?(self.act_fn(gate_proj)?*?up_proj).split(slice,?dim=2)
????????????down_proj?=?[
????????????????F.linear(intermediate_states[i],?down_proj_slices[i])?for?i?in?range(self.config.pretraining_tp)
????????????]
????????????down_proj?=?sum(down_proj)
????????else:
????????????down_proj?=?self.down_proj(self.act_fn(self.gate_proj(x))?*?self.up_proj(x))

????????return?down_proj

4,層歸一化 LlamaRMSNorm

這里的層歸一化叫做RMSNorm,和標(biāo)準(zhǔn)的LayerNorm有少許差異。

首先是沒有移除均值,直接除的RootMeanSquare,然后也沒有加上bias。

這兩個(gè)小的修正可以保證在層歸一化不會(huì)改變hidden_states對(duì)應(yīng)的詞向量的方向,只會(huì)改變其模長(zhǎng)。

在一定的意義上具有合理性。

class?LlamaRMSNorm(nn.Module):
????def?__init__(self,?hidden_size,?eps=1e-6):
????????"""
????????LlamaRMSNorm?is?equivalent?to?T5LayerNorm
????????"""
????????super().__init__()
????????self.weight?=?nn.Parameter(torch.ones(hidden_size))
????????self.variance_epsilon?=?eps

????def?forward(self,?hidden_states):
????????input_dtype?=?hidden_states.dtype
????????hidden_states?=?hidden_states.to(torch.float32)
????????variance?=?hidden_states.pow(2).mean(-1,?keepdim=True)
????????hidden_states?=?hidden_states?*?torch.rsqrt(variance?+?self.variance_epsilon)
????????return?self.weight?*?hidden_states.to(input_dtype)

5,Llama解碼層

解碼層LlamaDecoderLayer由LlamaAttention,LlamaMLP,以及兩個(gè)LlamaRMSNorm組成,并使用了兩次殘差結(jié)構(gòu)。

class?LlamaDecoderLayer(nn.Module):
????def?__init__(self,?config:?LlamaConfig):
????????super().__init__()
????????self.hidden_size?=?config.hidden_size
????????self.self_attn?=?LlamaAttention(config=config)
????????self.mlp?=?LlamaMLP(config)
????????self.input_layernorm?=?LlamaRMSNorm(config.hidden_size,?eps=config.rms_norm_eps)
????????self.post_attention_layernorm?=?LlamaRMSNorm(config.hidden_size,?eps=config.rms_norm_eps)

????def?forward(
????????self,
????????hidden_states:?torch.Tensor,
????????attention_mask:?Optional[torch.Tensor]?=?None,
????????position_ids:?Optional[torch.LongTensor]?=?None,
????????past_key_value:?Optional[Tuple[torch.Tensor]]?=?None,
????????output_attentions:?Optional[bool]?=?False,
????????use_cache:?Optional[bool]?=?False,
????)?->?Tuple[torch.FloatTensor,?Optional[Tuple[torch.FloatTensor,?torch.FloatTensor]]]:
????????"""
????????Args:
????????????hidden_states?(`torch.FloatTensor`):?input?to?the?layer?of?shape?`(batch,?seq_len,?embed_dim)`
????????????attention_mask?(`torch.FloatTensor`,?*optional*):?attention?mask?of?size
????????????????`(batch,?1,?tgt_len,?src_len)`?where?padding?elements?are?indicated?by?very?large?negative?values.
????????????output_attentions?(`bool`,?*optional*):
????????????????Whether?or?not?to?return?the?attentions?tensors?of?all?attention?layers.?See?`attentions`?under
????????????????returned?tensors?for?more?detail.
????????????use_cache?(`bool`,?*optional*):
????????????????If?set?to?`True`,?`past_key_values`?key?value?states?are?returned?and?can?be?used?to?speed?up?decoding
????????????????(see?`past_key_values`).
????????????past_key_value?(`Tuple(torch.FloatTensor)`,?*optional*):?cached?past?key?and?value?projection?states
????????"""

????????residual?=?hidden_states

????????hidden_states?=?self.input_layernorm(hidden_states)

????????#?Self?Attention
????????hidden_states,?self_attn_weights,?present_key_value?=?self.self_attn(
????????????hidden_states=hidden_states,
????????????attention_mask=attention_mask,
????????????position_ids=position_ids,
????????????past_key_value=past_key_value,
????????????output_attentions=output_attentions,
????????????use_cache=use_cache,
????????)
????????hidden_states?=?residual?+?hidden_states

????????#?Fully?Connected
????????residual?=?hidden_states
????????hidden_states?=?self.post_attention_layernorm(hidden_states)
????????hidden_states?=?self.mlp(hidden_states)
????????hidden_states?=?residual?+?hidden_states

????????outputs?=?(hidden_states,)

????????if?output_attentions:
????????????outputs?+=?(self_attn_weights,)

????????if?use_cache:
????????????outputs?+=?(present_key_value,)

????????return?outputs

6,Llama解碼器

LlamaModel由多個(gè)Llama解碼層堆疊而成。

有幾個(gè)理解上的要點(diǎn):

1,_make_causal_mask用于構(gòu)造下三角這種mask結(jié)構(gòu)以實(shí)現(xiàn)語(yǔ)言模型的單向注意力。

2,_expand_mask用于將傳入的等特殊符號(hào)相關(guān)的mask信息展開成和attention矩陣相同的張量結(jié)構(gòu)。

3,設(shè)置gradient_checkpointing=True可以節(jié)約顯存。其主要應(yīng)用了torch.utils.checkpoint.checkpoint方法。它的原理非常簡(jiǎn)單,在對(duì)decoder_layer進(jìn)行forward時(shí)不保存中間激活值從而節(jié)約顯存,backward時(shí)重新計(jì)算相關(guān)值,從而通過時(shí)間換取了空間。

4,gradient_checkpointing和use_cache不能同時(shí)設(shè)置為True,前者是為了節(jié)約顯存時(shí)間換空間的,后者是為了節(jié)約時(shí)間空間換時(shí)間。

#?Copied?from?transformers.models.bart.modeling_bart._make_causal_mask
def?_make_causal_mask(
????input_ids_shape:?torch.Size,?dtype:?torch.dtype,?
????device:?torch.device,?past_key_values_length:?int?=?0
):
????"""
????Make?causal?mask?used?for?bi-directional?self-attention.
????"""
????bsz,?tgt_len?=?input_ids_shape
????mask?=?torch.full((tgt_len,?tgt_len),?torch.finfo(dtype).min,?device=device)
????mask_cond?=?torch.arange(mask.size(-1),?device=device)
????mask.masked_fill_(mask_cond?<?(mask_cond?+?1).view(mask.size(-1),?1),?0)
????mask?=?mask.to(dtype)

????if?past_key_values_length?>?0:
????????mask?=?torch.cat([torch.zeros(tgt_len,?past_key_values_length,?dtype=dtype,?device=device),?mask],?dim=-1)
????return?mask[None,?None,?:,?:].expand(bsz,?1,?tgt_len,?tgt_len?+?past_key_values_length)


#?Copied?from?transformers.models.bart.modeling_bart._expand_mask
def?_expand_mask(mask:?torch.Tensor,?dtype:?torch.dtype,?tgt_len:?Optional[int]?=?None):
????"""
????Expands?attention_mask?from?`[bsz,?seq_len]`?to?`[bsz,?1,?tgt_seq_len,?src_seq_len]`.
????"""
????bsz,?src_len?=?mask.size()
????tgt_len?=?tgt_len?if?tgt_len?is?not?None?else?src_len

????expanded_mask?=?mask[:,?None,?None,?:].expand(bsz,?1,?tgt_len,?src_len).to(dtype)
????inverted_mask?=?1.0?-?expanded_mask

????return?inverted_mask.masked_fill(inverted_mask.to(torch.bool),?torch.finfo(dtype).min)


@add_start_docstrings(
????"The?bare?LLaMA?Model?outputting?raw?hidden-states?without?any?specific?head?on?top.",
????LLAMA_START_DOCSTRING,
)
class?LlamaPreTrainedModel(PreTrainedModel):
????config_class?=?LlamaConfig
????base_model_prefix?=?"model"
????supports_gradient_checkpointing?=?True
????_no_split_modules?=?["LlamaDecoderLayer"]
????_skip_keys_device_placement?=?"past_key_values"

????def?_init_weights(self,?module):
????????std?=?self.config.initializer_range
????????if?isinstance(module,?nn.Linear):
????????????module.weight.data.normal_(mean=0.0,?std=std)
????????????if?module.bias?is?not?None:
????????????????module.bias.data.zero_()
????????elif?isinstance(module,?nn.Embedding):
????????????module.weight.data.normal_(mean=0.0,?std=std)
????????????if?module.padding_idx?is?not?None:
????????????????module.weight.data[module.padding_idx].zero_()

????def?_set_gradient_checkpointing(self,?module,?value=False):
????????if?isinstance(module,?LlamaModel):
????????????module.gradient_checkpointing?=?value


@add_start_docstrings(
????"The?bare?LLaMA?Model?outputting?raw?hidden-states?without?any?specific?head?on?top.",
????LLAMA_START_DOCSTRING,
)
class?LlamaModel(LlamaPreTrainedModel):
????"""
????Transformer?decoder?consisting?of?*config.num_hidden_layers*?layers.?Each?layer?is?a?[`LlamaDecoderLayer`]

????Args:
????????config:?LlamaConfig
????"""

????def?__init__(self,?config:?LlamaConfig):
????????super().__init__(config)
????????self.padding_idx?=?config.pad_token_id
????????self.vocab_size?=?config.vocab_size

????????self.embed_tokens?=?nn.Embedding(config.vocab_size,?config.hidden_size,?self.padding_idx)
????????self.layers?=?nn.ModuleList([LlamaDecoderLayer(config)?for?_?in?range(config.num_hidden_layers)])
????????self.norm?=?LlamaRMSNorm(config.hidden_size,?eps=config.rms_norm_eps)

????????self.gradient_checkpointing?=?False
????????#?Initialize?weights?and?apply?final?processing
????????self.post_init()

????def?get_input_embeddings(self):
????????return?self.embed_tokens

????def?set_input_embeddings(self,?value):
????????self.embed_tokens?=?value

????#?Copied?from?transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
????def?_prepare_decoder_attention_mask(self,?attention_mask,?input_shape,?inputs_embeds,?past_key_values_length):
????????#?create?causal?mask
????????#?[bsz,?seq_len]?->?[bsz,?1,?tgt_seq_len,?src_seq_len]
????????combined_attention_mask?=?None
????????if?input_shape[-1]?>?1:
????????????combined_attention_mask?=?_make_causal_mask(
????????????????input_shape,
????????????????inputs_embeds.dtype,
????????????????device=inputs_embeds.device,
????????????????past_key_values_length=past_key_values_length,
????????????)

????????if?attention_mask?is?not?None:
????????????#?[bsz,?seq_len]?->?[bsz,?1,?tgt_seq_len,?src_seq_len]
????????????expanded_attn_mask?=?_expand_mask(attention_mask,?inputs_embeds.dtype,?tgt_len=input_shape[-1]).to(
????????????????inputs_embeds.device
????????????)
????????????combined_attention_mask?=?(
????????????????expanded_attn_mask?if?combined_attention_mask?is?None?else?expanded_attn_mask?+?combined_attention_mask
????????????)

????????return?combined_attention_mask

????@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
????def?forward(
????????self,
????????input_ids:?torch.LongTensor?=?None,
????????attention_mask:?Optional[torch.Tensor]?=?None,
????????position_ids:?Optional[torch.LongTensor]?=?None,
????????past_key_values:?Optional[List[torch.FloatTensor]]?=?None,
????????inputs_embeds:?Optional[torch.FloatTensor]?=?None,
????????use_cache:?Optional[bool]?=?None,
????????output_attentions:?Optional[bool]?=?None,
????????output_hidden_states:?Optional[bool]?=?None,
????????return_dict:?Optional[bool]?=?None,
????)?->?Union[Tuple,?BaseModelOutputWithPast]:
????????output_attentions?=?output_attentions?if?output_attentions?is?not?None?else?self.config.output_attentions
????????output_hidden_states?=?(
????????????output_hidden_states?if?output_hidden_states?is?not?None?else?self.config.output_hidden_states
????????)
????????use_cache?=?use_cache?if?use_cache?is?not?None?else?self.config.use_cache

????????return_dict?=?return_dict?if?return_dict?is?not?None?else?self.config.use_return_dict

????????#?retrieve?input_ids?and?inputs_embeds
????????if?input_ids?is?not?None?and?inputs_embeds?is?not?None:
????????????raise?ValueError("You?cannot?specify?both?decoder_input_ids?and?decoder_inputs_embeds?at?the?same?time")
????????elif?input_ids?is?not?None:
????????????batch_size,?seq_length?=?input_ids.shape
????????elif?inputs_embeds?is?not?None:
????????????batch_size,?seq_length,?_?=?inputs_embeds.shape
????????else:
????????????raise?ValueError("You?have?to?specify?either?decoder_input_ids?or?decoder_inputs_embeds")

????????seq_length_with_past?=?seq_length
????????past_key_values_length?=?0

????????if?past_key_values?is?not?None:
????????????past_key_values_length?=?past_key_values[0][0].shape[2]
????????????seq_length_with_past?=?seq_length_with_past?+?past_key_values_length

????????if?position_ids?is?None:
????????????device?=?input_ids.device?if?input_ids?is?not?None?else?inputs_embeds.device
????????????position_ids?=?torch.arange(
????????????????past_key_values_length,?seq_length?+?past_key_values_length,?dtype=torch.long,?device=device
????????????)
????????????position_ids?=?position_ids.unsqueeze(0).view(-1,?seq_length)
????????else:
????????????position_ids?=?position_ids.view(-1,?seq_length).long()

????????if?inputs_embeds?is?None:
????????????inputs_embeds?=?self.embed_tokens(input_ids)
????????#?embed?positions
????????if?attention_mask?is?None:
????????????attention_mask?=?torch.ones(
????????????????(batch_size,?seq_length_with_past),?dtype=torch.bool,?device=inputs_embeds.device
????????????)
????????attention_mask?=?self._prepare_decoder_attention_mask(
????????????attention_mask,?(batch_size,?seq_length),?inputs_embeds,?past_key_values_length
????????)

????????hidden_states?=?inputs_embeds

????????if?self.gradient_checkpointing?and?self.training:
????????????if?use_cache:
????????????????logger.warning_once(
????????????????????"`use_cache=True`?is?incompatible?with?gradient?checkpointing.?Setting?`use_cache=False`..."
????????????????)
????????????????use_cache?=?False

????????#?decoder?layers
????????all_hidden_states?=?()?if?output_hidden_states?else?None
????????all_self_attns?=?()?if?output_attentions?else?None
????????next_decoder_cache?=?()?if?use_cache?else?None

????????for?idx,?decoder_layer?in?enumerate(self.layers):
????????????if?output_hidden_states:
????????????????all_hidden_states?+=?(hidden_states,)

????????????past_key_value?=?past_key_values[idx]?if?past_key_values?is?not?None?else?None

????????????if?self.gradient_checkpointing?and?self.training:

????????????????def?create_custom_forward(module):
????????????????????def?custom_forward(*inputs):
????????????????????????#?None?for?past_key_value
????????????????????????return?module(*inputs,?output_attentions,?None)

????????????????????return?custom_forward

????????????????layer_outputs?=?torch.utils.checkpoint.checkpoint(
????????????????????create_custom_forward(decoder_layer),
????????????????????hidden_states,
????????????????????attention_mask,
????????????????????position_ids,
????????????????????None,
????????????????)
????????????else:
????????????????layer_outputs?=?decoder_layer(
????????????????????hidden_states,
????????????????????attention_mask=attention_mask,
????????????????????position_ids=position_ids,
????????????????????past_key_value=past_key_value,
????????????????????output_attentions=output_attentions,
????????????????????use_cache=use_cache,
????????????????)

????????????hidden_states?=?layer_outputs[0]

????????????if?use_cache:
????????????????next_decoder_cache?+=?(layer_outputs[2?if?output_attentions?else?1],)

????????????if?output_attentions:
????????????????all_self_attns?+=?(layer_outputs[1],)

????????hidden_states?=?self.norm(hidden_states)

????????#?add?hidden?states?from?the?last?decoder?layer
????????if?output_hidden_states:
????????????all_hidden_states?+=?(hidden_states,)

????????next_cache?=?next_decoder_cache?if?use_cache?else?None
????????if?not?return_dict:
????????????return?tuple(v?for?v?in?[hidden_states,?next_cache,?all_hidden_states,?all_self_attns]?if?v?is?not?None)
????????return?BaseModelOutputWithPast(
????????????last_hidden_state=hidden_states,
????????????past_key_values=next_cache,
????????????hidden_states=all_hidden_states,
????????????attentions=all_self_attns,
????????)

7,Llama語(yǔ)言模型

Llama語(yǔ)言模型 LlamaForCausalLM是在Llama解碼器LlamaModel的基礎(chǔ)上增加了一個(gè)lm_head作為Generator。

從而實(shí)現(xiàn)了一個(gè)完整的語(yǔ)言模型。

除此之外,Llama語(yǔ)言模型還實(shí)現(xiàn)了以下重要功能。

1,loss計(jì)算功能。當(dāng)forward方法中傳入labels時(shí),會(huì)自動(dòng)計(jì)算語(yǔ)言模型的交叉熵?fù)p失。注意labels中的-100會(huì)被忽略不參與計(jì)算。

2,文本生成generate方法。這個(gè)方法繼承自PreTrainedModel,可以設(shè)置model.generation_config.num_beams選擇束搜索的束寬度,默認(rèn)為1即貪心搜索。

_CONFIG_FOR_DOC?=?"LlamaConfig"

class?LlamaForCausalLM(LlamaPreTrainedModel):
????_tied_weights_keys?=?["lm_head.weight"]

????def?__init__(self,?config):
????????super().__init__(config)
????????self.model?=?LlamaModel(config)
????????self.vocab_size?=?config.vocab_size
????????self.lm_head?=?nn.Linear(config.hidden_size,?config.vocab_size,?bias=False)

????????#?Initialize?weights?and?apply?final?processing
????????self.post_init()

????def?get_input_embeddings(self):
????????return?self.model.embed_tokens

????def?set_input_embeddings(self,?value):
????????self.model.embed_tokens?=?value

????def?get_output_embeddings(self):
????????return?self.lm_head

????def?set_output_embeddings(self,?new_embeddings):
????????self.lm_head?=?new_embeddings

????def?set_decoder(self,?decoder):
????????self.model?=?decoder

????def?get_decoder(self):
????????return?self.model

????@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
????@replace_return_docstrings(output_type=CausalLMOutputWithPast,?config_class=_CONFIG_FOR_DOC)
????def?forward(
????????self,
????????input_ids:?torch.LongTensor?=?None,
????????attention_mask:?Optional[torch.Tensor]?=?None,
????????position_ids:?Optional[torch.LongTensor]?=?None,
????????past_key_values:?Optional[List[torch.FloatTensor]]?=?None,
????????inputs_embeds:?Optional[torch.FloatTensor]?=?None,
????????labels:?Optional[torch.LongTensor]?=?None,
????????use_cache:?Optional[bool]?=?None,
????????output_attentions:?Optional[bool]?=?None,
????????output_hidden_states:?Optional[bool]?=?None,
????????return_dict:?Optional[bool]?=?None,
????)?->?Union[Tuple,?CausalLMOutputWithPast]:

????????output_attentions?=?output_attentions?if?output_attentions?is?not?None?else?self.config.output_attentions
????????output_hidden_states?=?(
????????????output_hidden_states?if?output_hidden_states?is?not?None?else?self.config.output_hidden_states
????????)
????????return_dict?=?return_dict?if?return_dict?is?not?None?else?self.config.use_return_dict

????????#?decoder?outputs?consists?of?(dec_features,?layer_state,?dec_hidden,?dec_attn)
????????outputs?=?self.model(
????????????input_ids=input_ids,
????????????attention_mask=attention_mask,
????????????position_ids=position_ids,
????????????past_key_values=past_key_values,
????????????inputs_embeds=inputs_embeds,
????????????use_cache=use_cache,
????????????output_attentions=output_attentions,
????????????output_hidden_states=output_hidden_states,
????????????return_dict=return_dict,
????????)

????????hidden_states?=?outputs[0]
????????if?self.config.pretraining_tp?>?1:
????????????lm_head_slices?=?self.lm_head.weight.split(self.vocab_size?//?self.config.pretraining_tp,?dim=0)
????????????logits?=?[F.linear(hidden_states,?lm_head_slices[i])?for?i?in?range(self.config.pretraining_tp)]
????????????logits?=?torch.cat(logits,?dim=-1)
????????else:
????????????logits?=?self.lm_head(hidden_states)
????????logits?=?logits.float()

????????loss?=?None
????????if?labels?is?not?None:
????????????#?Shift?so?that?tokens?<?n?predict?n
????????????shift_logits?=?logits[...,?:-1,?:].contiguous()
????????????shift_labels?=?labels[...,?1:].contiguous()
????????????#?Flatten?the?tokens
????????????loss_fct?=?CrossEntropyLoss()
????????????shift_logits?=?shift_logits.view(-1,?self.config.vocab_size)
????????????shift_labels?=?shift_labels.view(-1)
????????????#?Enable?model?parallelism
????????????shift_labels?=?shift_labels.to(shift_logits.device)
????????????loss?=?loss_fct(shift_logits,?shift_labels)

????????if?not?return_dict:
????????????output?=?(logits,)?+?outputs[1:]
????????????return?(loss,)?+?output?if?loss?is?not?None?else?output

????????return?CausalLMOutputWithPast(
????????????loss=loss,
????????????logits=logits,
????????????past_key_values=outputs.past_key_values,
????????????hidden_states=outputs.hidden_states,
????????????attentions=outputs.attentions,
????????)

????def?prepare_inputs_for_generation(
????????self,?input_ids,?past_key_values=None,?attention_mask=None,?inputs_embeds=None,?**kwargs
????):
????????if?past_key_values:
????????????input_ids?=?input_ids[:,?-1:]

????????position_ids?=?kwargs.get("position_ids",?None)
????????if?attention_mask?is?not?None?and?position_ids?is?None:
????????????#?create?position_ids?on?the?fly?for?batch?generation
????????????position_ids?=?attention_mask.long().cumsum(-1)?-?1
????????????position_ids.masked_fill_(attention_mask?==?0,?1)
????????????if?past_key_values:
????????????????position_ids?=?position_ids[:,?-1].unsqueeze(-1)

????????#?if?`inputs_embeds`?are?passed,?we?only?want?to?use?them?in?the?1st?generation?step
????????if?inputs_embeds?is?not?None?and?past_key_values?is?None:
????????????model_inputs?=?{"inputs_embeds":?inputs_embeds}
????????else:
????????????model_inputs?=?{"input_ids":?input_ids}

????????model_inputs.update(
????????????{
????????????????"position_ids":?position_ids,
????????????????"past_key_values":?past_key_values,
????????????????"use_cache":?kwargs.get("use_cache"),
????????????????"attention_mask":?attention_mask,
????????????}
????????)
????????return?model_inputs

????@staticmethod
????def?_reorder_cache(past_key_values,?beam_idx):
????????reordered_past?=?()
????????for?layer_past?in?past_key_values:
????????????reordered_past?+=?(
????????????????tuple(past_state.index_select(0,?beam_idx.to(past_state.device))?for?past_state?in?layer_past),
????????????)
????????return?reordered_past

8,Llama分類模型

LlamaForSequenceClassification是一個(gè)序列分類模型。

這個(gè)分類模型可以用來(lái)訓(xùn)練RLHF流程中的Reward模型。

@add_start_docstrings(
????"""
????The?LLaMa?Model?transformer?with?a?sequence?classification?head?on?top?(linear?layer).

????[`LlamaForSequenceClassification`]?uses?the?last?token?in?order?to?do?the?classification,?as?other?causal?models
????(e.g.?GPT-2)?do.

????Since?it?does?classification?on?the?last?token,?it?requires?to?know?the?position?of?the?last?token.?If?a
????`pad_token_id`?is?defined?in?the?configuration,?it?finds?the?last?token?that?is?not?a?padding?token?in?each?row.?If
????no?`pad_token_id`?is?defined,?it?simply?takes?the?last?value?in?each?row?of?the?batch.?Since?it?cannot?guess?the
????padding?tokens?when?`inputs_embeds`?are?passed?instead?of?`input_ids`,?it?does?the?same?(take?the?last?value?in
????each?row?of?the?batch).
????""",
????LLAMA_START_DOCSTRING,
)
class?LlamaForSequenceClassification(LlamaPreTrainedModel):
????def?__init__(self,?config):
????????super().__init__(config)
????????self.num_labels?=?config.num_labels
????????self.model?=?LlamaModel(config)
????????self.score?=?nn.Linear(config.hidden_size,?self.num_labels,?bias=False)

????????#?Initialize?weights?and?apply?final?processing
????????self.post_init()

????def?get_input_embeddings(self):
????????return?self.model.embed_tokens

????def?set_input_embeddings(self,?value):
????????self.model.embed_tokens?=?value

????@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
????def?forward(
????????self,
????????input_ids:?torch.LongTensor?=?None,
????????attention_mask:?Optional[torch.Tensor]?=?None,
????????position_ids:?Optional[torch.LongTensor]?=?None,
????????past_key_values:?Optional[List[torch.FloatTensor]]?=?None,
????????inputs_embeds:?Optional[torch.FloatTensor]?=?None,
????????labels:?Optional[torch.LongTensor]?=?None,
????????use_cache:?Optional[bool]?=?None,
????????output_attentions:?Optional[bool]?=?None,
????????output_hidden_states:?Optional[bool]?=?None,
????????return_dict:?Optional[bool]?=?None,
????)?->?Union[Tuple,?SequenceClassifierOutputWithPast]:
????????r"""
????????labels?(`torch.LongTensor`?of?shape?`(batch_size,)`,?*optional*):
????????????Labels?for?computing?the?sequence?classification/regression?loss.?Indices?should?be?in?`[0,?...,
????????????config.num_labels?-?1]`.?If?`config.num_labels?==?1`?a?regression?loss?is?computed?(Mean-Square?loss),?If
????????????`config.num_labels?>?1`?a?classification?loss?is?computed?(Cross-Entropy).
????????"""
????????return_dict?=?return_dict?if?return_dict?is?not?None?else?self.config.use_return_dict

????????transformer_outputs?=?self.model(
????????????input_ids,
????????????attention_mask=attention_mask,
????????????position_ids=position_ids,
????????????past_key_values=past_key_values,
????????????inputs_embeds=inputs_embeds,
????????????use_cache=use_cache,
????????????output_attentions=output_attentions,
????????????output_hidden_states=output_hidden_states,
????????????return_dict=return_dict,
????????)
????????hidden_states?=?transformer_outputs[0]
????????logits?=?self.score(hidden_states)

????????if?input_ids?is?not?None:
????????????batch_size?=?input_ids.shape[0]
????????else:
????????????batch_size?=?inputs_embeds.shape[0]

????????if?self.config.pad_token_id?is?None?and?batch_size?!=?1:
????????????raise?ValueError("Cannot?handle?batch?sizes?>?1?if?no?padding?token?is?defined.")
????????if?self.config.pad_token_id?is?None:
????????????sequence_lengths?=?-1
????????else:
????????????if?input_ids?is?not?None:
????????????????sequence_lengths?=?(torch.eq(input_ids,?self.config.pad_token_id).long().argmax(-1)?-?1).to(
????????????????????logits.device
????????????????)
????????????else:
????????????????sequence_lengths?=?-1

????????pooled_logits?=?logits[torch.arange(batch_size,?device=logits.device),?sequence_lengths]

????????loss?=?None
????????if?labels?is?not?None:
????????????labels?=?labels.to(logits.device)
????????????if?self.config.problem_type?is?None:
????????????????if?self.num_labels?==?1:
????????????????????self.config.problem_type?=?"regression"
????????????????elif?self.num_labels?>?1?and?(labels.dtype?==?torch.long?or?labels.dtype?==?torch.int):
????????????????????self.config.problem_type?=?"single_label_classification"
????????????????else:
????????????????????self.config.problem_type?=?"multi_label_classification"

????????????if?self.config.problem_type?==?"regression":
????????????????loss_fct?=?MSELoss()
????????????????if?self.num_labels?==?1:
????????????????????loss?=?loss_fct(pooled_logits.squeeze(),?labels.squeeze())
????????????????else:
????????????????????loss?=?loss_fct(pooled_logits,?labels)
????????????elif?self.config.problem_type?==?"single_label_classification":
????????????????loss_fct?=?CrossEntropyLoss()
????????????????loss?=?loss_fct(pooled_logits.view(-1,?self.num_labels),?labels.view(-1))
????????????elif?self.config.problem_type?==?"multi_label_classification":
????????????????loss_fct?=?BCEWithLogitsLoss()
????????????????loss?=?loss_fct(pooled_logits,?labels)
????????if?not?return_dict:
????????????output?=?(pooled_logits,)?+?transformer_outputs[1:]
????????????return?((loss,)?+?output)?if?loss?is?not?None?else?output

????????return?SequenceClassifierOutputWithPast(
????????????loss=loss,
????????????logits=pooled_logits,
????????????past_key_values=transformer_outputs.past_key_values,
????????????hidden_states=transformer_outputs.hidden_states,
????????????attentions=transformer_outputs.attentions,
????????)

三,訓(xùn)練模型

下面,我們來(lái)訓(xùn)練一個(gè)LlamaForCausalLM 實(shí)現(xiàn)兩數(shù)之和的任務(wù)。

config?=?LlamaConfig(
????vocab_size=len(vocab),
????hidden_size=512,
????intermediate_size=2752,
????num_hidden_layers=8,
????num_attention_heads=16,
????num_key_value_heads=4,
????rope_scaling?=?None,
????hidden_act='silu',
????max_position_embeddings=128,
????initializer_range=0.02,
????rms_norm_eps=1e-06,
????use_cache=True,
????pad_token_id=0,
????bos_token_id=1,
????eos_token_id=2,
????tie_word_embeddings=False,
????pretraining_tp?=?1,
????max_new_tokens?=?100
)
#試算一下
model?=?LlamaForCausalLM(config)
out?=?model.forward(**batch)
print(out.loss)

tensor(2.7630, grad_fn=)

from?torchkeras?import?KerasModel?
from?accelerate?import?Accelerator?

class?StepRunner:
????def?__init__(self,?net,?loss_fn,?accelerator=None,?stage?=?"train",?metrics_dict?=?None,?
?????????????????optimizer?=?None,?lr_scheduler?=?None
?????????????????):
????????self.net,self.loss_fn,self.metrics_dict,self.stage?=?net,loss_fn,metrics_dict,stage
????????self.optimizer,self.lr_scheduler?=?optimizer,lr_scheduler
????????self.accelerator?=?accelerator?if?accelerator?is?not?None?else?Accelerator()?
????????if?self.stage=='train':
????????????self.net.train()?
????????else:
????????????self.net.eval()
????
????def?__call__(self,?batch):
????????
????????#loss
????????with?self.accelerator.autocast():
????????????loss?=?self.net(**batch).loss

????????#backward()
????????if?self.stage=="train"?and?self.optimizer?is?not?None:????????
????????????self.accelerator.backward(loss)
????????????if?self.accelerator.sync_gradients:
????????????????self.accelerator.clip_grad_norm_(self.net.parameters(),?1.0)
????????????self.optimizer.step()
????????????if?self.lr_scheduler?is?not?None:
????????????????self.lr_scheduler.step()
????????????self.optimizer.zero_grad()
????????????
????????all_loss?=?self.accelerator.gather(loss).sum()
????????
????????#losses?(or?plain?metrics?that?can?be?averaged)
????????step_losses?=?{self.stage+"_loss":all_loss.item()}
????????
????????#metrics?(stateful?metrics)
????????step_metrics?=?{}
????????
????????if?self.stage=="train":
????????????if?self.optimizer?is?not?None:
????????????????step_metrics['lr']?=?self.optimizer.state_dict()['param_groups'][0]['lr']
????????????else:
????????????????step_metrics['lr']?=?0.0
????????return?step_losses,step_metrics
????
KerasModel.StepRunner?=?StepRunner
keras_model?=?KerasModel(model,loss_fn?=?None,
????????optimizer=torch.optim.AdamW(model.parameters(),lr=3e-5))


#加載?之前訓(xùn)練過的權(quán)重
ckpt_path?=?'llama_twosum'

keras_model.fit(train_data?=?dl_train,
????????????????val_data?=?dl_val,
????????????????epochs=100,patience=5,
????????????????monitor='val_loss',mode='min',
????????????????ckpt_path?=?ckpt_path,
????????????????mixed_precision='fp16'
???????????????)

Llama深入淺出,llama

四,使用模型

from?transformers.generation.utils?import?GenerationConfig
model.generation_config?=?GenerationConfig.from_dict({'num_beams':1,
????????????????????????????'max_new_tokens':100,
????????????????????????????'max_length':200})
model.generation_config.num_beams=1
model.generation_config.max_new_tokens?=?100?
model.generation_config.max_length=200
def?get_ans(tensor)?->"str":
????s?=?"".join([vocab_r[i]?for?i?in?tensor.tolist()])
????ans?=?s[s.find('=')+1:s.find('<EOS>')].replace('<BOS>','').replace('<EOS>','')
????return?ans
x,y?=?get_data()?
print('x:?'+''.join(x).replace('<BOS>',''))
print('y:?'+''.join(y).replace('<EOS>',''))
x: 3481340050+90157504501803=
y: 90160985841853
input_ids?=?torch.tensor([[vocab[i]?for?i?in?x]])?
out?=?model.generate(inputs=input_ids)
out

tensor([[ 1, ?5, ?6, 10, ?3, ?5, ?6, 12, 12, ?7, 12, 13, 11, 12, ?3, ?7, ?9, ?7, 12, ?6, ?7, 12, ?3, 10, 12, ?5, 14, 11, 12, ?3, ?8, 12, 11, 10, ?7, 10, 6, ?3, 10, ?7, ?5, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, 2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, 12, ?2, ?2, ?2, ?2, ?2, ?2, ?2, 2, 12, ?3, 12, ?3]])

get_ans(out[0])

'90160985841853'

五,評(píng)估模型

from?tqdm?import?tqdm?
loop?=?tqdm(range(1,201))
correct?=?0
for?i?in?loop:
????x,y?=?get_data()?
????input_ids?=?torch.tensor([[vocab[i]?for?i?in?x]])?
????out?=?model.generate(inputs=input_ids)
????pred?=?get_ans(out[0])
????gt?=?''.join(y).replace('<EOS>','')
????if?pred==gt:
????????correct+=1
????loop.set_postfix(acc?=?correct/i)
????
print("acc=",correct/len(loop))

acc= 0.99

漂亮,我們的測(cè)試準(zhǔn)確率達(dá)到了99%!

公眾號(hào)算法美食屋后臺(tái)回復(fù)關(guān)鍵詞:torchkeras,獲取本文notebook源碼,以及更多有趣范例~

Llama深入淺出,llamaLlama深入淺出,llama文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-699877.html

到了這里,關(guān)于Llama深入淺出的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!

本文來(lái)自互聯(lián)網(wǎng)用戶投稿,該文觀點(diǎn)僅代表作者本人,不代表本站立場(chǎng)。本站僅提供信息存儲(chǔ)空間服務(wù),不擁有所有權(quán),不承擔(dān)相關(guān)法律責(zé)任。如若轉(zhuǎn)載,請(qǐng)注明出處: 如若內(nèi)容造成侵權(quán)/違法違規(guī)/事實(shí)不符,請(qǐng)點(diǎn)擊違法舉報(bào)進(jìn)行投訴反饋,一經(jīng)查實(shí),立即刪除!

領(lǐng)支付寶紅包贊助服務(wù)器費(fèi)用

相關(guān)文章

  • 深入淺出IAM(1)

    深入淺出IAM(1)

    在本人即將入職的一份基礎(chǔ)架構(gòu)的工作前,我提前聯(lián)系到了團(tuán)隊(duì)leader并跟他進(jìn)行了一次1-1。談話中提到了我可能會(huì)先上手的一個(gè)項(xiàng)目是IAM相關(guān)的實(shí)現(xiàn),于是趁著入職前的間隙,我學(xué)習(xí)了部分優(yōu)秀開源IAM項(xiàng)目實(shí)現(xiàn)思路以及騰訊云開發(fā)專家孔老師的專欄。 在反復(fù)思考和總結(jié)提煉后

    2024年02月05日
    瀏覽(34)
  • 深入淺出前端本地儲(chǔ)存

    深入淺出前端本地儲(chǔ)存

    2021 年,如果你的前端應(yīng)用,需要在瀏覽器上保存數(shù)據(jù),有三個(gè)主流方案: Cookie Web Storage (LocalStorage) IndexedDB 這些方案就是如今應(yīng)用最廣、瀏覽器兼容性最高的三種前端儲(chǔ)存方案 今天這篇文章就聊一聊這三種方案的歷史,優(yōu)缺點(diǎn),以及各自在今天的適用場(chǎng)景 文章在后面還會(huì)提

    2024年04月17日
    瀏覽(28)
  • 深入淺出Kafka

    深入淺出Kafka

    這個(gè)主題 武哥漫談IT ,作者駱俊武 講得更好 首先我們得去官網(wǎng)看看是怎么介紹Kafka的: https://kafka.apache.org/intro Apache Kafka is an open-source distributed event streaming platform. 翻譯成中文就是:Apache Kafka 是一個(gè)開源的分布式流處理平臺(tái)。 Kafka 不是一個(gè)消息系統(tǒng)嗎?為什么被稱為分布式

    2023年04月11日
    瀏覽(27)
  • 深入淺出 Typescript

    深入淺出 Typescript

    TypeScript 是 JavaScript 的一個(gè)超集,支持 ECMAScript 6 標(biāo)準(zhǔn)(ES6 教程)。 TypeScript 由微軟開發(fā)的自由和開源的編程語(yǔ)言。 TypeScript 設(shè)計(jì)目標(biāo)是開發(fā)大型應(yīng)用,它可以編譯成純 JavaScript,編譯出來(lái)的 JavaScript 可以運(yùn)行在任何瀏覽器上。 TypeScript JavaScript JavaScript 的超集,用于解決大型

    2024年02月14日
    瀏覽(38)
  • 深入淺出理解HTTPS

    深入淺出理解HTTPS

    1.對(duì)稱密鑰(Symmetric Encryption) 對(duì)稱密鑰加密算法使用相同的 密鑰(Symmetric key) 來(lái)進(jìn)行數(shù)據(jù) 加密(encryption) 和 解密(decryption) 加密和解密過程都使用相同的密鑰,因此 加密速度較快 ,適用于大量數(shù)據(jù)的加密。 問題在于密鑰的管理:在通信雙方交流之前,需要確保安全地分

    2024年02月10日
    瀏覽(25)
  • 機(jī)器學(xué)習(xí)深入淺出

    目錄 機(jī)器學(xué)習(xí)基本概念 機(jī)器學(xué)習(xí)算法類型 機(jī)器學(xué)習(xí)的實(shí)現(xiàn)步驟 機(jī)器學(xué)習(xí)三個(gè)基本要素 機(jī)器學(xué)習(xí)相關(guān)應(yīng)用 1.語(yǔ)音識(shí)別 2.圖像識(shí)別 機(jī)器學(xué)習(xí)是一種人工智能的分支,它使用算法和數(shù)學(xué)模型來(lái)讓計(jì)算機(jī)自主學(xué)習(xí)數(shù)據(jù)并做出預(yù)測(cè)和決策。這種技術(shù)正在被廣泛應(yīng)用于各種領(lǐng)域,包括

    2023年04月08日
    瀏覽(17)
  • 深入淺出Spring AOP

    深入淺出Spring AOP

    第1章:引言 大家好,我是小黑,咱們今天要聊的是Java中Spring框架的AOP(面向切面編程)。對(duì)于程序員來(lái)說,理解AOP對(duì)于掌握Spring框架來(lái)說是超級(jí)關(guān)鍵的。它像是魔法一樣,能讓咱們?cè)诓桓淖冊(cè)写a的情況下,給程序增加各種功能。 AOP不僅僅是一個(gè)編程范式,它更是一種思

    2024年01月20日
    瀏覽(28)
  • 深入淺出以太坊MEV

    深入淺出以太坊MEV

    要介紹MEV,就繞不開Front-running(直譯為搶跑),也稱為Priority Gas Auctions (PGAs),實(shí)際上是一個(gè)意思。 我們都知道,常規(guī)意義上,在以太坊上提交交易是一個(gè)看似有序的過程,現(xiàn)在重新審視一下這個(gè)過程: 用戶需要在錢包構(gòu)建交易,并簽名,錢包后端會(huì)將該交易廣播到 P2P 網(wǎng)絡(luò)

    2024年02月09日
    瀏覽(22)
  • 深入淺出C++ ——線程庫(kù)

    深入淺出C++ ——線程庫(kù)

    ??在C++11之前,涉及到多線程問題,都是和平臺(tái)相關(guān)的,比如windows和linux下各有自己的接口,這使得代碼的可移植性比較差。C++11中最重要的特性就是對(duì)線程進(jìn)行支持了,使得C++在并行編程時(shí)不需要依賴第三方庫(kù),而且在原子操作中還引入了 原子類 的概念。要使用標(biāo)準(zhǔn)庫(kù)中

    2024年02月03日
    瀏覽(98)
  • 深入淺出如何防勒索病毒

    深入淺出如何防勒索病毒

    1.勒索病毒是如何傳播的 這是不法分子通過改造之前泄露的NSA黑客武器庫(kù)中“永恒之藍(lán)”攻擊程序發(fā)起的網(wǎng)絡(luò)攻擊事件?!坝篮阒{(lán)”通過掃描開放445文件共享端口的Windows電腦甚至是電子信息屏,無(wú)需用戶進(jìn)行任何操作,只要開機(jī)聯(lián)網(wǎng),不法分子就能在電腦和服務(wù)器中植入勒

    2024年02月04日
    瀏覽(19)

覺得文章有用就打賞一下文章作者

支付寶掃一掃打賞

博客贊助

微信掃一掃打賞

請(qǐng)作者喝杯咖啡吧~博客贊助

支付寶掃一掃領(lǐng)取紅包,優(yōu)惠每天領(lǐng)

二維碼1

領(lǐng)取紅包

二維碼2

領(lǐng)紅包