1 nn.BatchNorm
BatchNorm is an algorithm widely used in deep networks to accelerate training, speed up convergence, and improve stability; it is an essential part of training deep networks and has practically become a standard component.
BatchNorm, i.e. batch normalization, normalizes the data in each batch to a common distribution to help the network train; the change in the distribution of the input data that this normalization counteracts is known as covariate shift.
As data passes through layer after layer, its distribution keeps changing: after every parameter update, the previous layer's output, once transformed by the current layer's parameters, follows a new distribution, which makes learning harder for the next layer. This drift inside the network is called internal covariate shift, and normalizing at every layer counteracts it and eases training, since a neural network is essentially learning the distribution of the data.
The effect of BatchNorm is demonstrated with code below.
First, be clear that BatchNorm does not change the shape of its input (a quick check follows the list below):
nn.BatchNorm1d: N * d --> N * d
nn.BatchNorm2d: N * C * H * W --> N * C * H * W
nn.BatchNorm3d: N * C * D * H * W --> N * C * D * H * W
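A quick shape check (a minimal sketch; the tensor sizes below are arbitrary and chosen only for illustration):
>>> import torch
>>> import torch.nn as nn
>>> nn.BatchNorm1d(8)(torch.randn(4, 8)).shape             # N * d
torch.Size([4, 8])
>>> nn.BatchNorm2d(8)(torch.randn(4, 8, 16, 16)).shape     # N * C * H * W
torch.Size([4, 8, 16, 16])
>>> nn.BatchNorm3d(8)(torch.randn(4, 8, 5, 16, 16)).shape  # N * C * D * H * W
torch.Size([4, 8, 5, 16, 16])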
The cases of nn.BatchNorm1d and nn.BatchNorm2d are explained below.
1.1 nn.BatchNorm1d
First, look at its parameters:
CLASS torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
Main parameters:
num_features: the input dimension, i.e. the feature dimension of the data;
eps: a small value added to the denominator to avoid division by zero so the computation stays well-defined;
affine: whether to apply a learnable affine transform; the scale γ and the shift β are initialized to 1 and 0 respectively (see the sketch after this list);
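A small illustration of what affine controls (a sketch; the feature size 3 is arbitrary): with affine=True the layer owns learnable weight/bias tensors, while with affine=False they are None:
>>> nn.BatchNorm1d(3, affine=True).weight
Parameter containing:
tensor([1., 1., 1.], requires_grad=True)
>>> print(nn.BatchNorm1d(3, affine=False).weight)  # no learnable γ/β when affine=False
None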
Usage:
It mainly acts on the feature dimension. For an input of shape N * d, N is the batch size and d is num_features;
nn.BatchNorm1d normalizes over num_features, i.e. each feature column is normalized across the batch;
For example, take an input with N = 5 (batch_size = 5) and d = 3 (a feature dimension of 3);
γ and β are learnable parameters; the documentation calls this the affine transform, and they are exposed as the module's weight and bias (m.weight and m.bias below); γ is initialized to 1 and β to 0;
The variance in the normalization is computed with the biased estimator;
The normalization formula:
y = (x - E[x]) / sqrt(Var[x] + ε) * γ + β
E[x] is the mean, Var[x] is the variance, and ε is the eps parameter described above, which prevents the denominator from being 0;
Demo code:
>>> import torch
>>> import torch.nn as nn
>>> m = nn.BatchNorm1d(3)  # must be instantiated before use; 3 is the feature dimension, i.e. num_features
>>> m.weight  # corresponds to γ, initialized to 1
Parameter containing:
tensor([1., 1., 1.], requires_grad=True)
>>> m.bias  # corresponds to β, initialized to 0
Parameter containing:
tensor([0., 0., 0.], requires_grad=True)
>>> x = torch.randn(5, 3)  # random input: batch_size = 5, num_features = 3
>>> output = m(x)
>>> output.mean(dim = 0)  # after normalization each feature's mean is 0 (values around 1e-08 are effectively 0)
tensor([ 0.0000e+00, -1.1921e-08, -2.3842e-08], grad_fn=<MeanBackward1>)
>>> output.std(dim = 0,unbiased = False)  # standard deviation is 1; the biased estimator is used, hence unbiased = False
tensor([1.0000, 1.0000, 1.0000], grad_fn=<StdBackward0>)
Implementing BatchNorm manually with ordinary tensor operations:
>>> x
tensor([[ 0.0482, -0.1098, 0.4099],
[ 0.9851, 2.8229, -0.7795],
[ 0.3493, -1.0165, -0.0416],
[ 1.5942, -1.3420, 1.0296],
[ 0.0452, -1.0462, -1.1866]])
>>> mean = x.mean(dim = 0)
>>> mean
tensor([ 0.6044, -0.1383, -0.1136])
>>> std = torch.sqrt(1e-5 + torch.var(x,dim = 0, unbiased = False))
>>> std
tensor([0.6020, 1.5371, 0.7976])
>>> (x - mean)/std
tensor([[-0.9239, 0.0185, 0.6564],
[ 0.6325, 1.9265, -0.8348],
[-0.4238, -0.5713, 0.0903],
[ 1.6442, -0.7831, 1.4333],
[-0.9290, -0.5906, -1.3452]])
>>> m(x) # same as the result computed manually above
tensor([[-0.9239, 0.0185, 0.6564],
[ 0.6325, 1.9265, -0.8348],
[-0.4238, -0.5713, 0.0903],
[ 1.6442, -0.7831, 1.4333],
[-0.9290, -0.5906, -1.3452]], grad_fn=<NativeBatchNormBackward0>)
1.2 nn.BatchNorm2d
First, look at its parameters:
CLASS torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
Usage:
It mainly acts on the channel dimension. For an input of shape B * C * H * W, B is the batch size, C is the number of channels, H is the image height, and W is the image width;
nn.BatchNorm2d normalizes over the channels, i.e. each channel is normalized across the batch and across the spatial positions;
For example, take an input with B * C * H * W = (2 * 3 * 2 * 2):
The mean and variance are in effect computed by flattening, within the batch, all values belonging to the same channel;
Demo code:
>>> y = torch.randn(2,3,2,2)
>>> y
tensor([[[[-0.3008, 0.7066],
[ 0.5374, -0.4211]],
[[-0.3935, 0.6193],
[ 0.5375, -0.2747]],
[[ 0.8895, 0.0956],
[-0.0622, 1.7511]]],
[[[-0.2402, 0.6884],
[ 0.5264, 0.3918]],
[[-0.3101, -0.6729],
[-0.5292, -1.0383]],
[[-0.6681, -0.3747],
[ 0.3431, 0.3245]]]])
>>> n = nn.BatchNorm2d(3)
>>> n.weight
Parameter containing:
tensor([1., 1., 1.], requires_grad=True)
>>> n.bias
Parameter containing:
tensor([0., 0., 0.], requires_grad=True)
>>> n(y)
tensor([[[[-1.2111, 1.0613],
[ 0.6797, -1.4823]],
[[-0.2544, 1.6433],
[ 1.4902, -0.0318]],
[[ 0.8494, -0.2705],
[-0.4931, 2.0649]]],
[[[-1.0742, 1.0204],
[ 0.6549, 0.3513]],
[[-0.0981, -0.7779],
[-0.5086, -1.4626]],
[[-1.3479, -0.9340],
[ 0.0786, 0.0524]]]], grad_fn=<NativeBatchNormBackward0>)
Demonstration of how the mean and variance are computed:
>>> z = [-1.2111, 1.0613, 0.6797, -1.4823, -1.0742, 1.0204, 0.6549, 0.3513] # the normalized values of channel 0, flattened across the batch
>>> import numpy as np
>>> np.mean(z) # on the order of 1e-17, i.e. effectively 0
-2.7755575615628914e-17
>>> np.std(z) # np.std is biased by default; torch.std defaults to the unbiased estimator
0.9999846111315913
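For completeness, the same per-channel check can be done for all channels at once in PyTorch (a sketch reusing the y and n defined above; 1e-5 is BatchNorm's default eps):
>>> mean = y.mean(dim = (0, 2, 3), keepdim = True)                  # per-channel mean over N, H, W
>>> var = y.var(dim = (0, 2, 3), unbiased = False, keepdim = True)  # biased per-channel variance
>>> torch.allclose((y - mean) / torch.sqrt(var + 1e-5), n(y), atol = 1e-6)
True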
Reference: [pytorch 網(wǎng)絡(luò)模型結(jié)構(gòu)] 深入理解 nn.BatchNorm1d/2d 計(jì)算過程_嗶哩嗶哩_bilibili