假設(shè)一共1000個samples,batch size=4,因此一個epoch會有250 iterations,也就是會更新250次
當(dāng)設(shè)置Trainer時
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger], max_steps=40, accumulate_grad_batches=2)
這個max_steps
指的是最多更新的次數(shù),這里也就是40次,而accumulate_grad_batches
指的是每次更新前積累多少個batch,這里為2
因此,每次更新前實際上積累了2 * 4 = 8個samples的gradients(當(dāng)然是取了平均),而最多更新40次,因此最后訓(xùn)練完看到完成的iterations則是80,因為兩個iterations被積累到一起來更新
注意:不論這個max_steps
和accumulate_grad_batches
是多少,訓(xùn)練時顯示的log永遠都是正常訓(xùn)練(無梯度累計)時的樣子
Epoch 0: 32%|▎| 80/250 [02:41<05:39, 2.00s/it, loss=0.85, v_num=15, train/loss_simple_step=0.820, train/loss_v
即,這個250不會因為我們要累計兩個batches而變成125,而是保持為250,且訓(xùn)練完后可以看到完成了80個iterations
默認情況下,Pytorch Lightning在每個epoch結(jié)束后,會保存一次模型,每個epoch包含多少iterations是固定的,不會因為max_steps
和accumulate_grad_batches
的改變而改變,在上面的例子中即250。在最后一次更新完成后也會保存一次模型,不論是在epoch末尾還是中間。
值得注意的是,一個epoch后保存下載的模型的名稱
epoch=0-step=124.ckpt
這個step代表的是目前為止一共更新的次數(shù),而不是iterations的數(shù)量。比如這個在epoch0結(jié)束后保存的模型,一共經(jīng)歷了125個更新steps,而每次step其實積累了兩個batch,即兩個iterations.文章來源:http://www.zghlxwxcb.cn/news/detail-550921.html
還需注意,默認情況Pytorch Lightning只會保存最新的model,然后會刪掉之前保存的舊的model文章來源地址http://www.zghlxwxcb.cn/news/detail-550921.html
到了這里,關(guān)于Pytorch Lightning 訓(xùn)練更新次數(shù)的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!