論文鏈接:Continual Learning with Pre-Trained Models: A Survey
代碼鏈接:Github: LAMDA-PILOT
持續(xù)學習(Continual Learning, CL)旨在使模型在學習新知識的同時能夠保留原來的知識信息了,然而現(xiàn)實任務中,模型并不能很好地保留原始信息,這也就是常說的災害性遺忘(Catastrophic forgetting)問題。傳統(tǒng)的CL方法需要從頭開始訓練模型(從隨機初始化參數(shù)開始訓練),目前基于大規(guī)模數(shù)據(jù)訓練得到的預訓練模型為持續(xù)學習帶來了新的研究思路,預訓練模型魯棒的泛化性給予新任務學習較為成熟的參數(shù),也因此基于預訓練模型的CL方法已逐漸成為研究熱點。
作者將基于預訓練模型的CL方法分為三種:Prompt-based方法、representation-based方法和model mixed-based方法。
1. Prompt-based 方法
在使用模型全局tuning的方式適應下游任務時,預訓練模型的泛化性能會被嚴重削弱,因此Prompt-based方法在保持預訓練模型參數(shù)權(quán)重不變的條件下, 增加額外可學習的Prompt tuning 模塊來實現(xiàn)對下游任務的泛化,這樣就能較好地保持原模型的泛化性能。
Vison Prompt Tuning(VPT)在feature patch中串聯(lián)了一組可學習的參數(shù)P,然后使用最小化交叉熵損失的方式將特定任務的信息嵌入到預訓練模型中。
VPT這種方式雖然可以較好地保留模型的泛化性,但是,在面對新的任務時,以往的Prompt模塊的知識同樣被覆蓋,依舊遭遇了災難性以往問題。為此,有學者提出了Prompt Pool的概念,設計了Prompt模塊的集合,即P={P1,P2,…,Pm}(m表示該Pool的最大尺寸)。Prompt Pool的思想有效避免了單一Prompt的問題,但是Pool的設計使得其需要進行Prompt Selection操作,也就是需要將特定任務與其對應的Prompt模塊進行索引匹配。
L2P算法是一種較為常用的Prompt selection算法,該算法設計了一種Key-Query的Prompt匹配方法,也就是為每一個Prompt提供一個可學習的索引鍵k,即P={(k1,P1),(k2,P2),…,(km,Pm)。L2P利用預訓練模型將輸入特征編碼到Key對用的嵌入空間中,然后利用余弦距離損失函數(shù)在已有的Pool中搜索最近似的Key。接著,利用如交叉熵損失等方法對搜索到的Key對應的Prompt進行進行優(yōu)化。
類似的Prompt Selection 算法很多,如DualPrompt算法,該算法將Prompt進行解耦,分化為General Prompt和Expert Prompt。General Prompt面向所有任務,為所有任務中共享信息,而Expert Prompt針對獨立任務,數(shù)量與任務量一致。其采用了和L2P相同的key-query匹配策略。
Prompt Selection雖然可行,但仍是硬匹配,選項有限?;谧⒁饬π畔⒓訖?quán)的Prompt Combination方法則有效緩解了該問題。如CODA-Prompt通過對Prompt Pool進行注意力機制嵌入,為每個注意力賦予自適應權(quán)重,進而求算全局Key-Query的加權(quán)和,實現(xiàn)可學習式Prompt組合。我覺得稀疏式注意力Prompt combination應該也是很有趣的研究。
從根本上來說Prompt Combination仍受制于Prompt Pool的范圍。為此, 許多學者則開展Prompt Generation有關(guān)的研究,如DAP,其利用MLP進行特定任務提示信息的編碼生成。
優(yōu)點:
- Prompt 有助于彌合domain gap,并可有效地對特定任務的知識進行編碼。
- Prompt Design 屬于lightweight模塊,與input feature具有相同的維度,因此保存Prompt是parameter-efficient,適用于邊緣場景。
- Prompt Pool作為預訓練模型的外部存儲器,其支持自適應知識的檢索和特定實例的預測。
缺點:文章來源:http://www.zghlxwxcb.cn/news/detail-858232.html
- 一些研究]發(fā)現(xiàn)L2P中的prompt selection過程收斂到一個單點,使得prompt selection只集中在特定子集上。
- 由于key和query在整個學習過程中不斷變化,這些參數(shù)的更新將會消除先前任務的參數(shù),導致matchimg-level和prompt-level的遺忘,使prompt selection成為CL的瓶頸。
- 固定大小的Prompt Pool會使得模型的表示能力受限。但是,若Prompt Pool隨著數(shù)據(jù)的發(fā)展而增長,可能會為舊任務檢索新的提示,導致訓練和測試之間的不匹配。
- 最后,一些研究發(fā)現(xiàn)prompt-based CL的性能低于簡單的representation-based的baseline性能。并且批量提示有損比較的公平性。
2. Representation-based 方法
representation-based方法直接利用預訓練模型強大的泛化性和通用性來實現(xiàn)持續(xù)學習。比如Simple-CIL方法,該算法是ADAM算法原文中提出的Baseline,Simple-CIL凍結(jié)預訓練模型參數(shù),并通過求算類別中心的方式來構(gòu)建Classifier。具體來說,在面對很多類別時,計算同類的embedding或features的平均值,并將該平均值作為該類別的標準(prototype),最后結(jié)合類別標準與余弦比較的方法替換模型的原始Classifier。
雖然基于prototype的方法存在一定的作用,但是并未很好地適應下游任務。為此,一些研究在基于prototype方法的基礎上結(jié)合了外置參數(shù)高效調(diào)節(jié)模塊或者外置適配器來使得預訓練模型更加適應下游任務,如ADAM等。
ADAM等算法在進行類別標準設定時,類別標準之間的仍存在聯(lián)系,導致任務效果降低。為此,RanPAC算法則采用online LDA classifier來去除原始方法prototype計算結(jié)果之間的相關(guān)性,加大類別間的分布差異。此外,RanPAC算法利用Random Projection layer將features映射到高維空間中,并在高維空間中進行prototype的計算,以使得特征分布符合高斯擬合。
相較于前面將預訓練模型的通用語和適應性分離處理的方式,SLCA算法采用了差異學習率調(diào)整和特征經(jīng)驗重播的方式進行持續(xù)學習研究。該算法使用較小的learn rate調(diào)整模型主體部分,而使用較大的learn rate 調(diào)節(jié)模型的classifier,以實現(xiàn)模型的逐步微調(diào)和classifier的快速適應。為了避免忘記以前的分類器,SLCA還對分類特征分布進行建模,并重播它們以校準classifier。
優(yōu)點:
由于class prototype代表了對應類別最常見的標準格式,因此利用其構(gòu)建模型具有直觀和可解釋性。
Representation-based 方法主要是凍結(jié)backbone和更新classifier權(quán)重。lightweight的更新成本增加了其現(xiàn)實應用的可行性。
缺點:
將不同模型的特征連接起來形成class prototype,容易造成模型信息冗余。例如,不同的backbone中存在重復提取共享特征。
當下游任務涉及多個領(lǐng)域時,在第一階段調(diào)整模型不足以彌合數(shù)據(jù)集之間的領(lǐng)域差距。在這種情況下,不斷調(diào)整backbone可能更適合提取特定于任務的特征。
3. Model Mixture-based 方法
Model Mixture-based 方法在持續(xù)學習工程中構(gòu)建了一組模型,然后再推理階段通過Model Ensemble和Model Merge來進行信息綜合決策。
Model Ensemble中,ESN算法憑借預訓練模型強大的通用性,構(gòu)建多個classifier,在面對新任務重新初始化和訓練一個新的classifier。在推理時,采用投票策略來整合多個模型的結(jié)果進行最終決策。
由于Model Ensemble的核心因素取決于模型的方差,一些研究通過增強模型之間的多樣性來替代使用相同的預訓練模型構(gòu)建不同的classifier。如PromptFusion利用預訓練的ViT和CLIP,并在推理過程中動態(tài)地對logit進行組合,即f(x) = λ fvit (x) +(1?λ)fclip(x)。
與多個backbone的集成不同,PROOF采用了僅使用單個CLIP的更全面的推理方法。由于CLIP支持視覺和文本特征的跨模態(tài)匹配,因此PROOF設計了一個三層集成,考慮image-to-text、image-to-image prototype、image-to-adjusted text的跨模態(tài)融合。
Model Merge將多個不同的模型合并為一個統(tǒng)一的模型,無需要額外的訓練。LAE定義了online和offline學習協(xié)議,online模型通過交叉熵損失進行更新,目的是在新的任務中獲取新的知識。離線模型則通過Model Merge進行更新,例如指數(shù)移動平均(EMA): θ offline←α·θ offline +(1?α)·θ Online,其中α為權(quán)衡參數(shù)。LAE僅將EMA應用于參數(shù)高效調(diào)諧模塊(如prompt),其利用online和offline模型的最大logit進行推斷。
與LAE一樣,ZSCL將合并技術(shù)應用于CLIP模型,目的是在持續(xù)學習過程中保持其zero-shot性能。然而,隨著EMA中權(quán)衡參數(shù)的改變,CLIP性能不再具有魯棒性。因此,ZSCL建議每隔幾次迭代合并參數(shù),從而在模型訓練期間創(chuàng)建平滑的損失軌跡。
此外,CoFiMA注意到EMA在Merge過程中對每個參數(shù)的重要性是相等的,CoFiMA 在Merge過程中插入Fisher information(費雪信息)作為每個參數(shù)的估計重要性。
優(yōu)點:
- 學習多個模型可以做出不同的決策。因此,使用Model Ensemble和Model Merge自然會產(chǎn)生更健壯的結(jié)果。
- 由于直接合并模型進行統(tǒng)一預測,因此可以調(diào)整前模型和后模型的權(quán)重,以突出不同階段之間知識共享的重要性。
- 由于模型集將在推理過程中合并,因此最終的推理成本不會隨著模型集中添加更多模型而增加。
缺點:
- Model Ensemble需要保存所有的歷史模型,并消耗大量的內(nèi)存緩沖區(qū)。雖然基于Model Merge不需要這么大的成本,但合并大型backbone的權(quán)重也需要大量的額外計算。
- 決定Merge哪些參數(shù)仍然是問題。
閱讀記錄
文章來源地址http://www.zghlxwxcb.cn/news/detail-858232.html
到了這里,關(guān)于【論文筆記】基于預訓練模型的持續(xù)學習(Continual Learning)(增量學習,Incremental Learning)的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!