UNet網(wǎng)絡(luò)詳解及PyTorch實(shí)現(xiàn)
一、UNet網(wǎng)絡(luò)原理
??U-Net,自2015年誕生以來,便以其卓越的性能在生物醫(yī)學(xué)圖像分割領(lǐng)域嶄露頭角。作為FCN的一種變體,U-Net憑借其Encoder-Decoder的精巧結(jié)構(gòu),不僅在醫(yī)學(xué)圖像分析中大放異彩,更在衛(wèi)星圖像分割、工業(yè)瑕疵檢測(cè)等多個(gè)領(lǐng)域展現(xiàn)出強(qiáng)大的應(yīng)用能力。UNet是一種常用于圖像分割的卷積神經(jīng)網(wǎng)絡(luò)架構(gòu),其特點(diǎn)在于其U型結(jié)構(gòu),包括一個(gè)收縮路徑(下采樣)和一個(gè)擴(kuò)展路徑(上采樣)。這種結(jié)構(gòu)使得UNet能夠在捕獲上下文信息的同時(shí),也能精確地定位到目標(biāo)邊界。
-
收縮路徑(編碼器Encoder):通過連續(xù)的卷積和池化操作,逐步減小特征圖的尺寸,從而捕獲到圖像的上下文信息。
-
擴(kuò)展路徑(解碼器Decoder):通過上采樣操作逐步恢復(fù)特征圖的尺寸,并與收縮路徑中對(duì)應(yīng)尺度的特征圖進(jìn)行拼接(concatenate),以融合不同尺度的特征信息。
-
跳躍連接:UNet中的跳躍連接使得擴(kuò)展路徑能夠利用到收縮路徑中的高分辨率特征,從而提高了分割的精度。
-
輸出層:UNet的輸出層通常是一個(gè)1x1的卷積層,用于將特征圖轉(zhuǎn)換為與輸入圖像相同尺寸的分割圖。
二、基于PyTorch的UNet實(shí)現(xiàn)
??下面是一個(gè)簡(jiǎn)單的基于PyTorch的UNet實(shí)現(xiàn),用于圖像分割任務(wù)。(環(huán)境安裝可以看我往期博客)
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels):
super(DoubleConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.conv(x)
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=True):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConv(n_channels, 64)
self.down1 = DoubleConv(64, 128)
self.down2 = DoubleConv(128, 256)
self.down3 = DoubleConv(256, 512)
factor = 2 if bilinear else 1
self.down4 = DoubleConv(512, 1024 // factor)
self.up1 = nn.ConvTranspose2d(1024 // factor, 512 // factor, kernel_size=2, stride=2)
self.up2 = nn.ConvTranspose2d(512 // factor, 256 // factor, kernel_size=2, stride=2)
self.up3 = nn.ConvTranspose2d(256 // factor, 128 // factor, kernel_size=2, stride=2)
self.up4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
三、訓(xùn)練與推理的完整代碼
??首先,我們需要準(zhǔn)備數(shù)據(jù)集、定義損失函數(shù)和優(yōu)化器,然后編寫訓(xùn)練循環(huán)。
python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from unet_model import UNet # 假設(shè)UNet定義在unet_model.py文件中
# 設(shè)定超參數(shù)
num_epochs = 10
learning_rate = 0.001
batch_size = 4
# 數(shù)據(jù)預(yù)處理
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
# 加載訓(xùn)練集
train_dataset = datasets.ImageFolder(root='path_to_train_dataset', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 定義模型、損失函數(shù)和優(yōu)化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = len(train_dataset.classes) # 根據(jù)數(shù)據(jù)集確定類別數(shù)
model = UNet(n_channels=3, n_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 訓(xùn)練循環(huán)
for epoch in range(num_epochs):
model.train() # 設(shè)置模型為訓(xùn)練模式
running_loss = 0.0
for i, data in enumerate(train_loader):
inputs, labels = data[0].to(device), data[1].to(device)
optimizer.zero_grad() # 梯度清零
outputs = model(inputs) # 前向傳播
loss = criterion(outputs, labels) # 計(jì)算損失
loss.backward() # 反向傳播
optimizer.step() # 更新權(quán)重
running_loss += loss.item() * inputs.size(0)
epoch_loss = running_loss / len(train_loader.dataset)
print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')
# 保存模型
torch.save(model.state_dict(), 'unet_model.pth')
推理
在推理階段,我們加載已訓(xùn)練好的模型,并對(duì)測(cè)試集或單個(gè)圖像進(jìn)行預(yù)測(cè)。
python
# 加載模型
model.load_state_dict(torch.load('unet_model.pth'))
model.eval() # 設(shè)置模型為評(píng)估模式
# 如果需要,準(zhǔn)備測(cè)試集
test_dataset = datasets.ImageFolder(root='path_to_test_dataset', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 對(duì)測(cè)試集進(jìn)行推理
with torch.no_grad():
for inputs, _ in test_loader:
inputs = inputs.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs, 1)
# 可以將predicted保存為文件或進(jìn)行其他處理
# 對(duì)單個(gè)圖像進(jìn)行推理
image_path = 'path_to_single_image.png'
image = Image.open(image_path).convert('RGB') # 確保是RGB格式
image = transform(image).unsqueeze(0).to(device) # 對(duì)圖像進(jìn)行預(yù)處理并添加到batch維度
with torch.no_grad():
prediction = model(image)
_, predicted = torch.max(prediction, 1)
predicted_class = train_dataset.classes[predicted.item()] # 獲取預(yù)測(cè)的類別名
# 可以將predicted保存為文件或進(jìn)行可視化
這里我假設(shè)你已經(jīng)有了適當(dāng)?shù)挠?xùn)練和測(cè)試數(shù)據(jù)集,并且它們已經(jīng)被組織成了ImageFolder可以理解的格式(即每個(gè)類別的圖像都在一個(gè)單獨(dú)的子文件夾中)。此外,代碼中的transform可能需要根據(jù)你的具體數(shù)據(jù)集進(jìn)行調(diào)整。
??在推理階段,我們使用torch.max來找出每個(gè)圖像最有可能的類別,并通過predicted_class變量打印或返回該類別。對(duì)于測(cè)試集,你可能希望將預(yù)測(cè)結(jié)果保存為文件,以便后續(xù)分析或可視化。對(duì)于單個(gè)圖像,你可以直接進(jìn)行可視化或?qū)⑵浔4鏋閹в蟹指罱Y(jié)果的圖像。
四、總結(jié)
??我們?cè)敿?xì)介紹了如何使用PyTorch實(shí)現(xiàn)并訓(xùn)練一個(gè)U-Net模型,以及如何在訓(xùn)練和推理階段使用它。首先,我們定義了一個(gè)U-Net模型的結(jié)構(gòu),該結(jié)構(gòu)通過下采樣路徑捕獲上下文信息,并通過上采樣路徑精確定位目標(biāo)區(qū)域。然后,我們準(zhǔn)備了訓(xùn)練和測(cè)試數(shù)據(jù)集,并應(yīng)用了適當(dāng)?shù)臄?shù)據(jù)預(yù)處理步驟。
??在訓(xùn)練階段,我們?cè)O(shè)置了模型、損失函數(shù)和優(yōu)化器,并編寫了一個(gè)循環(huán)來迭代訓(xùn)練數(shù)據(jù)集。在每個(gè)迭代中,我們執(zhí)行前向傳播來計(jì)算模型的輸出,計(jì)算損失,執(zhí)行反向傳播來更新模型的權(quán)重,并打印每個(gè)epoch的平均損失以監(jiān)控訓(xùn)練過程。訓(xùn)練完成后,我們保存了模型的權(quán)重。在推理階段,我們加載了已訓(xùn)練的模型,并將其設(shè)置為評(píng)估模式以關(guān)閉諸如dropout或batch normalization等訓(xùn)練特定的層。然后,我們對(duì)測(cè)試數(shù)據(jù)集或單個(gè)圖像進(jìn)行推理,使用模型生成預(yù)測(cè),并通過torch.max找到最有可能的類別。對(duì)于測(cè)試集,你可能希望保存預(yù)測(cè)結(jié)果以便后續(xù)分析;對(duì)于單個(gè)圖像,你可以直接進(jìn)行可視化或?qū)⑵浔4鏋閹в蟹指罱Y(jié)果的圖像。文章來源:http://www.zghlxwxcb.cn/news/detail-860483.html
??通過本博客,你應(yīng)該能夠了解如何使用PyTorch實(shí)現(xiàn)和訓(xùn)練一個(gè)U-Net模型,并能夠?qū)⑵鋺?yīng)用于圖像分割任務(wù)。當(dāng)然,實(shí)際應(yīng)用中可能還需要考慮更多的細(xì)節(jié)和優(yōu)化,如更復(fù)雜的數(shù)據(jù)增強(qiáng)、學(xué)習(xí)率調(diào)整策略、模型的正則化等。文章來源地址http://www.zghlxwxcb.cn/news/detail-860483.html
到了這里,關(guān)于【PyTorch 實(shí)戰(zhàn)2:UNet 分割模型】10min揭秘 UNet 分割網(wǎng)絡(luò)如何工作以及pytorch代碼實(shí)現(xiàn)(詳細(xì)代碼實(shí)現(xiàn))的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!