Deep Learning (28): YOLO Series (7)
Just saying: if you need the source code, please visit Jane's GitHub: here
This continues what I didn't finish in the morning; it's just a small leftover. The key parts of training and of the data pipeline were already covered, so what remains is the inference part.
During inference, to improve efficiency and make things faster, the Conv and BN layers are fused:
The full detect pipeline
1.1 attempt_load(weights)
- weights is the previously trained YOLOv7 checkpoint to load
- Right after loading, the BN layers are still there; nothing has been merged yet
- The key is the fuse() call below
1.2 model.fuse()
# Quite hidden; at first I didn't expect the entry point to be here
def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
    print('Fusing layers... ')
    for m in self.model.modules():
        if isinstance(m, RepConv):
            #print(f" fuse_repvgg_block")
            m.fuse_repvgg_block()
        elif isinstance(m, RepConv_OREPA):
            #print(f" switch_to_deploy")
            m.switch_to_deploy()
        elif type(m) is Conv and hasattr(m, 'bn'):
            m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
            delattr(m, 'bn')  # remove batchnorm
            m.forward = m.fuseforward  # update forward
        elif isinstance(m, (IDetect, IAuxDetect)):
            m.fuse()
            m.forward = m.fuseforward
    self.info()
    return self
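The traversal-and-patch pattern used by fuse() can be sketched in plain Python. The toy classes below are my own illustration (not the yolov7 ones): fold the BN state into the weight, delete the bn attribute, and rebind forward so later calls take the fused path.

```python
# Toy sketch (hypothetical classes) of the fuse() pattern above:
# walk the modules, merge state, then swap the forward method.
class Conv:
    def __init__(self, w, bn_scale):
        self.w = w
        self.bn = bn_scale              # stands in for the BatchNorm layer

    def forward(self, x):               # conv followed by BN
        return (self.w * x) * self.bn

    def fuseforward(self, x):           # single multiply after fusion
        return self.w * x

class Model:
    def __init__(self):
        self.modules = [Conv(2.0, 3.0), Conv(0.5, 4.0)]

    def fuse(self):
        for m in self.modules:
            if hasattr(m, 'bn'):
                m.w = m.w * m.bn        # fold BN into the conv weight
                del m.bn                # like delattr(m, 'bn')
                m.forward = m.fuseforward  # rebind forward on the instance
        return self

model = Model().fuse()
print(model.modules[0].forward(1.0))    # 6.0, same output via the fused path
```

Assigning `m.forward = m.fuseforward` shadows the class method with a bound method on the instance, which is exactly why subsequent calls in yolov7 skip the BN branch without any change to the call sites.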
Whenever a Conv is encountered here, it is always followed by a BN, so the two can be folded into one layer:
1.3 fuse_conv_and_bn(conv,bn)
- First define a new conv with the same in_channels, out_channels, and kernel_size as the one passed in
- Compute w_conv:
w_conv = conv.weight.clone().view(conv.out_channels, -1)
- Compute w_bn:
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
[bn.weight is gamma in the BN formula, and bn.running_var is the variance sigma²]
- Compute w_fuse:
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
- Compute b_conv; during training the bias is set to 0 (BN makes it redundant), so:
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
- Compute b_bn:
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
[bn.bias is beta in the BN formula, and bn.running_mean is the mean mu]
- Compute b_fuse:
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
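Putting the steps above together: with the conv output written as y = Wx + b and the standard batch-norm transform, the fused parameters the code computes are (gamma = bn.weight, beta = bn.bias, mu = bn.running_mean, sigma² = bn.running_var):

```latex
% BatchNorm applied to the conv output y:
\mathrm{BN}(y) = \gamma \cdot \frac{y - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta

% Folding BN into the conv yields a single conv with:
W_{\mathrm{fuse}} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\, W
\qquad
b_{\mathrm{fuse}} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\,(b - \mu) + \beta
```

In the code, w_bn is the diagonal matrix diag(gamma / sqrt(sigma² + eps)), and b_bn = beta - gamma·mu / sqrt(sigma² + eps), so w_fuse = w_bn·W and b_fuse = w_bn·b_conv + b_bn match the two terms above exactly.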
def fuse_conv_and_bn(conv, bn):
    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # prepare filters: bn.weight is gamma in the paper, bn.bias is beta,
    # bn.running_mean / bn.running_var are the mean and variance tracked
    # over the training batches
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # prepare spatial bias
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
    return fusedconv
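As a sanity check, the fusion algebra can be verified numerically. The following is a minimal NumPy sketch of my own (not part of the yolov7 code) using a 1x1 convolution, which acts as a per-pixel matrix multiply:

```python
# Verify that the fused weight/bias reproduce conv -> BN exactly,
# using a 1x1 convolution viewed as a matrix multiply per pixel.
import numpy as np

rng = np.random.default_rng(0)
cin, cout, n = 4, 3, 5                       # channels in/out, number of "pixels"
eps = 1e-5

# conv (1x1, no bias) and BN parameters
W = rng.normal(size=(cout, cin))             # conv.weight viewed as (out, -1)
gamma = rng.normal(size=cout)                # bn.weight
beta = rng.normal(size=cout)                 # bn.bias
mean = rng.normal(size=cout)                 # bn.running_mean
var = rng.uniform(0.5, 2.0, size=cout)       # bn.running_var

x = rng.normal(size=(cin, n))

# reference: conv then BN
y_ref = (W @ x - mean[:, None]) / np.sqrt(var + eps)[:, None]
y_ref = gamma[:, None] * y_ref + beta[:, None]

# fused: w_bn is diag(gamma / sqrt(var + eps)), exactly as in the code above
w_bn = np.diag(gamma / np.sqrt(var + eps))
W_fused = w_bn @ W
b_conv = np.zeros(cout)                      # conv bias was zero during training
b_fused = w_bn @ b_conv + (beta - gamma * mean / np.sqrt(var + eps))

y_fused = W_fused @ x + b_fused[:, None]
print(np.allclose(y_ref, y_fused))           # True
```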
1.4 Repvgg_block
Merge the convolutions and BN layers inside the RepVGG block.
- The original block:
- After fusing rbr_dense:
- After fusing rbr_1x1:
1.5 Padding the 1x1 conv into a 3x3 conv
After padding:
After everything has been changed, the model looks like this:
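The 1x1-to-3x3 padding step can be sketched in NumPy (my own illustration, not the repo's code): the 1x1 weight goes in the center of a 3x3 kernel and the surrounding taps are zero, so with padding=1 the 3x3 conv computes exactly the same output as the original 1x1 conv, and its kernel can then be added elementwise to the 3x3 branch.

```python
# Zero-pad a 1x1 kernel into an equivalent 3x3 kernel.
import numpy as np

cout, cin = 3, 4
rng = np.random.default_rng(0)
k1 = rng.normal(size=(cout, cin, 1, 1))          # 1x1 kernel

# pad the last two dims with one ring of zeros -> 3x3 kernel
k3 = np.pad(k1, ((0, 0), (0, 0), (1, 1), (1, 1)))

print(k3.shape)                                   # (3, 4, 3, 3)
print(np.allclose(k3[:, :, 1, 1], k1[:, :, 0, 0]))  # True: center holds the 1x1 weight
```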
OK, that's really all this time, bye bye~