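Preface
- It has been a long time since this blog was updated, so here is a new post.
- I have recently been working on image segmentation tasks, so this post walks through a simple implementation of how segmentation is done. The content is simple and somewhat dry, so it takes some patience to read.
- The content is fairly basic and is best suited to readers who are new to segmentation. (The training code in this post does not cover training strategies or data augmentation; these are deliberately left for you to explore.)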
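Contents
- Introduction to the U2Net algorithm
- Screenshots of training results from this post
- Overview of the code structure
- Preparing the training dataset
- Custom dataset
- u2net and u2netp network definitions
- Training code
- Model inference code
- Summary and GitHub link to the code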
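Introduction to the U2Net algorithm
- Many authors on CSDN have already published detailed explanations of the algorithm, and you can search for and read them yourself. This post focuses on hands-on practice, so the theory is omitted here.
- Official paper: https://arxiv.org/pdf/2005.09007.pdf
- Official GitHub repo: https://github.com/xuebinqin/U-2-Net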
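Screenshots of training results from this post
- Visualization of segmentation results on the task images
- The figure above shows the model's inference results on the test set (top left: the original annotated mask; bottom left: the predicted mask; right: the original image). The model's results look reasonably good.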
Overview of the code structure
- The overall project structure is shown in the figure below
- First folder: backup
backup is where models are saved during training. The training code automatically creates a subfolder in this directory and saves the .pth weight files there.
- Second folder: dataset
The dataset directory holds the training data: the original images used for training and their corresponding annotated masks. The dataset is organized as follows:
-images
    -train
        -0.jpg
        -1.jpg
        -....
    -test
        -0.jpg
        -1.jpg
        -....
    -val
        -0.jpg
        -1.jpg
        -....
-masks
    -train
        -0.jpg
        -1.jpg
        -....
    -test
        -0.jpg
        -1.jpg
        -....
    -val
        -0.jpg
        -1.jpg
        -....
- Third folder: src
The src folder contains two files: the network definition u2net.py and the custom dataset file seg_dataset.py
- train_u2net.py: model training code
- inference_u2net.py: model inference code
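Preparing the training dataset
- Prepare your training dataset following the description of the dataset folder in the previous section.
- Note: keep the filenames of each original image and its mask identical. If they differ, you will need to adjust the dataset code yourself.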
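Because the dataset code looks up each mask by the image's filename, it can help to check the pairing before training. The following is a minimal sketch, not part of the original project; the helper name `find_unpaired` is hypothetical and it only compares filenames, not image contents.

```python
import os


def find_unpaired(img_dir, mask_dir):
    """Return (images without a mask, masks without an image), matched by filename."""
    img_names = set(os.listdir(img_dir))
    mask_names = set(os.listdir(mask_dir))
    # Set differences give the names present on one side only.
    return sorted(img_names - mask_names), sorted(mask_names - img_names)
```

With the layout described above, it could be called as `find_unpaired('dataset/images/train', 'dataset/masks/train')`; two empty lists mean every image has a mask of the same name.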
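Custom dataset
- In general, a dataset has two core parts:
__getitem__ method (returns the sample for a given index)
__len__ method (returns the number of samples in the dataset)
(Note: the dataset class in this post does not include data augmentation; that is deliberately left for you to study and explore.)
- Following the description above, we now define the custom dataset
src/seg_dataset.py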
# coding: utf-8
# author: hxy
# 2022-04-20
"""
Dataset for reading the data
"""
import os
import cv2
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset


# dataset for u2net
class U2netSegDataset(Dataset):
    def __init__(self, img_dir, mask_dir, input_size=(320, 320)):
        """
        :param img_dir: path to the folder of dataset images
        :param mask_dir: path to the folder of dataset masks
        :param input_size: input size of the images, passed to cv2.resize as (width, height)
        """
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.input_size = input_size
        self.samples = list()
        self.gt_mask = list()
        self.std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
        self.mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
        # all images and masks are read and preprocessed once, up front
        self.load_data()

    def __len__(self):
        return len(self.samples)

    def load_data(self):
        img_dir_full_path = self.img_dir
        mask_dir_full_path = self.mask_dir
        img_files = os.listdir(img_dir_full_path)
        for img_name in tqdm(img_files):
            img_full_path = os.path.join(img_dir_full_path, img_name)
            # the mask is looked up by the same filename as the image
            mask_full_path = os.path.join(mask_dir_full_path, img_name)
            img = cv2.imread(img_full_path)
            img = cv2.resize(img, self.input_size)
            img2rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # subtract mean, divide by std (pixel values stay in 0-255 here;
            # the inference code applies the same formula)
            img2norm = (img2rgb - self.mean) / self.std
            # convert the layout from HWC to CHW
            img2nchw = np.transpose(img2norm, [2, 0, 1]).astype(np.float32)
            gt_mask = cv2.imread(mask_full_path)
            gt_mask = cv2.resize(gt_mask, self.input_size)
            gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_BGR2GRAY)
            # scale the mask to [0, 1] and add a channel axis: (1, H, W)
            gt_mask = gt_mask / 255.
            gt_mask = np.expand_dims(gt_mask, axis=0)
            self.samples.append(img2nchw)
            self.gt_mask.append(gt_mask)
        return self.samples, self.gt_mask

    def __getitem__(self, index):
        img = self.samples[index]
        mask = self.gt_mask[index]
        return img, mask
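A brief description of the code above: the os module lists the folder to get all filenames, each filename is joined into a full path, OpenCV reads the file, the resulting array is preprocessed (resize, normalization, channel reordering), and the preprocessed arrays are appended to the corresponding lists.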
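Since augmentation is left as an exercise, here is one small sketch of what a paired augmentation could look like for the arrays this dataset produces. It is an illustration rather than the author's method: the function name `random_hflip` and its parameters are hypothetical, and it assumes the image is a (C, H, W) array and the mask a (1, H, W) array. The important point is that a geometric transform has to be applied to the image and its mask together.

```python
import numpy as np


def random_hflip(img_chw, mask_1hw, rng, p=0.5):
    """Flip an image and its mask left-right together with probability p."""
    if rng.random() < p:
        # The last axis is the width axis for both (C, H, W) and (1, H, W) arrays.
        # copy() makes the result contiguous; torch.from_numpy does not accept
        # the negative-stride view that np.flip returns.
        img_chw = np.flip(img_chw, axis=-1).copy()
        mask_1hw = np.flip(mask_1hw, axis=-1).copy()
    return img_chw, mask_1hw
```

One way to use it would be to call it inside `__getitem__` with a generator such as `np.random.default_rng()`, so that each epoch sees a different mix of flipped and unflipped samples.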
u2net and u2netp network definitions
- The network definition code is copied directly from the original repo, so it is simply pasted below for reference.
src/u2net.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class REBNCONV(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
        return xout


## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):
    # F.upsample is deprecated; F.interpolate takes the same arguments
    src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=True)
    return src
### RSU-7 ###
class RSU7(nn.Module):  # UNet07DRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU7, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
        hx5 = self.rebnconv5(hx)
        hx = self.pool5(hx5)
        hx6 = self.rebnconv6(hx)
        hx7 = self.rebnconv7(hx6)
        hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
        hx6dup = _upsample_like(hx6d, hx5)
        hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
        return hx1d + hxin
### RSU-6 ###
class RSU6(nn.Module):  # UNet06DRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU6, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx = self.pool4(hx4)
        hx5 = self.rebnconv5(hx)
        hx6 = self.rebnconv6(hx5)
        hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
        return hx1d + hxin
### RSU-5 ###
class RSU5(nn.Module):  # UNet05DRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU5, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx = self.pool3(hx3)
        hx4 = self.rebnconv4(hx)
        hx5 = self.rebnconv5(hx4)
        hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
        return hx1d + hxin
### RSU-4 ###
class RSU4(nn.Module):  # UNet04DRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx = self.pool1(hx1)
        hx2 = self.rebnconv2(hx)
        hx = self.pool2(hx2)
        hx3 = self.rebnconv3(hx)
        hx4 = self.rebnconv4(hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
        return hx1d + hxin
### RSU-4F ###
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super(RSU4F, self).__init__()
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        hx = x
        hxin = self.rebnconvin(hx)
        hx1 = self.rebnconv1(hxin)
        hx2 = self.rebnconv2(hx1)
        hx3 = self.rebnconv3(hx2)
        hx4 = self.rebnconv4(hx3)
        hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
        hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
        hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
        return hx1d + hxin
##### U^2-Net ####
class U2NET(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage6 = RSU4F(512, 256, 512)
        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x
        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)
        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)
        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)
        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)
        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)
        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)
        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
        # side output
        d1 = self.side1(hx1d)
        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)
        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)
        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)
        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)
        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)
        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
### U^2-Net small ###
class U2NETP(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage6 = RSU4F(64, 16, 64)
        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x
        # stage 1
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)
        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)
        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)
        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)
        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)
        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)
        # decoder
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
        # side output
        d1 = self.side1(hx1d)
        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)
        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)
        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)
        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)
        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)
        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
        return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)
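One detail of the network code that is easy to miss is why `_upsample_like` resizes to the target tensor's size instead of simply doubling. The pooling layers use `ceil_mode=True`, so an odd side length rounds up when halved, and doubling would then overshoot. The sketch below is plain arithmetic for illustration (the helper name `encoder_sizes` is hypothetical and it assumes a square input); it lists the side length seen by each of the six encoder stages of `U2NET`.

```python
import math


def encoder_sizes(side, num_pools=5):
    """Side length at each encoder stage for a square input of the given size."""
    sizes = [side]
    for _ in range(num_pools):
        # MaxPool2d(2, stride=2, ceil_mode=True) maps n to ceil(n / 2)
        side = math.ceil(side / 2)
        sizes.append(side)
    return sizes


print(encoder_sizes(160))  # [160, 80, 40, 20, 10, 5]
print(encoder_sizes(161))  # [161, 81, 41, 21, 11, 6]
```

For the 161 case, doubling the 41-pixel map would give 82 rather than 81, so the skip connection could not be concatenated; resizing to the skip tensor's own shape avoids that mismatch.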
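Training code
- Training code
The general flow of deep learning training code is: model definition -> data loading -> model training -> model validation
The training code in this post follows this logic:
1 Define the network
2 Load the data
3 Define the loss function and optimizer
4 Start training
- Train the network
    - Zero the gradients
    - Compute the loss
    - Backpropagate
    - Update the parameters
(The training code in this post does not include a validation step; that is left for you to implement.)
train_u2net.py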
# coding: utf-8
# author: hxy
# 20220420
"""
Training code: u2net, u2netp
train it from scratch.
"""
import os
import datetime
import torch
import numpy as np
from tqdm import tqdm
from src.u2net import U2NET, U2NETP
from src.seg_dataset import U2netSegDataset
from torch.utils.data import DataLoader

# loss setup follows the original u2net source code
bce_loss = torch.nn.BCELoss(reduction='mean')


def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    # one BCE term for the fused output d0 and one for each side output d1..d6
    loss0 = bce_loss(d0, labels_v)
    loss1 = bce_loss(d1, labels_v)
    loss2 = bce_loss(d2, labels_v)
    loss3 = bce_loss(d3, labels_v)
    loss4 = bce_loss(d4, labels_v)
    loss5 = bce_loss(d5, labels_v)
    loss6 = bce_loss(d6, labels_v)
    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    # print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
    # loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(), loss5.data.item(),
    # loss6.data.item()))
    return loss0, loss


def load_data(img_folder, mask_folder, batch_size, num_workers, input_size):
    """
    :param img_folder: folder containing the images
    :param mask_folder: folder containing the masks
    :param batch_size: batch size
    :param num_workers: number of CPU workers for data loading
    :param input_size: model input size
    :return: train_loader, val_loader
    """
    train_dataset = U2netSegDataset(img_dir=os.path.join(img_folder, 'train'),
                                    mask_dir=os.path.join(mask_folder, 'train'),
                                    input_size=input_size)
    val_dataset = U2netSegDataset(img_dir=os.path.join(img_folder, 'val'),
                                  mask_dir=os.path.join(mask_folder, 'val'),
                                  input_size=input_size)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader


def train_model(epoch_nums, cuda_device, model_save_dir):
    """
    :param epoch_nums: total number of training epochs
    :param cuda_device: GPU to train on
    :param model_save_dir: folder where models are saved
    :return:
    """
    current_time = datetime.datetime.now()
    current_time = datetime.datetime.strftime(current_time, '%Y-%m-%d-%H:%M')
    model_save_dir = os.path.join(model_save_dir, current_time)
    os.makedirs(model_save_dir, exist_ok=True)
    device = torch.device(cuda_device)
    # load_data appends 'train' / 'val', so these point at the images and masks
    # subfolders of the dataset layout described earlier
    train_loader, val_loader = load_data(img_folder='dataset/images',
                                         mask_folder='dataset/masks',
                                         batch_size=32,
                                         num_workers=10,
                                         input_size=(160, 160))
    # input 3-channels, output 1-channels
    net = U2NET(3, 1)
    # net = U2NETP(3, 1)
    # if torch.cuda.device_count() > 1:
    #     net = torch.nn.DataParallel(net, device_ids=[6, 7])
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    for epoch in range(0, epoch_nums):
        run_loss = list()
        run_tar_loss = list()
        net.train()
        for i, (inputs, gt_masks) in enumerate(tqdm(train_loader)):
            optimizer.zero_grad()
            inputs = inputs.type(torch.FloatTensor)
            gt_masks = gt_masks.type(torch.FloatTensor)
            inputs, gt_masks = inputs.to(device), gt_masks.to(device)
            d0, d1, d2, d3, d4, d5, d6 = net(inputs)
            # loss2 is the loss of the fused output d0; loss is the sum over all outputs
            loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, gt_masks)
            loss.backward()
            optimizer.step()
            run_loss.append(loss.item())
            run_tar_loss.append(loss2.item())
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss
        print("--Train Epoch:{}--".format(epoch))
        print("--Train run_loss:{:.4f}--".format(np.mean(run_loss)))
        print("--Train run_tar_loss:{:.4f}--\n".format(np.mean(run_tar_loss)))
        # save a checkpoint every 20 epochs
        if epoch % 20 == 0:
            checkpoint_name = 'checkpoint_' + str(epoch) + '_' + str(np.mean(run_loss)) + '.pth'
            torch.save(net.state_dict(), os.path.join(model_save_dir, checkpoint_name))
            print("--model saved:{}--".format(checkpoint_name))


if __name__ == '__main__':
    train_model(epoch_nums=500, cuda_device='cuda:7',
                model_save_dir='backup')
This training code does not use many training strategies, such as learning-rate schedules or multi-stage training. It implements only the most basic training loop, so there is plenty of room for you to experiment.
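As one example of the learning-rate strategies mentioned above, the sketch below computes a linear warmup followed by cosine decay. It is an illustration under assumptions, not something the original training script does: the function name `warmup_cosine_factor` and the defaults `warmup_epochs=5` and `min_factor=0.01` are hypothetical choices.

```python
import math


def warmup_cosine_factor(epoch, total_epochs, warmup_epochs=5, min_factor=0.01):
    """Multiplier applied to the base learning rate at a given 0-based epoch."""
    if epoch < warmup_epochs:
        # Linear ramp that reaches 1.0 on the last warmup epoch.
        return (epoch + 1) / warmup_epochs
    # Fraction of the post-warmup schedule completed so far, in [0, 1].
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_factor + (1.0 - min_factor) * cosine
```

Because it returns a multiplier rather than an absolute rate, a function like this could be wrapped in a lambda and passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, with `scheduler.step()` called once per epoch.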
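Model inference program
- Model inference
The usual logic of an inference program is: load the model -> read the image -> preprocess the image (this must match the preprocessing used in training) -> run inference -> get the result and post-process it -> save the image and inspect the result visually
inference_u2net.py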
# coding: utf-8
# author: hxy
# 20220420
"""
Inference program for the u2net/u2netp models
"""
import os
import cv2
import torch
import numpy as np
from time import time
from tqdm import tqdm
from src.u2net import U2NET, U2NETP

"""
Load the model at startup
"""
try:
    print('===loading model===')
    current_project_path = os.getcwd()
    net = U2NET(3, 1)
    # net = U2NETP(3, 1)
    # placeholder path: point this at your own checkpoint file
    checkpoint_path = os.path.join(current_project_path,
                                   'backup/*****.pth')
    # the net itself is never moved with .to(), so inference below runs on the CPU
    if torch.cuda.is_available():
        net.load_state_dict(torch.load(checkpoint_path, map_location='cuda:1'))
    else:
        net.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    net.eval()
    print('===model load succeeded===')
except Exception as e:
    print('===model load error:{}==='.format(e))


# compute the Dice coefficient
def dice_coef(output, target):  # output is the prediction, target is the ground truth
    smooth = 1e-5  # avoids division by zero
    intersection = (output * target).sum()
    return (2. * intersection + smooth) / \
           (output.sum() + target.sum() + smooth)


# image normalization (same formula as in the training dataset)
def img2norm(img_array, input_size):
    std = [0.229, 0.224, 0.225]
    mean = [0.485, 0.456, 0.406]
    _std = np.array(std).reshape((1, 1, 3))
    _mean = np.array(mean).reshape((1, 1, 3))
    img_array = cv2.resize(img_array, input_size)
    norm_img = (img_array - _mean) / _std
    return norm_img


# min-max normalize the prediction
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d - mi) / (ma - mi)
    return dn


# inference over one folder
def inference1folder(img_folder, mask_folder, input_size):
    total_times = list()
    total_dices = list()
    # cv2.imwrite does not create missing folders, so make sure the output folder exists
    output_dir = os.path.join(current_project_path, 'infer_output')
    os.makedirs(output_dir, exist_ok=True)
    img_files = os.listdir(img_folder)
    for img_file in tqdm(img_files):
        img_full_path = os.path.join(img_folder, img_file)
        mask_full_path = os.path.join(mask_folder, img_file)
        img = cv2.imread(img_full_path)
        gt_mask = cv2.imread(mask_full_path)
        gt_mask = cv2.resize(gt_mask, input_size)
        gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_BGR2GRAY)
        gt_mask = gt_mask / 255.
        ori_h, ori_w = img.shape[:2]
        img2rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        norm_img = img2norm(img2rgb, input_size)
        # HWC -> CHW, then add the batch axis
        x_tensor = torch.from_numpy(norm_img).permute(2, 0, 1).float()
        x_tensor = torch.unsqueeze(x_tensor, 0)
        start_t = time()
        # no gradients are needed at inference time
        with torch.no_grad():
            d1, d2, d3, d4, d5, d6, d7 = net(x_tensor)
        end_t = time()
        total_times.append(end_t - start_t)
        # d1 here is the first return value of the network, i.e. the fused output
        pred = d1[:, 0, :, :]
        pred = normPRED(pred)
        pred = pred.squeeze().cpu().data.numpy()
        dice_value = dice_coef(pred, gt_mask)
        total_dices.append(dice_value)
        # pred[pred>=0.3]=255
        # pred[pred<0.3]=0
        # pred_res = pred
        pred_res = pred * 255
        # resize the prediction back to the original image size
        pred_res = cv2.resize(pred_res, (ori_w, ori_h))
        cv2.imwrite(os.path.join(output_dir, img_file), pred_res)
    print('==inference 1 pic avg cost:{:.4f}ms=='.format(np.mean(total_times) * 1000))
    print('==inference avg dice:{:.4f}=='.format(np.mean(total_dices)))
    return None


if __name__ == '__main__':
    test_img_folder = os.path.join(os.getcwd(), 'dataset/images/test')
    test_gt_mask_folder = os.path.join(os.getcwd(), 'dataset/masks/test')
    inference1folder(img_folder=test_img_folder, mask_folder=test_gt_mask_folder, input_size=(160, 160))
There is not much to say about this code; just read it carefully. I only wrote inference for a single folder. You could try running inference on video files, or add fancier post-processing to make the results look better.
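As a small example of such post-processing, the sketch below tints the predicted foreground on top of the original image so that the mask and the picture can be inspected together. It is only an illustration: the function name `overlay_mask` and its parameters are hypothetical, the default threshold of 0.3 simply echoes the commented-out lines in the inference code, and it assumes `pred` has already been resized to the image's height and width with values in [0, 1].

```python
import numpy as np


def overlay_mask(img_bgr, pred, threshold=0.3, color=(0, 0, 255), alpha=0.5):
    """Blend `color` into the pixels of a BGR image where pred >= threshold."""
    out = img_bgr.astype(np.float32)
    fg = pred >= threshold  # boolean (H, W) foreground map
    # Mix the original pixel with the highlight color on foreground pixels only.
    out[fg] = (1.0 - alpha) * out[fg] + alpha * np.array(color, dtype=np.float32)
    return out.round().astype(np.uint8)
```

The returned array is a `uint8` image of the same shape as the input, so it can be written out with `cv2.imwrite` like the masks above.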
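Summary and GitHub link to the code
- A blog post is not complete without a summary.
- This post implements only the most basic training procedure and training code, so there is plenty of room to build on it;
- For example: try different loss functions (dice loss, BCE + dice loss, IoU loss, and so on)
- Add data augmentation (the albumentations library is recommended; torchvision also works)
- Train with different tuning strategies (different learning-rate decay schedules, multi-stage training, and so on)
- Try training with different optimizers, and so on
- Once you have tried all of the above, try different networks and keep adding network structures to the src folder
- Tidy up the code and wrap it into reusable pieces
- There are always more experiments to run and more to learn.
- Finally, I hope this post is helpful to you. Let's learn from each other, and please bear with any shortcomings in the code!
- GitHub repository for the code in this post: https://github.com/YingXiuHe/u2net-pytorch.git/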
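As a starting point for the loss-function experiments suggested in the summary, the sketch below writes out soft Dice and soft IoU losses. It uses NumPy only to show the arithmetic, and the names `soft_dice_loss` and `soft_iou_loss` are hypothetical. The Dice formula mirrors `dice_coef` from the inference code; for actual training the same expressions would need to be written with torch tensors so that gradients can flow.

```python
import numpy as np


def soft_dice_loss(pred, target, smooth=1e-5):
    """1 - Dice coefficient, computed on probabilities rather than hard labels."""
    intersection = (pred * target).sum()
    dice = (2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth)
    return 1.0 - dice


def soft_iou_loss(pred, target, smooth=1e-5):
    """1 - IoU, using the soft intersection and soft union."""
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return 1.0 - (intersection + smooth) / (union + smooth)
```

Both losses are close to 0 when the prediction matches the mask and close to 1 when they do not overlap at all. A "BCE + dice" loss in the sense of the list above would add one of these terms to the existing BCE term.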