模型推理詳細(xì)步驟
模型加載步驟
首先,模型加載總共分為三步,第一步加載網(wǎng)絡(luò)結(jié)構(gòu),需要和你訓(xùn)時(shí)的network結(jié)構(gòu)一樣。
model = Model.FeedBack3(cfg, config_path=None, pretrained=True).to(device)
第二步,加載訓(xùn)練好的參數(shù),實(shí)際上雖然我們一直說訓(xùn)練模型,實(shí)際上訓(xùn)練出來的就是一組參數(shù),這個(gè)參數(shù)是一個(gè)字典類型,一般保存的名稱為xxx.pt或者pth。里面存放的是模型每一層中的權(quán)重等數(shù)據(jù)。pytorch中對(duì)于加載參數(shù)使torch.load()
pretrained_dict = torch.load('outputmicrosoft-deberta-v3-base_fold3_best.pth')
第三步,將參數(shù)加載進(jìn)模型里
model.load_state_dict(pretrained_dict['model_state_dict'], strict=True)
以上就是加載模型的所有步驟了
關(guān)于模型參數(shù)和字典對(duì)不上的問題
一般報(bào)錯(cuò)為:Missing key(s) in state_dict: xxxx
最近在做模型部署的時(shí)候發(fā)現(xiàn)了這個(gè)問題,并且之前也遇到過,由于急于求成就簡單實(shí)在模型加載參數(shù)的時(shí)候用了strict=False這樣的條件,這個(gè)條件會(huì)使模型直接忽略所有對(duì)不上的參數(shù),本質(zhì)上沒有解決問題。今天在debug時(shí)對(duì)模型每一層的參數(shù)排查終于發(fā)現(xiàn)了問題所在。
首先開啟debug模式,直接將斷點(diǎn)打在模型加載的代碼上:
首先查看model的結(jié)構(gòu)有沒有問題:
接下來進(jìn)行下一步,執(zhí)行到加載參數(shù)字典,同樣查看你的參數(shù)字典(這里由于參數(shù)過多就不詳細(xì)展示了):
那么要如何排查呢,具體步驟如下:
首先參數(shù)字典里都是以鍵值對(duì)和tensor型式存儲(chǔ)的,那么我們只需要一一排查鍵值對(duì)和參數(shù)。比如首先是model建,那么只有你加載參數(shù)的時(shí)候只有加載里面的model建模型才能讀到參數(shù),實(shí)際上我就是錯(cuò)在這里了,因?yàn)槲壹虞d的是通常使用的‘model_state_dict’這個(gè)建,因?yàn)槲矣?xùn)練部分是網(wǎng)上復(fù)制來的代碼,沒想到他把參數(shù)保存為model。
也就是我只需要把前面的
model.load_state_dict(pretrained_dict['model_state_dict'])
改成
model.load_state_dict(pretrained_dict['model'])
就行了。
那么如果你的問題不是這里,接下來改如何排查呢
接著看OrderedDict里,這里面是模型每一層的參數(shù),對(duì)照方法如下:
相當(dāng)于網(wǎng)絡(luò)結(jié)構(gòu)中的每一層都會(huì)變?yōu)橐粋€(gè)對(duì)應(yīng)的tensor
(model)(embeddings)(LayerNorm)在參數(shù)中就會(huì)存為:(‘model.embdeddings.LayerNorm’, tensor([xxxxx])
這樣就看懂了吧,如此對(duì)照每一層網(wǎng)絡(luò)結(jié)構(gòu),只要你有耐心,就能找出來具體是那一層不對(duì),不過大多情況下這種在網(wǎng)絡(luò)中間層出現(xiàn)參數(shù)不對(duì)的情況很少,出現(xiàn)的原因也肯定是你推理部分加載的網(wǎng)絡(luò)結(jié)構(gòu)和訓(xùn)練時(shí)的網(wǎng)絡(luò)結(jié)構(gòu)不一致導(dǎo)致的。
順便推薦一個(gè)能幫你排查模型參數(shù)的代碼,他會(huì)輸出具體有多少參數(shù)使用了和沒使用:
def check_keys(model, pretrained_state_dict):
ckpt_keys = set(pretrained_state_dict.keys())
model_keys = set(model.state_dict().keys())
used_pretrained_keys = model_keys & ckpt_keys
unused_pretrained_keys = ckpt_keys - model_keys
missing_keys = model_keys - ckpt_keys
# filter 'num_batches_tracked'
missing_keys = [x for x in missing_keys
if not x.endswith('num_batches_tracked')]
if len(missing_keys) > 0:
print('[Warning] missing keys: {}'.format(missing_keys))
print('missing keys:{}'.format(len(missing_keys)))
if len(unused_pretrained_keys) > 0:
print('[Warning] unused_pretrained_keys: {}'.format(
unused_pretrained_keys))
print('unused checkpoint keys:{}'.format(
len(unused_pretrained_keys)))
print('used keys:{}'.format(len(used_pretrained_keys)))
assert len(used_pretrained_keys) > 0, \
'check_key load NONE from pretrained checkpoint'
return True
模型推理中的數(shù)據(jù)處理
首先模型推理中數(shù)據(jù)最終的處理格式要和訓(xùn)練時(shí)輸入進(jìn)網(wǎng)絡(luò)中的格式一致,不過我們通常不再構(gòu)造新的dataset和使用dataloader,而是直接針對(duì)input處理成我們需要的格式。
主要步驟為,讀取數(shù)據(jù),embedding,增加維度
讀取的數(shù)據(jù)可以是本地存的,如果你是要將模型部署在web上那么數(shù)據(jù)就是從客戶端傳來的json格式的數(shù)據(jù),因此通常需要先將真正的input取出來。
接下來是向量化,這里步驟和訓(xùn)練中的一致,比如訓(xùn)練中使用了resize([800,800])和toTensor,那么推理中也要這樣設(shè)置。
由于我是NLP任務(wù),那么處理的步驟為
inputs = cfg.tokenizer.encode_plus(
input,
return_tensors=None,
add_special_tokens=True,
max_length=cfg.max_lenth,
pad_to_max_length=True,
truncation=True
)
for k, v in inputs.items():
inputs[k] = torch.tensor(v, dtype=torch.long)
至此,再次輸出此時(shí)的tensor和訓(xùn)練時(shí)輸入進(jìn)模型的tensor相比,只是少了一個(gè)維度,這個(gè)維度通??梢岳斫馕覀?cè)谟?xùn)練的時(shí)候是有batch_size的,而推理時(shí)沒有,因此要手動(dòng)升維,升維度的函數(shù)有很多,通常使用unsequeeze(1)或者expand:文章來源:http://www.zghlxwxcb.cn/news/detail-647935.html
for k, v in inputs.items():
s = v.shape
inputs[k] = v.expand(1,-1).to(device) #-1自動(dòng)計(jì)算
這樣處理完數(shù)據(jù)格式就和訓(xùn)練時(shí)完全一致了,說白了還是要先debug一下訓(xùn)練時(shí)的數(shù)據(jù),看看到底輸進(jìn)去的是什么格式,然后在推理部分照著一點(diǎn)一點(diǎn)改。文章來源地址http://www.zghlxwxcb.cn/news/detail-647935.html
到了這里,關(guān)于模型推理詳細(xì)步驟以及如何排查模型和參數(shù)字典對(duì)不上的問題:Missing key(s) in state_dict: xxxx的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!