Conv1d
Conv1d 的輸入數(shù)據(jù)維度通常是一個(gè)三維張量,形狀為 (batch_size, in_channels, sequence_length),其中:
batch_size 表示當(dāng)前輸入數(shù)據(jù)的批次大??;
in_channels 表示當(dāng)前輸入數(shù)據(jù)的通道數(shù),對于文本分類任務(wù)通常為 1,對于圖像分類任務(wù)通常為 3(RGB)、1(灰度)等;
sequence_length 表示當(dāng)前輸入數(shù)據(jù)的序列長度,對于文本分類任務(wù)通常為詞向量的長度,對于時(shí)序信號(hào)處理任務(wù)通常為時(shí)間序列的長度,對于圖像分類任務(wù)通常為圖像的高或?qū)挕?br> 具體來說,Conv1d 模塊會(huì)對第二維和第三維分別進(jìn)行一維卷積操作,保留第一維(即批次大?。┎蛔?,輸出一個(gè)新的三維張量,形狀為 (batch_size, out_channels, new_sequence_length),其中 out_channels 表示卷積核的數(shù)量,new_sequence_length 表示卷積后的序列長度。
示例:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
nn.Conv1d(in_channels=1, out_channels=16, kernel_size=2),
nn.ReLU(),
# nn.MaxPool1d(kernel_size=2),
nn.Conv1d(in_channels=16, out_channels=32, kernel_size=2),
nn.ReLU(),
# nn.MaxPool1d(kernel_size=2)
)
self.fc = nn.Linear(128, 2)
def forward(self, x):
x = x.unsqueeze(1)
x = self.conv(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
x = torch.randn(200,6)
# x = x.unsqueeze(1)
net = Net()
output = net(x)
print(x.shape)
Conv2d
在 PyTorch 中,使用 nn.Conv2d 創(chuàng)建卷積層時(shí),輸入數(shù)據(jù)的維度應(yīng)該是 (batch_size, input_channels, height, width)。其中,文章來源:http://www.zghlxwxcb.cn/news/detail-793741.html
batch_size 表示當(dāng)前輸入數(shù)據(jù)的批次大小;
input_channels 表示當(dāng)前輸入數(shù)據(jù)的通道數(shù),對于彩色圖像通常為 3(RGB),對于灰度圖像通常為 1;
height 和 width 分別表示輸入數(shù)據(jù)的高和寬。因此,在 PyTorch 框架中,Conv2d 的輸入數(shù)據(jù)維度應(yīng)該是一個(gè)四維張量,形狀為 (batch_size, input_channels, height, width)。文章來源地址http://www.zghlxwxcb.cn/news/detail-793741.html
到了這里,關(guān)于pytorch框架:conv1d、conv2d的輸入數(shù)據(jù)維度是什么樣的的文章就介紹完了。如果您還想了解更多內(nèi)容,請?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!