一、前言
在高質(zhì)量標(biāo)注數(shù)據(jù)稀缺的工業(yè)界來說,少樣本學(xué)習(xí)或者零樣本學(xué)習(xí)的方法特別受歡迎,后面出現(xiàn)過一些少樣本和零樣本的方法,例如對比學(xué)習(xí)和prompt等,主流prompt的工作分為離散型和連續(xù)型模板。離散型主要還是插入bert特殊的token為主,連續(xù)型則是插入數(shù)字token。離散型可解釋性強于連續(xù)型,我這里講的soft prompt則是連續(xù)型的。
大型預(yù)訓(xùn)練語言模型的規(guī)模不斷擴大,在許多自然語言處理 (NLP) 基準(zhǔn)測試中取得了最先進的結(jié)果。自GPT和BERT開發(fā)以來,標(biāo)準(zhǔn)做法一直是在下游任務(wù)上微調(diào)模型,這涉及調(diào)整網(wǎng)絡(luò)中的每個權(quán)重(即模型調(diào)整)。然而,隨著模型變得越來越大,為每個下游任務(wù)存儲和提供模型的微調(diào)變得不切實際。
一個吸引人的替代方案是在所有下游任務(wù)中共享一個單一的凍結(jié)預(yù)訓(xùn)練語言模型,其中所有權(quán)重都是固定的。在一個令人興奮的發(fā)展中,GPT-3令人信服地表明,可以通過“上下文”學(xué)習(xí)來調(diào)節(jié)凍結(jié)模型以執(zhí)行不同的任務(wù)。使用這種方法,用戶通過提示設(shè)計為給定任務(wù)準(zhǔn)備模型,即手工制作帶有手頭任務(wù)描述或示例的文本提示。例如,要為情感分析設(shè)置模型,可以附加提示“以下電影評論是正面的還是負面的?”?在輸入序列之前,“這部電影太棒了!”
二、soft prompt
跨任務(wù)共享相同的凍結(jié)模型極大地簡化了服務(wù)并允許有效的混合任務(wù)推理,但不幸的是,這是以犧牲任務(wù)性能為代價的。文本提示需要人工設(shè)計,即使是精心設(shè)計的提示與模型調(diào)優(yōu)相比仍然表現(xiàn)不佳。例如,在SuperGLUE基準(zhǔn)測試中凍結(jié)的 GPT-3 175B 參數(shù)模型的性能比使用少 800 倍參數(shù) 的微調(diào)T5 模型低 5 個點。
在EMNLP 2021上發(fā)表的“參數(shù)高效提示調(diào)整的規(guī)模力量”中,我們探索了提示調(diào)整,一種使用可調(diào)軟提示調(diào)節(jié)凍結(jié)模型的更有效方法。就像工程文本提示一樣,軟提示連接到輸入文本。但不是從現(xiàn)有的詞匯項目中選擇,軟提示的“標(biāo)記”是可學(xué)習(xí)的向量。這意味著可以在訓(xùn)練數(shù)據(jù)集上端到端優(yōu)化軟提示。除了消除手動設(shè)計的需要之外,這還允許提示從包含數(shù)千或數(shù)百萬個示例的數(shù)據(jù)集中壓縮信息。相比之下,由于模型輸入長度的限制,離散文本提示通常限制在 50 個以下示例。
要為給定任務(wù)創(chuàng)建軟提示,我們首先將提示初始化為固定長度的向量序列(例如,20 個token長)。我們將這些向量附加到每個嵌入輸入的開頭,并將組合序列輸入模型。將模型的預(yù)測與目標(biāo)進行比較以計算損失,并將誤差反向傳播以計算梯度,但是我們僅將這些梯度更新應(yīng)用于我們的新可學(xué)習(xí)向量——保持核心模型凍結(jié)。雖然以這種方式學(xué)習(xí)的軟提示不能立即解釋,但在直觀的層面上,軟提示正在從標(biāo)記的數(shù)據(jù)集中提取有關(guān)如何執(zhí)行任務(wù)的證據(jù),其作用與手動編寫的文本提示相同,但??不需要限于離散的語言。
soft prompt是參數(shù)越多效果越好,引自Google發(fā)表原文。
文章來源:http://www.zghlxwxcb.cn/news/detail-487999.html
三、soft prompt實現(xiàn)
import torch
import torch.nn as nn
class SoftEmbedding(nn.Module):
def __init__(self,
wte: nn.Embedding,
n_tokens: int = 10,
random_range: float = 0.5,
initialize_from_vocab: bool = True):
"""appends learned embedding to
Args:
wte (nn.Embedding): original transformer word embedding
n_tokens (int, optional): number of tokens for task. Defaults to 10.
random_range (float, optional): range to init embedding (if not initialize from vocab). Defaults to 0.5.
initialize_from_vocab (bool, optional): initalizes from default vocab. Defaults to True.
"""
super(SoftEmbedding, self).__init__()
self.wte = wte
self.n_tokens = n_tokens
self.learned_embedding = nn.parameter.Parameter(self.initialize_embedding(wte,
n_tokens,
random_range,
initialize_from_vocab))
def initialize_embedding(self,
wte: nn.Embedding,
n_tokens: int = 10,
random_range: float = 0.5,
initialize_from_vocab: bool = True):
"""initializes learned embedding
Args:
same as __init__
Returns:
torch.float: initialized using original schemes
"""
if initialize_from_vocab:
return self.wte.weight[:n_tokens].clone().detach()
return torch.FloatTensor(n_tokens, wte.weight.size(1)).uniform_(-random_range, random_range)
def forward(self, tokens):
"""run forward pass
Args:
tokens (torch.long): input tokens before encoding
Returns:
torch.float: encoding of text concatenated with learned task specifc embedding
"""
input_embedding = self.wte(tokens[:, self.n_tokens:])
learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)
return torch.cat([learned_embedding, input_embedding], 1)
from transformers import AutoConfig, AdamW, AutoTokenizer, AutoModel
import torch
import torch.nn as nn
from soft_embedding import SoftEmbedding
n_tokens = 20
initialize_from_vocab = True
tokenizer = AutoTokenizer.from_pretrained("nezha-base-wwm")
config = AutoConfig.from_pretrained("nezha-base-wwm", num_labels=num_class)
config.output_hidden_states = True # 需要設(shè)置為true才輸出
model = AutoModel.from_pretrained(self.model_path, config=config)
s_wte = SoftEmbedding(model.get_input_embeddings(),
n_tokens=n_tokens,
initialize_from_vocab=initialize_from_vocab)
model.set_input_embeddings(s_wte)
inputs = tokenizer("May the force be", return_tensors="pt")
# need to pad attention_mask and input_ids to be full seq_len + n_learned_tokens
# even though it does not matter what you pad input_ids with, it's just to make HF happy
inputs['input_ids'] = torch.cat([torch.full((1,n_tokens), 50256), inputs['input_ids']], 1)
inputs['attention_mask'] = torch.cat([torch.full((1,n_tokens), 1), inputs['attention_mask']], 1)
outputs = model(**inputs)
四、總結(jié)
soft prompt比較依賴于模型參數(shù)大小,更加適合零樣本和小樣本,如果用來大數(shù)據(jù)量下微調(diào)模型,效果可能會比普通微調(diào)差不多或者更差點。文章來源地址http://www.zghlxwxcb.cn/news/detail-487999.html
到了這里,關(guān)于提示學(xué)習(xí)soft prompt淺嘗,啟發(fā)了p-tuing的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!