1. 文章簡介
- 標(biāo)題:Prompt Consistency for Zero-Shot Task Generalization
- 作者:Chunting Zhou, Junxian He, Xuezhe Ma, Taylor Berg-Kirkpatrick, Graham Neubig
- 日期:2022
- 期刊:Arxiv preprint
2. 文章概括
??文章基于prompt的一致性學(xué)習(xí)給出了一種zero-shot task generalization(零樣本泛化學(xué)習(xí))的無監(jiān)督方法。數(shù)值實(shí)驗(yàn)表明,文章提出的指令一致性學(xué)習(xí)方法只需在幾個(gè)prompt、幾十個(gè)樣本上進(jìn)行訓(xùn)練,就可以在NLI等NLP任務(wù)上追平SOTA水平。
??文章整體架構(gòu)如下
3 文章重點(diǎn)技術(shù)
3.1 Prompt-based zero-shot task generalization
??首先簡單介紹下zero-shot task generalization(零樣本泛化學(xué)習(xí)):給定輸入
x
∈
X
x\in \mathcal{X}
x∈X,零樣本泛化學(xué)習(xí)旨在學(xué)習(xí)一個(gè)預(yù)訓(xùn)練模型PLM預(yù)測出
y
∈
Y
y\in \mathcal{Y}
y∈Y,其中PLM未在數(shù)據(jù)集
X
\mathcal{X}
X上訓(xùn)練過。零樣本泛化學(xué)習(xí)要求模型可以泛化出一個(gè)新的表達(dá)式
f
:
X
→
Y
f: \mathcal{X} \to \mathcal{Y}
f:X→Y,而非僅僅在數(shù)據(jù)集上具有泛化能力。
??給定prompt
r
r
r,
r
r
r包含一個(gè)輸入模板
r
x
r_x
rx?、輸出模板
r
y
r_y
ry?以及待放入模板的元數(shù)據(jù)
x
,
y
x, y
x,y,我們可以得到prompt-based輸入:
r
x
(
x
)
,
r
y
(
y
)
r_x(x), r_y(y)
rx?(x),ry?(y)?;趐rompt的學(xué)習(xí)方法一般用
p
θ
(
r
y
(
y
)
∣
r
x
(
x
)
)
p_{\theta} (r_y(y)|r_x(x))
pθ?(ry?(y)∣rx?(x))來計(jì)算輸出的概率
q
(
y
∣
x
,
r
)
)
q(y|x, r))
q(y∣x,r)),其中
θ
\theta
θ表示模型的參數(shù)。本文重點(diǎn)關(guān)注NLP的分類任務(wù),則可以通過如下公式計(jì)算輸出的概率:
q
(
y
∣
x
,
r
)
=
p
θ
(
r
y
(
y
)
∣
r
x
(
x
)
)
∑
y
′
∈
Y
p
θ
(
r
y
(
y
′
)
∣
r
x
(
x
)
)
(1)
q(y|x, r) = \frac{p_{\theta} (r_y(y)|r_x(x))}{\sum_{y'\in\mathcal{Y}} p_{\theta} (r_y(y')|r_x(x))}\tag{1}
q(y∣x,r)=∑y′∈Y?pθ?(ry?(y′)∣rx?(x))pθ?(ry?(y)∣rx?(x))?(1)。
3.2 Prompt Consistency Training
?? 文章的方法需要無標(biāo)注的數(shù)據(jù)集
{
x
1
,
…
,
x
N
}
\{x_1, \dots, x_N\}
{x1?,…,xN?}和
K
K
K個(gè)不同的prompt
{
(
r
x
1
,
r
y
1
)
,
…
,
(
r
x
K
,
r
y
K
)
}
\{(r_x^1, r_y^1), \dots, (r_x^K, r_y^K)\}
{(rx1?,ry1?),…,(rxK?,ryK?)}。其中無標(biāo)注的數(shù)據(jù)集可以來自任意NLP(分類)任務(wù)的訓(xùn)練數(shù)據(jù)集或測試數(shù)據(jù)集,也可以來自我們要測試的任務(wù)的數(shù)據(jù)集。prompt可直接采用Public Pool of Prompts(p3)數(shù)據(jù)集里的prompt。
?? 傳統(tǒng)的一致性訓(xùn)練會擾亂樣本,使得擾亂后的樣本和之前的樣本得到的輸出盡可能一致。本文希望學(xué)習(xí)prompt級別的一致性,即不同prompt在單個(gè)樣本上的學(xué)習(xí)結(jié)構(gòu)盡可能一致。這樣做可以1) 概念非常簡單 2)緩解PLM“輸入不同prompt結(jié)果不一致”的問題。
??損失函數(shù)定義如下
L
=
?
E
x
∈
p
d
(
x
)
E
r
i
,
e
r
j
∈
p
(
r
)
E
y
^
∈
q
^
(
y
∣
x
,
r
i
)
log
?
p
θ
(
r
y
j
(
y
^
)
∣
r
x
j
(
x
)
)
\mathcal{L} = -\mathbb{E}_{x\in p_d(x)} \mathbb{E}_{r^i, er^j\in p(r)} \mathbb{E}_{\hat{y} \in \hat{q}(y|x,r^i)} \log p_{\theta} (r_y^j(\hat{y})|r_x^j(x))
L=?Ex∈pd?(x)?Eri,erj∈p(r)?Ey^?∈q^?(y∣x,ri)?logpθ?(ryj?(y^?)∣rxj?(x))
,
p
d
p_d
pd?是數(shù)據(jù)集的分布,
p
(
r
)
p(r)
p(r)表示
K
K
K個(gè)prompt的隨機(jī)prompt對的均勻分布,
q
^
\hat{q}
q^?定義為式(1)的條件分布。這里簡單解釋下,如圖所示,給定prompt
r
i
,
r
j
r^i, r^j
ri,rj,我們首先預(yù)測
y
^
∈
q
^
(
y
∣
x
,
r
i
)
\hat{y}\in \hat{q}(y|x, r^i)
y^?∈q^?(y∣x,ri),即當(dāng)promt為
r
i
r^i
ri時(shí)得到輸出
y
^
\hat{y}
y^?。當(dāng)prompt為
r
j
r^j
rj時(shí),我們希望最大化輸出結(jié)果為
y
^
\hat{y}
y^?(即和
r
i
r^i
ri輸出相同)的概率
p
θ
(
r
y
j
(
y
^
)
∣
r
x
j
(
x
)
)
p_{\theta} (r_y^j(\hat{y})|r_x^j(x))
pθ?(ryj?(y^?)∣rxj?(x)),取負(fù)對數(shù)和期望之后,即得到上述損失函數(shù)。我們稱上述訓(xùn)練方法為swarm distillation。
3.3 如何防止遺忘和退化?
??如果直接采用上述方法進(jìn)行訓(xùn)練,則我們很容易collapse,得到一個(gè)平凡解:所有prompt、所有樣本均輸出同一個(gè)結(jié)果可以實(shí)現(xiàn)損失函數(shù)最小。另一方面,訓(xùn)練后的模型可以能忘記之前的知識,即castrophic forgetting。為了避免collapse和catastrophic forgetting,文章提出下述兩種方法:
- LoRA:文章是在T0模型上層進(jìn)行訓(xùn)練的,為了不發(fā)生災(zāi)難性遺忘,文章采用了LoRA方法,即通過兩個(gè)低階矩陣的乘積進(jìn)行迭代學(xué)習(xí),具體如下圖所示。在實(shí)際訓(xùn)練時(shí)我們將LoRA應(yīng)用到Transformer每一個(gè)前饋層。
- Fleiss’ Kappa:由于我們沒有標(biāo)注數(shù)據(jù)作為validation set,從而很難選擇一個(gè)最佳的checkpoint作為最終模型。為此文章采用了Fleiss’ Kappa指標(biāo)來度量模型的效果。首先,我們定義一致性概率。對給定的樣本 x i x_i xi?,記所有 K K K個(gè)prompt中預(yù)測輸出為第 j j j個(gè)label的prompt數(shù)量為 n i j n_{ij} nij?,則對該樣本,任意兩個(gè)prompt給出相同的預(yù)測結(jié)果的概率為 p i = ∑ j ( n i j 2 ) / ( K 2 ) = ∑ j n i j ( n i j ? 1 ) / K ( K ? 1 ) p_i = \sum_j \binom {n_{ij}}2 /\binom K2 = \sum_{j} n_{ij}(n_{ij} - 1) / K(K-1) pi?=j∑?(2nij??)/(2K?)=j∑?nij?(nij??1)/K(K?1),所有樣本的絕對一致性為 P  ̄ = ∑ i p i \overline{P} = \sum_i p_i P=∑i?pi?。另一方面,第 j j j個(gè)label的占比為 q j = ∑ i n i j / N K q_j = \sum_i n_{ij}/NK qj?=∑i?nij?/NK,則 P  ̄ e = ∑ j q j 2 \overline{P}_e = \sum_j q_j^2 Pe?=∑j?qj2?表示任意兩個(gè)prompts按照標(biāo)簽的分布隨機(jī)預(yù)測結(jié)果一致的概率。當(dāng)所有 q j q_j qj?均相等時(shí), P  ̄ e \overline{P}_e Pe?最小,即預(yù)測的標(biāo)簽隨機(jī)分布。最終得到Fleiss’ kappa度量為 κ = P  ̄ ? P  ̄ e 1 ? P  ̄ e ∈ ( ? 1 , 1 ) \kappa = \frac {\overline{P} - \overline{P}_e}{1 - \overline{P}_e} \in (-1, 1) κ=1?Pe?P?Pe??∈(?1,1),其中 P  ̄ e \overline{P}_e Pe?越大, κ \kappa κ越小,即預(yù)測的結(jié)果如果被一個(gè)類別主導(dǎo),則 κ \kappa κ會被懲罰。
4. 文章亮點(diǎn)
??文章提出了一種基于prompt一致性的zero-shot task generation學(xué)習(xí)方法swarm distillation,且采用了LoRA和Fleiss’ Kappa方法避免學(xué)習(xí)災(zāi)難性遺忘或?qū)W習(xí)結(jié)果collapse。文章在多個(gè)NLP下游任務(wù)上進(jìn)行了驗(yàn)證,發(fā)現(xiàn)swarm distillation在多個(gè)任務(wù)上表現(xiàn)超過SOTA。此外,數(shù)值實(shí)驗(yàn)表明,swarm distillation只需要4個(gè)prompt,10+個(gè)樣本就可以對源模型(T0)進(jìn)行提升。
??但實(shí)驗(yàn)也表明,swarm distillation方法在增加到一定樣本量之后性能就達(dá)到了飽和,當(dāng)我們有很多標(biāo)記樣本可用的時(shí)候,性能可能不及監(jiān)督微調(diào)。未來可以將swarm distillation與few-shot少樣本學(xué)習(xí)或預(yù)訓(xùn)練相結(jié)合來實(shí)現(xiàn)在標(biāo)記樣本上的性能提升。文章來源:http://www.zghlxwxcb.cn/news/detail-498102.html
5. 原文傳送門
Prompt Consistency for Zero-Shot Task Generalization文章來源地址http://www.zghlxwxcb.cn/news/detail-498102.html
到了這里,關(guān)于論文筆記--Prompt Consistency for Zero-Shot Task Generalization的文章就介紹完了。如果您還想了解更多內(nèi)容,請?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!