問(wèn)題產(chǎn)生的原因是使用nn.CrossEntropyLoss()來(lái)計(jì)算損失的時(shí)候,target的維度超過(guò)4
import torch
import torch.nn as nn
logit = torch.ones(size=(4, 32, 256, 256)) # b,c,h,w
target = torch.ones(size=(4, 1, 256, 256))
criterion = nn.CrossEntropyLoss()
loss = criterion(logit, target)
如實(shí)target中的C不是1,則可以:
import torch
import torch.nn as nn
logit = torch.ones(size=(4, 32, 256, 256)) # b,c,h,w
target = torch.ones(size=(4, 2, 256, 256))
criterion = nn.CrossEntropyLoss()
losses = 0
for i in range(2):
loss = criterion(logit, target[:, i, ...].long())
losses += loss
?可以看到代碼里面有個(gè).long(),如果不用的話則會(huì)報(bào)錯(cuò):文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-508363.html
RuntimeError: expected scalar type Long but found Float文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-508363.html
到了這里,關(guān)于only batches of spatial targets supported (3D tensors) but got targets of dimension的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!