在常見的多分類問題中,先經(jīng)過softmax處理后進(jìn)行交叉熵計(jì)算,原理很簡單可以將計(jì)算loss理解為,為了使得網(wǎng)絡(luò)對測試集預(yù)測的概率分布和其真實(shí)分布接近,常用的做法是使用one-hot對真實(shí)標(biāo)簽進(jìn)行編碼,然后用預(yù)測概率去擬合one-hot的真實(shí)概率。但是這樣會(huì)帶來兩個(gè)問題:
無法保證模型的泛化能力,使網(wǎng)絡(luò)過于自信會(huì)導(dǎo)致過擬合;
全概率和0概率鼓勵(lì)所屬類別和其他類別之間的差距盡可能加大,而由梯度有界可知,這種情況很難adapt。會(huì)造成模型過于相信預(yù)測的類別。
標(biāo)簽平滑可以緩解這個(gè)問題,可以有兩個(gè)角度理解這件事。
角度一
軟化這種one-hot編碼方式。
?
等號(hào)左側(cè):是一種新的預(yù)測的分布
等號(hào)右側(cè):前半部分是對原分布乘一個(gè)權(quán)重, ? \epsilon? 是一個(gè)超參,需要自己設(shè)定,取值在0到1范圍內(nèi)。后半部分u是一個(gè)均勻分布,k表示模型的類別數(shù)。
?
由以上公式可以看出,這種方式使label有 ? \epsilon? 概率來自于均勻分布, 1 ? ? 1-\epsilon1?? 概率來自于原分布。這就相當(dāng)于在原label上增加噪聲,讓模型的預(yù)測值不要過度集中于概率較高的類別,把一些概率放在概率較低的類別。
因此,交叉熵可以替換為:
可以理解為:loss為對“預(yù)測的分布與真實(shí)分布”及“預(yù)測分布與先驗(yàn)分布(均勻分布)”的懲罰。
代碼實(shí)現(xiàn)如下:
?
class LabelSmoothingCrossEntropy(nn.Module):
def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
super(LabelSmoothingCrossEntropy, self).__init__()
self.eps = eps
self.reduction = reduction
self.ignore_index = ignore_index
def forward(self, output, target):
c = output.size()[-1]
log_pred = torch.log_softmax(output, dim=-1)
if self.reduction == 'sum':
loss = -log_pred.sum()
else:
loss = -log_pred.sum(dim=-1)
if self.reduction == 'mean':
loss = loss.mean()
return loss * self.eps / c + (1 - self.eps) * torch.nn.functional.nll_loss(log_pred, target, reduction=self.reduction, ignore_index=self.ignore_index)
角度二
對于以Dirac函數(shù)分布的真實(shí)標(biāo)簽,我們將它變成分為兩部分獲得(替換):
- 第一部分:將原本Dirac分布的標(biāo)簽變量替換為(1 - ?)的Dirac函數(shù);
- 第二部分:以概率 ? ,在u(k)u(k) 中份分布的隨機(jī)變量。
def label_smoothing(inputs, epsilon=0.1): K = inputs.get_shape().as_list()[-1] # number of channels return ((1-epsilon) * inputs) + (epsilon / K)
代碼的第一行是取Y的channel數(shù)也就是類別數(shù),第二行就是對應(yīng)公式了。
下面用一個(gè)例子理解一下:假設(shè)我做一個(gè)蛋白質(zhì)二級結(jié)構(gòu)分類,是三分類,那么K=3;假如一個(gè)真實(shí)標(biāo)簽是[0, 0, 1],取epsilon = 0.1,
新標(biāo)簽就變成了 (1 - 0.1)× [0, 0, 1] + (0.1 / 3) = [0, 0, 0.9] + [0.0333, 0.0333, 0.0333]= [0.0333, 0.0333, 0.9333]
實(shí)際上分了一點(diǎn)概率給其他兩類(均勻分),讓標(biāo)簽沒有那么絕對化,留給學(xué)習(xí)一點(diǎn)泛化的空間。
從而能夠提升整體的效果。
torch版本
首先,讓我們使用一個(gè)輔助函數(shù)來計(jì)算兩個(gè)值之間的線性組合:
def linear_combination(x, y, epsilon):
return epsilon*x + (1-epsilon)*y
接下來,我們使用 PyTorch nn.Module實(shí)現(xiàn)一個(gè)新的損失函數(shù)
import torch.nn.functional as F
def reduce_loss(loss, reduction='mean'):
return loss.mean() if reduction=='mean' else loss.sum() if reduction=='sum' else loss
class LabelSmoothingCrossEntropy(nn.Module):
def __init__(self, epsilon:float=0.1, reduction='mean'):
super().__init__()
self.epsilon = epsilon
self.reduction = reduction
def forward(self, preds, target):
n = preds.size()[-1]
log_preds = F.log_softmax(preds, dim=-1)
loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
nll = F.nll_loss(log_preds, target, reduction=self.reduction)
return linear_combination(loss/n, nll, self.epsilon)
我們現(xiàn)在可以在我們的代碼中使用這個(gè)類。 對于這個(gè)例子,我們使用標(biāo)準(zhǔn)的 fast.ai pets 例子。
from fastai.vision import *
from fastai.metrics import error_rate
# prepare the data
path = untar_data(URLs.PETS)
path_img = path/'images'
fnames = get_image_files(path_img)
bs = 64
np.random.seed(2)
pat = r'/([^/]+)_\d+.jpg$'
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=bs) \
.normalize(imagenet_stats)
# train the model
learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.loss_func = LabelSmoothingCrossEntropy()
learn.fit_one_cycle(4)
Tensorflow中使用方法時(shí)候只要在損失函數(shù)中加上label_smoothing的值即可,如下:
tf.losses.softmax_cross_entropy(
onehot_labels,
logits,
weights=1.0,
label_smoothing=0,
scope=None,
loss_collection=tf.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
)
-
————————————————
原文鏈接:https://blog.csdn.net/qq_40176087/article/details/121519888文章來源:http://www.zghlxwxcb.cn/news/detail-427750.html -
標(biāo)簽平滑Label Smoothing_奔跑的小仙女的博客-CSDN博客_label smoothing文章來源地址http://www.zghlxwxcb.cn/news/detail-427750.html
到了這里,關(guān)于標(biāo)簽平滑(label smoothing) torch和tensorflow的實(shí)現(xiàn)的文章就介紹完了。如果您還想了解更多內(nèi)容,請?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!