The full traceback is shown below:
    x = self.fc(x)
  File "D:\Python36\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Python36\lib\site-packages\torch\nn\modules\linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "D:\Python36\lib\site-packages\torch\nn\functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x704 and 2304x4)
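According to the message, the two matrices multiplied in the fully connected layer have incompatible shapes: the input (mat1) is 8x704, while the layer's weight (mat2) is 2304x4, so the layer expects 2304 input features per row but receives 704. In this code, batchSize is 8 and the number of output classes is 4.
Cause: the total number of samples is not a multiple of the batch size, so the last batch holds fewer than batchSize samples. Reshaping that short batch with the fixed batchSize changes the number of features per row, which triggers the error.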
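The arithmetic behind this cause can be sketched in plain Python. The numbers below (30 samples, 24 values per sample) are hypothetical and chosen only for illustration; they are not the sizes from the original model. The helper batch_sizes is likewise an illustrative name, not part of PyTorch.

```python
def batch_sizes(num_samples, batch_size, drop_last=False):
    """Return the number of samples in each batch (illustrative helper)."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)  # the short final batch
    return sizes


# Hypothetical numbers: 30 samples, 24 values per sample, batch size 8.
num_samples, seq_len, batch_size = 30, 24, 8

sizes = batch_sizes(num_samples, batch_size)
print(sizes)  # [8, 8, 8, 6]

# Forcing the 6-sample batch into batch_size rows, as reshape(batchSize, 1, -1)
# does, spreads 6 * 24 = 144 values over 8 rows: 18 per row instead of 24.
last_total = sizes[-1] * seq_len
per_row_after_reshape = last_total // batch_size
print(per_row_after_reshape)  # 18
```

If the element count of the short batch were not divisible by batchSize, the reshape itself would fail before the model is ever called, so the symptom can differ from one dataset to another.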
Solution 1: set drop_last=True on the DataLoader, so that the final batch is dropped when it contains fewer than batchSize samples.
trainLoader = DataLoader(dataset=trainSet, batch_size=batchSize, shuffle=True, drop_last=True)
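To see what drop_last changes, here is a minimal sketch that compares the batch shapes produced with and without it. The dataset is hypothetical random data (30 samples of 24 values each), standing in for the real trainSet.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data: 30 samples of 24 values each, with integer labels in 0..3.
features = torch.randn(30, 24)
labels = torch.randint(0, 4, (30,))
trainSet = TensorDataset(features, labels)
batchSize = 8

keep_loader = DataLoader(dataset=trainSet, batch_size=batchSize, shuffle=True, drop_last=False)
drop_loader = DataLoader(dataset=trainSet, batch_size=batchSize, shuffle=True, drop_last=True)

keep_shapes = [tuple(seq.shape) for seq, _ in keep_loader]
drop_shapes = [tuple(seq.shape) for seq, _ in drop_loader]

print(keep_shapes)  # the last entry is (6, 24)
print(drop_shapes)  # every entry is (8, 24)
```

The trade-off is that up to batchSize - 1 samples go unused in each epoch. With shuffle=True the dropped samples differ from epoch to epoch, so this usually matters little for training, but it is probably not what you want for an evaluation loader.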
Solution 2: when reshaping, use the actual number of samples in the batch instead of batchSize.
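......
for seq, y_train in trainLoader:
    sampleSize = seq.shape[0]  # actual number of samples in this batch
    optimizer.zero_grad()
    # Works with drop_last=False, where the last batch may be shorter than batchSize
    y_pred = model(seq.reshape(sampleSize, 1, -1))
    # y_pred = model(seq.reshape(batchSize, 1, -1))  # fails on a short last batch
......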
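The difference between the two reshape calls can be reproduced in isolation. This is a minimal sketch under stated assumptions: the model is a hypothetical stand-in (flatten plus one linear layer), not the network from the original code, and the sizes are made up.

```python
import torch
from torch import nn

batchSize, seq_len, num_classes = 8, 24, 4  # hypothetical sizes

# Hypothetical stand-in model: flatten (N, 1, 24) to (N, 24), then classify.
model = nn.Sequential(nn.Flatten(), nn.Linear(seq_len, num_classes))

short_batch = torch.randn(6, seq_len)  # a final batch with only 6 samples

# Reshaping by the real sample count keeps 24 features per row.
good = short_batch.reshape(short_batch.shape[0], 1, -1)
out = model(good)
print(out.shape)  # torch.Size([6, 4])

# Reshaping by batchSize yields (8, 1, 18), which the linear layer rejects.
bad = short_batch.reshape(batchSize, 1, -1)
try:
    model(bad)
    failed = False
except RuntimeError as err:
    failed = True
    print(err)  # a "mat1 and mat2 shapes cannot be multiplied" message
```

Unlike drop_last=True, this approach trains on every sample, at the cost of one smaller batch per epoch.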