torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
功能:對(duì)輸入的四維數(shù)組進(jìn)行批量標(biāo)準(zhǔn)化處理(歸一化)
計(jì)算公式如下:
對(duì)于所有的batch中樣本的同一個(gè)channel的數(shù)據(jù)元素進(jìn)行標(biāo)準(zhǔn)化處理,即如果有C個(gè)通道,無(wú)論batch中有多少個(gè)樣本,都會(huì)在通道維度上進(jìn)行標(biāo)準(zhǔn)化處理,一共進(jìn)行C次
num_features:通道數(shù)
eps:分母中添加的值,目的是計(jì)算的穩(wěn)定性(分母不出現(xiàn)0),默認(rèn)1e-5
momentum:用于運(yùn)行過(guò)程中均值方差的估計(jì)參數(shù),默認(rèn)0.1
affine:設(shè)為true時(shí),給定開易學(xué)習(xí)的系數(shù)矩陣r和b
track_running_stats:BN中存儲(chǔ)的均值方差是否需要更新,true需要更新
舉個(gè)例子
>import torch
>import torch.nn as nn
>input = torch.arange(0, 12, dtype=torch.float32).view(1, 3, 2, 2)
>print(m)
tensor([[[[ 0., 1.],
[ 2., 3.]],
[[ 4., 5.],
[ 6., 7.]],
[[ 8., 9.],
[10., 11.]]]])
>m= nn.BatchNorm2d(3)
>print(m.weight)
tensor([1., 1., 1.], requires_grad=True)
>print(m.bias)
tensor([0., 0., 0.], requires_grad=True)
>output = m(input)
>print(output)
tensor([[[[-1.3416, -0.4472],
[ 0.4472, 1.3416]],
[[-1.3416, -0.4472],
[ 0.4472, 1.3416]],
[[-1.3416, -0.4472],
[ 0.4472, 1.3416]]]], grad_fn=<NativeBatchNormBackward0>)
上面是使用nn接口計(jì)算,現(xiàn)在我們拿第一個(gè)數(shù)據(jù)計(jì)算一下驗(yàn)證
公式:
#先計(jì)算第一個(gè)通道的均值、方差
>first_channel = input[0][0] #第一個(gè)通道
tensor([[0., 1.],
[2., 3.]])
#1、計(jì)算均值方差
>mean = torch.Tensor.mean(first_channel)
tensor(1.5000) #均值
>var=torch.Tensor.var(first_channel,False)
tensor(1.2500) #方差
#2、按照公式計(jì)算
>bn_value =((input[0][0][0][0] -mean)/(torch.pow(var,0.5)+m.eps))*m.weight[0]+m.bias[0]
#這里就是(0-1.5)/sqrt(1.25+1e-5)*1.0 + 1.0
tensor(-1.3416, grad_fn=<AddBackward0>)
第一個(gè)值都是-1.3416,對(duì)上了,其他都是一樣。文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-568239.html
再來(lái)個(gè)batch_size>1的情況文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-568239.html
#先把結(jié)果貼出來(lái)
tensor([[[[-1.2288, -1.0650],
[-0.9012, -0.7373]],
[[-1.2288, -1.0650],
[-0.9012, -0.7373]],
[[-1.2288, -1.0650],
[-0.9012, -0.7373]]],
[[[ 0.7373, 0.9012],
[ 1.0650, 1.2288]],
[[ 0.7373, 0.9012],
[ 1.0650, 1.2288]],
[[ 0.7373, 0.9012],
[ 1.0650, 1.2288]]]], grad_fn=<NativeBatchNormBackward0>)
>input = torch.arange(0, 24, dtype=torch.float32).view(2, 3, 2, 2)
tensor([[[[ 0., 1.],
[ 2., 3.]],
[[ 4., 5.],
[ 6., 7.]],
[[ 8., 9.],
[10., 11.]]],
[[[12., 13.],
[14., 15.]],
[[16., 17.],
[18., 19.]],
[[20., 21.],
[22., 23.]]]])
>first_channel =input[:, 0, :, :]
tensor([[[ 0., 1.],
[ 2., 3.]],
[[12., 13.],
[14., 15.]]])
>mean = torch.Tensor.mean(first_channel)
tensor(7.5000)
>var=torch.Tensor.var(first_channel,False)
tensor(37.2500)
#第1個(gè)batch中的第一個(gè)c
>print(((input[0][0][:][:] -mean)/(torch.pow(var,0.5)+m.eps))*m.weight[0]+m.bias[0])
tensor([[-1.2288, -1.0650],
[-0.9012, -0.7373]], grad_fn=<AddBackward0>)
#第2個(gè)batch中的第一個(gè)c(共用c的weight、bias、mean、var)
>print(((input[1][0][:][:] -mean)/(torch.pow(var,0.5)+m.eps))*m.weight[0]+m.bias[0])
tensor([[0.7373, 0.9012],
[1.0650, 1.2288]], grad_fn=<AddBackward0>)
到了這里,關(guān)于【CNN記錄】pytorch中BatchNorm2d的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!