1 Introduction
Tongue body segmentation is the foundation of tongue diagnosis: only an accurate segmentation of the tongue body can guarantee the accuracy of subsequent training and prediction. The actual task here is to locate, in a user-uploaded image, exactly the pixels that belong to the tongue. Tongue body segmentation falls under biomedical image segmentation. The segmentation result looks like this:
2 Dataset
The tongue image dataset contains the original tongue photos together with the finished binary segmentation masks, 979 × 2 images in total. Example images:
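If you want to verify the mask/original pairing yourself, a minimal sketch like the one below works; it assumes the folder layout described in section 4.1 (.png masks in notedata, .jpg originals in ordata) and the project root used later in the code:

import os

root = 'E:/ITEM_TIME/project/UNET/'  # adjust to your own path
masks = os.listdir(os.path.join(root, 'notedata'))
missing = [m for m in masks
           if not os.path.exists(os.path.join(root, 'ordata', m.replace('png', 'jpg')))]
print(len(masks), 'masks,', len(missing), 'without a matching original')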
3 Model
U-Net is a well-proven semantic segmentation model. In 中e诊, the U-Net consists of three parts: the backbone feature-extraction part, the enhanced feature-extraction part, and the prediction part. The backbone produces five preliminary effective feature levels; the enhanced part then upsamples these five levels and fuses them, yielding a single effective feature map that combines all of them. That final feature map is used to classify each pixel and pick out the ones belonging to the tongue body. The structure is shown in the figure below:
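To make the five feature levels concrete, here is a quick sketch that prints the encoder shapes for a 256×256 input, using the UNet class defined in net.py (section 4.4 below):

import torch
from net import UNet

net = UNet()
x = torch.randn(1, 3, 256, 256)
R1 = net.c1(x)             # 64   x 256 x 256
R2 = net.c2(net.d1(R1))    # 128  x 128 x 128
R3 = net.c3(net.d2(R2))    # 256  x 64  x 64
R4 = net.c4(net.d3(R3))    # 512  x 32  x 32
R5 = net.c5(net.d4(R4))    # 1024 x 16  x 16
for r in (R1, R2, R3, R4, R5):
    print(r.shape)

Each downsampling step halves the spatial resolution while the channel count doubles; the decoder reverses this, fusing each upsampled map with the matching encoder level.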
After annotation, a U-Net model is built with the PyTorch framework to learn the tongue-image features and predict the segmentation labels. To evaluate the model, the average loss of each epoch is computed during training; the final loss is about 2% per image. The average loss curve is shown below:
Training took 4 days on the 979 annotated images, ending at an average prediction loss of about 2%. The model's predictions, i.e. the tongue segmentations, are very good; below are example segmentations at a loss of 40% and at a loss of 2%:
(1) Segmentation result at a loss of 40%
(2) Segmentation result at a loss of 2%
Based on the model's predictions, the pixels belonging to the tongue body are extracted and everything else is filled with green. The final segmentation looks like this:
4 Implementation Details
4.1 Project Files
The notedata folder holds the segmentation label images, ordata the original images, params the trained model weights, result the test output images, and train_image the snapshots saved during training.
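Laid out as a tree, the project root used throughout the code looks like this:

UNET/
├── notedata/     # segmentation label images (.png)
├── ordata/       # original images (.jpg)
├── params/       # trained model weights (unet.pth)
├── result/       # test output images
└── train_image/  # snapshots saved during training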
4.2 utils.py
Utility function: the images in the dataset come in different sizes, so to keep the later stages simple we define a helper that scales each image proportionally to 256×256 (change this to suit your needs).
from PIL import Image

def keep_image_size_open(path, size=(256, 256)):
    # Paste the image onto a square black canvas so its aspect ratio
    # is preserved, then resize the canvas to the target size.
    img = Image.open(path)
    temp = max(img.size)
    mask = Image.new('RGB', (temp, temp), (0, 0, 0))
    mask.paste(img, (0, 0))
    mask = mask.resize(size)
    return mask
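A quick way to sanity-check the helper (the file name here is just a placeholder):

img = keep_image_size_open('some_tongue_image.jpg')  # placeholder path
print(img.size)  # (256, 256) regardless of the input's original size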
4.3 data.py
This file matches each label image in the dataset with its original image; the steps are explained in the code comments.
import os
from torch.utils.data import Dataset
from utils import *
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor()
])

class MyDataset(Dataset):
    def __init__(self, path):
        # Collect the file names in the label folder.
        self.path = path
        self.name = os.listdir(os.path.join(path, 'notedata'))

    def __len__(self):
        # The number of label files is the dataset size.
        return len(self.name)

    def __getitem__(self, index):
        # Match each label file name against the original-image folder.
        # Labels are .png while originals are .jpg, so the extension
        # has to be converted.
        segment_name = self.name[index]  # xx.png
        segment_path = os.path.join(self.path, 'notedata', segment_name)
        image_path = os.path.join(self.path, 'ordata', segment_name.replace('png', 'jpg'))
        segment_image = keep_image_size_open(segment_path)  # proportional resize
        image = keep_image_size_open(image_path)  # proportional resize
        return transform(image), transform(segment_image)

if __name__ == "__main__":
    data = MyDataset("E:/ITEM_TIME/project/UNET/")
    print(data[0][0].shape)
    print(data[0][1].shape)
As the printed shapes show (both should be torch.Size([3, 256, 256])), the dataset is now uniform.
4.4 net.py
The U-Net network itself.
from torch import nn
import torch
from torch.nn import functional as F

class Conv_Block(nn.Module):  # double convolution block
    def __init__(self, in_channel, out_channel):
        super(Conv_Block, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU(),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout2d(0.3),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.layer(x)

class DownSample(nn.Module):  # downsampling via strided convolution
    def __init__(self, channel):
        super(DownSample, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(channel, channel, 3, 2, 1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.layer(x)

class UpSample(nn.Module):  # upsampling (nearest-neighbor interpolation)
    def __init__(self, channel):
        super(UpSample, self).__init__()
        self.layer = nn.Conv2d(channel, channel // 2, 1, 1)

    def forward(self, x, feature_map):
        up = F.interpolate(x, scale_factor=2, mode='nearest')
        out = self.layer(up)
        # Concatenate with the skip connection from the encoder.
        return torch.cat((out, feature_map), dim=1)

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.c1 = Conv_Block(3, 64)
        self.d1 = DownSample(64)
        self.c2 = Conv_Block(64, 128)
        self.d2 = DownSample(128)
        self.c3 = Conv_Block(128, 256)
        self.d3 = DownSample(256)
        self.c4 = Conv_Block(256, 512)
        self.d4 = DownSample(512)
        self.c5 = Conv_Block(512, 1024)
        self.u1 = UpSample(1024)
        self.c6 = Conv_Block(1024, 512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)
        self.out = nn.Conv2d(64, 3, 3, 1, 1)
        self.Th = nn.Sigmoid()

    def forward(self, x):
        # Encoder: five feature levels.
        R1 = self.c1(x)
        R2 = self.c2(self.d1(R1))
        R3 = self.c3(self.d2(R2))
        R4 = self.c4(self.d3(R3))
        R5 = self.c5(self.d4(R4))
        # Decoder: upsample and fuse with the matching encoder level.
        O1 = self.c6(self.u1(R5, R4))
        O2 = self.c7(self.u2(O1, R3))
        O3 = self.c8(self.u3(O2, R2))
        O4 = self.c9(self.u4(O3, R1))
        return self.Th(self.out(O4))

if __name__ == "__main__":
    x = torch.randn(2, 3, 256, 256)
    net = UNet()
    print(net(x).shape)
The output shape matches the input, so the network is wired up correctly.
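As an extra sanity check you can also print the parameter count; a one-off sketch to run in the same __main__ block:

total = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f'{total:,} trainable parameters')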
4.5 train.py
The training code.
from torch import nn
from torch import optim
import torch
from data import *
from net import *
from torchvision.utils import save_image
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weight_path = 'params/unet.pth'
data_path = 'E:/ITEM_TIME/project/UNET/'
save_path = 'train_image'

if __name__ == "__main__":
    dic = []  # average loss per epoch
    # batch_size of 3 or 4 both work, depending on your machine.
    data_loader = DataLoader(MyDataset(data_path), batch_size=3, shuffle=True)
    net = UNet().to(device)
    if os.path.exists(weight_path):
        net.load_state_dict(torch.load(weight_path))
        print('success load weight')
    else:
        print('not success load weight')

    opt = optim.Adam(net.parameters())
    loss_fun = nn.BCELoss()

    epoch = 1
    while True:
        avg = []  # per-batch losses in this epoch
        for i, (image, segment_image) in enumerate(data_loader):
            image, segment_image = image.to(device), segment_image.to(device)
            out_image = net(image)
            train_loss = loss_fun(out_image, segment_image)
            opt.zero_grad()
            train_loss.backward()
            opt.step()
            if i % 5 == 0:
                print('{}-{}-train_loss===>>{}'.format(epoch, i, train_loss.item()))
            if i % 50 == 0:
                torch.save(net.state_dict(), weight_path)
            # Stitch the original image, label, and prediction together
            # so training progress is easy to inspect.
            _image = image[0]
            _segment_image = segment_image[0]
            _out_image = out_image[0]
            img = torch.stack([_image, _segment_image, _out_image], dim=0)
            save_image(img, f'{save_path}/{i}.jpg')
            avg.append(float(train_loss.item()))
        loss_avg = sum(avg) / len(avg)
        dic.append(loss_avg)
        epoch += 1
        print(dic)
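Since dic accumulates the per-epoch average loss, the loss curve from section 3 can be redrawn with a short matplotlib sketch (assuming matplotlib is installed; in practice you would dump dic to a file inside the loop, because the while True loop above never exits on its own):

import matplotlib.pyplot as plt

plt.plot(range(1, len(dic) + 1), dic)
plt.xlabel('epoch')
plt.ylabel('average BCE loss')
plt.savefig('train_image/loss_curve.png')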
The code runs as expected. The loss quoted above is after 4 days of training; at the beginning it will be much worse, so be patient!
4.6 test.py
The test code: segment a new image automatically.
from net import *
from utils import keep_image_size_open
import os
import torch
from data import *
from torchvision.utils import save_image
from PIL import Image
import numpy as np

net = UNet().cpu()  # or move it to cuda
weights = 'params/unet.pth'  # trained weights
if os.path.exists(weights):
    net.load_state_dict(torch.load(weights))
    print('success')
else:
    print('no loading')

_input = 'xxxx.jpg'  # path of the test image
img = keep_image_size_open(_input)
img_data = transform(img)
print(img_data.shape)
img_data = torch.unsqueeze(img_data, dim=0)
out = net(img_data)
save_image(out, 'result/result.jpg')
save_image(img_data, 'result/original.jpg')

#E:\ITEM_TIME\UNET\ordata\4292.jpg
img_after = Image.open(r"result\result.jpg")
img_before = Image.open(r"result\original.jpg")
img_after_array = np.array(img_after)    # convert the images to numpy arrays
img_before_array = np.array(img_before)
shape_after = img_after_array.shape
shape_before = img_before_array.shape
print(shape_after, shape_before)

# Restore the matching pixels: turn the black-and-white mask into a
# color extraction of the tongue region.
if shape_after == shape_before:
    height = shape_after[0]
    width = shape_after[1]
    dst = np.zeros((height, width, 3))
    for h in range(0, height):
        for w in range(0, width):
            (r1, g1, b1) = img_after_array[h, w]  # PIL arrays are RGB
            # A dark mask pixel means background, so fill it with green.
            # (The comparison must be element-wise; a tuple comparison
            # like (r1, g1, b1) <= (90, 90, 90) is lexicographic and
            # mislabels pixels.)
            if r1 <= 90 and g1 <= 90 and b1 <= 90:
                img_before_array[h, w] = (144, 238, 144)
            dst[h, w] = img_before_array[h, w]
    img2 = Image.fromarray(np.uint8(dst))
    img2.save(r"result\blend.png", "png")
else:
    print("Failed!")
Results:
(1) Original image (original.jpg):
(2) Model segmentation mask (result.jpg):
(3) Pixel-restored image (blend.png): the white region of (2) is filled with the corresponding original-image pixels and the black region is filled with green.
With this, tongue body segmentation is complete!