
Medical Image Segmentation 2 - TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation


TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation

In this article you will find:
- How did attention take off in CNNs? (Non-Local)
- What does the Transformer architecture bring? (Multi-Head Self-Attention)
- Why is the Transformer so popular in CV? (Vision Transformer and SETR)
- How does TransUNet remix UNet and the Transformer? (ResNet50 + ViT as the backbone/encoder)
- A PyTorch implementation of TransUNet
- The author's complaints and traces of laziness

Introduction

In medical image segmentation, U-shaped networks, and UNet in particular, have achieved excellent results. However, CNNs are not good at building long-range connections; in other words, a CNN's receptive field is limited. The receptive field can be enlarged by stacking more convolutions or by using dilated convolutions, but those tricks introduce their own problems (including, but not limited to, kernel degradation and the gridding artifacts caused by dilated convolutions), so the final gain is limited.

The Transformer, built on the self-attention mechanism, has already achieved great things in NLP. Vision Transformer (ViT) brought the Transformer into computer vision and produced excellent results in its year, and Transformers have been popular in CV ever since.

So, back to the question: why does the Transformer architecture work so well in CV?

Attention is all you need?

Before introducing the Transformer, let's look at something interesting inside CNNs.
First, recall the Non-Local block:

$Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$

Starting from Non-Local, attention mechanisms swept the top conferences of 2017 and 2018, producing NonLocal Net, DANet, PSANet, ISANet, CCNet and many others. The core idea is always the same: attention can build long-range connections regardless of distance, breaking through the limited receptive field of CNN models. The drawback of computing attention this way is the heavy computation, so along this line CCNet, ISANet and others optimised exactly that cost and turned it into more top-conference papers.
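To make the cost concrete, here is a minimal sketch of scaled dot-product attention over a flattened feature map (purely illustrative shapes and names, not the Non-Local implementation):

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v):
    # q, k, v: (batch, seq_len, d_k); for a feature map, seq_len = H * W
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-1, -2)) / d_k ** 0.5   # (batch, seq_len, seq_len)
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, v)                                  # (batch, seq_len, d_k)

# the (H*W) x (H*W) attention map is exactly where the heavy compute/memory cost comes from
q = k = v = torch.randn(1, 64 * 64, 256)   # e.g. a 64x64 feature map with 256 channels
print(scaled_dot_product_attention(q, k, v).shape)   # torch.Size([1, 4096, 256])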

Why did Non-Local compute attention like this in the first place? Because its authors took the idea from the Transformer. So let's go back to the classic paper that introduced the Transformer, Attention Is All You Need.

That paper makes two main contributions: it proposes the Transformer, and it proposes Multi-Head Attention, i.e. replacing a single attention with a multi-head mechanism.

The Transformer structure is simple: essentially Multi-Head Attention, FFN and Norm modules. The part worth paying attention to is Multi-Head Attention.

Multi-Head Attention is not hard to understand; it is just one kind of attention. As the name suggests, there are several heads, each computing one set of attention: the Scaled Dot-Product Attention step is run h times and the outputs are concatenated. Each position then has h representations, so the output is richer than with a single Scaled Dot-Product Attention.

$MultiHead(Q, K, V) = Concat(head_1, \ldots, head_h)W^O$
$head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)$
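As a shape-level illustration of these two formulas, here is my own naive multi-head attention sketch (not the paper's code; the einops-based Attention class further below is an equivalent version):

import torch
import torch.nn as nn

class NaiveMultiHeadAttention(nn.Module):
    # illustrative re-implementation of the two formulas above
    def __init__(self, dim, heads):
        super().__init__()
        assert dim % heads == 0
        self.heads, self.d_head = heads, dim // heads
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)   # fused W^Q, W^K, W^V
        self.proj = nn.Linear(dim, dim)                      # W^O

    def forward(self, x):                                    # x: (b, n, dim)
        b, n, _ = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # split the channel dim into h heads: (b, n, dim) -> (b, h, n, d_head)
        q, k, v = (t.view(b, n, self.heads, self.d_head).transpose(1, 2) for t in (q, k, v))
        attn = (q @ k.transpose(-1, -2) / self.d_head ** 0.5).softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)   # Concat(head_1, ..., head_h)
        return self.proj(out)                                # ... W^O

print(NaiveMultiHeadAttention(dim=512, heads=8)(torch.randn(1, 16, 512)).shape)  # [1, 16, 512]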

Vision Transformer - the pioneer from CNN to Transformer

Vision Transformer is the trailblazer of the CV world, and something of a saviour for CVers; without ViT, who knows how much longer people would have kept struggling inside Non-Local. (Of course, the Transformer itself is starting to struggle now.)
The ViT paper is "AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE"; the people at Google do pick fun titles.

The idea is simple. A Transformer consumes sequences, and an image cannot be fed into it directly. ViT therefore splits the image into 9 blocks, i.e. 9 patches (it could just as well be 16 or 25, depending on the patch size). The patches are flattened in order into a sequence, a positional encoding is added, and the sequence goes into the Transformer. The positional encoding tells the model the order of the image patches, which helps it learn.
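A minimal sketch of this patch-to-sequence step (the sizes are just for illustration; the cls token is left out and the positional encoding is a random tensor standing in for a learned parameter):

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

patch_size, dim = 16, 768
to_patches = nn.Sequential(
    # (b, 3, 224, 224) -> (b, 14*14, 16*16*3): every 16x16 patch becomes one token
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
    nn.Linear(3 * patch_size * patch_size, dim),   # the "Linear Projection" step
)
img = torch.randn(1, 3, 224, 224)
tokens = to_patches(img)
pos = torch.randn(1, tokens.shape[1], dim)         # stands in for the learned positional encoding
print((tokens + pos).shape)                        # torch.Size([1, 196, 768]) -> ready for the Transformer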

ViT's success on ImageNet gave the CV community hope. Segmentation is one of the big CV tasks, and if ViT can classify, it can also serve as a backbone for segmentation, just like ResNet.

SETR: ViT can do semantic segmentation too!

In another CVPR paper, "Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers", SETR was the first to use ViT as the backbone for semantic segmentation.

The SETR model is also straightforward: a classic encoder-decoder network with ViT as the backbone and three different decoder designs, used to show experimentally that ViT works for semantic segmentation. A very simple idea, but whoever gets there first eats first (thanks to ViT for the free top-conference paper).

Main part

After all that preamble about CNNs, attention, Non-Local and the Transformer, let's get back to the TransUNet model. A large share of CV papers are cut-and-paste collages (and, to be fair, TransUNet also looks like one), but collage is an art of its own. See the TransUNet architecture in the figure below.

It is still a classic U-shaped network, but unlike a CNN-based UNet, the first three encoder stages are CNN-based while the last stage is Transformer-based. In other words, the last stage of the UNet encoder is replaced with a Transformer.

Why only one Transformer stage?

TransUNet has its own reasons for replacing only part of the encoder with a Transformer. Although the Transformer obtains a global receptive field, it is weaker at handling fine details.
The Segmenter paper ("Segmenter: Transformer for Semantic Segmentation") examines how patch size affects the predictions: a large patch size computes faster, but the segmentation along boundaries is clearly worse, while a small patch size gives noticeably more precise edges.

Plenty of evidence shows that the Transformer is weak at segmenting local details, whereas a CNN, precisely because of its local receptive field, recovers fine details more accurately. TransUNet therefore replaces only the last encoder stage, the one that focuses on global information where the Transformer excels, and leaves the shallow, detail-oriented stages to the CNN.

TransUNet details

  • The decoder is simple: the usual combination of skip connections and upsampling.
  • For the encoder:
    • The first three stages come from ResNet50, which is easy to justify: ResNet just works.
    • The last stage is a ViT, i.e. 12 Transformer layers.
    • The authors call this encoder R50-ViT.

For an introduction to ViT itself, see the companion article on ViT + SETR; I'll be lazy and skip it here.

One thing to note: if the input to the ViT has size (b, c, W, H) and the patch size is P, the spatial size of the ViT output is (W/P, H/P), so it has to be upsampled back towards (W, H).
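A quick shape check of that statement, with made-up numbers (in the implementation below, the ViT module itself only upsamples by patch_size // 2 and leaves the remaining upsampling to the decoder):

import torch
import torch.nn.functional as F

b, dim, W, H, P = 1, 512, 64, 64, 16
vit_tokens_as_map = torch.randn(b, dim, W // P, H // P)   # (1, 512, 4, 4): one token per patch
restored = F.interpolate(vit_tokens_as_map, size=(W, H), mode='bilinear', align_corners=True)
print(restored.shape)                                      # torch.Size([1, 512, 64, 64])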

TransUNet implementation

Encoder

The encoder consists of ResNet50 and a ViT. In the ResNet50 part, the 4x downsampling of the stem block is removed and only the first three stages are kept, each downsampling by 2x; the output of the last of these stages is fed into the ViT. This keeps the feature sizes and channel numbers aligned with the original image.

import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
    expansion: int = 1  # BasicBlock keeps the channel count (only Bottleneck expands by 4)
    def __init__(self, inplanes, planes, stride = 1, downsample = None, groups = 1,
        base_width = 64, dilation = 1, norm_layer = None):
        
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, planes ,kernel_size=3, stride=stride, 
                               padding=dilation,groups=groups, bias=False,dilation=dilation)
        
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=dilation, groups=groups, bias=False, dilation=dilation)
        
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample= None,
        groups = 1, base_width = 64, dilation = 1, norm_layer = None,):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, stride=1, bias=False)
        self.bn1 = norm_layer(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, bias=False, padding=dilation, dilation=dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = nn.Conv2d(width, planes * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(
        self,block, layers,num_classes = 1000, zero_init_residual = False, groups = 1,
        width_per_group = 64, replace_stride_with_dilation = None, norm_layer = None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 2
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
            
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
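        # stem: a 3x3 stride-1 conv, i.e. the usual 7x7 stride-2 conv + maxpool (4x downsampling) is removed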
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
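        # planes are divided by 4 so that, with Bottleneck expansion = 4,
        # the three stages output 64 / 128 / 256 channels for the skip connections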
        self.layer1 = self._make_layer(block, 64//4, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128//4, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256//4, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512//4, layers[3], stride=1, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block,
        planes,
        blocks,
        stride = 1,
        dilate = False,
    ):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = stride
            
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes,  planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                norm_layer(planes * block.expansion))

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        out = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        out.append(x)
        x = self.layer2(x)
        out.append(x)
        x = self.layer3(x)
        out.append(x)
        # layer4 is skipped; only the first three stages are returned as skip features
        # x = self.layer4(x)
        # out.append(x)
        return out

    def forward(self, x) :
        return self._forward_impl(x)

    @staticmethod
    def _resnet(block, layers, pretrained_path = None, **kwargs,):
        model = ResNet(block, layers, **kwargs)
        if pretrained_path is not None:
            model.load_state_dict(torch.load(pretrained_path),  strict=False)
        return model
    
    @staticmethod
    def resnet50(pretrained_path=None, **kwargs):
        return ResNet._resnet(Bottleneck, [3, 4, 6, 3], pretrained_path,**kwargs)
    
    @staticmethod
    def resnet101(pretrained_path=None, **kwargs):
        return ResNet._resnet(Bottleneck, [3, 4, 23, 3], pretrained_path,**kwargs)

if __name__ == "__main__":
    v = ResNet.resnet50().cuda()
    img = torch.randn(1, 3, 512, 512).cuda()
    preds = v(img)
    # torch.Size([1, 64, 256, 256])
    print(preds[0].shape)
    # torch.Size([1, 128, 128, 128])
    print(preds[1].shape)
    # torch.Size([1, 256, 64, 64])
    print(preds[2].shape)

Next comes the ViT part; the ViT takes the third ResNet50 output as its input.

import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


def pair(t):
    return t if isinstance(t, tuple) else (t, t)

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim, channels = 512, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.out = Rearrange("b (h w) c->b c h w", h=image_height//patch_height, w=image_width//patch_width)
        
        # upsample by 8x here (patch_size // 2) to match the feature size in the architecture figure
        self.upsample = nn.UpsamplingBilinear2d(scale_factor = patch_size//2)
        self.conv = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.BatchNorm2d(dim),
            nn.ReLU())

    def forward(self, img):
        # this is the "Linear Projection" in the figure: split the image into patches and embed them as a sequence
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        # prepend the learnable cls token to the patch sequence
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        # run the sequence through the Transformer
        x = self.transformer(x)

        # drop the cls token before reshaping back into a feature map
        output = x[:,1:,:]
        output = self.out(output)

        # after the Transformer, upsample the feature map (by patch_size // 2) and refine with a conv
        output = self.upsample(output)
        output = self.conv(output)

        return output


import torch
if __name__ == "__main__":
    v = ViT(image_size = (64, 64), patch_size = 16, channels = 256, dim = 512, depth = 12, heads = 16, mlp_dim = 1024, dropout = 0.1, emb_dropout = 0.1).cpu()
    # assume the third ResNet50 stage outputs (1, 256, 64, 64), i.e. (b, c, W/8, H/8)
    img = torch.randn(1, 256, 64, 64).cpu()
    preds = v(img)
    # output is (b, dim, W/16, H/16)
    # preds:  torch.Size([1, 512, 32, 32])
    print("preds: ",preds.size())

Now combine the two parts and wrap them into a TransUnetEncoder class.

class TransUnetEncoder(nn.Module):
    def __init__(self, **kwargs):
        super(TransUnetEncoder, self).__init__()
        self.R50 = ResNet.resnet50()
        self.Vit = ViT(image_size = (64, 64), patch_size = 16, channels = 256, dim = 512, depth = 12, heads = 16, mlp_dim = 1024, dropout = 0.1, emb_dropout = 0.1)

    def forward(self, x):
        x1, x2, x3 = self.R50(x)
        x4 = self.Vit(x3)
        return [x1, x2, x3, x4]

if __name__ == "__main__":
    x = torch.randn(1, 3, 512, 512).cuda()
    net = TransUnetEncoder().cuda()
    out = net(x)
    # torch.Size([1, 64, 256, 256])
    print(out[0].shape)
    # torch.Size([1, 128, 128, 128])
    print(out[1].shape)
    # torch.Size([1, 256, 64, 64])
    print(out[2].shape)
    # torch.Size([1, 512, 32, 32])
    print(out[3].shape)

Decoder

The decoder is the classic UNet decoder: take the skip connection, concatenate, convolve, upsample, convolve again. It is likewise wrapped into a TransUnetDecoder class.

class TransUnetDecoder(nn.Module):
    def __init__(self, out_channels=64, **kwargs):
        super(TransUnetDecoder, self).__init__()
        self.decoder1 = nn.Sequential(
            nn.Conv2d(out_channels//4, out_channels//4, 3, padding=1), 
            nn.BatchNorm2d(out_channels//4),
            nn.ReLU()            
        )
        self.upsample1 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(out_channels, out_channels//4, 3, padding=1),
            nn.BatchNorm2d(out_channels//4),
            nn.ReLU()     
        )

        self.decoder2 = nn.Sequential(
            nn.Conv2d(out_channels*2, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()            
        )
        self.upsample2 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(out_channels*2, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()     
        )

        self.decoder3 = nn.Sequential(
            nn.Conv2d(out_channels*4, out_channels*2, 3, padding=1),
            nn.BatchNorm2d(out_channels*2),
            nn.ReLU()            
        )        
        self.upsample3 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(out_channels*4, out_channels*2, 3, padding=1),
            nn.BatchNorm2d(out_channels*2),
            nn.ReLU()     
        )

        self.decoder4 = nn.Sequential(
            nn.Conv2d(out_channels*8, out_channels*4, 3, padding=1),
            nn.BatchNorm2d(out_channels*4),
            nn.ReLU()                           
        )
        self.upsample4 = nn.Sequential(
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(out_channels*8, out_channels*4, 3, padding=1),
            nn.BatchNorm2d(out_channels*4),
            nn.ReLU()     
        )

    def forward(self, inputs):
        x1, x2, x3, x4 = inputs
        # x4: (b, 512, H/16, W/16)
        
        x4 = self.upsample4(x4)
        x = self.decoder4(torch.cat([x4, x3], dim=1))        
        
        x = self.upsample3(x)
        x = self.decoder3(torch.cat([x, x2], dim=1))

        x = self.upsample2(x)
        x = self.decoder2(torch.cat([x, x1], dim=1))

        x = self.upsample1(x)
        x = self.decoder1(x)

        return x

if __name__ == "__main__":
    x1 = torch.randn([1, 64, 256, 256]).cuda()
    x2 = torch.randn([1, 128, 128, 128]).cuda()
    x3 = torch.randn([1, 256, 64, 64]).cuda()
    x4 = torch.randn([1, 512, 32, 32]).cuda()
    net = TransUnetDecoder().cuda()
    out = net([x1,x2,x3,x4])
    # out: torch.Size([1, 16, 512, 512])
    print(out.shape)

The TransUnet class

Finally, wrap the encoder and decoder into TransUnet.

class TransUnet(nn.Module):
    # the main thing to adjust per dataset is num_classes
    def __init__(self, num_classes=4, **kwargs):
        super(TransUnet, self).__init__()
        self.TransUnetEncoder = TransUnetEncoder()
        self.TransUnetDecoder = TransUnetDecoder()
        self.cls_head = nn.Conv2d(16, num_classes, 1)
    def forward(self, x):
        x = self.TransUnetEncoder(x)
        x = self.TransUnetDecoder(x)
        x = self.cls_head(x)
        return x
    
if __name__ == "__main__":
    # input image size: [1, 3, 512, 512]
    x1 = torch.randn([1, 3, 512, 512]).cuda()
    net = TransUnet().cuda()
    out = net(x1)
    # output: [batch, num_classes, 512, 512]
    print(out.shape)

A quick test on the CamVid dataset

I don't have a suitable medical imaging dataset at hand, so I just grabbed a dataset to check the segmentation quality.
CamVid is a segmentation dataset for autonomous driving; with only eight or nine hundred images it is small and runs quickly on my machine.
Some settings:

# imports
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
 
torch.manual_seed(17)
# custom CamVidDataset
class CamVidDataset(torch.utils.data.Dataset):
    """CamVid Dataset. Read images, apply augmentation and preprocessing transformations.
    
    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        class_values (list): values of classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transformation pipeline 
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing 
            (e.g. normalization, shape manipulation, etc.)
    """
    
    def __init__(self, images_dir, masks_dir):
        self.transform = A.Compose([
            A.Resize(512, 512),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.Normalize(),
            ToTensorV2(),
        ]) 
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
 
    
    def __getitem__(self, i):
        # read data
        image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
        mask = np.array( Image.open(self.masks_fps[i]).convert('RGB'))
        image = self.transform(image=image,mask=mask)
        
        return image['image'], image['mask'][:,:,0]
        
    def __len__(self):
        return len(self.ids)
    
    
# dataset paths
DATA_DIR = r'../blork_file/dataset//camvid/' # adjust to your own path
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')
    
train_dataset = CamVidDataset(
    x_train_dir, 
    y_train_dir, 
)
val_dataset = CamVidDataset(
    x_valid_dir, 
    y_valid_dir, 
)
 
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=True, drop_last=True)

Model and training setup:

from d2l import torch as d2l
from tqdm import tqdm
import pandas as pd
import monai
# model
model = TransUnet(num_classes=33).cuda()
# training loop 100 epochs
epochs_num = 100
# train with the SGD optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30,80], gamma=0.5)

# multi-class cross-entropy loss
lossf = nn.CrossEntropyLoss(ignore_index=255)

def evaluate_accuracy_gpu(net, data_iter, device=None):
    if isinstance(net, nn.Module):
        net.eval()  # Set the model to evaluation mode
        if not device:
            device = next(iter(net.parameters())).device
    # No. of correct predictions, no. of predictions
    metric = d2l.Accumulator(2)

    with torch.no_grad():
        for X, y in data_iter:
            if isinstance(X, list):
                # Required for BERT Fine-tuning (to be covered later)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            output = net(X)
            metric.add(d2l.accuracy(output, y), d2l.size(y))
    return metric[0] / metric[1]


# training function
def train_ch13(net, train_iter, test_iter, loss, optimizer, num_epochs, schedule, devices=d2l.try_all_gpus()):
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1], legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    # lists for logging training statistics

    loss_list = []
    train_acc_list = []
    test_acc_list = []
    epochs_list = []
    time_list = []
    lr_list = []

    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = d2l.Accumulator(4)
        for i, (X, labels) in enumerate(train_iter):
            timer.start()

            if isinstance(X, list):
                X = [x.to(devices[0]) for x in X]
            else:
                X = X.to(devices[0])
            gt = labels.long().to(devices[0])

            net.train()
            optimizer.zero_grad()
            result = net(X)
            loss_sum = loss(result, gt)
            loss_sum.sum().backward()
            optimizer.step()

            acc = d2l.accuracy(result, gt)
            metric.add(loss_sum, acc, labels.shape[0], labels.numel())

            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,(metric[0] / metric[2], metric[1] / metric[3], None))
                
        schedule.step()

        test_acc = evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
        print(f"epoch {epoch+1}/{epochs_num} --- loss {metric[0] / metric[2]:.3f} --- train acc {metric[1] / metric[3]:.3f} --- test acc {test_acc:.3f} --- lr {optimizer.state_dict()['param_groups'][0]['lr']} --- cost time {timer.sum()}")
        
        #--------- save training logs ---------------
        df = pd.DataFrame()
        loss_list.append(metric[0] / metric[2])
        train_acc_list.append(metric[1] / metric[3])
        test_acc_list.append(test_acc)
        epochs_list.append(epoch+1)
        time_list.append(timer.sum())
        lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
        
        df['epoch'] = epochs_list
        df['loss'] = loss_list
        df['train_acc'] = train_acc_list
        df['test_acc'] = test_acc_list
        df["lr"] = lr_list
        df['time'] = time_list
        
        df.to_excel("../blork_file/savefile/TransUnet_camvid.xlsx")
        #---------------- save checkpoints -------------------
        if np.mod(epoch+1, 5) == 0:
            torch.save(net.state_dict(), f'../blork_file/checkpoints/TransUnet_{epoch+1}.pth')

    # save the final model
    torch.save(net.state_dict(), f'../blork_file/checkpoints/TransUnet_last.pth')
    
# start training
train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num, schedule)

Training results:
(figure: training and validation curves from the run above, omitted here)

Closing remarks

The code in this article is fairly rough, but it broadly matches the original TransUNet figure. If you want a model of a different size, all you need to change are the per-stage channel numbers, which you have to modify and cross-check in the ResNet50, the ViT and the decoder. If you want to use TransUNet on a different dataset, you only need to change num_classes when constructing the model.
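For example, reusing the classes defined above (a minimal usage sketch; the class count here is hypothetical):

# e.g. a binary task: num_classes = 2 (1 background + 1 foreground)
model = TransUnet(num_classes=2).cuda()
x = torch.randn(2, 3, 512, 512).cuda()
print(model(x).shape)   # torch.Size([2, 2, 512, 512])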

Author's notes (source: http://www.zghlxwxcb.cn/news/detail-447482.html)

  • num_classes is made up of: background + class 1 + class 2 + ... + class n.
  • The author is lazy and still engaged in self-criticism: a less lazy author would have tied the channel numbers together so that changing the model size only requires editing one place, instead of several places that then need re-verification.
  • Then again, verifying the code is itself part of learning, so reading and tweaking the code is very good practice for beginners.
  • So the author has found a decent excuse for being lazy here.
  • This article finishes TransUNet; at a reader's request, the next one will cover SwinUnet.
  • Personally, I don't think the Transformer is guaranteed to work better. At least on my own cell dataset, Swin Transformer did not beat traditional CNN models. The Transformer's weaknesses are obvious, and it consumes a lot of GPU memory. It does segment large objects very well, which is where the attention mechanism shines, but it handles small objects and boundaries noticeably worse. In that situation, the multi-scale Deformable Attention from Deformable-DETR might work well, since it can focus more on local information. In any case, the big 2022 conferences are already full of Transformer remixes that fuse CNNs into the Transformer to capture both local and global information, e.g. MixFormer, MaxViT and so on.
  • All in all, I think CV is close to a bottleneck; I'm waiting for the next dark horse to knock out both the Transformer and the CNN.
