增加LLM上下文長(zhǎng)度可以提升大語(yǔ)言模型在一些任務(wù)上的表現(xiàn),這包括多輪長(zhǎng)對(duì)話、長(zhǎng)文本摘要、視覺(jué)-語(yǔ)言Transformer模型的高分辨4k模型的理解力以及代碼生成、圖像以及音頻生成等。
對(duì)長(zhǎng)上下文場(chǎng)景,在解碼階段,緩存先前token的Key和Value(KV)需要巨大的內(nèi)存開(kāi)銷,其次主流的LLM模型在推理的時(shí)候上下文長(zhǎng)度都小于等于訓(xùn)練時(shí)的上下文長(zhǎng)度。為了約束長(zhǎng)文本時(shí)緩存先前KV的內(nèi)存和計(jì)算量,很容易想到的方法是對(duì)KV進(jìn)行加窗選擇,這樣可以限制參與當(dāng)前token計(jì)算的KV歷史數(shù)量,將內(nèi)存和計(jì)算量約束在可控的范圍內(nèi)。
Llama 2官方支持的標(biāo)準(zhǔn)版模型(下稱基座模型)上下文長(zhǎng)度是是4k,而Chinese-LLaMA-Alpaca-2基于支持16K上下文,并可通過(guò)NTK方法進(jìn)一步擴(kuò)展至24K+,在大語(yǔ)言模型之十一 Transformer后繼者Retentive Networks (RetNet)博客中知道隨著上下文長(zhǎng)度的增加,基座模型在訓(xùn)練和推理時(shí)Attention模塊的內(nèi)存和算力需求按照文本長(zhǎng)度的平方倍增加。
EFFICIENT STREAMING LANGUAGE MODELS WITH ATTENTION SINKS,觀察到一個(gè)有趣的現(xiàn)象,該論文稱之為attention sink,即初始token的KV對(duì)于加窗Attention方法的得到的模型性能較為重要(盡管從語(yǔ)義上來(lái)看,對(duì)于長(zhǎng)文本而言初始的token并不重要),StreamingLLM是根據(jù)這一思想(保留有限數(shù)量的最近KV以及初始的Attention sink)開(kāi)源的不限上下文長(zhǎng)度方法(論文中給的上下文token可長(zhǎng)達(dá)4百萬(wàn))。
LongLoRA的paper以Llama-2中給出了文本增加的算力需求增加情況,當(dāng)上下文長(zhǎng)度從2k變?yōu)?k的時(shí),模型的自注意力模塊計(jì)算量將增加16倍。
想要減少長(zhǎng)上下文的內(nèi)存和算力需求,可以從改變Transformer的Attention結(jié)構(gòu)以及優(yōu)化GPU計(jì)算單元這兩個(gè)方面出發(fā),在充分利用GPU算力基礎(chǔ)上(FlashAttention結(jié)構(gòu)),優(yōu)化Transformer的Attention計(jì)算和推理結(jié)構(gòu)(如LongLoRA),此外還有在大語(yǔ)言模型之十一 Transformer后繼者Retentive Networks (RetNet)博客中提及微軟的RetNet,其將并行計(jì)算的Attention改成了RNN結(jié)構(gòu)的Attention結(jié)構(gòu),既可以在訓(xùn)練時(shí)并行又可以在推理的時(shí)RNN計(jì)算,這使得算力并不會(huì)隨著上下文長(zhǎng)度增加而顯著上升,幾乎可以做到不限上下文長(zhǎng)度。
本篇介紹的基座模型Llama-2 7B將上下文長(zhǎng)度從4k擴(kuò)充到100k,基于FlashAttention和LongLoRA技術(shù),二者利用了GPU和Transformer的Attention結(jié)構(gòu)改進(jìn)兩方面的技術(shù)。此外通過(guò)finetune來(lái)擴(kuò)展上下文長(zhǎng)度還有以下一些方法:
- Position Interpolation: 增強(qiáng)版的RoPE將Llama擴(kuò)展到32K
- Focused Transformer使用contrastive learning 訓(xùn)練得到 LongLLaMA
- Landmark attention計(jì)算量比1和2更高,但是精度損失較大,將長(zhǎng)文本壓縮成了retrieved tokens
- NTK-aware
- Yarn
- positional Skipping
- out-of-distribution related method
StreamingLLM
該方法利用attention sink現(xiàn)象,采用緩存頭尾KV的方式,并不需要額外重新訓(xùn)練模型,只需要改變Attention模塊的forward計(jì)算使用到的歷史信息即可,即在基座模型上修改Attention的前向推理即可。該方法的對(duì)比原理圖如下:
這是基于LLama-2-13B的測(cè)試數(shù)據(jù)情況,圖a是密集注意力模型,即計(jì)算當(dāng)前token時(shí),會(huì)緩存先前T個(gè)token的KV值,緩存的長(zhǎng)度是大于訓(xùn)練時(shí)上下文長(zhǎng)度的,圖b是加窗注意力模型,T-L個(gè)token狀態(tài)將不再緩存,只緩存最近的L個(gè)KV值,這使得計(jì)算量大大減少,但是PPL卻并不是很理想。圖c的滑窗方法會(huì)重新計(jì)算最近L個(gè)token緩存的KV值,但是計(jì)算復(fù)雜度高,耗時(shí)長(zhǎng),圖d是StreamingLLM中給出的方法,始終保留Attention sink(初始幾個(gè)token的KV),然后在同保留的最近L個(gè)token的KV值一起計(jì)算當(dāng)前token的Attention值,其在計(jì)算效率和混淆度上取得了不錯(cuò)的收益。
官方給出了前向推理llama的示例方法,enable_streaming_llm,這一方法并不一定在所有LLM上都能取得該表現(xiàn),我個(gè)人覺(jué)得采取這一方法需要較為謹(jǐn)慎。這里分析主要是為了梳理其中的思想方法。
如果對(duì)圖中每行的意義不是很理解,那么可以參考《大語(yǔ)言模型之四-LlaMA-2從模型到應(yīng)用》中圖三第一個(gè)Linear層的“你好!”為例,
- 第一次輸入是你,對(duì)應(yīng)第一行,因?yàn)槭堑谝粋€(gè)token,因而先前的緩存的KV沒(méi)有;
- 第二次輸入是好,對(duì)應(yīng)于第二行,因?yàn)槭堑诙€(gè)token(深色),淺藍(lán)色則是緩存的“你”輸入是的KV,可以依次類推,
- 第三次輸入是!,對(duì)應(yīng)于第三行,兩個(gè)淺藍(lán)色塊是“你好”,從這個(gè)下三角可以看出模型是因果的(即當(dāng)前的輸入token只能看到歷史的KV)。
替換Attention模塊的forward方法如下:
def enable_llama_pos_shift_attention(model):
for name, module in reversed(model._modules.items()):
if len(list(module.children())) > 0:
enable_llama_pos_shift_attention(
module,
)
if isinstance(module, LlamaAttention):
model._modules[name].forward = types.MethodType(
llama_pos_shift_attention_forward, model._modules[name]
)
啟用LLAMA模型中的位置偏移注意力機(jī)制。它遍歷了給定模型的所有模塊,如果模塊是LlamaAttention類型,則將其前向傳遞函數(shù)替換為llama_pos_shift_attention_forward函數(shù)。該函數(shù)實(shí)現(xiàn)了位置偏移注意力機(jī)制,它可以在LLAMA模型中提高性能。
parser.add_argument("--start_size", type=int, default=4)
parser.add_argument("--recent_size", type=int, default=2000)
默認(rèn)緩存前四個(gè)token的KV值,并且緩存最近2000個(gè)KV值。
LongLoRA
LongLoRA的paper以Llama-2為實(shí)驗(yàn)對(duì)象,LongLoRA在Attention層和權(quán)重層都節(jié)約了算力資源。論文中主要引入了shifted short attention,這個(gè)結(jié)構(gòu)和Flash-Attention是兼容的,并且在推理的時(shí)候并不需要,論文中將Llama-2的7B模型上下文從4K擴(kuò)充到了100K,13B基座模型擴(kuò)充到了64K,70B基座模型上下文擴(kuò)充到了32K。
LoRA通過(guò)低秩矩陣將self-Attention矩陣進(jìn)行了線性投影(矩陣的低秩分解),即將一個(gè)參數(shù)量為NxN的矩陣分解為Nxd和dxN的兩個(gè)矩陣,而這2xd卻遠(yuǎn)小于N,如d取64,而N取4096,這就使得需要訓(xùn)練的參數(shù)量大大減少。LoRA方法可行的前提是假設(shè)基座模型在遷移的時(shí)候本質(zhì)上是一個(gè)低秩(low intrinsic rank)問(wèn)題,但是這種方法對(duì)長(zhǎng)上下文情況,單單采用LoRA會(huì)導(dǎo)致混淆度(perplexity)較高。在經(jīng)典的Transformer架構(gòu)中,LoRA方法通常只改變Attention層權(quán)重。
LongLoRA相比LoRA有兩點(diǎn)改進(jìn),第一點(diǎn)是:類似Attention層一樣,embedding和normalization層參數(shù)也會(huì)參與微調(diào)訓(xùn)練調(diào)整;第二點(diǎn)是:S2(shift short) Attention機(jī)制,這兩點(diǎn)的修改顯示在paper圖中:
從右邊可以看到LoRA是Self-Attention而言的,而LongLoRA除了有紅色的LoRA部分,對(duì)Embedding和Norm層都進(jìn)行參數(shù)調(diào)整(圖中有??的部分都是調(diào)整的對(duì)象),對(duì)于算力以及內(nèi)存的需求對(duì)比這里不展示了,有興趣可以進(jìn)一步參考論文。
有必要對(duì)shift short進(jìn)行展開(kāi)一下,假設(shè)LLM訓(xùn)練的時(shí)候是按照2k上下文進(jìn)行的,對(duì)于一個(gè)長(zhǎng)度為8k的輸入文本token,記為[1,2,…,8192],則在訓(xùn)練的時(shí)候會(huì)被分為4個(gè)組,每個(gè)組上下文長(zhǎng)度是2k,即:
[1,2,…,2048],[2049,2050,…,4096],[4097,4098,…,6144],[6145,6146,…,8192],這四個(gè)組在訓(xùn)練的時(shí)候是分開(kāi)進(jìn)行的,每個(gè)組之間Attention模塊并沒(méi)有進(jìn)行信息交互,如果按照傳統(tǒng)的方式分為n(8192)個(gè)組[1,2,…,2048],[2,3,…,2049]…,這樣訓(xùn)練的時(shí)間又會(huì)非常多,S2-Attention就是在這個(gè)方式下提出的,方法是按照目標(biāo)上下文長(zhǎng)度的1/2循環(huán)移動(dòng)上下文token,對(duì)于上面的8192個(gè)token,shift之后為[1025,1026,…,3072],[3073,3074,…,5120],[5121,5122,…,7168],[7169,7170,…,1023,1024],這使得每個(gè)shift之后的每個(gè)分組有一半是前一個(gè)分組,一半是后一個(gè)分組,如[1025,1026,…,3072],各有一半在[1,2,…,2048]和[2049,2050,…,4096]。
Deep Vison lab開(kāi)源了longLoRA的fine-tune代碼,fine-tune.py
首先加載模型和tokenizer,這在前幾篇博客中反復(fù)提及了的,使用的是Huggingface提供的接口。
# Load model and tokenizer
model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
cache_dir=training_args.cache_dir,
torch_dtype=torch.bfloat16,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=True,
)
多線程加載訓(xùn)練數(shù)據(jù)集
rank = int(os.environ.get('RANK', -1))
if rank > 0:
barrier()
dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir)
dataset = dataset.map(partial(tokenize_fn,tokenizer),batched=True, num_proc=128, remove_columns=["text", "meta"])
if rank == 0:
barrier()
接來(lái)下是loRA參數(shù)配置
config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=targets,
lora_dropout=0,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
train是常規(guī)的train過(guò)程
trainer = Trainer(
model=model, tokenizer=tokenizer, args=training_args,
train_dataset=dataset["train"],
eval_dataset=None,
data_collator=data_collator)
trainer.train()
S2-Attention的實(shí)現(xiàn)和StreamingLLM的方法類似,都是更改基座模型Attention的forward方法。文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-737811.html
def replace_llama_attn(use_flash_attn=True, use_full=False, inference=False):
if use_flash_attn:
cuda_major, cuda_minor = torch.cuda.get_device_capability()
if cuda_major < 8:
warnings.warn(
"Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
)
if inference:
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask_inference
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_inference
else:
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_full if use_full else forward_flashattn
else:
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_noflashattn
其自身實(shí)現(xiàn)的方法添加了flashattn方法,這里可以忽略,在計(jì)算atten時(shí),多了一個(gè)shift操作,修改(就是增加了幾行代碼)的代碼行數(shù)并不多。文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-737811.html
# shift
def shift(qkv, bsz, q_len, group_size, num_heads, head_dim):
qkv[:, num_heads // 2:] = qkv[:, num_heads // 2:].roll(-group_size // 2, dims=2)
qkv = qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2)
return qkv
query_states = shift(query_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
key_states = shift(key_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
value_states = shift(value_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
到了這里,關(guān)于大語(yǔ)言模型之十六-基于LongLoRA的長(zhǎng)文本上下文微調(diào)Llama-2的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!