去年我們梳理過OpenAI,Anthropic和DeepMind出品的經(jīng)典RLHF論文。今年我們會(huì)針對(duì)經(jīng)典RLHF算法存在的不穩(wěn)定,成本高,效率低等問題討論一些新的方案。不熟悉RLHF的同學(xué)建議先看這里哦解密Prompt7. 偏好對(duì)齊RLHF-OpenAI·DeepMind·Anthropic對(duì)比分析
RLHF算法當(dāng)前存在的一些問題有
- RL的偏好樣本的人工標(biāo)注成本太高,效率低,容易存在標(biāo)注偏好不一致的問題
- RLHF屬于online訓(xùn)練策略,在訓(xùn)練過程中需要讓模型進(jìn)行解碼,時(shí)間成本高訓(xùn)練效率低
- RLHF在訓(xùn)練過程中需要同時(shí)部署Reward模型和SFT模型和更新后的模型,顯存占用高訓(xùn)練成本高
- RLHF需要兩階段的訓(xùn)練,需要先訓(xùn)練reward模型,再使用reward模型更新SFT模S型
這一章我們先聊聊訓(xùn)練策略的新方案。用新方案而不是優(yōu)化或者改良,因?yàn)?strong>平替?zhèn)?/strong>的效果需要更長時(shí)間的驗(yàn)證。
SLiC-HF
- SLiC-HF: Sequence Likelihood Calibration with Human Feedback
- CALIBRATING SEQUENCE LIKELIHOOD IMPROVES CONDITIONAL LANGUAGE GENERATION
要說SLiC-HF,肯定要先說下前置的Calibartion Sequence likelihood(SLiC)的對(duì)齊技術(shù),畢竟上面這兩篇論文的部分作者都相同,思路自然是一脈相承。
SLiC
SLiC對(duì)標(biāo)SFT,也是post-training的指令對(duì)齊方案。方案針對(duì)指令微調(diào)階段使用MLE也就是next token prediction帶來的稀疏訓(xùn)練問題。因?yàn)榻o定context,是有無數(shù)種output可能的。而微調(diào)階段只使用唯一的答案進(jìn)行訓(xùn)練,導(dǎo)致模型訓(xùn)練不充分。一個(gè)明顯的現(xiàn)象就是序列的解碼概率越高,并不意味著生成序列的質(zhì)量越好,這意味著生成序列其實(shí)是未修正的(uncalibrated)
SLiC的思路有些類似半監(jiān)督。也就是標(biāo)注數(shù)據(jù)有限,導(dǎo)致模型參數(shù)更新的空間有限的情況下,我們可以使用半監(jiān)督的平滑性和一致性原則,既和標(biāo)注樣本相似的樣本label相同,反之不同的思路,使用無標(biāo)注樣本對(duì)模型進(jìn)行更新
那我們把半監(jiān)督的思路放到文本生成:
第一步.先使用SFT對(duì)齊后的模型,針對(duì)標(biāo)注樣本,每個(gè)樣本生成m個(gè)推理候選結(jié)果,這些就是半監(jiān)督中的未標(biāo)注樣本
第二步.使用無監(jiān)督樣本進(jìn)行對(duì)比訓(xùn)練,核心就是訓(xùn)練模型對(duì)和標(biāo)注答案更相似的候選樣本給予更高的解碼概率,反之更低
這里訓(xùn)練就有兩個(gè)細(xì)節(jié)
- 序列相似如何定義?這里沒有引入新的向量模型,直接使用大模型解碼輸出層的向量表征(seq * hidden)和標(biāo)注結(jié)果的向量表征來計(jì)算cosine相似度,相似度計(jì)算參考了BertScore的F1值。并且這里對(duì)序列進(jìn)行了切分,分別計(jì)算span=1,2,4,8等不同長度的F1值,再進(jìn)行聚合。
- 損失函數(shù)如何定義?論文嘗試了以下4種不同的對(duì)比損失函數(shù),主要差異在pair-wise還是list-wise,擬合相似度的相對(duì)排序(i-j),還是絕對(duì)打分(P(yi|x)-P(yj|x))的高低。消融實(shí)驗(yàn)顯示第一個(gè)Rank Loss的效果最好。也就是從所有解碼生成的候選中隨機(jī)采樣兩個(gè),以上F1更高的為正樣本,反之為負(fù)樣本。計(jì)算解碼概率的Hinge-Loss
這里論文同樣加入了正則項(xiàng),避免模型過度偏離原始SFT對(duì)齊的模型,分別嘗試了KL和MLE兩種不同的正則。消融實(shí)驗(yàn)顯示KL正則項(xiàng)的效果更好。
所以綜上SLiC使用了無監(jiān)督的思路,用對(duì)比學(xué)習(xí)來進(jìn)行對(duì)齊。下面我們來看如何使用SLiC來對(duì)齊人類偏好
SLiC-HF
偏好樣本
首先SLiC-HF用的是offline的訓(xùn)練方案,所以先說下偏好樣本是如何構(gòu)建的。論文嘗試了Direct和Sample and Rank兩種樣本構(gòu)建方案。
Direct方案就是直接使用Reddit摘要數(shù)據(jù)集中人工標(biāo)注的正負(fù)偏好樣本作為\(y^+,y^-\),優(yōu)點(diǎn)是成本低,缺點(diǎn)是這里的解碼結(jié)果可能和SFT模型的解碼分布存在偏差。
Sample and Rank,也就是先使用以上偏好數(shù)據(jù)訓(xùn)練Reward模型,論文嘗試了兩種方案,一個(gè)是絕對(duì)偏好,模型預(yù)測Good/Bad使用解碼概率作為label。另一個(gè)是相對(duì)偏好,也就是模型學(xué)習(xí)兩個(gè)摘要之間的相對(duì)好壞。
之后使用SFT模型隨機(jī)解碼(temperature=0.7)生成的8個(gè)解碼候選,使用以上模型打分或排序后,隨機(jī)采樣8個(gè)正負(fù)樣本對(duì)。
效果上Sample and Rank要優(yōu)于Direct,但如果Driect部分是直接使用SFT模型生成候選再人工標(biāo)注的話,其實(shí)結(jié)果可能也不差。
損失函數(shù)
已經(jīng)有了正負(fù)樣本對(duì),那其實(shí)只需要用到上面的對(duì)比損失函數(shù)了,不需要使用半監(jiān)督了。不過這里的正則器沒有選用KL,而是直接使用SFT樣本的MLE來防止模型能力衰減。最終的損失函數(shù)如下
除了Offline的樣本構(gòu)建訓(xùn)練效率更高之外,SLiC-HF直接使用序列概率表征偏好,因此不需要使用reward模型,同時(shí)對(duì)比來自樣本而非來自模型,因此也不再需要使用凍結(jié)參數(shù)的SFT模型。訓(xùn)練過程內(nèi)容中只有一個(gè)SFT模型進(jìn)行梯度更新。
DPO
- Direct Preference Optimization: Your Language Model is Secretly a Reward Model
- https://github.com/eric-mitchell/direct-preference-optimization
- https://github.com/huggingface/trl/blob/0a6c42c12c637bb7f28782fa72ec45dd64bce0bd/trl/trainer/dpo_trainer.py
DPO和SLiC同樣是基于offline的正負(fù)偏好樣本對(duì),通過對(duì)比學(xué)習(xí)來進(jìn)行偏好對(duì)齊。DPO的偏好樣本標(biāo)注是直接基于SFT模型生成候選,然后人工標(biāo)注得到正負(fù)(win,loss)樣本對(duì),然后直接使用損失函數(shù)進(jìn)行擬合,不訓(xùn)練reward模型。不過二者的對(duì)比損失函數(shù)不同,DPO的損失函數(shù)如下
以上\(\pi\)是模型解碼輸出層每個(gè)token
的輸出概率logp求和,\(\theta\)是參與梯度更新的模型,ref是SFT對(duì)齊后的模型參數(shù)作為基準(zhǔn)參數(shù)被凍結(jié)。
所以簡單直觀的理解也就是DPO的損失函數(shù),讓模型對(duì)偏好樣本的解碼概率相比ref升高,讓模型對(duì)負(fù)樣本的解碼概率相比ref下降。和Triplet Loss的對(duì)比損失函數(shù)的思路有些相似。
我們和SLiC-HF做下對(duì)比,首先SLiC是hinge-loss(maximum-margin),DPO不是。其次SLiC是正負(fù)樣本直接對(duì)比,DPO是正負(fù)樣本概率分別和基準(zhǔn)模型(SFT模型)進(jìn)行對(duì)比,二者的差異有些類似simases和triplet loss,只不過DPO的錨點(diǎn)不是錨點(diǎn)樣本而是基準(zhǔn)模型。所以模型既需要擬合相對(duì)偏好,也需要保證絕對(duì)分布不會(huì)答復(fù)偏離原始SFT模型。在后面的一些對(duì)比論文中普遍結(jié)論是DPO的損失函數(shù)更優(yōu),SLiC的對(duì)比函數(shù)會(huì)導(dǎo)致一些reward hacking
論文還進(jìn)一步從梯度計(jì)算的角度進(jìn)行了闡述,如果上述損失函數(shù)對(duì)\(\theta\)求導(dǎo)。會(huì)得到以下公式
其中\(\hat{r_{\theta}}(x,y)=\beta log(\frac{\pi_{\theta}(y|x)}{\pi_{ref}(y|x)})\)是DPO的核心,既對(duì)齊模型的輸出層的概率偏離原始SFT模型的幅度能隱式表征偏好,作為 pseudo Reward來進(jìn)行模型對(duì)齊。正負(fù)樣本差異越大越多更新幅度越大,梯度方向是提高偏好樣本的解碼概率,降低負(fù)樣本的解碼概率。
RRHF
- RRHF: Rank Responses to Align Language Models with Human Feedback without tears
- https://github.com/GanjinZero/RRHF
RRHF同樣是offline構(gòu)建正負(fù)樣本對(duì),再采用對(duì)比學(xué)習(xí)進(jìn)行偏好對(duì)齊的方案,那這里我們只看RRHF和SLiC的差異點(diǎn)。
其一是RRHF使用了長度歸一化的序列概率來表征偏好,SLiC直接使用了解碼概率
其二是SLiC使用了Hinge-Loss,而RRHF是直接擬合正負(fù)樣本的概率差
其三是正負(fù)樣本的構(gòu)建方案,SLiC是基于SFT模型進(jìn)行隨機(jī)解碼生成候選,并基于Reward模型離線構(gòu)建正負(fù)樣本,而RRHF的候選采樣方案還對(duì)比了beam-search,diversity-beam-search,以及Iterate-beam-search,也就是每訓(xùn)練一個(gè)epoch基于微調(diào)后的模型重新生成一波候選。Iterate-beam-search的采樣方案會(huì)有一些效果提升,考慮生成樣本會(huì)隨分布修正而逐漸優(yōu)化,可以覆蓋更多的分布空間。以及Iterate-beam-search其實(shí)和PPO在線解碼進(jìn)行模型更新的方案更加相似,但相對(duì)效率更高。
三合一大禮包- RSO
STATISTICAL REJECTION SAMPLING IMPROVES PREFERENCE OPTIMIZATION
RSO方案融合了以上三者,主要是DPO和SLiC,分別對(duì)損失函數(shù)和偏好樣本對(duì)的構(gòu)建方式進(jìn)行了改良。先說損失函數(shù),RSO把SLiC的Hinge-loss加入到DPO的sigmoid-norm損失函數(shù)中,得到了如下的hinge-norm損失函數(shù)
再有是偏好樣本構(gòu)建,RSO指出既然以上對(duì)比函數(shù)的目標(biāo)是擬合最優(yōu)的Policy,那理論上偏好樣本對(duì)也應(yīng)該從\(\pi^*\)來構(gòu)建。近似于以上RRHF的Iterate-beam-search的最后一個(gè)Iterate的樣本分布。但\(\pi^*\)還沒訓(xùn)練出來要如何拿到它的對(duì)比樣本呢?
這里RSO提出可以采用從\(\pi_{SFT}\)中拒絕采樣來近似\(\pi_{r}\)的分布,對(duì)比SLiC的SFT-sample-rank,稱之為RSO-Sample-Rank。具體構(gòu)建方式還是從SFT生成多個(gè)解碼候選,并使用訓(xùn)練的Reward模型對(duì)每個(gè)候選進(jìn)行打分,接著進(jìn)行拒絕采樣。
首先拒絕采樣使用g(x)擬合f(x), 計(jì)算一個(gè)常數(shù)C,使得\(c*g(x)>=f(x)\)。則采樣過程是從g(x)中采樣,當(dāng)隨機(jī)變量\(U\sim(0,1)<=\frac{f(x)}{c*g(x)}\)則保留樣本,反之拒絕。
這里g(x)就是SFT模型\(\pi_{sft}\),f(x)是最終對(duì)齊的模型\(\pi_{r_{\tau}}\),理論上\(m*\pi_{sft}>=\pi_{r_{\tau}}\),這樣當(dāng)\(U<= \frac{\pi_{r_{\tau}}}{m*\pi_{sft}}\)我們保留樣本,但因?yàn)檫@里的的\(\pi_{r_{\tau}}\)并無法獲得,因此我們用DPO中推導(dǎo)的Policy和reward的關(guān)系
為了diff掉正則項(xiàng)Z,論文使用所有隨機(jī)解碼樣本的最大reward的(x,y)來作為常數(shù)C的估計(jì)。
最終得到的拒絕采樣的代碼如下
效果上論文對(duì)比了DPO,SLiC,RSO,以及不同損失函數(shù),不同采樣方案的效果差異。整體上采樣帶來的收益是更為顯著,DPO的損失函數(shù)上加不加hinge差異并不大,但都會(huì)優(yōu)于SLiC的直接對(duì)比損失函數(shù)。
文章來源:http://www.zghlxwxcb.cn/news/detail-835609.html
想看更全的大模型相關(guān)論文梳理·微調(diào)及預(yù)訓(xùn)練數(shù)據(jù)和框架·AIGC應(yīng)用,移步Github >> DecryPrompt文章來源地址http://www.zghlxwxcb.cn/news/detail-835609.html
到了這里,關(guān)于解密prompt系列24. RLHF新方案之訓(xùn)練策略:SLiC-HF & DPO & RRHF & RSO的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!