forward()函數(shù)和__init__()的關(guān)系
__init__()
是一個(gè)類的構(gòu)造函數(shù),用于初始化對(duì)象的屬性。它會(huì)在創(chuàng)建對(duì)象時(shí)自動(dòng)調(diào)用,而且通常在這里完成對(duì)象所需的所有初始化操作。
forward()
是一個(gè)神經(jīng)網(wǎng)絡(luò)模型中的方法,用于定義數(shù)據(jù)流的向前傳播過(guò)程。它接受輸入數(shù)據(jù),通過(guò)網(wǎng)絡(luò)的各個(gè)層進(jìn)行計(jì)算,最終返回輸出結(jié)果。
在神經(jīng)網(wǎng)絡(luò)的 PyTorch 實(shí)現(xiàn)中,__init__()
方法通常用于實(shí)例化各個(gè)網(wǎng)絡(luò)層(例如卷積層、池化層、全連接層的維度等【這里只是執(zhí)行了初始化,但是可以通過(guò)后面實(shí)例化時(shí)調(diào)用的forward()重新給神經(jīng)網(wǎng)絡(luò)維度賦值】),并設(shè)置各層的超參數(shù)(例如卷積核大小、步幅、填充等)。而 forward()
方法則定義了這些網(wǎng)絡(luò)層之間的計(jì)算順序與邏輯,它負(fù)責(zé)將輸入數(shù)據(jù)傳遞到網(wǎng)絡(luò)中,并返回計(jì)算結(jié)果【這里輸入進(jìn)forward的數(shù)據(jù)維度要和forward()接收的第一個(gè)參數(shù)維度相同,雖然你看它只接受了一個(gè)參數(shù)‘x’,但是這個(gè)x的維度是多維的(在本代碼中就是(input_dim, hidden_dim)兩個(gè)大維度),而不是普通意義上的一個(gè)自然數(shù)】。
因此,兩個(gè)方法通常一起使用,__init__()
用于設(shè)置網(wǎng)絡(luò)結(jié)構(gòu)和超參數(shù),forward()
則定義了從輸入到輸出的完整計(jì)算流程。
例子:
定義類:
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out
在上面的代碼中,我們定義了一個(gè)名為 SimpleNet
的神經(jīng)網(wǎng)絡(luò)模型,它繼承自 PyTorch 中的 nn.Module
類。我們?cè)?__init__()
方法中定義了三層網(wǎng)絡(luò)結(jié)構(gòu),分別是輸入層 fc1
、激活層 relu
和輸出層 fc2
。其中,輸入層和輸出層都使用了全連接層(nn.Linear
),而激活層使用了 ReLU 激活函數(shù)。
在 forward()
方法中,我們按照輸入數(shù)據(jù) x
經(jīng)過(guò) fc1
、relu
和 fc2
三層的順序進(jìn)行計(jì)算,最終返回輸出結(jié)果 out
。
調(diào)用
調(diào)用上述代碼的 forward()
方法需要先創(chuàng)建一個(gè) SimpleNet
類的對(duì)象,并將輸入數(shù)據(jù)傳遞給該對(duì)象。以下是一個(gè)簡(jiǎn)單的示例:
# 創(chuàng)建一個(gè) SimpleNet 對(duì)象,設(shè)置輸入維度為 10,隱藏層維度為 20,輸出維度為 5
net = SimpleNet(10, 20, 5)
# 構(gòu)造一個(gè)隨機(jī)的輸入張量,大小為 [batch_size, input_dim],這里令 batch_size=1
input_tensor = torch.randn(1, 10)
# 將輸入張量傳入網(wǎng)絡(luò)中,得到輸出張量
output_tensor = net(input_tensor)
# 打印輸出張量的形狀
print(output_tensor.shape)
為什么上面的代沒(méi)有看到 __init__()、
forword()函數(shù)的出現(xiàn)就完成了上述代碼的調(diào)用呢?
初始化一個(gè)類時(shí),則自動(dòng)調(diào)用了該類的 __init__()
方法【net = SimpleNet(10, 20, 5)】文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-591099.html
調(diào)用一個(gè)類的實(shí)例時(shí),會(huì)自動(dòng)調(diào)用該類的forward()
方法【output_tensor = net(input_tensor)】文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-591099.html
到了這里,關(guān)于重新理解一個(gè)類中的forward()和__init__()函數(shù)的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!