Contents
Introduction to the ASPP structure
Building ASPP in code
References
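Introduction to the ASPP structure
ASPP stands for Atrous Spatial Pyramid Pooling.
Put simply, it is a deluxe version of a pooling layer. Its goal is the same as that of an ordinary pooling layer: to extract as much feature information as possible.
The backbone feature-extraction network produces one shallow feature map and one deep feature map. This post covers how the deeper feature map is further enhanced, which is the part shown inside the Encoder.
That part is ASPP, and it has five branches: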
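- a 1x1 convolution
- a 3x3 convolution with dilation rate 6
- a 3x3 convolution with dilation rate 12
- a 3x3 convolution with dilation rate 18
- pooling applied to the input feature map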
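The outputs of these five branches are then stacked along the channel dimension, and a 1x1 convolution adjusts the channel count, producing the feature map shown in green in the architecture figure.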
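To see why the three dilated branches can be stacked with the 1x1 branch, it helps to check the arithmetic. The sketch below uses the standard convolution output-size formula; the helper names and the 32x32 feature-map size are assumptions made for illustration and are not part of the original post.

```python
def effective_kernel_size(kernel_size: int, dilation: int) -> int:
    """Span covered by a dilated kernel: dilation * (kernel_size - 1) + 1."""
    return dilation * (kernel_size - 1) + 1


def conv_output_size(size: int, kernel_size: int, dilation: int,
                     padding: int, stride: int = 1) -> int:
    """Output side length of a convolution with dilation."""
    k_eff = effective_kernel_size(kernel_size, dilation)
    return (size + 2 * padding - k_eff) // stride + 1


feature_size = 32  # hypothetical side length of the deep feature map
for dilation in (6, 12, 18):
    k_eff = effective_kernel_size(3, dilation)
    # padding == dilation is what keeps a 3x3 dilated conv size-preserving
    out = conv_output_size(feature_size, 3, dilation, padding=dilation)
    print(f"dilation={dilation}: covers {k_eff}x{k_eff}, output side {out}")
```

With padding equal to the dilation rate, each 3x3 branch covers a 13x13, 25x25, or 37x37 window while leaving the spatial size unchanged, which is what makes the later channel-wise stacking possible.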
Building ASPP in code
import torch
import torch.nn as nn
import torch.nn.functional as F


class ASPP(nn.Module):
    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()
        # Branch 1: 1x1 convolution
        self.branch1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=(1,1), stride=(1,1), padding=0, dilation=rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        # Branches 2-4: 3x3 convolutions with dilation 6, 12 and 18 (scaled by rate).
        # padding == dilation keeps the spatial size unchanged.
        self.branch2 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=(3,3), stride=(1,1), padding=6 * rate, dilation=6 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=(3,3), stride=(1,1), padding=12 * rate, dilation=12 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch4 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=(3,3), stride=(1,1), padding=18 * rate, dilation=18 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        # Branch 5: layers applied after global average pooling (see forward)
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, kernel_size=(1,1), stride=(1,1), padding=0, bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)
        # 1x1 convolution that fuses the five stacked branches
        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 5, dim_out, kernel_size=(1,1), stride=(1,1), padding=0, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        b, c, row, col = x.size()
        # The first four branches
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)
        # Fifth branch: global average pooling followed by a convolution
        global_feature = torch.mean(x, 2, True)
        global_feature = torch.mean(global_feature, 3, True)
        global_feature = self.branch5_conv(global_feature)
        global_feature = self.branch5_bn(global_feature)
        global_feature = self.branch5_relu(global_feature)
        # Upsample back to the spatial size of the other branches
        global_feature = F.interpolate(global_feature, size=(row, col), mode='bilinear', align_corners=True)
        # Stack the five branches, then fuse the features with a 1x1 convolution
        feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
        result = self.conv_cat(feature_cat)
        return result


if __name__ == "__main__":
    model = ASPP(dim_in=320, dim_out=256, rate=16//16)
    print(model)
Seen this way, the code is quite clear: branch1 through branch5 correspond to the five parts of ASPP, as set up in __init__(), and each one consists of a convolution, batch normalization, and an activation function.
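The shared convolution, normalization, activation pattern can be sketched with a small helper. This is only an illustration: conv_bn_relu is a hypothetical name that does not appear in the original code, and the 2x320x32x32 input is an assumed shape chosen to match dim_in=320.

```python
import torch
import torch.nn as nn


def conv_bn_relu(dim_in, dim_out, kernel_size, dilation):
    """Hypothetical helper: convolution -> batch norm -> ReLU."""
    # A 1x1 kernel needs no padding; for a 3x3 kernel, padding == dilation
    # keeps the height and width unchanged.
    padding = 0 if kernel_size == 1 else dilation
    return nn.Sequential(
        nn.Conv2d(dim_in, dim_out, kernel_size, padding=padding, dilation=dilation),
        nn.BatchNorm2d(dim_out),
        nn.ReLU(inplace=True),
    )


# The four convolutional branches, with the dilation rates used in the post
branches = nn.ModuleList([
    conv_bn_relu(320, 256, 1, 1),
    conv_bn_relu(320, 256, 3, 6),
    conv_bn_relu(320, 256, 3, 12),
    conv_bn_relu(320, 256, 3, 18),
])

x = torch.randn(2, 320, 32, 32)  # assumed input: batch of 2, 320 channels
with torch.no_grad():
    shapes = [tuple(branch(x).shape) for branch in branches]
print(shapes)
```

Every branch should report the same output shape, which is the precondition for concatenating them along the channel dimension.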
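The fifth branch is handled in forward. It first applies global average pooling, then a 1x1 convolution to adjust the channel count, followed by batch normalization and the activation function. The result is then upsampled to the same spatial size as the feature maps from the other branches, so that all five can be stacked with torch.cat(). Finally, a 1x1 convolution adjusts the channel count of the stacked feature map, producing the part shown in green in the architecture figure. This effective feature map, which carries high-level semantic information, is then passed to the Decoder.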
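The pooling branch can be checked in isolation. The sketch below is a simplified illustration under stated assumptions: it omits batch normalization and ReLU, uses zero tensors as stand-ins for the other four branch outputs, and assumes a 2x320x32x32 input; none of these choices come from the original post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 320, 32, 32)  # assumed input shape
b, c, row, col = x.shape

# Global average pooling: two successive means match a single adaptive pool.
pooled = torch.mean(torch.mean(x, 2, True), 3, True)
same_pool = torch.allclose(pooled, F.adaptive_avg_pool2d(x, 1), atol=1e-5)

conv = nn.Conv2d(320, 256, kernel_size=1)
with torch.no_grad():
    global_feature = conv(pooled)  # shape (2, 256, 1, 1)
    upsampled = F.interpolate(global_feature, size=(row, col),
                              mode="bilinear", align_corners=True)

# Upsampling a 1x1 map only repeats each channel's single value everywhere.
is_constant = torch.allclose(upsampled, global_feature.expand(-1, -1, row, col))

# Zero tensors stand in for the outputs of the other four branches.
others = [torch.zeros(b, 256, row, col) for _ in range(4)]
feature_cat = torch.cat(others + [upsampled], dim=1)
print(same_pool, is_constant, tuple(feature_cat.shape))
```

The stacked tensor has 5 x 256 = 1280 channels, which is why conv_cat takes dim_out * 5 input channels. One practical caveat: in training mode, BatchNorm2d needs more than one value per channel, so the pooled 1x1 feature map in the full module requires a batch size of at least 2.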
References
Pytorch-torchvision source code walkthrough: ASPP, xiongxyowo's blog, CSDN
DeepLabV3-/deeplabv3+.pdf at main · Auorui/DeepLabV3- (github.com)