1. Model Overview
The UNet model was proposed in 2015 in the paper "U-Net: Convolutional Networks for Biomedical Image Segmentation". It was originally designed for medical image segmentation, specifically cell-level segmentation tasks.
UNet is built on top of the FCN architecture. Because FCN struggles to combine contextual information with precise localization, its accuracy is limited; UNet therefore introduces a U-shaped structure to capture both kinds of information. The model is simple, efficient, and easy to build, and it achieves high accuracy even on relatively small datasets.
Paper: https://arxiv.org/abs/1505.04597
Code: https://github.com/Cjl-MedSeg/U-Net
1.1 Model Architecture
The overall UNet architecture consists of two parts, a feature-extraction network and a feature-fusion network, which together form what is also called an "encoder-decoder" structure. Because the network as a whole resembles the capital letter "U", the model is named UNet. The architecture defined in the original paper is shown in Figure 1.
After the input image enters the model, features are first extracted and then fused:
a) The left half of the network performs feature extraction (the encoder). Each "down-sampling block" consists of two 3x3 convolutions and one 2x2 pooling layer: the block first applies two valid convolutions to the feature map and then one pooling operation. After the 4 down-sampling blocks and the two further 3x3 convolutions at the bottom of the "U", the original 572x572, single-channel image becomes a 28x28 feature map with 1024 channels (the size arithmetic is traced in the sketch after this list).
b) The right half of the network performs up-sampling (the decoder). Each "up-sampling block" consists of one transposed convolution (up-convolution), a feature-concatenation step, and two 3x3 convolutions: the block first doubles the spatial size of the feature map with the transposed convolution, then increases the channel count by concatenating the corresponding encoder feature map, and finally applies two valid convolutions. After the 4 up-sampling blocks, the 28x28, 1024-channel feature map from the encoder becomes a 388x388 feature map with 64 channels.
c) The last part of the network uses a 1x1 convolution with two output kernels to map the 64-channel feature map produced by the decoder to a 2-channel image, which is output as the prediction.
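As a sanity check on the size bookkeeping above, the following few lines of plain Python trace the spatial sizes through the original valid-convolution U-Net; this is purely illustrative arithmetic, nothing is learned:
# Trace the spatial size of the original (valid-convolution) U-Net.
size = 572
for _ in range(4):            # 4 down-sampling blocks
    size -= 4                 # two 3x3 valid convolutions: -2 each
    size //= 2                # 2x2 max pooling halves the size
size -= 4                     # two convolutions at the bottom of the "U"
print(size)                   # 28  -> 28x28 feature map with 1024 channels
for _ in range(4):            # 4 up-sampling blocks
    size *= 2                 # 2x2 up-convolution doubles the size
    size -= 4                 # two 3x3 valid convolutions after concatenation
print(size)                   # 388 -> matches the 388x388, 64-channel output in the paper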
1.2 Model Characteristics
a) Concatenation (skip connections) fuses low-level feature maps with high-level feature maps.
b) The fully symmetric U-shaped structure carries both high-resolution and low-resolution information through to the output, so that encoder and decoder features are fused more thoroughly.
c) It combines the low-resolution information from down-sampling (which provides the evidence for recognizing object classes) with the high-resolution information from up-sampling (which provides the basis for precise segmentation and localization); the fusion operation also fills in low-level detail, improving segmentation accuracy.
2. Case Implementation
2.1 Environment Setup and Data Loading
This case is implemented with MindSpore 1.8.1 and can be trained on CPU, GPU, and Ascend.
The data used in this case is the ISBI Drosophila electron-microscopy dataset, which can be downloaded from http://brainiac2.mit.edu/isbi_challenge/. The download contains 3 tif files, corresponding to the test-set samples, the training-set labels, and the training-set samples. The file layout is as follows:
.datasets/
└── ISBI
├── test-volume.tif
├── train-labels.tif
└── train-volume.tif
Each tif file is a stack of 30 images, so the next step is to extract every image stored in each tif file and save it in png format, yielding 30 png images for the training samples, 30 for the training labels, and 30 for the test samples.
Concretely, each tif file is first converted to an array, and skimage is then used to save the array of each image as a png; a processed training sample and its label image are shown in Figure 2. After the 3 tif files have been converted to png, the training samples and labels are re-split into a training set and a validation set at a ratio of 2:1 (a conversion and split sketch follows the directory listing below). The resulting file layout is:
.datasets/
└── ISBI
├── test_imgs
│ ├── 00000.png
│ ├── 00001.png
│ └── . . . . .
├── train
│ ├── image
│ │ ├── 00001.png
│ │ ├── 00002.png
│ │ └── . . . . .
│ └── mask
│ ├── 00001.png
│ ├── 00002.png
│ └── . . . . .
└── val
├── image
│ ├── 00000.png
│ ├── 00003.png
│ └── . . . . .
└── mask
├── 00000.png
├── 00003.png
└── . . . . .
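The conversion and split can be scripted along the following lines. This is only a minimal sketch: the paths and the "every third slice goes to the validation set" rule are assumptions inferred from the directory listing above, so adjust them to your own layout:
import os
import shutil
from skimage import io

root = "datasets/ISBI"                      # assumed location of the three tif files
stacks = {
    "train-volume.tif": "train_image",
    "train-labels.tif": "train_mask",
    "test-volume.tif": "test_imgs",
}

# 1) unpack each 30-slice tif stack into individual png files
for tif_name, out_dir in stacks.items():
    pages = io.imread(os.path.join(root, tif_name))     # ndarray of shape (30, H, W)
    os.makedirs(os.path.join(root, out_dir), exist_ok=True)
    for i, page in enumerate(pages):
        io.imsave(os.path.join(root, out_dir, "%05d.png" % i), page)

# 2) split the 30 training slices 2:1 into train/ and val/
#    (every third slice -> val, which reproduces the 00000/00003... numbering above)
for i in range(30):
    subset = "val" if i % 3 == 0 else "train"
    for src_dir, dst_sub in (("train_image", "image"), ("train_mask", "mask")):
        dst_dir = os.path.join(root, subset, dst_sub)
        os.makedirs(dst_dir, exist_ok=True)
        shutil.move(os.path.join(root, src_dir, "%05d.png" % i),
                    os.path.join(dst_dir, "%05d.png" % i))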
2.2 Dataset Creation
After the tif format conversion and the further split into training and validation sets described above, all the preparation needed for data loading is done. The next step is to use the processed images to perform data augmentation through image transformations and to create the datasets.
For data augmentation, mindspore.dataset.vision is used. The training samples and labels are first resized to a uniform size with Resize(), rescaled to [0, 1] with Rescale(), randomly flipped horizontally and vertically, and finally transposed from HWC to CHW. The validation samples and labels go through the same Resize(), Rescale(), and HWC2CHW() steps but without the random flips.
For dataset creation, a Data_Loader class is defined first. Its __init__ method uses the data_path argument to locate the training or validation directory prepared in the previous step and collects the paths of the sample and label images under it. Its __getitem__ method reads the sample and label image for a given index (reshaping the label to H x W x 1). A create_dataset function is then defined with data_dir, batch_size, and other parameters: it instantiates Data_Loader on data_dir to obtain the (sample, label) pairs under the training or validation path, wraps them with GeneratorDataset from mindspore.dataset, applies either the training or the validation transform pipeline through map (this is where augmentation and the HWC-to-CHW transpose happen and the data become Tensors), and finally groups samples and labels into batches of size batch_size. This completes dataset creation; the corresponding code is as follows:
import os
import cv2
import mindspore.dataset as ds
import glob
import mindspore.dataset.vision as vision_C
import mindspore.dataset.transforms as C_transforms
import random
import mindspore
from mindspore.dataset.vision import Inter
def train_transforms(img_size):
return [
vision_C.Resize(img_size, interpolation=Inter.NEAREST),
vision_C.Rescale(1./255., 0.0),
vision_C.RandomHorizontalFlip(prob=0.5),
vision_C.RandomVerticalFlip(prob=0.5),
vision_C.HWC2CHW()
]
def val_transforms(img_size):
return [
vision_C.Resize(img_size, interpolation=Inter.NEAREST),
vision_C.Rescale(1/255., 0),
vision_C.HWC2CHW()
]
class Data_Loader:
def __init__(self, data_path):
        # Initialization: collect all sample and label paths under data_path
        # (sorted so that image and mask files stay aligned)
        self.data_path = data_path
        self.imgs_path = sorted(glob.glob(os.path.join(data_path, 'image/*.png')))
        self.label_path = sorted(glob.glob(os.path.join(data_path, 'mask/*.png')))
def __getitem__(self, index):
        # Read the sample and label image for the given index
image = cv2.imread(self.imgs_path[index])
label = cv2.imread(self.label_path[index], cv2.IMREAD_GRAYSCALE)
label = label.reshape((label.shape[0], label.shape[1], 1))
return image, label
@property
def column_names(self):
column_names = ['image', 'label']
return column_names
def __len__(self):
        # Return the dataset size
return len(self.imgs_path)
def create_dataset(data_dir, img_size, batch_size, augment, shuffle):
mc_dataset = Data_Loader(data_path=data_dir)
dataset = ds.GeneratorDataset(mc_dataset, mc_dataset.column_names, shuffle=shuffle)
if augment:
transform_img = train_transforms(img_size)
else:
transform_img = val_transforms(img_size)
seed = random.randint(1,1000)
mindspore.set_seed(seed)
dataset = dataset.map(input_columns='image', num_parallel_workers=1, operations=transform_img)
mindspore.set_seed(seed)
dataset = dataset.map(input_columns="label", num_parallel_workers=1, operations=transform_img)
if shuffle:
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(batch_size, num_parallel_workers=1)
    if augment and shuffle:
        print("Training set size:", len(mc_dataset))
    elif not augment and not shuffle:
        print("Validation set size:", len(mc_dataset))
    else:
        pass
return dataset
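A quick way to sanity-check the pipeline is to build the training dataset and inspect one batch. This is only an illustrative sketch; the path assumes the directory layout from section 2.1, and the expected shapes are simply what the transforms above should produce:
# Illustrative check of the data pipeline (path assumes the layout from section 2.1).
train_ds = create_dataset("datasets/ISBI/train/", img_size=224, batch_size=4,
                          augment=True, shuffle=True)
for image, label in train_ds.create_tuple_iterator():
    print(image.shape, label.shape)   # expected: (4, 3, 224, 224) (4, 1, 224, 224)
    break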
2.3 Model Construction
The UNet built in this case is largely the same as the architecture proposed in the 2015 paper, but the convolutions used in both the down-sampling and up-sampling blocks here are same convolutions, whereas the original paper uses valid convolutions. In addition, the original network ends with a 1x1 convolution with two output channels that produces a 2-channel prediction map, while the network in this case ends with a single 1x1 convolution kernel that outputs a 1-channel grayscale map, matching the format of the label images. The UNet structure actually built here is shown in Figure 3.
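The effect of the two padding schemes can be checked directly in MindSpore. The snippet below is only an illustration (the layer sizes are arbitrary); it relies on nn.Conv2d defaulting to pad_mode='same':
# Contrast MindSpore's default 'same' padding with the 'valid' convolutions of the original paper.
import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

x = Tensor(np.zeros((1, 3, 224, 224)), ms.float32)
same_conv = nn.Conv2d(3, 64, 3)                     # default pad_mode='same'
valid_conv = nn.Conv2d(3, 64, 3, pad_mode='valid')  # what the original paper uses
print(same_conv(x).shape)    # (1, 64, 224, 224) -> spatial size preserved
print(valid_conv(x).shape)   # (1, 64, 222, 222) -> shrinks by 2 per 3x3 convolution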
The workflow for building a network with MindSpore is similar to PyTorch: a model class inherits from Cell and overrides the __init__ and construct methods. Concretely, a double_conv helper is defined first, which returns an nn.SequentialCell containing the two convolutions shared by the down-sampling and up-sampling blocks. It uses nn.Conv2d for each convolution, inserts an nn.BatchNorm2d layer after every convolution to normalize the feature maps and stabilize training, and adds an nn.ReLU layer as the nonlinear activation.
With double_conv in place, the UNet class completes the whole network. In UNet's __init__ method, a double_conv instance provides the two consecutive convolution layers and nn.MaxPool2d performs the max pooling; together they form one down-sampling block, and repeating this four times builds the encoder. For the decoder, nn.ResizeBilinear is used for up-sampling in place of the transposed convolution of the original paper, followed by a double_conv instance for the two convolutions; this forms one up-sampling block, and repeating it four times builds the decoder. A final nn.Conv2d layer produces the prediction image, which is mapped to [0, 1] by a Sigmoid. The construct method then wires these operations into the forward network, completing the UNet model. The corresponding code is as follows:
from mindspore import nn
import mindspore.numpy as np
import mindspore.ops as ops
import mindspore.ops.operations as F
def double_conv(in_ch, out_ch):
return nn.SequentialCell(nn.Conv2d(in_ch, out_ch, 3),
nn.BatchNorm2d(out_ch), nn.ReLU(),
nn.Conv2d(out_ch, out_ch, 3),
nn.BatchNorm2d(out_ch), nn.ReLU())
class UNet(nn.Cell):
def __init__(self, in_ch = 3, n_classes = 1):
super(UNet, self).__init__()
self.concat1 = F.Concat(axis=1)
self.concat2 = F.Concat(axis=1)
self.concat3 = F.Concat(axis=1)
self.concat4 = F.Concat(axis=1)
self.double_conv1 = double_conv(in_ch, 64)
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.double_conv2 = double_conv(64, 128)
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.double_conv3 = double_conv(128, 256)
self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.double_conv4 = double_conv(256, 512)
self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
self.double_conv5 = double_conv(512, 1024)
self.upsample1 = nn.ResizeBilinear()
self.double_conv6 = double_conv(1024 + 512, 512)
self.upsample2 = nn.ResizeBilinear()
self.double_conv7 = double_conv(512 + 256, 256)
self.upsample3 = nn.ResizeBilinear()
self.double_conv8 = double_conv(256 + 128, 128)
self.upsample4 = nn.ResizeBilinear()
self.double_conv9 = double_conv(128 + 64, 64)
self.final = nn.Conv2d(64, n_classes, 1)
self.sigmoid = ops.Sigmoid()
def construct(self, x):
feature1 = self.double_conv1(x)
tmp = self.maxpool1(feature1)
feature2 = self.double_conv2(tmp)
tmp = self.maxpool2(feature2)
feature3 = self.double_conv3(tmp)
tmp = self.maxpool3(feature3)
feature4 = self.double_conv4(tmp)
tmp = self.maxpool4(feature4)
feature5 = self.double_conv5(tmp)
up_feature1 = self.upsample1(feature5, scale_factor=2)
tmp = self.concat1((feature4, up_feature1))
tmp = self.double_conv6(tmp)
up_feature2 = self.upsample2(tmp, scale_factor=2)
tmp = self.concat2((feature3, up_feature2))
tmp = self.double_conv7(tmp)
up_feature3 = self.upsample3(tmp, scale_factor=2)
tmp = self.concat3((feature2, up_feature3))
tmp = self.double_conv8(tmp)
up_feature4 = self.upsample4(tmp, scale_factor=2)
tmp = self.concat4((feature1, up_feature4))
tmp = self.double_conv9(tmp)
output = self.sigmoid(self.final(tmp))
return output
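As a quick usage check, the network can be run on a random input to confirm the output shape. This is only an illustrative sketch (batch and input size are arbitrary):
# Quick shape check of the UNet defined above.
import numpy as onp   # aliased to avoid clashing with mindspore.numpy imported above as np
from mindspore import Tensor

net = UNet(in_ch=3, n_classes=1)
x = Tensor(onp.random.rand(1, 3, 224, 224).astype(onp.float32))
print(net(x).shape)   # expected: (1, 1, 224, 224)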
2.4 Custom Evaluation Metrics
To observe the training behaviour of the network more comprehensively and intuitively, this case also uses the MindSpore framework to define custom metrics. The custom metrics class uses several evaluation functions to assess the model: accuracy (Acc), intersection over union (IoU), the Dice coefficient, sensitivity (Sens), and specificity (Spec).
a) Accuracy (Acc) is the percentage of correctly classified pixels in the image, i.e., the proportion of correctly classified pixels among all pixels:
$$Acc = \frac{TP + TN}{TP + TN + FP + FN}$$
where:
- TP: true positives, pixels that are positive in the label and predicted as positive.
- TN: true negatives, pixels that are negative in the label and predicted as negative.
- FP: false positives, pixels that are negative in the label but predicted as positive.
- FN: false negatives, pixels that are positive in the label but predicted as negative.
b) Intersection over union (IoU) is the overlap between the predicted segmentation and the label divided by their union (intersection of the two / union of the two). It is one of the most commonly used metrics in semantic segmentation and is computed as:
$$IoU = \frac{|A \cap B|}{|A \cup B|} = \frac{TP}{TP + FP + FN}$$
c) The Dice coefficient is defined as twice the intersection divided by the sum of the pixel counts; it is also called the F1 score and is positively correlated with IoU. It is computed as:
$$Dice = \frac{2|A \cap B|}{|A| + |B|} = \frac{2TP}{2TP + FP + FN}$$
d) Sensitivity (Sens) is the proportion of actual positives that are identified as positive, and specificity (Spec) is the proportion of actual negatives that are identified as negative. They are computed as:
$$Sens = \frac{TP}{TP + FN}, \qquad Spec = \frac{TN}{FP + TN}$$
Concretely, a metrics_ class is defined and, following the MindSpore documentation, inherits from the Metric parent class. Computation methods for the five metrics above are defined in the class according to their formulas. The clear method is reimplemented to initialize the internal state; the update method is reimplemented to receive the model predictions and labels and, using the computation methods defined above, accumulate each metric's value into a list; and the eval method is reimplemented to return the list of metric values, averaged over the number of samples. The corresponding code is as follows:
import numpy as np
from mindspore._checkparam import Validator as validator
from mindspore.nn import Metric
from mindspore import Tensor
class metrics_(Metric):
def __init__(self, metrics, smooth=1e-5):
super(metrics_, self).__init__()
self.metrics = metrics
self.smooth = validator.check_positive_float(smooth, "smooth")
self.metrics_list = [0. for i in range(len(self.metrics))]
self._samples_num = 0
self.clear()
def Acc_metrics(self,y_pred, y):
tp = np.sum(y_pred.flatten() == y.flatten(), dtype=y_pred.dtype)
total = len(y_pred.flatten())
single_acc = float(tp) / float(total)
return single_acc
def IoU_metrics(self,y_pred, y):
intersection = np.sum(y_pred.flatten() * y.flatten())
unionset = np.sum(y_pred.flatten() + y.flatten()) - intersection
single_iou = float(intersection) / float(unionset + self.smooth)
return single_iou
def Dice_metrics(self,y_pred, y):
intersection = np.sum(y_pred.flatten() * y.flatten())
unionset = np.sum(y_pred.flatten()) + np.sum(y.flatten())
single_dice = 2*float(intersection) / float(unionset + self.smooth)
return single_dice
def Sens_metrics(self,y_pred, y):
tp = np.sum(y_pred.flatten() * y.flatten())
actual_positives = np.sum(y.flatten())
single_sens = float(tp) / float(actual_positives + self.smooth)
return single_sens
def Spec_metrics(self,y_pred, y):
true_neg = np.sum((1 - y.flatten()) * (1 - y_pred.flatten()))
total_neg = np.sum((1 - y.flatten()))
single_spec = float(true_neg) / float(total_neg + self.smooth)
return single_spec
def clear(self):
"""Clears the internal evaluation result."""
self.metrics_list = [0. for i in range(len(self.metrics))]
self._samples_num = 0
def update(self, *inputs):
if len(inputs) != 2:
raise ValueError("For 'update', it needs 2 inputs (predicted value, true value), ""but got {}.".format(len(inputs)))
y_pred = Tensor(inputs[0]).asnumpy() #modelarts,cpu
# y_pred = np.array(Tensor(inputs[0])) #cpu
y_pred[y_pred > 0.5] = float(1)
y_pred[y_pred <= 0.5] = float(0)
y = Tensor(inputs[1]).asnumpy()
self._samples_num += y.shape[0]
if y_pred.shape != y.shape:
raise ValueError(f"For 'update', predicted value (input[0]) and true value (input[1]) "
f"should have same shape, but got predicted value shape: {y_pred.shape}, "
f"true value shape: {y.shape}.")
for i in range(y.shape[0]):
if "acc" in self.metrics:
single_acc = self.Acc_metrics(y_pred[i], y[i])
self.metrics_list[0] += single_acc
if "iou" in self.metrics:
single_iou = self.IoU_metrics(y_pred[i], y[i])
self.metrics_list[1] += single_iou
if "dice" in self.metrics:
single_dice = self.Dice_metrics(y_pred[i], y[i])
self.metrics_list[2] += single_dice
if "sens" in self.metrics:
single_sens = self.Sens_metrics(y_pred[i], y[i])
self.metrics_list[3] += single_sens
if "spec" in self.metrics:
single_spec = self.Spec_metrics(y_pred[i], y[i])
self.metrics_list[4] += single_spec
def eval(self):
if self._samples_num == 0:
raise RuntimeError("The 'metrics' can not be calculated, because the number of samples is 0, "
"please check whether your inputs(predicted value, true value) are empty, or has "
"called update method before calling eval method.")
for i in range(len(self.metrics_list)):
self.metrics_list[i] = self.metrics_list[i] / float(self._samples_num)
return self.metrics_list
Testing the metrics on a small example: after thresholding at 0.5, x binarizes to [[0, 0, 1], [0, 0, 0], [1, 1, 1]]; compared with y this gives 6 of 9 pixels correct, TP = 3, FP = 1, FN = 2, and TN = 3 of the 4 negatives, which matches the values printed below.
x = Tensor(np.array([[[[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.8]]]]))
y = Tensor(np.array([[[[0, 1, 1], [1, 0, 0], [0, 1, 1]]]]))
metric = metrics_(["acc", "iou", "dice", "sens", "spec"],smooth=1e-5)
metric.clear()
metric.update(x, y)
res = metric.eval()
print( '丨acc: %.4f丨丨iou: %.4f丨丨dice: %.4f丨丨sens: %.4f丨丨spec: %.4f丨' % (res[0], res[1], res[2], res[3],res[4]), flush=True)
丨acc: 0.6667丨丨iou: 0.5000丨丨dice: 0.6667丨丨sens: 0.6000丨丨spec: 0.7500丨
2.5 Model Training and Evaluation
For training, the training and validation sets are created with the create_dataset function defined in section 2.2, with the images uniformly resized to 224x224; the loss function is nn.BCELoss and the optimizer is nn.Adam. At the end of every epoch the five metrics defined in section 2.4 are computed on the validation set, and the model with the best validation IoU is saved.
The training code is as follows:
import mindspore.nn as nn
from mindspore import ops
import mindspore
from mindspore import ms_function
import ml_collections
def get_config():
"""configuration """
config = ml_collections.ConfigDict()
config.epochs = 100
config.train_data_path = "src/datasets/ISBI/train/"
config.val_data_path = "src/datasets/ISBI/val/"
config.imgsize = 224
config.batch_size = 4
config.pretrained_path = None
config.in_channel = 3
config.n_classes = 1
config.lr = 0.0001
return config
cfg = get_config()
train_dataset = create_dataset(cfg.train_data_path, img_size=cfg.imgsize, batch_size= cfg.batch_size, augment=True, shuffle = True)
val_dataset = create_dataset(cfg.val_data_path, img_size=cfg.imgsize, batch_size= cfg.batch_size, augment=False, shuffle = False)
def train(model, dataset, loss_fn, optimizer, met):
# Define forward function
def forward_fn(data, label):
logits = model(data)
loss = loss_fn(logits, label)
return loss, logits
# Get gradient function
grad_fn = ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
# Define function of one-step training
@ms_function
def train_step(data, label):
(loss, logits), grads = grad_fn(data, label)
loss = ops.depend(loss, optimizer(grads))
return loss, logits
size = dataset.get_dataset_size()
model.set_train(True)
train_loss = 0
train_pred = []
train_label = []
for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
loss, logits = train_step(data, label)
train_loss += loss.asnumpy()
train_pred.extend(logits.asnumpy())
train_label.extend(label.asnumpy())
train_loss /= size
metric = metrics_(met, smooth=1e-5)
metric.clear()
metric.update(train_pred, train_label)
res = metric.eval()
print(f'Train loss:{train_loss:>4f}','丨acc: %.3f丨丨iou: %.3f丨丨dice: %.3f丨丨sens: %.3f丨丨spec: %.3f丨' % (res[0], res[1], res[2], res[3], res[4]))
def val(model, dataset, loss_fn, met):
size = dataset.get_dataset_size()
model.set_train(False)
val_loss = 0
val_pred = []
val_label = []
for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
pred = model(data)
val_loss += loss_fn(pred, label).asnumpy()
val_pred.extend(pred.asnumpy())
val_label.extend(label.asnumpy())
val_loss /= size
metric = metrics_(met, smooth=1e-5)
metric.clear()
metric.update(val_pred, val_label)
res = metric.eval()
print(f'Val loss:{val_loss:>4f}','丨acc: %.3f丨丨iou: %.3f丨丨dice: %.3f丨丨sens: %.3f丨丨spec: %.3f丨' % (res[0], res[1], res[2], res[3], res[4]))
checkpoint = res[1]
return checkpoint, res[4]
net = UNet(cfg.in_channel, cfg.n_classes)
criterion = nn.BCELoss(reduction='mean')   # the network already ends with a Sigmoid, so plain BCELoss is used
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=cfg.lr)
iters_per_epoch = train_dataset.get_dataset_size()
total_train_steps = iters_per_epoch * cfg.epochs
print('iters_per_epoch: ', iters_per_epoch)
print('total_train_steps: ', total_train_steps)
metrics_name = ["acc", "iou", "dice", "sens", "spec"]
best_iou = 0
ckpt_path = 'checkpoint/best_UNet.ckpt'
for epoch in range(cfg.epochs):
print(f"Epoch [{epoch+1} / {cfg.epochs}]")
train(net, train_dataset, criterion, optimizer, metrics_name)
checkpoint_best, spec = val(net, val_dataset, criterion, metrics_name)
if epoch > 2 and spec > 0.2:
if checkpoint_best > best_iou:
print('IoU improved from %0.4f to %0.4f' % (best_iou, checkpoint_best))
best_iou = checkpoint_best
mindspore.save_checkpoint(net, ckpt_path)
print("saving best checkpoint at: {} ".format(ckpt_path))
else:
print('IoU did not improve from %0.4f' % (best_iou),"\n-------------------------------")
print("Done!")
2.6 Model Prediction
The code is as follows:
import os
import cv2
import mindspore.dataset as ds
import glob
import mindspore.dataset.vision as vision_C
import mindspore.dataset.transforms as C_transforms
import random
import mindspore
from mindspore.dataset.vision import Inter
import numpy as np
from tqdm import tqdm
def val_transforms(img_size):
return C_transforms.Compose([
vision_C.Resize(img_size, interpolation=Inter.NEAREST),
vision_C.Rescale(1/255., 0),
vision_C.HWC2CHW()
])
class Data_Loader:
def __init__(self, data_path, have_mask):
        # Initialization: collect all sample (and, if present, label) paths under data_path
        self.data_path = data_path
        self.have_mask = have_mask
        self.imgs_path = sorted(glob.glob(os.path.join(data_path, 'image/*.png')))
        if self.have_mask:
            self.label_path = sorted(glob.glob(os.path.join(data_path, 'mask/*.png')))
def __getitem__(self, index):
        # Read the image (and the label, if available) for the given index
image = cv2.imread(self.imgs_path[index])
if self.have_mask:
label = cv2.imread(self.label_path[index], cv2.IMREAD_GRAYSCALE)
label = label.reshape((label.shape[0], label.shape[1], 1))
else:
label = image
return image, label
@property
def column_names(self):
column_names = ['image', 'label']
return column_names
def __len__(self):
return len(self.imgs_path)
def create_dataset(data_dir, img_size, batch_size, shuffle, have_mask = False):
mc_dataset = Data_Loader(data_path=data_dir, have_mask = have_mask)
print(len(mc_dataset))
dataset = ds.GeneratorDataset(mc_dataset, mc_dataset.column_names, shuffle=shuffle)
transform_img = val_transforms(img_size)
seed = random.randint(1, 1000)
mindspore.set_seed(seed)
dataset = dataset.map(input_columns='image', num_parallel_workers=1, operations=transform_img)
mindspore.set_seed(seed)
dataset = dataset.map(input_columns="label", num_parallel_workers=1, operations=transform_img)
dataset = dataset.batch(batch_size, num_parallel_workers=1)
return dataset
def model_pred(model, test_loader, result_path, have_mask):
model.set_train(False)
test_pred = []
test_label = []
for batch, (data, label) in enumerate(test_loader.create_tuple_iterator()):
pred = model(data)
pred[pred > 0.5] = float(1)
pred[pred <= 0.5] = float(0)
        preds = pred.asnumpy()                                     # convert to numpy before reshaping
        img = np.transpose(np.squeeze(preds, axis=0), (1, 2, 0))   # (1, 1, H, W) -> (H, W, 1)
        if not os.path.exists(result_path):
            os.makedirs(result_path)
        cv2.imwrite(os.path.join(result_path, "%05d.png" % batch), img * 255.)
test_pred.extend(pred.asnumpy())
test_label.extend(label.asnumpy())
if have_mask:
mtr = ['acc', 'iou', 'dice', 'sens', 'spec']
metric = metrics_(mtr, smooth=1e-5)
metric.clear()
metric.update(test_pred, test_label)
res = metric.eval()
print(f'丨acc: %.3f丨丨iou: %.3f丨丨dice: %.3f丨丨sens: %.3f丨丨spec: %.3f丨' % (res[0], res[1], res[2], res[3], res[4]))
else:
print("Evaluation metrics cannot be calculated without Mask")
if __name__ == '__main__':
net = UNet(3, 1)
mindspore.load_checkpoint("best_UNet.ckpt", net=net)
result_path = "predict"
test_dataset = create_dataset("datasets/ISBI/test/", 224, 1, shuffle=False, have_mask=False)
model_pred(net, test_dataset, result_path, have_mask=False)
3. Summary
Based on the MindSpore framework and the ISBI dataset, this case walked through data loading, dataset creation, and construction of the UNet model, defined custom evaluation metrics for the task, carried out model training and evaluation, and produced the prediction outputs. Working through the case deepens the understanding of the UNet architecture and its characteristics and, together with the documentation and tutorials provided by MindSpore, demonstrates the workflow of implementing a concrete case with MindSpore and the use of a range of its APIs, laying the groundwork for applying MindSpore in real-world scenarios.