最近的工作有涉及該任務(wù),整理一下思路以及代碼細(xì)節(jié)。
流程
總體來說思路就是首先用預(yù)訓(xùn)練的bert模型,在訓(xùn)練集的序列上進(jìn)行CLS任務(wù)。對序列內(nèi)容(這里默認(rèn)是token id的sequence)以0.3左右的概率進(jìn)行隨機(jī)mask,然后將相應(yīng)sequence的attention mask(原來決定padding index)和label(也就是mask的ground truth)輸入到bert model里面。
當(dāng)然其中vocab.txt并不存在的token是需要add進(jìn)去的,具體方法不再詳述,網(wǎng)上例子很多,注意word embedding也需要初始化就行。
模型定義:self.model = AutoModelForMaskedLM.from_pretrained('./bert')
模型的輸入:result = self.bert_model(tail_mask, attention_mask, labels)
得到模型訓(xùn)練的結(jié)果之后,要做一個選擇:
(1)transformer的bert model可以輸出要預(yù)測時間步的hidden state,可以選擇取出對應(yīng)的hidden state,其中需要在數(shù)據(jù)處理的時候記錄下每個sequence的tail position,也就是要預(yù)測位置的idx。另外我認(rèn)為既然要進(jìn)行序列推薦,那么最后一個tail position的token表征一定是最重要的,所以需要對tail position的idx專門給個寫死的mask,效果會好一些。然后與sequence中item的全集進(jìn)行相似度的計算,再去算交叉熵loss。
bert_hidden = result.hidden_states[-1]
bert_seq_hidden = torch.zeros((self.args.batch_size, 312)).to(self.device)
for i in range(self.args.batch_size):
bert_seq_hidden[i,:] = bert_hidden[i, tail_pos[i], :]
logits = torch.matmul(bert_seq_hidden, test_item_emb.transpose(0, 1))
main_loss = self.criterion(logits, targets)
(2)同時也可以result.loss
直接數(shù)據(jù)mask prediction的loss,我理解這個loss面對的任務(wù)是我要求sequence中的各個token表征都要盡可能準(zhǔn)確,都要考慮,(1)可能更加注重最后一個位置的標(biāo)準(zhǔn)的準(zhǔn)確性。
然后在evaluate階段,需要注意輸入到模型的不再是tail_mask,而是僅僅mask掉tail token id的sequence,因為我們需要盡可能準(zhǔn)確的序列信息,只需要保證要預(yù)測的存在mask就夠了。
由于是推薦任務(wù),而且bert得到的hidden state表征過于隱式,所以需要一定的個性化引導(dǎo)它進(jìn)行訓(xùn)練。經(jīng)過個人的實(shí)驗也確實(shí)如此,而且結(jié)果相差很多。文章來源:http://www.zghlxwxcb.cn/news/detail-634835.html
以上就是我個人的總結(jié)經(jīng)驗,歡迎大家指點(diǎn)。文章來源地址http://www.zghlxwxcb.cn/news/detail-634835.html
到了這里,關(guān)于使用Bert預(yù)訓(xùn)練模型處理序列推薦任務(wù)的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!