During model training we often need to run validation as training proceeds, or to test the trained model afterwards. The usual approach is to create a model instance, call load_state_dict() to load the trained weights into it, and then call model.eval() to switch it to evaluation mode. That fits a dedicated test run after training well, but for validation during training it means writing a fair amount of boilerplate. Using copy.deepcopy() to copy the model under training and running validation on the copy is clearly more concise. However, PyTorch tensors restrict what deepcopy can copy. If you are not careful when writing the model code, calling copy.deepcopy(model) may fail with this error: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment. The full error message is:
  File "/usr/local/lib/python3.6/site-packages/prc/framework/model/validation.py", line 147, in init_val_model
    val_model = copy.deepcopy(model)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib64/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/usr/lib64/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 306, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib64/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/usr/lib64/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 306, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib64/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/usr/lib64/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 161, in deepcopy
    y = copier(memo)
  File "/root/.local/lib/python3.6/site-packages/torch/_tensor.py", line 55, in __deepcopy__
    raise RuntimeError("Only Tensors created explicitly by the user "
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
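The failure can be reproduced outside the original project with a toy model. This is a minimal sketch that assumes a hypothetical `CachingSegHead`, which keeps its output on `self` in the same way as the head found further down in this article. It also includes a hypothetical helper, `find_non_leaf_tensors()`, that scans each submodule's plain attributes for non-leaf tensors. This helper is one possible alternative to setting breakpoints in copy.py. Neither name comes from the original code.

```python
import copy

import torch
import torch.nn as nn


class CachingSegHead(nn.Module):
    """Hypothetical head that caches its output on self (the pattern that breaks deepcopy)."""

    def __init__(self, in_ch: int, num_classes: int):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, num_classes, kernel_size=1)
        self.last_output = None  # plain attribute, not a registered buffer or parameter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv(x)
        self.last_output = out  # non-leaf tensor (has grad_fn) now lives on the module
        return out


def find_non_leaf_tensors(module: nn.Module):
    """Hypothetical helper: dotted names of plain tensor attributes that are non-leaf."""
    hits = []
    for mod_name, mod in module.named_modules():
        for attr, value in vars(mod).items():
            if isinstance(value, torch.Tensor) and not value.is_leaf:
                hits.append(f"{mod_name}.{attr}" if mod_name else attr)
    return hits


model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), CachingSegHead(8, 2))
model(torch.randn(1, 3, 16, 16))  # one training-style forward pass with grad enabled

print(find_non_leaf_tensors(model))  # ['1.last_output']
try:
    copy.deepcopy(model)
except RuntimeError as e:
    print("deepcopy failed:", e)
```

The helper only looks at tensors stored directly as attributes. A tensor hidden inside a list or dict attribute would need an extra recursive check.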
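Simply put, the error means that copy.deepcopy() cannot copy a non-leaf Tensor. A non-leaf Tensor is one produced by an operation that autograd recorded, so it has requires_grad=True and a grad_fn that is not None. Inside a network, such tensors are typically intermediate results. At first I assumed that some tensor somewhere really did have requires_grad set incorrectly. I spent several late nights checking and debugging the network code without finding any clue, which was frustrating. Then it occurred to me that the exception is raised inside copy.deepcopy(), and its source code is available. Debugging inside it should reveal which part of the network was being copied when the exception was thrown. After some trial and error, I found that the following spot in copy.py (inside _reconstruct()) is a convenient place for a breakpoint: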
if dictiter is not None:
    if deep:
        for key, value in dictiter:
            key = deepcopy(key, memo)
            value = deepcopy(value, memo)
            y[key] = value
    else:
        for key, value in dictiter:
            y[key] = value
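My network's structure is defined with Python dicts and built dynamically at runtime through a registry mechanism. Because it is dict-based, the key and value at this point correspond to the keys and values of the dicts in the config file that define each layer. A breakpoint here therefore shows quite clearly which part triggers the exception. The culprit turned out to be the head class that implements segmentation. It kept this layer's output Tensor in a member variable so the loss could be computed later. A layer's output Tensor naturally has requires_grad=True and is a non-leaf. I removed that member variable and had forward() return the result instead. The network's main class now receives the result and passes it to the loss function. After that change, deepcopy(model) no longer raises the error above!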
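The fix can be sketched as follows. `SegHead` and `SegNet` are hypothetical stand-ins for the author's registry-built network. The sketch only shows that the head returns its output and the main class passes it straight to the loss. After a training step, the module therefore holds no non-leaf tensors, and `copy.deepcopy()` succeeds. The load_state_dict() route from the beginning of the article is included for comparison.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class SegHead(nn.Module):
    """Hypothetical head: returns its output instead of caching it on self."""

    def __init__(self, in_ch: int, num_classes: int):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class SegNet(nn.Module):
    """Hypothetical main network: receives the head output and passes it to the loss."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.head = SegHead(8, num_classes)

    def forward(self, images, targets=None):
        logits = self.head(self.backbone(images))
        if targets is not None:
            # The loss consumes the returned tensor directly; nothing is stored on self.
            return logits, F.cross_entropy(logits, targets)
        return logits


model = SegNet()
images = torch.randn(2, 3, 16, 16)
targets = torch.randint(0, 2, (2, 16, 16))
_, loss = model(images, targets)
loss.backward()

# deepcopy now only meets parameters (leaf tensors), so it succeeds.
val_model = copy.deepcopy(model)
val_model.eval()
with torch.no_grad():
    preds = val_model(images).argmax(dim=1)
print(preds.shape)  # torch.Size([2, 16, 16])

# The longer route: new instance + load_state_dict() + eval().
val_model2 = SegNet()
val_model2.load_state_dict(model.state_dict())
val_model2.eval()
```

Returning intermediate results from forward() rather than caching them on modules also avoids keeping the previous step's autograd graph alive between iterations.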
Note also that explicitly creating a Tensor with requires_grad=True (the default is False) does not by itself make copy.deepcopy() fail, whether the Tensor is on the CPU or the GPU. The key point is that a Tensor created directly by the user is a leaf Tensor, and its grad_fn is None. A new Tensor obtained from it by an operation such as slicing or moving it to the GPU is no longer a leaf. PyTorch assumes that any result computed from a requires_grad=True Tensor needs gradients, so it automatically attaches a grad_fn, regardless of whether the Tensor is part of a network. Deep-copying such a derived Tensor with copy.deepcopy() then raises the error above, as the following example shows:
>>> import copy
>>> import torch
>>> t = torch.tensor([1,2,3.5],dtype=torch.float32, requires_grad=True, device='cuda:0')
>>> t
tensor([1.0000, 2.0000, 3.5000], device='cuda:0', requires_grad=True)
>>> x = copy.deepcopy(t)
>>> x
tensor([1.0000, 2.0000, 3.5000], device='cuda:0', requires_grad=True)
>>> t1 = t[:2]
>>> t1
tensor([1., 2.], device='cuda:0', grad_fn=<SliceBackward0>)
>>> x = copy.deepcopy(t1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/python3.8/lib/python3.8/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/root/.local/lib/python3.8/site-packages/torch/_tensor.py", line 85, in __deepcopy__
    raise RuntimeError("Only Tensors created explicitly by the user "
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
>>> t = torch.tensor([1,2,3.5],dtype=torch.float32, requires_grad=True)
>>> t1 = t.cuda()
>>> t1
tensor([1.0000, 2.0000, 3.5000], device='cuda:0', grad_fn=<ToCopyBackward0>)
>>> x = copy.deepcopy(t1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/python3.8/lib/python3.8/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/root/.local/lib/python3.8/site-packages/torch/_tensor.py", line 85, in __deepcopy__
    raise RuntimeError("Only Tensors created explicitly by the user "
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
>>> t = torch.tensor([1,2,3.5],dtype=torch.float32, requires_grad=False)
>>> t
tensor([1.0000, 2.0000, 3.5000])
>>> x = copy.deepcopy(t)
>>> x
tensor([1.0000, 2.0000, 3.5000])
>>> t1 = t[:2]
>>> t1
tensor([1., 2.])
>>> x = copy.deepcopy(t1)
Why doesn't deepcopy() simply support Tensors that carry gradient history? In principle, copying the value at that moment should be no problem. A bearded regular who frequently answers questions on the PyTorch forums offered a guess in this thread: https://discuss.pytorch.org/t/copy-deepcopy-vs-clone/55022/10

