目前pytorch支持2種多卡訓(xùn)練:
1. torch.nn.DataParallel
2. torch.nn.parallel.DistributedDataParallel
第一種只支持單機(jī)多卡,第二種支持單機(jī)多卡和多機(jī)多卡;性能上,第二種優(yōu)于第一種,真正實(shí)現(xiàn)分布式訓(xùn)練。下面只介紹第二種方法的單機(jī)多卡訓(xùn)練:
import torch
import torch.nn as nn
import torch.utils.data as Data
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParalled as ddp
import os
import argparse
# 1. 初始化group,參數(shù)backend默認(rèn)為nccl,表示多gpu,多cpu是其它參數(shù)
dist.init_process_group(backend='nccl')
# 2. 添加一個(gè)local_rank參數(shù)
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank")
args = parser.parse_args()
# 3. 從外面得到local_rank參數(shù),在調(diào)用DDP的時(shí)候,其會(huì)根據(jù)調(diào)用gpu自動(dòng)給出這個(gè)參數(shù)
local_rank = args.local_rank
# 4.根據(jù)local_rank指定使用那塊gpu
torch.cuda.set_device(local_rank)
# 5.定義設(shè)備
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3" # 根據(jù)gpu的數(shù)量來設(shè)定,初始gpu為0,這里我的gpu數(shù)量為4
DEVICE = torch.device("cuda", local_rank)
# 6. 把模型加載到cuda上
model = Model().to(device) # 自己定義的模型
# 7. 初始化ddp模型
model = ddp(model, device_ids=[local_rank], output_device=local_rank)
# 8. 數(shù)據(jù)分到各gpu上
traindata = datasetloader() # 此處為加載訓(xùn)練集
testdata = datasetloader() # 此處為加載測(cè)試集
train_samper =Data.distributed.DistributedSampler(traindata) # 切分訓(xùn)練數(shù)據(jù)集
test_samper =Data.distributed.DistributedSampler(testdata) # 切分測(cè)試數(shù)據(jù)集
trainloader = Data.DataLoader(dataset=traindata, batch_size=32, sampler=train_samper, shuffle=False, num_workers=2)
testloader = Data.DataLoader(dataset=testdata, batch_size=32, sampler=test_samper, shuffle=False, num_workers=2)
根據(jù)以上設(shè)置,便可實(shí)現(xiàn)單機(jī)多卡的分布式訓(xùn)練,這里需要注意一點(diǎn),也是我踩過的坑,就是Data.DataLoader的shuffle必須為false,否則報(bào)錯(cuò)。
最后,在終端,使用以下命令行運(yùn)行訓(xùn)練腳本(train.py)即可:文章來源:http://www.zghlxwxcb.cn/news/detail-642675.html
python -m torch.distributed.launch --nproc_per_node=4 train.py
這里nproc_per_node的參數(shù)表示gpu數(shù)量文章來源地址http://www.zghlxwxcb.cn/news/detail-642675.html
到了這里,關(guān)于pytorch:?jiǎn)螜C(jī)多卡(GPU)訓(xùn)練的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!