??作者簡介:禿頭小蘇,致力于用最通俗的語言描述問題
??往期回顧:深度學(xué)習(xí)語義分割篇——FCN原理詳解篇
??近期目標(biāo):寫好專欄的每一篇文章
??支持小蘇:點(diǎn)贊????、收藏?、留言??
?
深度學(xué)習(xí)語義分割篇——FCN源碼解析篇
寫在前面
???本篇文章參考霹靂吧啦Wz在B站上的視頻進(jìn)行講解,點(diǎn)擊???下載FCN源碼。閱讀本文之前建議先閱讀上篇對FCN原理講解的文章。
???本文將從數(shù)據(jù)集讀取、模型訓(xùn)練、模型推理和模型搭建幾部分為大家講解,每次做代碼的講解我都要說一句話,就是不管是看視頻還是看文章只是對你了解代碼起輔助的作用,你應(yīng)花更多的時(shí)間自己調(diào)試,這樣你會(huì)對整個(gè)代碼的流程無比熟悉?。?!??????
???廢話也不多說了,讓我們一起來看看FCN的源碼吧。??????
?
數(shù)據(jù)集讀取——my_dataset.py
???在讀取數(shù)據(jù)集部分,我們定義了一個(gè)VOCSegmentation類,首先我們需要獲取輸入(image)和標(biāo)簽(target)的路徑,相關(guān)代碼如下:
class VOCSegmentation(data.Dataset):
def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
super(VOCSegmentation, self).__init__()
assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
assert os.path.exists(root), "path '{}' does not exist.".format(root)
image_dir = os.path.join(root, 'JPEGImages')
mask_dir = os.path.join(root, 'SegmentationClass')
txt_path = os.path.join(root, "ImageSets", "Segmentation", txt_name)
assert os.path.exists(txt_path), "file '{}' does not exist.".format(txt_path)
with open(os.path.join(txt_path), "r") as f:
file_names = [x.strip() for x in f.readlines() if len(x.strip()) > 0]
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
???這部分非常簡單啦,voc_root我們應(yīng)該傳入VOCdevkit所在的文件夾,以我的數(shù)據(jù)路徑為例,我應(yīng)指定voc_root="D:\數(shù)據(jù)集\VOC\VOCtrainval_11-May-2012"
???最終self.image和self.masks里存儲的就是我們輸入和標(biāo)簽的路徑了。接著我們對輸入圖片和標(biāo)簽進(jìn)行transformer預(yù)處理,本代碼主要進(jìn)行了隨機(jī)縮放、水平翻轉(zhuǎn)、隨機(jī)裁剪、toTensor和Normalize【訓(xùn)練集采用了這些,驗(yàn)證集僅使用了隨機(jī)縮放、toTensor和Normalize】,相關(guān)代碼如下:【這部分代碼其實(shí)是在train.py文件中的,這里放在了此部分講解】
#訓(xùn)練集所用預(yù)處理方法
class SegmentationPresetTrain:
def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
min_size = int(0.5 * base_size)
max_size = int(2.0 * base_size)
trans = [T.RandomResize(min_size, max_size)]
if hflip_prob > 0:
trans.append(T.RandomHorizontalFlip(hflip_prob))
trans.extend([
T.RandomCrop(crop_size),
T.ToTensor(),
T.Normalize(mean=mean, std=std),
])
self.transforms = T.Compose(trans)
def __call__(self, img, target):
return self.transforms(img, target)
# 驗(yàn)證集所用預(yù)處理方法
class SegmentationPresetEval:
def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
self.transforms = T.Compose([
T.RandomResize(base_size, base_size),
T.ToTensor(),
T.Normalize(mean=mean, std=std),
])
def __call__(self, img, target):
return self.transforms(img, target)
???上述代碼中crop_size設(shè)置為480,即訓(xùn)練圖片都會(huì)裁剪到480*480大小,而驗(yàn)證時(shí)沒有使用隨機(jī)裁剪方法,因此驗(yàn)證集的圖片尺寸是不一致的。
???在數(shù)據(jù)集讀取類中,還定義了collate_fn方法,其實(shí)在訓(xùn)練過程中加載數(shù)據(jù)時(shí)用到的,定義了我們數(shù)據(jù)是如何打包的,代碼如下:
def collate_fn(batch):
images, targets = list(zip(*batch))
batched_imgs = cat_list(images, fill_value=0)
batched_targets = cat_list(targets, fill_value=255)
return batched_imgs, batched_targets
def cat_list(images, fill_value=0):
# 計(jì)算該batch數(shù)據(jù)中,channel, h, w的最大值
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
batch_shape = (len(images),) + max_size
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
for img, pad_img in zip(images, batched_imgs):
pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
return batched_imgs
???這個(gè)方法即是將我們一個(gè)batch的數(shù)據(jù)打包到一塊兒,一起輸入網(wǎng)絡(luò)。這里光看代碼可能不好理解,打上斷點(diǎn)調(diào)試調(diào)試吧!?。??????
?
模型訓(xùn)練——train.py
???其實(shí),模型的訓(xùn)練步驟大致都差不多,不熟悉的可以先參考我的這篇博文:使用pytorch自己構(gòu)建網(wǎng)絡(luò)模型實(shí)戰(zhàn)??????
???下面一起來看看FCN的訓(xùn)練過程吧!?。??????
數(shù)據(jù)集讀取和加載
# VOCdevkit -> VOC2012 -> ImageSets -> Segmentation -> train.txt
train_dataset = VOCSegmentation(args.data_path,
year="2012",
transforms=get_transform(train=True),
txt_name="train.txt")
# VOCdevkit -> VOC2012 -> ImageSets -> Segmentation -> val.txt
val_dataset = VOCSegmentation(args.data_path,
year="2012",
transforms=get_transform(train=False),
txt_name="val.txt")
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=True,
pin_memory=True,
collate_fn=train_dataset.collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset,
batch_size=1,
num_workers=num_workers,
pin_memory=True,
collate_fn=val_dataset.collate_fn)
???我想這部分大家肯定沒什么問題啦,每個(gè)網(wǎng)絡(luò)訓(xùn)練基本都是這樣的數(shù)據(jù)讀取和加載步驟,我就不過多介紹了。???
?
創(chuàng)建網(wǎng)絡(luò)模型
model = create_model(aux=args.aux, num_classes=num_classes)
???這里大家現(xiàn)在大家就可以理解為是FCN原理部分所創(chuàng)建的模型,即以VGG為backbone構(gòu)建的網(wǎng)絡(luò)。有關(guān)網(wǎng)絡(luò)模型的搭建我會(huì)在下文講述。??????
?
設(shè)置損失函數(shù)、優(yōu)化器
# 設(shè)置優(yōu)化器
optimizer = torch.optim.SGD(
params_to_optimize,
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
)
???此代碼損失函數(shù)的設(shè)置是在訓(xùn)練一個(gè)epoch數(shù)據(jù)時(shí)定義的,使用的是cross_entropy損失函數(shù),后文會(huì)重點(diǎn)解釋。??????
?
網(wǎng)絡(luò)訓(xùn)練???
for epoch in range(args.start_epoch, args.epochs):
mean_loss, lr = train_one_epoch(model, optimizer, train_loader, device, epoch,
lr_scheduler=lr_scheduler, print_freq=args.print_freq, scaler=scaler)
? 上面定義了一個(gè)train_one_epoch
方法,我們一起來看看:
def train_one_epoch(model, optimizer, data_loader, device, epoch, lr_scheduler, print_freq=10, scaler=None):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
for image, target in metric_logger.log_every(data_loader, print_freq, header):
image, target = image.to(device), target.to(device)
with torch.cuda.amp.autocast(enabled=scaler is not None):
output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
if scaler is not None:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()
lr_scheduler.step()
lr = optimizer.param_groups[0]["lr"]
metric_logger.update(loss=loss.item(), lr=lr)
return metric_logger.meters["loss"].global_avg, lr
???這部分是不是很難看懂呢,大家動(dòng)起手來調(diào)試調(diào)試吧,其實(shí)這部分和其它網(wǎng)絡(luò)訓(xùn)練過程也基本差不多。我重點(diǎn)講一下 loss = criterion(output, target)
,即損失函數(shù)的部分,criterion函數(shù)的定義如下:
def criterion(inputs, target):
losses = {}
for name, x in inputs.items():
# 忽略target中值為255的像素,255的像素是目標(biāo)邊緣或者padding填充
losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)
if len(losses) == 1:
return losses['out']
return losses['out'] + 0.5 * losses['aux']
???關(guān)于損失函數(shù)部分要講解的內(nèi)容還是很多的,因此我放在了附錄–>損失函數(shù)cross_entropy詳解中,大家可去查看。??????
?
網(wǎng)絡(luò)測試
confmat = evaluate(model, val_loader, device=device, num_classes=num_classes)
???網(wǎng)絡(luò)測試部分原視頻中介紹的很詳細(xì),用到了混淆矩陣,我就不帶大家進(jìn)evaluate中一行一行的看了。但這里我來說一下這部分的調(diào)試小技巧,因?yàn)闇y試是在網(wǎng)絡(luò)訓(xùn)練一個(gè)epoch后執(zhí)行的,但我們肯定很難等訓(xùn)練一個(gè)epoch后再調(diào)試測試部分,因此我們在調(diào)試前先注釋掉訓(xùn)練部分,這樣就可以很快速的跳到測試部分啦,快去試試吧?。?!??????
?
模型保存
save_file = {"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
"epoch": epoch,
"args": args}
torch.save(save_file, "save_weights/model_{}.pth".format(epoch))
?
模型預(yù)測——predict.py???
???這部分有很多和訓(xùn)練部分重復(fù)的代碼哈,我就不一一的去分析了。重點(diǎn)看一下如何由模型輸出的結(jié)果得到最終的P模式的圖片,相關(guān)代碼如下:
output = model(img.to(device))
prediction = output['out'].argmax(1).squeeze(0)
prediction = prediction.to("cpu").numpy().astype(np.uint8)
mask = Image.fromarray(prediction)
mask.putpalette(pallette)
mask.save("test_result.png")
???上述代碼中我認(rèn)為這句prediction = output['out'].argmax(1).squeeze(0)
是最重要的,其主要作用是在輸出中的chanel維度求最大值對應(yīng)的類別索引,為方便大家理解,作圖如下:
???我們來解釋一下上圖,輸出為 1 × c × h × w 1×c×h×w 1×c×h×w,因?yàn)檫@是預(yù)測,故batch=1,這里使用的是VOC數(shù)據(jù),故這里的c=num_class=21。【包含一個(gè)背景類】首先我們會(huì)取輸出中每個(gè)像素在21個(gè)通道中的最大值,如第一個(gè)像素在21個(gè)通道的最大值在通道0上取得,即上圖橙色的通道。這個(gè)通道對應(yīng)的索引是0,在VOC中是背景類,故這個(gè)像素所屬類別為背景。其它像素同理。
???我們可以來看看模型預(yù)測的結(jié)果,如下圖所示:
? 是不是發(fā)現(xiàn)這只可愛的小貓咪被分割出來了呢,大家快去試試吧?。?!??????
?
模型搭建
???這部分我之所以放在最后一部分,是因?yàn)槲矣X得這部分是相對最好理解的。我們只要照著我們理論部分一步步的搭建就好。需要注意的是理論部分我們介紹時(shí)采用的時(shí)VGG做為backbone,這是因?yàn)楫?dāng)時(shí)論發(fā)表在15年,resnet網(wǎng)絡(luò)也是15出來的,所以論文中沒用到,但是很多年過去,resnet的有效性得到實(shí)踐證明,pytorch官方也采用了resnet作為FCN的backbone,并且使用了空洞卷積。對空洞卷積不了解的請點(diǎn)擊???查看相關(guān)解釋。這里放上backbone的整體結(jié)構(gòu)圖,大家作為參考,剩下的時(shí)間就去調(diào)試吧!??!??????

??????????????圖片來自霹靂吧啦Wz
?
參考鏈接
FCN源碼解析(Pytorch)??????
?
附錄
損失函數(shù)cross_entropy詳解???
???在講解損失函數(shù)之前,我有必要在為大家分析一波VOC的標(biāo)注,在FCN原理詳解篇的附錄我向大家說明說明了標(biāo)注是單通道的P模式圖片,不清楚的請點(diǎn)擊???了解詳情。
???單僅知道標(biāo)注是單通道的圖片還不夠,我們先來看看這張標(biāo)注圖片:
???這張圖片大致可以分為三部分,一部分是藍(lán)框框住的人,一部分是綠框框住的飛機(jī),還有一部分是黃框框住的神秘物體。我先來說說人和飛機(jī)部分,你可以發(fā)現(xiàn),它們一個(gè)是粉紅色(人),一個(gè)是大紅色(飛機(jī)),這是因?yàn)樵谖覀冋{(diào)色板中人和飛機(jī)索引對應(yīng)的類別分別為粉紅色和大紅色,如下圖所示:

???我們也可以來看看標(biāo)注圖片的背景,它是黑色的,背景類別為0,因此在調(diào)色板中0所對應(yīng)的RGB值為[0,0,0],為黑色,如下圖所示:
????????
???接著我們來看看這個(gè)白色的神秘物體,這是什么呢?我們可以看看此標(biāo)注圖像對應(yīng)的原圖,如下:
???通過上圖可以看到,這個(gè)白色的物體其實(shí)也是一個(gè)小飛機(jī),但很難分辨,故標(biāo)注時(shí)用白色像素給隱藏起來了,最后白色對應(yīng)的像素也不會(huì)參與損失計(jì)算。如果你足夠細(xì)心的話,你會(huì)發(fā)現(xiàn)在人和飛機(jī)的邊緣其實(shí)都是存在一圈白色的像素的,這是為了更好的區(qū)分不同類別對應(yīng)的像素。同樣,這里的白色也不會(huì)參與損失計(jì)算。【至于怎么不參與馬上就會(huì)講解,不用急】
???接下來我們可以用程序來看看標(biāo)注圖像中是否有白色像素,代碼如下:
from PIL import Image
import numpy as np
img = Image.open('D:\\數(shù)據(jù)集\\VOC\\VOCtrainval_11-May-2012\\VOCdevkit\\VOC2012\\SegmentationClass\\2007_000032.png')
img_np = np.array(img)
? 我們可以看看img_np里的部分?jǐn)?shù)據(jù),如下圖所示:
???可以看到地下的像素是1,表示飛機(jī)(大紅色),上面的像素為0,表示背景(黑色),中間的像素為255,這就對應(yīng)著飛機(jī)周圍的白色像素。我們可以看一下255對應(yīng)的RGB值,如下:【這里的255需要大家記住哦,后面計(jì)算損失時(shí)白色部分不計(jì)算正是通過忽略這個(gè)值實(shí)現(xiàn)的】
??????????
? [224,224,192]表示的RGB顏色為白色。
???有了上面的先驗(yàn)知識,就可以來介紹cross_entropy函數(shù)了。我們直接來看求損失的公式,如下:
???我舉個(gè)例子來解釋一下上面的公式。設(shè)輸入為[0.1,0.2,0.3],即x=[0.1,0.2,0.3],標(biāo)簽為1,即class=1,則
? l o s s ( x , c l a s s ) = ? x [ c l a s s ] + log ? ( ∑ j exp ? ( x [ j ] ) ) = ? 0.2 + l o g ( e x [ 0 ] + e x [ 1 ] + e x [ 2 ] ) = ? 0.2 + l o g ( e 0.1 + e 0.2 + e 0.3 ) loss(x,class) = - x\left[ {class} \right] + \log (\sum\limits_j {\exp (x[j])})=-0.2+log(e^{x[0]}+e^{x[1]}+e^{x[2]})=-0.2+log(e^{0.1}+e^{0.2}+e^{0.3}) loss(x,class)=?x[class]+log(j∑?exp(x[j]))=?0.2+log(ex[0]+ex[1]+ex[2])=?0.2+log(e0.1+e0.2+e0.3)
???通過上文的例子我想你大概知道了損失的計(jì)算方法,上文的x是一維的,現(xiàn)在我們來看一下二維的x是怎么計(jì)算,首先先定義輸入和標(biāo)簽,代碼如下:
import torch
import numpy as np
import math
input = torch.tensor([[0.1, 0.2, 0.3],[0.1, 0.2, 0.3],[0.1, 0.2, 0.3]])
target = torch.tensor([0, 1, 2])
? 可以來看一下input和target的值:

? 接著我們可以先用函數(shù)來計(jì)算損失,如下:
loss = torch.nn.functional.cross_entropy(input, target)
? 計(jì)算得到的loss值如下:
??????????????
???接著我們手動(dòng)來計(jì)算損失,看其是否和直接用函數(shù)計(jì)算一致,即loss=1.1019。手動(dòng)計(jì)算代碼如下:【對于多維數(shù)據(jù),需要計(jì)算target對應(yīng)的x的損失,然后求平均】
res0 = -0.1+np.log(math.exp(0.1)+math.exp(0.2)+math.exp(0.3))
res1 = -0.2+np.log(math.exp(0.1)+math.exp(0.2)+math.exp(0.3))
res2 = -0.3+np.log(math.exp(0.1)+math.exp(0.2)+math.exp(0.3))
res = (res0 + res1 + res2)/3
???計(jì)算得到的結(jié)果如下,和利用函數(shù)計(jì)算時(shí)結(jié)果一致,僅精度有差別,所以這證明了我們的計(jì)算方式是沒有錯(cuò)的。
?????????????
???我們上文在介紹VOC標(biāo)注時(shí)說,計(jì)算損失是會(huì)忽略白色的像素,其就對應(yīng)著標(biāo)簽中的255。這里我們用這個(gè)小例子來說明程序是怎么實(shí)現(xiàn)忽略的,其實(shí)很簡單,只要在函數(shù)調(diào)用時(shí)傳入ignore_index并指定對應(yīng)的值即可。如對本例來說,現(xiàn)我打算忽略target中標(biāo)簽為2的數(shù)據(jù),即不讓其參與損失計(jì)算,我們來看看如何使用cross_entropy函數(shù)來實(shí)現(xiàn):
loss = torch.nn.functional.cross_entropy(input, target, ignore_index=2)
? 上述loss結(jié)果如下:
??????????????
? 現(xiàn)在我們手動(dòng)計(jì)算一下忽略target=2時(shí)的損失結(jié)果,如下:
res0 = -0.1+np.log(math.exp(0.1)+math.exp(0.2)+math.exp(0.3))
res1 = -0.2+np.log(math.exp(0.1)+math.exp(0.2)+math.exp(0.3))
res = (res0 + res1)/2
? 上述代碼中target=2沒有參與損失計(jì)算,其結(jié)果如下:
??????????????
? 上述實(shí)驗(yàn)都證明了我們的計(jì)算方式是沒有偏差的。??????
???相信你讀了上文對cross_entropy解釋,已經(jīng)基本對cross_entropy這個(gè)函數(shù)了解了。但是大家可能會(huì)發(fā)現(xiàn)在我們程序中輸入cross_entropy函數(shù)中的x通常是4維的tensor,即[N,C,H,W],這時(shí)候訓(xùn)練損失是怎么計(jì)算的呢?我們以x的維度為[1,2,2,2]為例為大家講解,首先定義輸入和target,如下:
import torch
import numpy as np
import math
input = torch.tensor([[[[0.1, 0.2],[0.3, 0.4]], [[0.5, 0.6],[0.7, 0.8]]]]) #shape(1 2 2 2 )
target = torch.tensor([[[0, 1],[0, 1]]])
? 來看看input和target的值:

? 接著來看看通過函數(shù)計(jì)算的loss,代碼如下:
loss = torch.nn.functional.cross_entropy(input, target)
? 此時(shí)loss的值為:
??????????????
? 接下來我們就來看看手動(dòng)計(jì)算的步驟,在用代碼實(shí)現(xiàn)前,我先來解釋下大致步驟,如下圖所示:
? 我們會(huì)將數(shù)據(jù)按通道方向展開,然后分別計(jì)算cross_entropy,最后求平均,代碼如下:
res0 = -0.1+np.log(math.exp(0.1)+math.exp(0.5))
res1 = -0.6+np.log(math.exp(0.2)+math.exp(0.6))
res2 = -0.3+np.log(math.exp(0.3)+math.exp(0.7))
res3 = -0.8+np.log(math.exp(0.4)+math.exp(0.8))
res = (res0 + res1 + res2 + res3)/4
? res的結(jié)果如下,其和使用函數(shù)計(jì)算一致。
?????????????
? 那我們不妨在來看看忽略某個(gè)target時(shí)loss的結(jié)果,以忽略target=0為例:
loss = torch.nn.functional.cross_entropy(input, target, ignore_index=0)
? loss的結(jié)果如下:
??????????????
? 我們來看看手動(dòng)計(jì)算的步驟:
? 代碼如下:
res1 = -0.6+np.log(math.exp(0.2)+math.exp(0.6))
res3 = -0.8+np.log(math.exp(0.4)+math.exp(0.8))
res = (res0 + res3)/2
? res的結(jié)果如下,同樣和使用函數(shù)計(jì)算是一致的。
??????????????
? 到這里,我們在來看FCN中的代碼,如下:
losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)
? 我想大家就很清楚了叭,這里忽略了255像素,不讓其參與到損失的計(jì)算中。
? 這一節(jié)我覺得是整個(gè)FCN最難理解的地方,我已經(jīng)介紹的非常詳細(xì)了,大家自己也要花些時(shí)間理解理解。??????
?
如若文章對你有所幫助,那就??????文章來源:http://www.zghlxwxcb.cn/news/detail-467497.html
???????? 文章來源地址http://www.zghlxwxcb.cn/news/detail-467497.html
到了這里,關(guān)于深度學(xué)習(xí)語義分割篇——FCN源碼解析篇的文章就介紹完了。如果您還想了解更多內(nèi)容,請?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!