0. Preface
As is customary, a disclaimer first: this article reflects only my own understanding. Although it draws on valuable insights from others, it may contain inaccuracies. If you find mistakes, criticism and corrections are welcome so we can improve together.
1. Two small questions about forward
1.1 Why always def forward, instead of some other name?
When building a neural network model in PyTorch, the forward method is used constantly: it defines the forward pass of the network once the model has been constructed. Put plainly, forward is the method that computes the network's output for a given input.
In code it is always written as def forward, and I originally assumed this was just a conventional name that could be swapped for anything you like.
After reading more code, however, I found that this is not the case: PyTorch grants the forward method some special "powers".
(I can't resist noting that some seemingly impressive PyTorch "gurus" don't know this, and can only offer the hand-wavy explanation: "that's just how it is...")
1.2 What is special about forward?
First: .forward() can be omitted
This is where I first noticed that forward() is different. Let's start with an example:
import torch.nn as nn

class test(nn.Module):
    def __init__(self, input):
        super(test, self).__init__()
        self.input = input

    def forward(self, x):
        return self.input * x

T = test(8)
print(T(6))
# print(T.forward(6))
--------------------------Output-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
Process finished with exit code 0
As you can see, T(6) produces output! Without naming any method, the call dispatches to forward by default. Of course, writing .forward() explicitly also runs fine and behaves exactly the same.
Without PyTorch (under plain Python semantics), this would certainly raise an error:
# import torch.nn as nn  # torch is no longer used

class test():
    def __init__(self, input):
        self.input = input

    def forward(self, x):
        return self.input * x

T = test(8)
print(T.forward(6))
print("************************")
print(T(6))
--------------------------Output-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
************************
Traceback (most recent call last):
File "C:\Users\Lenovo\Desktop\DL\pythonProject\tt.py", line 77, in <module>
print(T(6))
TypeError: 'test' object is not callable
Process finished with exit code 1
This raises: 'test' object is not callable.
A plain Python instance cannot be called like a function: Python has no way of knowing which method you mean unless the class defines the __call__ method.
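To make the plain-Python version callable, it is enough to define the magic method __call__ yourself. A minimal sketch (no PyTorch involved; this is the same mechanism the article digs into later):

```python
class test:
    def __init__(self, input):
        self.input = input

    def __call__(self, x):
        # Defining __call__ makes instances callable with T(...),
        # and here we simply route the call to forward()
        return self.forward(x)

    def forward(self, x):
        return self.input * x

T = test(8)
print(T(6))  # prints 48, no PyTorch needed
```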
Second: calling the instance dispatches to forward
If we add another method to the class:
import torch.nn as nn

class test(nn.Module):
    def __init__(self, input):
        super(test, self).__init__()
        self.input = input

    def byten(self):
        return self.input * 10

    def forward(self, x):
        return self.input * x

T = test(8)
print(T(6))
print(T.byten())
--------------------------Output-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:/Users/Lenovo/Desktop/DL/pythonProject/tt.py
48
80
Process finished with exit code 0
As you can see, when a class has multiple methods, calling the instance without naming a method dispatches to forward.
2. Summary
In PyTorch, forward is a special method dedicated to performing the forward pass.
Update 2023-06-05
As requested in the comments, here is the official definition of forward. Rather than copying the content from the PyTorch website, I'll just give the direct link: nn.Module.forward.
Major update 2023-09-19
First of all, thank you all for liking this article! It started as a personal note, and I never expected so many readers!
Honestly, I had some regrets after finishing it: the article only probed the surface behavior of .forward() experimentally, without explaining the mechanism behind it. Knowing the what but not the why!
My thanks to the comment from Mr·小魚 below for the hint: the behavior of the magic method __call__() does indeed match what .forward() appears to do. But staring at the nn.Module source left me baffled, because there is no definition of __call__() in it at all!
So, on a hunch, I searched the historical versions of PyTorch on the official site, and that search did turn up a clue:
Below is the source of forward() and __call__() excerpted from the ancient PyTorch version v0.1.12:
class Module(object):
    # ...unrelated code omitted...
    def forward(self, *input):
        """Defines the computation performed at every call.

        Should be overriden by all subclasses.
        """
        raise NotImplementedError

    # ...unrelated code omitted...
    def __call__(self, *input, **kwargs):
        result = self.forward(*input, **kwargs)
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, result)
            if hook_result is not None:
                raise RuntimeError(
                    "forward hooks should never return any values, but '{}'"
                    "didn't return None".format(hook))
        var = result
        while not isinstance(var, Variable):
            var = var[0]
        creator = var.creator
        if creator is not None and len(self._backward_hooks) > 0:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                creator.register_hook(wrapper)
        return result
We can see that __call__() directly returns the result of self.forward(). Since the magic method __call__() is invoked automatically whenever the instance is called, this explains why forward() runs automatically.
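This dispatch can be reproduced in a few lines of plain Python. A minimal sketch (the class names MiniModule and Times are my own, not PyTorch's):

```python
class MiniModule:
    """Toy stand-in for nn.Module, showing only the dispatch trick."""

    def __call__(self, *args, **kwargs):
        # Calling the instance funnels straight into forward(),
        # just like the v0.1.12 __call__ above
        return self.forward(*args, **kwargs)

    def forward(self, *args):
        raise NotImplementedError


class Times(MiniModule):
    def __init__(self, factor):
        self.factor = factor

    def forward(self, x):
        return self.factor * x


t = Times(8)
print(t(6))  # prints 48
```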
The rest of the method deals with hook functions; I won't explore that part here.
Now let's return to the current version (I'm using 1.8.1). The source shows that, over many releases, forward() and __call__() have actually been renamed!
forward: Callable[..., Any] = _forward_unimplemented
...
__call__ : Callable[..., Any] = _call_impl
That is why I couldn't find definitions of these two methods in the source. Strictly speaking, they weren't renamed so much as given an extra name: each is now a class attribute bound to a module-level function. Why PyTorch made this change, I honestly can't say.
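The "extra name" trick is plain Python: assigning a module-level function to a class attribute binds it as a method under that attribute's name, and subclasses can override it with an ordinary def. A minimal sketch (the names _greet_unimplemented, Base, and Child are hypothetical, mirroring forward = _forward_unimplemented):

```python
def _greet_unimplemented(self):
    raise NotImplementedError


class Base:
    # Bind the module-level function under the name `greet`,
    # mirroring `forward: Callable[..., Any] = _forward_unimplemented`
    greet = _greet_unimplemented


class Child(Base):
    def greet(self):  # overriding works exactly as with a normal def
        return "hello"


print(Child().greet())  # prints hello
```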
_forward_unimplemented() itself is unchanged:
def _forward_unimplemented(self, *input: Any) -> None:
    r"""Defines the computation performed at every call.

    Should be overridden by all subclasses.

    .. note::
        Although the recipe for forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this since the former takes care of running the
        registered hooks while the latter silently ignores them.
    """
    raise NotImplementedError
while _call_impl(), compared with the ancient version, has grown outrageously complex!
def _call_impl(self, *input, **kwargs):
    # Do not call functions when jit is used
    full_backward_hooks, non_full_backward_hooks = [], []
    if len(self._backward_hooks) > 0 or len(_global_backward_hooks) > 0:
        full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()

    for hook in itertools.chain(
            _global_forward_pre_hooks.values(),
            self._forward_pre_hooks.values()):
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result

    bw_hook = None
    if len(full_backward_hooks) > 0:
        bw_hook = hooks.BackwardHook(self, full_backward_hooks)
        input = bw_hook.setup_input_hook(input)

    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in itertools.chain(
            _global_forward_hooks.values(),
            self._forward_hooks.values()):
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result

    if bw_hook:
        result = bw_hook.setup_output_hook(result)

    # Handle the non-full backward hooks
    if len(non_full_backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in non_full_backward_hooks:
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
            self._maybe_warn_non_full_backward_hook(input, result, grad_fn)

    return result
The added complexity comes from the handling of the various _hook functions. Interested readers can refer to the article "pytorch 中_call_impl()函數(shù)". That part is well beyond our scope here!
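To get a feel for what those hooks do without wading through the real _call_impl, here is a toy module that mimics just the pre-hook / forward / forward-hook sequence. This is a simplified sketch: HookedModule and Double are my own names, and real PyTorch stores hooks in OrderedDicts and returns removable handles.

```python
class HookedModule:
    """Toy module mimicking how _call_impl wraps forward() with hooks."""

    def __init__(self):
        self._forward_pre_hooks = []
        self._forward_hooks = []

    def register_forward_pre_hook(self, hook):
        self._forward_pre_hooks.append(hook)

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def __call__(self, *input):
        # Pre-hooks may rewrite the input before forward runs
        for hook in self._forward_pre_hooks:
            result = hook(self, input)
            if result is not None:
                input = result if isinstance(result, tuple) else (result,)
        result = self.forward(*input)
        # Forward hooks see (module, input, output) and may replace the output
        for hook in self._forward_hooks:
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        return result

    def forward(self, *input):
        raise NotImplementedError


class Double(HookedModule):
    def forward(self, x):
        return 2 * x


m = Double()
m.register_forward_pre_hook(lambda mod, inp: (inp[0] + 1,))  # 6 -> 7
m.register_forward_hook(lambda mod, inp, out: out + 100)     # 14 -> 114
print(m(6))  # prints 114
```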
Finally, a few experiments to consolidate the understanding:
Experiment ①
import torch.nn as nn

class test(nn.Module):
    def __init__(self, input):
        super(test, self).__init__()
        self.input = input

    def forward(self, x):
        return self.input * x

T = test(8)
print(T.__call__(6))
--------------------------Output-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py
48
Process finished with exit code 0
Here T.__call__(6) is equivalent to T(6).
Experiment ②
import torch.nn as nn

class test(nn.Module):
    def __init__(self, input):
        super(test, self).__init__()
        self.input = input

    def forward(self, x):
        return self.input * x

T = test(8)
print(T.forward(6))
--------------------------Output-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py
48
Process finished with exit code 0
Here T.forward(6) also computes the correct result, but this style is not recommended: calling forward() directly bypasses __call__(), so any registered hooks are silently skipped, exactly as the note in _forward_unimplemented's docstring warns. Always call the module instance instead.
Experiment ③
import torch.nn as nn

class test(nn.Module):
    def __init__(self, input):
        super(test, self).__init__()
        self.input = input

    # def forward(self, x):
    #     return self.input * x

T = test(8)
print(T())
--------------------------Output-------------------------
D:\Users\Lenovo\anaconda3\python.exe C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py
Traceback (most recent call last):
File "C:\Users\Lenovo\Desktop\DL\Pytest\calc_graph\test.py", line 11, in <module>
print(T())
File "D:\Users\Lenovo\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "D:\Users\Lenovo\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 201, in _forward_unimplemented
raise NotImplementedError
NotImplementedError
forward() must be defined, because __call__() automatically calls forward(). If forward() is missing entirely, __call__() has nothing to dispatch to; per the source of _forward_unimplemented, this raises NotImplementedError.
At this point, I think forward in PyTorch has been pretty much fully explained...