這一章我們聊聊有哪些方案可以不用微調(diào)直接讓大模型支持超長文本輸入,注意這里主要針對(duì)無限輸入場(chǎng)景。之前在BERT系列中我們就介紹過稀疏注意力和片段遞歸的一些長文本建模方案長文本建模 BigBird & Longformer & Reformer & Performer,不過以上方案無一例外都需要在訓(xùn)練階段引入。針對(duì)當(dāng)前大模型微調(diào)成本高的問題,更多研究放到如何在模型外部支持長文本輸入。先討論下為啥當(dāng)前的大模型會(huì)在推理時(shí)存在輸入長度的限制,主要有以下幾點(diǎn)原因
-
Attention矩陣計(jì)算復(fù)雜度:在不引入稀疏注意力的條件下,Attention矩陣的內(nèi)存和計(jì)算復(fù)雜度是\(O(序列長度^2)\),文本長度的上升會(huì)帶來顯存的指數(shù)增長。
-
訓(xùn)練耗時(shí):訓(xùn)練階段的文本長度會(huì)顯著影響訓(xùn)練速度, 因此2048一般是當(dāng)前預(yù)訓(xùn)練常見的最大長度。
- 位置編碼的外推性: 這里的外推性是指推理長度超過訓(xùn)練長度。包括推理會(huì)出現(xiàn)沒訓(xùn)練過的位置編碼,以及注意力機(jī)制需要處理比訓(xùn)練更長的輸入。已有的旋轉(zhuǎn)位置編碼等相對(duì)位置編碼已經(jīng)具有了外推性,既推理長度可以超過訓(xùn)練長度,但在ALibi位置編碼的測(cè)試中,這種外推性是以大幅性能損失為代價(jià)的。
針對(duì)以上問題本章介紹4種方案:顯式搜索的知識(shí)庫外掛方案,隱式搜索的Unlimiformer, 并行輸入的pcw和并行解碼NBCE。
顯式搜索: 知識(shí)庫外掛
- paper: Unleashing Infinite-Length Input Capacity for Large-scale Language Models with Self-Controlled Memory System
- 看到最無敵的應(yīng)用,文本和表格解析超厲害https://chatdoc.com/?viaurl=ainavpro.com
- ChatGPT代碼實(shí)現(xiàn): https://github.com/arc53/DocsGPT
- ChatGLM代碼實(shí)現(xiàn): https://github.com/imClumsyPanda/langchain-ChatGLM
- 適用于大規(guī)模知識(shí)問答場(chǎng)景
這塊可能是GPT后比較火的方向,有一陣每天都能看到類似的新應(yīng)用,從GPT讀論文,再到百科問答,搭配langchain框架,在DocQA,KBQA的場(chǎng)景簡直無往不利, 以上分別給出了基于ChatGPT和ChatGLM的兩個(gè)實(shí)現(xiàn)方案。
實(shí)現(xiàn)的步驟基本可以被下圖概括
- 長文本解析切分成chunk: 實(shí)際使用過程中發(fā)現(xiàn)文本解析竟然是最核心的部分,能否把需要保留語義完整性的段落拆成整段,能否高質(zhì)量的解析表格,和結(jié)構(gòu)化數(shù)據(jù),對(duì)后續(xù)QA的影響最大
- 文本向量化:中文可用的embedding模型有不少,也可以基于simcse,consert在垂直領(lǐng)域做進(jìn)一步的微調(diào)。在向量化階段主要的問題是文本截?cái)鄮淼纳舷挛膿p失會(huì)影響召回,因此可以嘗試重疊切分,拼接摘要/標(biāo)題等方式
- 向量入庫:需要高效向量檢索的數(shù)據(jù)庫,Milvus、Pinecone,這塊最近也火了一波初創(chuàng)公司
- 用戶問題改寫:在多輪QA的場(chǎng)景,對(duì)話歷史有兩種使用方式,其一使用歷史對(duì)話對(duì)當(dāng)前query進(jìn)行改寫再召回,其二種是使用原始用戶query去召回文本,在回復(fù)階段引入對(duì)話歷史
- 召回:基于用戶query或改寫query進(jìn)行向量化檢索,topK或者閾值召回。除了考慮相關(guān)性,在部分場(chǎng)景也要考慮時(shí)效性,文本質(zhì)量等等
- 答案生成:使用召回文檔拼接用戶query進(jìn)行答案生成,這一步往往還需要用到模型摘要,Refine等能力,核心是對(duì)以上召回的長文本進(jìn)行壓縮
搜索法最大的優(yōu)點(diǎn)是實(shí)現(xiàn)簡單,不過也有許多限制就是只能支持NLU任務(wù),以及會(huì)破壞輸入文本的上下文連續(xù)性,和文本順序。但在大規(guī)模知識(shí)問答這塊算是現(xiàn)在看到最好的方案。
隱式搜索:Unlimiformer
- Unlimiformer: Long-Range Transformers with Unlimited Length Input
- https://github.com/abertsch72/unlimiformer
- 適用于Encoder-Decoder模型,長文本摘要等場(chǎng)景
特意起了個(gè)隱式搜索的標(biāo)題,是因?yàn)楹蜕厦娴奈谋舅阉鲗?shí)現(xiàn)有異曲同工之妙,本質(zhì)的差異只是以上是離散文本塊的搜索。而Unlimiformer是在解碼階段對(duì)超長輸入,token粒度的輸出層embedding進(jìn)行檢索,選擇最相關(guān)的Top Token計(jì)算Attention。
首先對(duì)于超長輸入,unlimiformr采用以上提到的重疊切分的方法,重疊率50%,這樣可以更好保留上文和文本連貫性,例如第一段文本是1-500字,第二段重疊250字取250-750字。然后使用Encoder對(duì)每段文本進(jìn)行獨(dú)立編碼,繞過Attention的平方復(fù)雜度問題。最后輸出每段文本的Embedding,注意這里不是文本整體embedidng, 而是后半部分(250~500字)每個(gè)Token最上層的Embedding,并寫入向量索引,這里用的是Faiss。
在解碼層,每一步解碼,query都會(huì)檢索注意力最高的Top-k個(gè)輸入Token,作為編碼器部分的信息用于解碼器的解碼。這里簡單回憶下Attention計(jì)算, Top-K個(gè)Token就是讓以下注意力取值最高的key。
考慮Decoder的每一層(N層)中的每一個(gè)head(L個(gè)頭)都需要和Encoder的輸出層進(jìn)行交互, 檢索Top Key,如果存儲(chǔ)每一層每個(gè)head的Key,需要構(gòu)建\(O(L*N*seqlen)\)的向量存儲(chǔ)。對(duì)此作者進(jìn)行了優(yōu)化,改變了以下QK的計(jì)算順序,用每一層每個(gè)頭Key的映射矩陣對(duì)Q進(jìn)行映射,這樣只需要存儲(chǔ)一份seq_len的編碼向量(\(h_{encoder}\)),在每一層檢索時(shí)用映射后的Q進(jìn)行檢索既可,其實(shí)就是時(shí)間換空間
unlimiformer提供了代碼實(shí)現(xiàn),核心代碼抽出來看下有兩塊
- 超長文本編碼:對(duì)文本進(jìn)行切塊,分別編碼,取后半部分
for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices:
chunk = input_ids[:, context_start_ind:context_end_ind]
chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind]
hidden_states = self.model(chunk, attention_mask=chunk_attention_mask, labels=dummy_labels, return_dict=True)
last_hidden = hidden_states.encoder_last_hidden_state # (batch, chunked_source_len, dim)
to_add = last_hidden[:, update_start_ind:update_end_ind].detach()
to_apply_mask = chunk_attention_mask[:, update_start_ind:update_end_ind]
- 向前計(jì)算檢索Top-key用于Attention矩陣的計(jì)算
def attention_forward_hook(self, module, input, output):
# output: (batch, time, 3 * heads * attention_dim)
with torch.no_grad():
query = self.process_query(output)[:,-1] # (batch * beam, head, dim)
query = query[:, self.head_nums] # (batch * beam, head, dim)
#這是前面提到的計(jì)算優(yōu)化使用每層每個(gè)head的Key映射矩陣對(duì)Query進(jìn)行映射用于搜索
attention_layer_list = self.attention_layer_to_capture(self.layer_begin, self.layer_end)
k_proj_layer = [layers[0] for layers in attention_layer_list][self.cur_decoder_layer_index]
# modify query by k_projs
k_proj = k_proj_layer.weight
k_proj = k_proj.view(1, self.num_heads, query.shape[-1], k_proj.shape[0]) # (1, num_heads, attn_dim, embed_dim)
datastore_query = query.unsqueeze(-2) # (batch * beam, num_heads, 1, attn_dim)
datastore_query = torch.matmul(datastore_query, k_proj) # (batch * beam, num_heads, 1, embed_dim)
datastore_query = datastore_query.squeeze(-2) # (batch * beam, num_heads, embed_dim)
datastore_query = datastore_query.view((self.datastore.batch_size, -1, datastore_query.shape[2])) # (batch, beam * num_heads, embed_dim)
# 這里進(jìn)行Top Key的檢索:得到Key的索引,Embedding和得分
top_search_key_scores, top_search_key_indices = self.datastore.search(datastore_query, k=self.actual_model_window_size)
embeddings = torch.take_along_dim(input=self.embeddings.unsqueeze(1),
indices=top_search_key_indices.unsqueeze(-1).to(self.embeddings.device), dim=-2)
##后面就是常規(guī)的對(duì)Embedding進(jìn)行Key和Value的映射然后做Attention了
和前面的文本檢索對(duì)比,unlimiformer的存儲(chǔ)成本會(huì)更高,因?yàn)橐鎯?chǔ)token粒度的Embedding信息,更適用于on-the-fly的長文本推理使用,例如針對(duì)單一文檔的QA,只存儲(chǔ)當(dāng)前文檔,而前面文本塊檢索方案更適合一些大規(guī)模知識(shí),批量的文檔的存儲(chǔ)。
但其實(shí)unlimiformer直接對(duì)Token進(jìn)行離散召回,這一點(diǎn)我讓我有些困惑,這樣單一token的檢索召回,真的不會(huì)破壞上文連續(xù)性么?還是說Encoder編碼方式已經(jīng)保證了檢索召回大概率會(huì)召回成段的Token,又或者說每個(gè)Token的Embedding內(nèi)已經(jīng)充分編碼了連續(xù)上下文的信息,召回離散Token也不會(huì)出現(xiàn)割裂的語義信息?哈哈考慮unlimiformer只支持Encoder-Decoder的框架,和我們用的Decoder框架不適配,我決定不細(xì)糾結(jié)了!有在中文嘗試過效果的童鞋可以分享下~
并行輸入:PCW
- Parallel Context Windows for Large Language Models
- https://github.com/AI21Labs/Parallel-Context-Windows
- 適用于Decoder模型,以及小規(guī)模內(nèi)容理解場(chǎng)景
同樣是對(duì)超長文本進(jìn)行切塊,然后獨(dú)立編碼,PCW使用的是Decoder框架。和unlimiformer只使用Top-Key進(jìn)行解碼,PCW在解碼過程中對(duì)全部輸入上文進(jìn)行Attention。對(duì)比Encoder-Decoder框架,因?yàn)檩斎牒洼敵龆荚贒ecoder側(cè),PCW需要解決兩個(gè)問題:位置編碼和注意力矩陣如何調(diào)整, 下圖基本概括了這兩個(gè)細(xì)節(jié)
position_ids = attention_mask.long().cumsum(-1) - 1
n_task_tokens = position_ids.shape[1] - sum_windows_size
# 保證解碼器的位置編碼比最長上文要長度+1
position_ids[0, -n_task_tokens:] = torch.arange(max_window_size, max_window_size + n_task_tokens, 1)
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values: # i.e., first token is already generated
position_ids = position_ids[:, -1].unsqueeze(-1)
elif windows_key_values: # i.e., we are in the first token generation #其實(shí)就是取-n_task_tokens:
position_ids = position_ids[:, sum_windows_size:]
- 注意力矩陣
- 輸入文本進(jìn)行截?cái)嗪蟾髯元?dú)立通過Decoder進(jìn)行編碼。因此每一段輸入的文本的注意力矩陣是相互獨(dú)立的。這塊不需要修改注意力矩陣的實(shí)現(xiàn),只需要文本chunk后分別過模型即可。得到每段文本的past-key-values直接進(jìn)行拼接
def combine_past_key_values(past_lst: List[Tuple[Tuple[torch.Tensor]]],
contains_bos_token: bool = True) -> Tuple[Tuple[torch.Tensor]]:
# 這里past_lst是每段文本的past-key-value
# GPT是n_layer * 2(key+value) * tensor(seq_len,batch,n_head,n_hidden)
# 注意不同模型past-key-value的shape不同
# Chatglm是n_layer * 2(key+value) * tensor(seq_len,batch, n_head, n_hidden)
return tuple(
(torch.cat([c[i][0] for c in past_lst], dim=2),
torch.cat([c[i][1] for c in past_lst], dim=2))
for i in range(len(past_lst[0])))
- 解碼器對(duì)全部上文進(jìn)行Attention計(jì)算:這里需要修改Attention把上文的全部Attention進(jìn)行拼接,讓解碼器的每一步可以對(duì)全部上文計(jì)算Attention
res['past_attention_mask'] = torch.cat([window['attention_mask'] for window in windows], dim=1)
combined_attention_mask = torch.cat((cache['past_attention_mask'], encoded_task_text['attention_mask']), dim=1)
考慮ChatGLM本身是二維的Attention矩陣和位置編碼,特殊的BOS和GMASK,我重寫了PCW,但是在長文本QA問題上表現(xiàn)比較一般,表現(xiàn)在當(dāng)上文多段文本無明顯關(guān)系的時(shí)候例如多個(gè)完全無關(guān)的新聞,在進(jìn)行問答的時(shí)候,正確答案中會(huì)混雜很多無關(guān)的文本變短,以及這個(gè)問題當(dāng)上文片段變多,或者指令問題變多的時(shí)候會(huì)變得越來越嚴(yán)重,直到開始完全胡說八道。當(dāng)然不排除我寫bug了哈哈哈,但我自己是真的沒查出來。
不過也有一種可能,是PCW是在輸入層就開始對(duì)超長上文進(jìn)行Attention,因?yàn)椴煌衔牡奈恢镁幋a相同,一定程度上會(huì)讓解碼注意力變得非常分散,導(dǎo)致注意力的熵值變高,解碼的不確定性變大,更容易出現(xiàn)亂碼。
并行解碼:NBCE
- 蘇劍林. (May. 23, 2023). 《NBCE:使用樸素貝葉斯擴(kuò)展LLM的Context處理長度 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9617
- 蘇劍林. (May. 31, 2023). 《關(guān)于NBCE方法的一些補(bǔ)充說明和分析 》[Blog post]. Retrieved from https://spaces.ac.cn/archives/9632
- https://github.com/bojone/NBCE
- 適用于Encoder-Decoder模型,長文本內(nèi)容理解如摘要問答等場(chǎng)景
壓軸的必須是蘇神的NBCE!這里我把看完博客后的理解進(jìn)行簡單的總結(jié),詳細(xì)推理請(qǐng)看去蘇神的科學(xué)空間!答應(yīng)我一定要去看!每次看蘇神推導(dǎo),都會(huì)覺得數(shù)學(xué)之魂在燃燒!
NBCE的原理簡單解釋如下圖,和PCW相同是對(duì)每段上文進(jìn)行獨(dú)立編碼,但差異在于PCW是在輸入層進(jìn)行融合,而NBCE是在輸出層對(duì)每一個(gè)Step輸出的預(yù)測(cè)token的概率矩陣進(jìn)行融合,更大程度上避免了注意力被分散,保證了解碼的合理性。
既然說了是簡化假設(shè),因此可以對(duì)上式進(jìn)行一些調(diào)優(yōu),核心是讓模型對(duì)上文的解碼更加準(zhǔn)確,降低無關(guān)上文帶來的解碼噪聲,比較重要的優(yōu)化包括
- 準(zhǔn)確率優(yōu)化解碼
以上解碼概率求和,其實(shí)是對(duì)k段文本生成的\(vocab * K\)的概率矩陣,沿K做AvergePooling,得到最終\(vocab*1\)的解碼概率。但考慮LM訓(xùn)練其實(shí)是擬合one-hot(出現(xiàn)概率最高的詞),也就是除了概率最高的幾個(gè)token之外其余token的預(yù)測(cè)概率都不靠譜。如果直接取平均的多路打分,很容易投出一個(gè)在各段文本上打分都不高不低的token,上文越多這個(gè)問題越明顯。但其實(shí)在閱讀理解例如抽取,QA問題的解碼策略上我們要的是在某段文本上打分置信度最高的token,因?yàn)榇鸢竿粊碜砸粋€(gè)上文片段。
因此蘇神給出了兩種準(zhǔn)確率更高的解碼方案,一個(gè)是MaxPooling+GreedySearch,其實(shí)就是對(duì)\(vocab*k\)的概率矩陣取全局概率最高的token,另一個(gè)是最小熵+RandomSampling,也就是從多段上文中取1個(gè)預(yù)測(cè)置信度最高的上文進(jìn)行解碼。這里其實(shí)是和PCW最大的差異,也就是在解碼層進(jìn)行融合,并通過熵值較低的融合策略來保證解碼的準(zhǔn)確率。
以及后面蘇神還通過Top-P來進(jìn)一步過濾尾部的噪聲,以及通過控制每一步解碼的轉(zhuǎn)移概率,來讓解碼器不會(huì)在不同上文片段之間反復(fù)切換,而是保證連續(xù)的解碼片段大概率來自相同的上文片段。
- Context-aware解碼
基于上文來進(jìn)行解碼的一個(gè)核心是為了降低模型回答胡說八道的概率。例如在金融場(chǎng)景我們直接問chatgpt基金贖回費(fèi)用是多少 vs 我們基于某個(gè)基金的介紹問模型該基金的贖回費(fèi)用是多少,后者得到的答案一定是更準(zhǔn)確的。而其實(shí)以上二者的差異在于條件(上文)解碼和無條件解碼, 因此可以通過diff無條件編碼的方式來提高解碼對(duì)上文的依賴程度(reliablity)。如下圖
因此蘇神把把n變成超參Beta, 控制條件概率和無條件概率的占比,Beta越高解碼和上文的關(guān)聯(lián)度越高,QA等場(chǎng)景的解碼準(zhǔn)確率越高,生成自由度越低。
當(dāng)前NBCE的局限性在于無法處理上文片段之間的位置關(guān)系,以及無法處理解碼需要依賴多個(gè)上文片段的場(chǎng)景。后者感覺可以通過預(yù)測(cè)概率矩陣的相關(guān)性修改Pooling方式,而前者
基于蘇神提供的代碼,在chatglm上做了嘗試,只需要簡單調(diào)整下輸入輸出的部分就可以直接使用。我在論文,書籍,和新聞上進(jìn)行摘要,實(shí)體抽取和QA問答后發(fā)現(xiàn),INT8量化的模型效果似乎要略優(yōu)于FP16, 顯著優(yōu)于INT4。INT8量化下,10K左右的輸入,顯存占用基本可以限制在單卡A100(40g),大家可以自行嘗試下~
@torch.inference_mode()
def generate(max_tokens):
device = torch.device('cuda')
"""Naive Bayes-based Context Extension 演示代碼
"""
inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(device)
input_ids = inputs.input_ids
n = input_ids.shape[0]
with torch.no_grad():
for i in range(max_tokens):
# 模型輸出
model_input = model.prepare_inputs_for_generation(input_ids)
outputs = model(**model_input,
return_dict=True,
use_cache=True
)
"""
中間代碼不變
"""
# 把唯一的回答擴(kuò)充到每一個(gè)batch進(jìn)行下一輪的解碼
next_tokens = next_tokens.unsqueeze(-1).tile(n, 1)
input_ids = torch.cat([input_ids, next_tokens], dim=-1)
# 更新past-key-values, 更新attention_mask, 更新position_ids
model_kwargs = model._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
)
想看更全的大模型相關(guān)論文梳理·微調(diào)及預(yù)訓(xùn)練數(shù)據(jù)和框架·AIGC應(yīng)用,移步Github >>?DecryptPropmt文章來源:http://www.zghlxwxcb.cn/news/detail-480690.html
Reference文章來源地址http://www.zghlxwxcb.cn/news/detail-480690.html
- https://blog.langchain.dev/langchain-chat/
- https://blog.frankzhao.cn/build_gpt_bot_for_doc/
- https://zhuanlan.zhihu.com/p/616620170
- ALiBi:Train short, test long:attention with linear bias enables input length extrapolation
- https://github.com/ofirpress/attention_with_linear_biases
- Trusting Your Evidence: Hallucinate Less with Context-aware Decoding
到了這里,關(guān)于解密Prompt系列8. 無需訓(xùn)練讓LLM支持超長輸入:知識(shí)庫 & unlimiformer & PCW & NBCE的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!