1. The Problem
I used a custom Dataset and collate_fn with the default DataLoader. As soon as num_workers was set above 0, the following error was raised:

```
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
```
2. The Cause
Comparing my code with other people's, I found that I was calling .to(device) on tensors inside collate_fn. The usual pattern is to move each batch to the device only after fetching it from the dataloader, right before it goes into the model, for example:
```python
for epoch in range(max_epoch):
    for batch_idx, (en_id, att_mask_en, zh_id, att_mask_zh) in enumerate(data_loader):
        en_id = en_id.to(device)
```
whereas I had mistakenly moved the data to the device inside collate_fn:
```python
def collate_fn(batch):
    en, zh = list(zip(*batch))
    en_output = tokenizer.encode_batch(en, is_pretokenized=True)
    zh_id, att_mask_zh = zh_mapper.encode_batch(zh)
    en_id = []
    att_mask_en = []
    for item in en_output:
        en_id.append(item.ids)
        att_mask_en.append(item.attention_mask)
    # Bug: these .to(device) calls run inside a forked worker process,
    # which tries to re-initialize CUDA there and triggers the error.
    en_id = torch.tensor(en_id, dtype=torch.long).to(device)
    att_mask_en = torch.tensor(att_mask_en, dtype=torch.bool).to(device)
    zh_id = torch.tensor(zh_id, dtype=torch.long).to(device)
    att_mask_zh = torch.tensor(att_mask_zh, dtype=torch.bool).to(device)
    return en_id, att_mask_en, zh_id, att_mask_zh
```
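The fix is to keep collate_fn CPU-only and let the training loop do the device transfer. A minimal sketch of the corrected version (the tokenizer/zh_mapper calls from the original are elided here; this assumes each batch item already holds id and mask lists):

```python
import torch

def collate_fn(batch):
    # Build tensors on the CPU only; DataLoader worker processes
    # must never touch CUDA.
    en_id, att_mask_en, zh_id, att_mask_zh = zip(*batch)
    en_id = torch.tensor(en_id, dtype=torch.long)          # no .to(device)
    att_mask_en = torch.tensor(att_mask_en, dtype=torch.bool)
    zh_id = torch.tensor(zh_id, dtype=torch.long)
    att_mask_zh = torch.tensor(att_mask_zh, dtype=torch.bool)
    return en_id, att_mask_en, zh_id, att_mask_zh
```

The tensors come back on the CPU and are moved to the GPU in the main process, as in the loop shown above.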
Naturally, moving tensors to the device inside the dataset's __getitem__ method raises the same error.
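For the same reason, __getitem__ should also return CPU tensors only. A sketch with a hypothetical PairDataset (the name and data layout are illustrative, not from the original code):

```python
import torch
from torch.utils.data import Dataset

class PairDataset(Dataset):
    """Hypothetical sentence-pair dataset: returns CPU tensors only."""

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        en, zh = self.pairs[idx]
        # No .to(device) here -- the training loop owns the transfer.
        return (torch.tensor(en, dtype=torch.long),
                torch.tensor(zh, dtype=torch.long))
```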
I also tried the spawn start method that the error message suggests:
```python
if __name__ == '__main__':
    torch.multiprocessing.set_start_method('spawn')
    train()
```
but then the GPU stalled and data loading never ran.
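In practice, the common workaround is not spawn at all: keep the workers CPU-only and do the host-to-device copy in the main process, optionally with pinned memory so the copy can overlap compute. A sketch with toy stand-in tensors (falls back to CPU when no GPU is available):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy stand-ins for the real (en_id, att_mask) tensors.
data = TensorDataset(torch.arange(8).view(4, 2),
                     torch.ones(4, 2, dtype=torch.bool))
loader = DataLoader(data, batch_size=2, num_workers=2,
                    pin_memory=torch.cuda.is_available())

batches = []
for en_id, att_mask in loader:
    # Workers hand over CPU tensors; the copy to the GPU happens here,
    # in the main process. non_blocking lets the copy overlap compute
    # when the source memory is pinned.
    en_id = en_id.to(device, non_blocking=True)
    att_mask = att_mask.to(device, non_blocking=True)
    batches.append(en_id)
```

With this split, num_workers > 0 works because no worker process ever initializes CUDA.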