1 問題描述
PyTorch提供了非常便捷的多GPU網(wǎng)絡(luò)訓(xùn)練方法:DataParallel
和DistributedDataParallel
。在涉及到一些復(fù)雜模型時,基本都是采用多個GPU并行訓(xùn)練并保存模型。但在推理階段往往只采用單個GPU或者CPU運行。這時怎么將多GPU環(huán)境下保存的模型權(quán)重加載到單GPU/CPU運行環(huán)境下的模型上成了一個關(guān)鍵的問題。
如果忽視環(huán)境問題直接加載往往會出現(xiàn)兩類問題:
1 出現(xiàn)錯誤:IndexError: list index out of range

出現(xiàn)這個錯誤的原因是:現(xiàn)有模型參數(shù)是在多GPU上獲得并保存的,因此在讀入時默認(rèn)會保存至對應(yīng)的GPU上,但是目前推理環(huán)境中只有一塊GPU,所以導(dǎo)致那些本來在其它GPU上的參數(shù)找不到自己應(yīng)該去的GPU編號,出現(xiàn)了一個溢出錯誤,本質(zhì)是GPU編號溢出。
2 出現(xiàn)錯誤:Missing key(s) in state_dict:

出現(xiàn)這個錯誤的原因是:由于模型訓(xùn)練和推理的環(huán)境不同,導(dǎo)致一些參數(shù)丟失,因此報錯。目前在網(wǎng)上的一些解決策略是忽視這些丟失的參數(shù),例如使用命令:model.load_state_dict(torch.load('model.pth'), strict=False)
來成功導(dǎo)入模型。這條命令可以讓程序不報錯并看似成功的導(dǎo)入模型參數(shù)。但實際上這條命令的含義是在導(dǎo)入模型參數(shù)時通過設(shè)置 strict=False
來忽略丟失的參數(shù),也就是說那些丟失參數(shù)地方的模型權(quán)重初仍為初始化隨機(jī)狀態(tài),等同于沒有進(jìn)行訓(xùn)練和學(xué)習(xí),何談推理與驗證?。?!
2 模型保存方式
不論是用哪種方式進(jìn)行推理,在訓(xùn)練的時候要保證程序保存模型的方式是這樣的:
torch.save(model.state_dict(), "model.pth")
3 單塊GPU上加載模型
將多GPU訓(xùn)練的權(quán)重文件加載到單GPU上:
# 1 加載模型
model = Model()
# 2 指定運行設(shè)備,這里為單塊GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 3 將模型用DataParallel方法封裝一次
model = torch.nn.DataParallel(model)
# 4 將模型讀入到GPU設(shè)備上
model = model_E2E.to(device)
# 5 加載權(quán)重文件
model.load_state_dict(torch.load(weight_path, map_location=device))
通過上面的程序就可以實現(xiàn)將多塊GPU上訓(xùn)練得到的權(quán)重文件加載到單塊GPU環(huán)境下的模型中。這里有兩點需要注意:
-
在多GPU訓(xùn)練時,模型使用了
DataParallel
或DistributedDataParallel
方法,這兩種并行化工具會修改模型的結(jié)構(gòu),將模型封裝在一個新的模塊中,通常名為:module
。因此在權(quán)重文件中保存的模型是經(jīng)過DataParallel
封裝后的結(jié)構(gòu)。為了能夠載入全部參數(shù),需要通過步驟3使推理模型與原始多GPU訓(xùn)練模型在結(jié)構(gòu)上保持一致。 -
在步驟5加載模型參數(shù)時使用了
map_location
參數(shù)。這個參數(shù)會告訴 PyTorch在加載模型時應(yīng)該將張量放置在哪個設(shè)備上。設(shè)置map_location=device
,那么無論模型原來是在哪個設(shè)備上訓(xùn)練的,現(xiàn)在都將放置在指定的設(shè)備device='cuda:0'
上。
4 CPU上加載模型
在CPU上加載模型:
from collections import OrderedDict
# 1 加載模型
model = Model()
# 2 指定設(shè)備CPU
device = "cpu"
# 3 讀取權(quán)重文件
state_dict = torch.load(weight_path, map_location=device)
# 4 剝除權(quán)重文件中的module層
if next(iter(state_dict)).startswith("module."):
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
state_dict = new_state_dict
# 5 加載權(quán)重文件
model.load_state_dict(state_dict)
# 6 將模型載入到CPU
model = model.to(device)
在CPU上加載模型的邏輯和GPU的差不多,核心都是因為原權(quán)重文件中的模型被封裝成了module.Model
,所以需要將這層外殼去掉,最后再進(jìn)行讀取并將模型加載到CPU上。文章來源:http://www.zghlxwxcb.cn/news/detail-699399.html
5 總結(jié)
在深度學(xué)習(xí)任務(wù)中訓(xùn)練與推理環(huán)境存在差異的情況十分常見 ,有差異的環(huán)境下實現(xiàn)網(wǎng)絡(luò)權(quán)重文件的正確讀取十分重要。實際操作中一定要確保正確的權(quán)重文件被讀入,這是進(jìn)行推理最基本的前提!最好在推理前做一些對比實驗(例如:選取一部分?jǐn)?shù)據(jù),分別套用已有的程序進(jìn)行訓(xùn)練和推理,對比二者的效果)來確保已經(jīng)讀入到正確的權(quán)重。文章來源地址http://www.zghlxwxcb.cn/news/detail-699399.html
到了這里,關(guān)于PyTorch多GPU訓(xùn)練模型——使用單GPU或CPU進(jìn)行推理的方法的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!