目錄
Mobilenetv2的改進(jìn)
淺層特征和深層特征的融合
完整代碼
參考資料
Mobilenetv2的改進(jìn)
在DeeplabV3當(dāng)中,一般不會5次下采樣,可選的有3次下采樣和4次下采樣。因?yàn)橐M(jìn)行五次下采樣的話會損失較多的信息。
在這里mobilenetv2會從之前寫好的模塊中得到,但注意的是,我們在這里獲得的特征是[-1],也就是最后的1x1卷積不取,只取循環(huán)完后的模型。
down_idx是InvertedResidual進(jìn)行的次數(shù)。
# t, c, n, s
[1, 16, 1, 1],?
[6, 24, 2, 2], ?? 2
[6, 32, 3, 2], ? ?4
[6, 64, 4, 2], ? ?7 ?
[6, 96, 3, 1],
[6, 160, 3, 2], ??14
[6, 320, 1, 1],?
根據(jù)下采樣的不同,當(dāng)downsample_factor=8時(shí),進(jìn)行3次下采樣,對倒數(shù)兩次,步長為2的InvertedResidual進(jìn)行參數(shù)的修改,讓步長變?yōu)?,膨脹系數(shù)為2。
當(dāng)downsample_factor=16時(shí),進(jìn)行4次下采樣,只需對最后一次進(jìn)行參數(shù)的修改。
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from net.mobilenetv2 import mobilenetv2
from net.ASPP import ASPP
class MobileNetV2(nn.Module):
def __init__(self, downsample_factor=8, pretrained=True):
super(MobileNetV2, self).__init__()
model = mobilenetv2(pretrained)
self.features = model.features[:-1]
self.total_idx = len(self.features)
self.down_idx = [2, 4, 7, 14]
if downsample_factor == 8:
for i in range(self.down_idx[-2], self.down_idx[-1]):
self.features[i].apply(
partial(self._nostride_dilate, dilate=2)
)
for i in range(self.down_idx[-1], self.total_idx):
self.features[i].apply(
partial(self._nostride_dilate, dilate=4)
)
elif downsample_factor == 16:
for i in range(self.down_idx[-1], self.total_idx):
self.features[i].apply(
partial(self._nostride_dilate, dilate=2)
)
def _nostride_dilate(self, m, dilate):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
if m.stride == (2, 2):
m.stride = (1, 1)
if m.kernel_size == (3, 3):
m.dilation = (dilate//2, dilate//2)
m.padding = (dilate//2, dilate//2)
else:
if m.kernel_size == (3, 3):
m.dilation = (dilate, dilate)
m.padding = (dilate, dilate)
def forward(self, x):
low_level_features = self.features[:4](x)
x = self.features[4:](low_level_features)
return low_level_features, x
forward當(dāng)中,會輸出兩個(gè)特征層,一個(gè)是淺層特征層,具有淺層的語義信息;另一個(gè)是深層特征層,具有深層的語義信息。
淺層特征和深層特征的融合
?具有高語義信息的部分先進(jìn)行上采樣,低語義信息的特征層進(jìn)行1x1卷積,二者進(jìn)行特征融合,再進(jìn)行3x3卷積進(jìn)行特征提取
self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16//downsample_factor)
這一步就是獲得那個(gè)綠色的特征層;
low_level_features = self.shortcut_conv(low_level_features)
從這里將是對淺層特征的初步處理(1x1卷積);
x = F.interpolate(x, size=(low_level_features.size(2), low_level_features.size(3)), mode='bilinear', align_corners=True)
x = self.cat_conv(torch.cat((x, low_level_features), dim=1))
上采樣后進(jìn)行特征融合,這樣我們輸入和輸出的大小才相同,每一個(gè)像素點(diǎn)才能進(jìn)行預(yù)測;
完整代碼
# deeplabv3plus.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from net.xception import xception
from net.mobilenetv2 import mobilenetv2
from net.ASPP import ASPP
class MobileNetV2(nn.Module):
def __init__(self, downsample_factor=8, pretrained=True):
super(MobileNetV2, self).__init__()
model = mobilenetv2(pretrained)
self.features = model.features[:-1]
self.total_idx = len(self.features)
self.down_idx = [2, 4, 7, 14]
if downsample_factor == 8:
for i in range(self.down_idx[-2], self.down_idx[-1]):
self.features[i].apply(
partial(self._nostride_dilate, dilate=2)
)
for i in range(self.down_idx[-1], self.total_idx):
self.features[i].apply(
partial(self._nostride_dilate, dilate=4)
)
elif downsample_factor == 16:
for i in range(self.down_idx[-1], self.total_idx):
self.features[i].apply(
partial(self._nostride_dilate, dilate=2)
)
def _nostride_dilate(self, m, dilate):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
if m.stride == (2, 2):
m.stride = (1, 1)
if m.kernel_size == (3, 3):
m.dilation = (dilate//2, dilate//2)
m.padding = (dilate//2, dilate//2)
else:
if m.kernel_size == (3, 3):
m.dilation = (dilate, dilate)
m.padding = (dilate, dilate)
def forward(self, x):
low_level_features = self.features[:4](x)
x = self.features[4:](low_level_features)
return low_level_features, x
class DeepLab(nn.Module):
def __init__(self, num_classes, backbone="mobilenet", pretrained=True, downsample_factor=16):
super(DeepLab, self).__init__()
if backbone=="xception":
# 獲得兩個(gè)特征層:淺層特征 主干部分
self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained)
in_channels = 2048
low_level_channels = 256
elif backbone=="mobilenet":
# 獲得兩個(gè)特征層:淺層特征 主干部分
self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained)
in_channels = 320
low_level_channels = 24
else:
raise ValueError('Unsupported backbone - `{}`, Use mobilenet, xception.'.format(backbone))
# ASPP特征提取模塊
# 利用不同膨脹率的膨脹卷積進(jìn)行特征提取
self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16//downsample_factor)
# 淺層特征邊
self.shortcut_conv = nn.Sequential(
nn.Conv2d(low_level_channels, 48, 1),
nn.BatchNorm2d(48),
nn.ReLU(inplace=True)
)
self.cat_conv = nn.Sequential(
nn.Conv2d(48+256, 256, kernel_size=(3,3), stride=(1,1), padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Conv2d(256, 256, kernel_size=(3,3), stride=(1,1), padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
)
self.cls_conv = nn.Conv2d(256, num_classes, kernel_size=(1,1), stride=(1,1))
def forward(self, x):
H, W = x.size(2), x.size(3)
# 獲得兩個(gè)特征層,low_level_features: 淺層特征-進(jìn)行卷積處理
# x : 主干部分-利用ASPP結(jié)構(gòu)進(jìn)行加強(qiáng)特征提取
low_level_features, x = self.backbone(x)
x = self.aspp(x)
low_level_features = self.shortcut_conv(low_level_features)
# 將加強(qiáng)特征邊上采樣,與淺層特征堆疊后利用卷積進(jìn)行特征提取
x = F.interpolate(x, size=(low_level_features.size(2), low_level_features.size(3)), mode='bilinear', align_corners=True)
x = self.cat_conv(torch.cat((x, low_level_features), dim=1))
x = self.cls_conv(x)
x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
return x
參考資料
DeepLabV3-/論文精選 at main · Auorui/DeepLabV3- (github.com)文章來源:http://www.zghlxwxcb.cn/news/detail-805206.html
(6條消息) 憨批的語義分割重制版9——Pytorch 搭建自己的DeeplabV3+語義分割平臺_Bubbliiiing的博客-CSDN博客文章來源地址http://www.zghlxwxcb.cn/news/detail-805206.html
到了這里,關(guān)于DeepLabV3+:Mobilenetv2的改進(jìn)以及淺層特征和深層特征的融合的文章就介紹完了。如果您還想了解更多內(nèi)容,請?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!