TransUnet:Transformers Make Strong Encoders for Medical Image Segmentation
這篇文章中你可以找到一下內(nèi)容:
- Attention是怎么樣在CNN中火起來的?-Non Local
- Transformer結(jié)構(gòu)帶來了什么?-Multi Head Self Attention
- Transformer結(jié)構(gòu)為何在CV中如此流行?-Vision Transformer和SETR
- TransUnet又是如何魔改Unet和Transformer?-ResNet50+VIT作為backbone\Encoder
- TransUnet的pytorch代碼實(shí)現(xiàn)
- 作者吐槽以及偷懶的痕跡
引文
在醫(yī)學(xué)圖像分割領(lǐng)域,U形結(jié)構(gòu)的網(wǎng)絡(luò),尤其是Unet,已經(jīng)取得了很優(yōu)秀的效果。但是,CNN結(jié)構(gòu)并不擅長(zhǎng)建立遠(yuǎn)程信息連接,也就是CNN結(jié)構(gòu)的感受野有限。盡管可以通過堆疊CNN結(jié)構(gòu)、使用空洞卷積等方式增加感受野,但也會(huì)引入一些奇怪的問題(包括但不限于卷積核退化、空洞卷積造成的柵格化),導(dǎo)致最終效果受限。
基于self-attention機(jī)制的Transformer結(jié)構(gòu)在NLP任務(wù)中已經(jīng)取得了重要的成就,Vision Transformer將Transformer結(jié)構(gòu)引入了CV領(lǐng)域,并在當(dāng)年取得了十分優(yōu)秀的成果。Transformer因此在CV中流行起來。
話說回來,為什么Transformer結(jié)構(gòu)能夠在CV領(lǐng)域中獲得不錯(cuò)的效果?
Attention is all you need?
在介紹Transformer之前,我們先看一下CNN結(jié)構(gòu)中有什么好玩的東西。
先回顧一下 Non Local結(jié)構(gòu)
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k} } )V Attention(Q,K,V)=softmax(dk??QKT?)V
從Non Local開始,注意力(Attention)機(jī)制在17、18年的各大頂會(huì)大殺四方,出現(xiàn)了包括NonLocal Net、DANet、PSANet、ISANet、CCNet等等網(wǎng)絡(luò)。這里的核心思想只有一個(gè),就是Attention機(jī)制,可以不限距離的建立遠(yuǎn)程連接,突破了CNN模型感受野不足的問題。當(dāng)然,這種Attention的計(jì)算方法有一個(gè)缺陷就是計(jì)算量很大。因此,在這一個(gè)方向,CCNet、ISANet等等網(wǎng)絡(luò),也針對(duì)計(jì)算量大這一個(gè)缺陷進(jìn)行優(yōu)化,從而發(fā)了一些頂會(huì)論文。
當(dāng)然,為什么會(huì)想到提出Non Local來計(jì)算Attention呢,是因?yàn)镹on Local作者從Transformer中得到了靈感。所以,再回到提出Transformer的那篇經(jīng)典論文《Attention is all you need》。
這篇論文主要是兩個(gè)工作,一個(gè)是提出了Transformer,另一個(gè)則是Multi-head Attention,也就是用多頭注意力機(jī)制來代替注意力。
Transformer的結(jié)構(gòu)很簡(jiǎn)單,主要就是Multi-Head Atention、FFN、Norm幾個(gè)模塊。其中需要注意的就是Multi-Head Atention。
Multi-Head Atention其實(shí)并不難理解,Multi-Head Atention只是Attention機(jī)制中的一種。Multi-Head Atention顧名思義,也就是有多個(gè)Head,其中每一個(gè)Head計(jì)算一組注意力,也就是將Scaled Dot-Product Attention的過程做h次,再把輸出合并起來。這樣,同一個(gè)位置有擁有了h個(gè)表示,相比于Scaled Dot-Product Attention,輸出的內(nèi)容就更加豐富了。
M
u
l
t
i
?
H
e
a
d
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
C
o
n
c
a
t
(
h
e
a
d
1
,
.
.
.
,
h
e
a
d
h
)
W
O
\small Multi-Head Attention(Q, K, V) = Concat(head_1, ..., head_h)W^O
Multi?HeadAttention(Q,K,V)=Concat(head1?,...,headh?)WO
h
e
a
d
i
=
A
t
t
e
n
t
i
o
n
(
Q
W
i
Q
,
K
W
i
K
,
V
W
i
V
)
\small head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
headi?=Attention(QWiQ?,KWiK?,VWiV?)
Vision Transformer - the pioneer from CNN to Transformer
Vision Transformer可謂是CV屆的開路先鋒,也是CVer的救世主,在沒有Vit前,CVer不知道還要在Non Local中掙扎多久。(當(dāng)然,現(xiàn)在Transformer也快掙扎不下去了)。
Vit的論文《AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》Google的人取名字都挺有意思。
實(shí)現(xiàn)原理也很簡(jiǎn)單,Transformer處理的都是序列數(shù)據(jù),而圖像數(shù)據(jù)是不能直接輸入Transformer的。因此呢,Vit就想了一個(gè)方法,把圖像分成9塊,也就是9個(gè)patch(當(dāng)然,可以分成16塊,25塊等等,具體取決于你的一個(gè)patch的大小)。這樣,再把patch按順序拼接起來,變成一個(gè)序列,這個(gè)序列添加了一個(gè)positional encoding后,就可以輸入Transformer中進(jìn)行處理。這里的positional encoding作用是讓模型知道圖像patch的順序,有助于模型學(xué)習(xí)。
Vit在ImageNet上的成功,讓CV屆看到了希望。分割是CV的一大任務(wù),既然Vit能夠進(jìn)行分類,那他就能像ResNet一樣充當(dāng)分割任務(wù)的Backbone。
SERT Vit也能用于語義分割!
那么,在另一個(gè)CVPR頂會(huì)論文中,《Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers》SERT就最先使用Vit作為BackBone實(shí)現(xiàn)語義分割任務(wù)。
SERT模型實(shí)現(xiàn)也很簡(jiǎn)單,用經(jīng)典的encoder-decoder網(wǎng)絡(luò),Vit作為BackBone,設(shè)計(jì)了三種不同的Decoder結(jié)構(gòu),進(jìn)行語義分割實(shí)驗(yàn),證明Vit在語義分割中是可行的。很簡(jiǎn)單的一個(gè)思路,先實(shí)現(xiàn)就能先吃到肉(感謝Vit白送的一個(gè)頂會(huì))。
正文
前面廢話了很多,都是關(guān)于CNN、Attention、Non Local、Transformer,我們回到TransUnet模型。CV論文中很大一部分都是拼湊剪裁(雖然TransUnet看起來也像是拼湊剪裁)。不過,拼湊剪裁也是一門藝術(shù)。正如下圖,TransUnet結(jié)構(gòu)。
還是很經(jīng)典的Unet形網(wǎng)絡(luò),但和CNN-base的Unet不同,這里前三層是CNN-based,但是最后一層是Transformer-based。也就是把Unet的encoder最后一層換成了Transformer模型。
為什么只有一層Transformer
TransUnet只將其中一部分換成Transformer也是有它自己的考慮。雖然Transformer能夠獲得到全局的感受野,但是在細(xì)節(jié)特征的處理上存在缺陷。
SegFormer:《Segmenter: Transformer for Semantic Segmentation》論文中討論了patch size大小對(duì)于模型預(yù)測(cè)結(jié)果的影響,發(fā)現(xiàn),大patch size雖然計(jì)算速度更快,但是邊緣的分割效果明顯很差,而小patch size邊緣相對(duì)更為精確一些。
很多事實(shí)都證明,Transformer對(duì)于局部的細(xì)節(jié)分割是有缺陷的。而CNN反而是得益于其局部的感受野,能夠較為精確恢復(fù)細(xì)節(jié)特征。因此呢,TransUnet模型只替換了最后一層,而這一層則更多關(guān)注全局信息,這是Transformer擅長(zhǎng)的,至于淺層的細(xì)節(jié)識(shí)別任務(wù)則由CNN來完成。
TransUnet具體細(xì)節(jié)
- decoder結(jié)構(gòu)很簡(jiǎn)單,還是典型的skip-connection和upsample結(jié)合。
- 對(duì)于encoder部分:
- 作者選取了ResNet50的前三層作為CNN結(jié)構(gòu),這很好理解,ResNet牛逼嘛。
- 最后一層則是Vit結(jié)構(gòu),也就是12層Transformer Layer
- 作者把encoder叫做R50-ViT。
對(duì)于Vit的一些介紹,可以看另一篇文章:VIT+SETR,本文就偷懶省略了。
不過,需要注意的是,如果輸入Vit的大小為(b, c, W, H),patch size=P時(shí),Vit的輸出為(b, c, W/P, H/P), 也就是 H / P H/P H/P , W / P W/P W/P,需要上采樣到(W, H)大小。
TransUnet模型實(shí)現(xiàn)
Encoder部分
Encoder部分主要由ResNet50和Vit組成,在ResNet50部分,取消掉stem_block結(jié)構(gòu)中的4倍下采樣,保留前三層模型結(jié)構(gòu),這三層都選擇兩倍下采樣,其中最后一層的輸出作為Vit的輸入,這樣保證了feature size、channel number和原圖對(duì)應(yīng)。
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
expansion: int = 4
def __init__(self, inplanes, planes, stride = 1, downsample = None, groups = 1,
base_width = 64, dilation = 1, norm_layer = None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = nn.Conv2d(inplanes, planes ,kernel_size=3, stride=stride,
padding=dilation,groups=groups, bias=False,dilation=dilation)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes ,kernel_size=3, stride=stride,
padding=dilation,groups=groups, bias=False,dilation=dilation)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample= None,
groups = 1, base_width = 64, dilation = 1, norm_layer = None,):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.0)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1, bias=False)
self.bn1 = norm_layer(width)
self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, bias=False, padding=dilation, dilation=dilation)
self.bn2 = norm_layer(width)
self.conv3 = nn.Conv2d(width, planes * self.expansion, kernel_size=1, stride=1, bias=False)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(
self,block, layers,num_classes = 1000, zero_init_residual = False, groups = 1,
width_per_group = 64, replace_stride_with_dilation = None, norm_layer = None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 2
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError(
"replace_stride_with_dilation should be None "
f"or a 3-element tuple, got {replace_stride_with_dilation}"
)
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(block, 64//4, layers[0], stride=2)
self.layer2 = self._make_layer(block, 128//4, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256//4, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512//4, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
def _make_layer(
self,
block,
planes,
blocks,
stride = 1,
dilate = False,
):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = stride
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
norm_layer(planes * block.expansion))
layers = []
layers.append(
block(
self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
)
)
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm_layer=norm_layer,
)
)
return nn.Sequential(*layers)
def _forward_impl(self, x):
out = []
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x)
out.append(x)
x = self.layer2(x)
out.append(x)
x = self.layer3(x)
out.append(x)
# 最后一層不輸出
# x = self.layer4(x)
# out.append(x)
return out
def forward(self, x) :
return self._forward_impl(x)
def _resnet(block, layers, pretrained_path = None, **kwargs,):
model = ResNet(block, layers, **kwargs)
if pretrained_path is not None:
model.load_state_dict(torch.load(pretrained_path), strict=False)
return model
def resnet50(pretrained_path=None, **kwargs):
return ResNet._resnet(Bottleneck, [3, 4, 6, 3], pretrained_path,**kwargs)
def resnet101(pretrained_path=None, **kwargs):
return ResNet._resnet(Bottleneck, [3, 4, 23, 3], pretrained_path,**kwargs)
if __name__ == "__main__":
v = ResNet.resnet50().cuda()
img = torch.randn(1, 3, 512, 512).cuda()
preds = v(img)
# torch.Size([1, 64, 256, 256])
print(preds[0].shape)
# torch.Size([1, 128, 128, 128])
print(preds[1].shape)
# torch.Size([1, 256, 64, 64])
print(preds[2].shape)
接著是Vit部分,Vit接受ResNet50的第三個(gè)輸出。
import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
def pair(t):
return t if isinstance(t, tuple) else (t, t)
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim, channels = 512, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width
self.to_patch_embedding = nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
nn.Linear(patch_dim, dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
self.out = Rearrange("b (h w) c->b c h w", h=image_height//patch_height, w=image_width//patch_width)
# 這里上采樣倍數(shù)為8倍。為了保持和圖中的feature size一樣
self.upsample = nn.UpsamplingBilinear2d(scale_factor = patch_size//2)
self.conv = nn.Sequential(
nn.Conv2d(dim, dim, 3, padding=1),
nn.BatchNorm2d(dim),
nn.ReLU())
def forward(self, img):
# 這里對(duì)應(yīng)了圖中的Linear Projection,主要是將圖片分塊嵌入,成為一個(gè)序列
x = self.to_patch_embedding(img)
b, n, _ = x.shape
# 為圖像切片序列加上索引
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
# 輸入到Transformer中處理
x = self.transformer(x)
# delete cls_tokens, 輸出前需要?jiǎng)h除掉索引
output = x[:,1:,:]
output = self.out(output)
# Transformer輸出后,上采樣到原始尺寸
output = self.upsample(output)
output = self.conv(output)
return output
import torch
if __name__ == "__main__":
v = ViT(image_size = (64, 64), patch_size = 16, channels = 256, dim = 512, depth = 12, heads = 16, mlp_dim = 1024, dropout = 0.1, emb_dropout = 0.1).cpu()
# 假設(shè)ResNet50第三層輸出大小是 1, 256, 64, 64 也就是b, c, W/8, H/8
img = torch.randn(1, 256, 64, 64).cpu()
preds = v(img)
# 輸出是 b, c, W/16, H/16
# preds: torch.Size([1, 512, 32, 32])
print("preds: ",preds.size())
再把兩個(gè)部分合并一下,包裝成TransUnetEncoder類。
class TransUnetEncoder(nn.Module):
def __init__(self, **kwargs):
super(TransUnetEncoder, self).__init__()
self.R50 = ResNet.resnet50()
self.Vit = ViT(image_size = (64, 64), patch_size = 16, channels = 256, dim = 512, depth = 12, heads = 16, mlp_dim = 1024, dropout = 0.1, emb_dropout = 0.1)
def forward(self, x):
x1, x2, x3 = self.R50(x)
x4 = self.Vit(x3)
return [x1, x2, x3, x4]
if __name__ == "__main__":
x = torch.randn(1, 3, 512, 512).cuda()
net = TransUnetEncoder().cuda()
out = net(x)
# torch.Size([1, 64, 256, 256])
print(out[0].shape)
# torch.Size([1, 128, 128, 128])
print(out[1].shape)
# torch.Size([1, 256, 64, 64])
print(out[2].shape)
# torch.Size([1, 512, 32, 32])
print(out[3].shape)
Decoder部分
Decoder部分就是經(jīng)典的Unet decoder模塊了,接受skip connection,然后卷積,上采樣、卷積。同樣包裝成TransUnetDecoder類。
class TransUnetDecoder(nn.Module):
def __init__(self, out_channels=64, **kwargs):
super(TransUnetDecoder, self).__init__()
self.decoder1 = nn.Sequential(
nn.Conv2d(out_channels//4, out_channels//4, 3, padding=1),
nn.BatchNorm2d(out_channels//4),
nn.ReLU()
)
self.upsample1 = nn.Sequential(
nn.UpsamplingBilinear2d(scale_factor=2),
nn.Conv2d(out_channels, out_channels//4, 3, padding=1),
nn.BatchNorm2d(out_channels//4),
nn.ReLU()
)
self.decoder2 = nn.Sequential(
nn.Conv2d(out_channels*2, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
self.upsample2 = nn.Sequential(
nn.UpsamplingBilinear2d(scale_factor=2),
nn.Conv2d(out_channels*2, out_channels, 3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
self.decoder3 = nn.Sequential(
nn.Conv2d(out_channels*4, out_channels*2, 3, padding=1),
nn.BatchNorm2d(out_channels*2),
nn.ReLU()
)
self.upsample3 = nn.Sequential(
nn.UpsamplingBilinear2d(scale_factor=2),
nn.Conv2d(out_channels*4, out_channels*2, 3, padding=1),
nn.BatchNorm2d(out_channels*2),
nn.ReLU()
)
self.decoder4 = nn.Sequential(
nn.Conv2d(out_channels*8, out_channels*4, 3, padding=1),
nn.BatchNorm2d(out_channels*4),
nn.ReLU()
)
self.upsample4 = nn.Sequential(
nn.UpsamplingBilinear2d(scale_factor=2),
nn.Conv2d(out_channels*8, out_channels*4, 3, padding=1),
nn.BatchNorm2d(out_channels*4),
nn.ReLU()
)
def forward(self, inputs):
x1, x2, x3, x4 = inputs
# b 512 H/8 W/8
x4 = self.upsample4(x4)
x = self.decoder4(torch.cat([x4, x3], dim=1))
x = self.upsample3(x)
x = self.decoder3(torch.cat([x, x2], dim=1))
x = self.upsample2(x)
x = self.decoder2(torch.cat([x, x1], dim=1))
x = self.upsample1(x)
x = self.decoder1(x)
return x
if __name__ == "__main__":
x1 = torch.randn([1, 64, 256, 256]).cuda()
x2 = torch.randn([1, 128, 128, 128]).cuda()
x3 = torch.randn([1, 256, 64, 64]).cuda()
x4 = torch.randn([1, 512, 32, 32]).cuda()
net = TransUnetDecoder().cuda()
out = net([x1,x2,x3,x4])
# out: torch.Size([1, 16, 512, 512])
print(out.shape)
TransUnet類
最后將Encoder和Decoder包裝成TransUnet。
class TransUnet(nn.Module):
# 主要是修改num_classes
def __init__(self, num_classes=4, **kwargs):
super(TransUnet, self).__init__()
self.TransUnetEncoder = TransUnetEncoder()
self.TransUnetDecoder = TransUnetDecoder()
self.cls_head = nn.Conv2d(16, num_classes, 1)
def forward(self, x):
x = self.TransUnetEncoder(x)
x = self.TransUnetDecoder(x)
x = self.cls_head(x)
return x
if __name__ == "__main__":
# 輸入的圖像尺寸 [1, 3, 512, 512]
x1 = torch.randn([1, 3, 512, 512]).cuda()
net = TransUnet().cuda()
out = net(x1)
# 輸出的結(jié)果[batch, num_classes, 512, 512]
print(out.shape)
在Camvid測(cè)試集上測(cè)試一下
因?yàn)槭诸^沒有合適的醫(yī)學(xué)領(lǐng)域的圖像,就隨便找個(gè)數(shù)據(jù)集測(cè)試一下分割效果。
Camvid是自動(dòng)駕駛領(lǐng)域的一個(gè)分割數(shù)據(jù)集,八九百?gòu)垐D像比較少,在我的電腦上運(yùn)行快一點(diǎn)。
一些參數(shù)設(shè)置如下
# 導(dǎo)入庫(kù)
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
torch.manual_seed(17)
# 自定義數(shù)據(jù)集CamVidDataset
class CamVidDataset(torch.utils.data.Dataset):
"""CamVid Dataset. Read images, apply augmentation and preprocessing transformations.
Args:
images_dir (str): path to images folder
masks_dir (str): path to segmentation masks folder
class_values (list): values of classes to extract from segmentation mask
augmentation (albumentations.Compose): data transfromation pipeline
(e.g. flip, scale, etc.)
preprocessing (albumentations.Compose): data preprocessing
(e.g. noralization, shape manipulation, etc.)
"""
def __init__(self, images_dir, masks_dir):
self.transform = A.Compose([
A.Resize(512, 512),
A.HorizontalFlip(),
A.VerticalFlip(),
A.Normalize(),
ToTensorV2(),
])
self.ids = os.listdir(images_dir)
self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
def __getitem__(self, i):
# read data
image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
mask = np.array( Image.open(self.masks_fps[i]).convert('RGB'))
image = self.transform(image=image,mask=mask)
return image['image'], image['mask'][:,:,0]
def __len__(self):
return len(self.ids)
# 設(shè)置數(shù)據(jù)集路徑
DATA_DIR = r'../blork_file/dataset//camvid/' # 根據(jù)自己的路徑來設(shè)置
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')
train_dataset = CamVidDataset(
x_train_dir,
y_train_dir,
)
val_dataset = CamVidDataset(
x_valid_dir,
y_valid_dir,
)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=True, drop_last=True)
一些模型和訓(xùn)練過程設(shè)置
from d2l import torch as d2l
from tqdm import tqdm
import pandas as pd
import monai
# model
model = TransUnet(num_classes=33).cuda()
# training loop 100 epochs
epochs_num = 100
# 選用SGD優(yōu)化器來訓(xùn)練
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,80], gamma=0.5)
# 損失函數(shù)選用多分類交叉熵?fù)p失函數(shù)
lossf = nn.CrossEntropyLoss(ignore_index=255)
def evaluate_accuracy_gpu(net, data_iter, device=None):
if isinstance(net, nn.Module):
net.eval() # Set the model to evaluation mode
if not device:
device = next(iter(net.parameters())).device
# No. of correct predictions, no. of predictions
metric = d2l.Accumulator(2)
with torch.no_grad():
for X, y in data_iter:
if isinstance(X, list):
# Required for BERT Fine-tuning (to be covered later)
X = [x.to(device) for x in X]
else:
X = X.to(device)
y = y.to(device)
output = net(X)
metric.add(d2l.accuracy(output, y), d2l.size(y))
return metric[0] / metric[1]
# 訓(xùn)練函數(shù)
def train_ch13(net, train_iter, test_iter, loss, optimizer, num_epochs, schedule, devices=d2l.try_all_gpus()):
timer, num_batches = d2l.Timer(), len(train_iter)
animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1], legend=['train loss', 'train acc', 'test acc'])
net = nn.DataParallel(net, device_ids=devices).to(devices[0])
# 用來保存一些訓(xùn)練參數(shù)
loss_list = []
train_acc_list = []
test_acc_list = []
epochs_list = []
time_list = []
lr_list = []
for epoch in range(num_epochs):
# Sum of training loss, sum of training accuracy, no. of examples,
# no. of predictions
metric = d2l.Accumulator(4)
for i, (X, labels) in enumerate(train_iter):
timer.start()
if isinstance(X, list):
X = [x.to(devices[0]) for x in X]
else:
X = X.to(devices[0])
gt = labels.long().to(devices[0])
net.train()
optimizer.zero_grad()
result = net(X)
loss_sum = loss(result, gt)
loss_sum.sum().backward()
optimizer.step()
acc = d2l.accuracy(result, gt)
metric.add(loss_sum, acc, labels.shape[0], labels.numel())
timer.stop()
if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[2], metric[1] / metric[3], None))
schedule.step()
test_acc = evaluate_accuracy_gpu(net, test_iter)
animator.add(epoch + 1, (None, None, test_acc))
print(f"epoch {epoch+1}/{epochs_num} --- loss {metric[0] / metric[2]:.3f} --- train acc {metric[1] / metric[3]:.3f} --- test acc {test_acc:.3f} --- lr {optimizer.state_dict()['param_groups'][0]['lr']} --- cost time {timer.sum()}")
#---------保存訓(xùn)練數(shù)據(jù)---------------
df = pd.DataFrame()
loss_list.append(metric[0] / metric[2])
train_acc_list.append(metric[1] / metric[3])
test_acc_list.append(test_acc)
epochs_list.append(epoch+1)
time_list.append(timer.sum())
lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
df['epoch'] = epochs_list
df['loss'] = loss_list
df['train_acc'] = train_acc_list
df['test_acc'] = test_acc_list
df["lr"] = lr_list
df['time'] = time_list
df.to_excel("../blork_file/savefile/TransUnet_camvid.xlsx")
#----------------保存模型-------------------
if np.mod(epoch+1, 5) == 0:
torch.save(net.state_dict(), f'../blork_file/checkpoints/TransUnet_{epoch+1}.pth')
# 保存下最后的model
torch.save(net.state_dict(), f'../blork_file/checkpoints/TransUnet_last.pth')
# 開始訓(xùn)練
train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num, schedule)
訓(xùn)練結(jié)果:
說在最后
文章的代碼雖然比較粗糙,但大抵上是與TransUnet原圖對(duì)應(yīng)的。如果你想得到不同規(guī)模的模型,需要更改的只是每一層的通道數(shù)量,你需要在ResNet50中、Vit、Decoder中進(jìn)行修改和確認(rèn)。如果你想將TransUnet用在不同的數(shù)據(jù)集中,你只需要在創(chuàng)建模型時(shí)修改num_classes的數(shù)值即可。文章來源:http://www.zghlxwxcb.cn/news/detail-447482.html
作者注:文章來源地址http://www.zghlxwxcb.cn/news/detail-447482.html
- num_classes的構(gòu)成主要為:background+類別1+類別2+類別n。
- 作者比較懶,還在自我批評(píng)中。如果作者不懶的話,可以把通道數(shù)的關(guān)系連接一下,這樣只需要改一處就可以修改模型規(guī)模了,不像現(xiàn)在需要改好幾個(gè)地方,還需要進(jìn)行驗(yàn)證。
- 不過,驗(yàn)證的過程也是學(xué)習(xí)的過程,所以,多看一看代碼改一改對(duì)小白來說是有很大的好處的。
- 因此,作者在這里為自己偷懶找了一個(gè)不錯(cuò)的借口。
- 這篇文章寫完了TransUnet,應(yīng)某位讀者的要求,下一篇文章會(huì)寫SwinUnet。
- 個(gè)人認(rèn)為,Transformer效果不一定會(huì)很好。至少作者在自己的細(xì)胞數(shù)據(jù)集上測(cè)試情況來講,Swin Transformer的結(jié)果不如傳統(tǒng)的CNN模型來得更好。Transformer存在的缺陷很明顯,同時(shí)GPU資源消耗很大。但是在大物體上的分割效果會(huì)很不錯(cuò),這也是注意力機(jī)制的強(qiáng)大之處。但其在細(xì)小物體和邊界的處理上,明顯來的不那么好。這種情況下,使用deformable-DETR中提到的multi-scale Deformable Attention或許會(huì)達(dá)到一個(gè)不錯(cuò)的效果,畢竟可以更關(guān)注局部信息。不過2022年的各大頂會(huì)已經(jīng)也都開始了對(duì)Transformer的魔改,融合CNN到Transformer中,從而達(dá)到局部全局兩手抓的效果,像什么MixFormer、MaxVit啊等等。
- 總之呢,個(gè)人認(rèn)為,CV快到瓶頸期了,期待下一匹黑馬誕生,干翻Transformer和CNN。
到了這里,關(guān)于醫(yī)學(xué)圖像分割2 TransUnet:Transformers Make Strong Encoders for Medical Image Segmentation的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!