1. 前言
文末附有源碼下載地址。
灰度圖自動上色
2.圖像格式(RGB,HSV,Lab)
2.1 RGB
想要對灰度圖片上色,首先要了解圖像的格式,對于一副普通的圖像通常為RGB格式的,即紅、綠、藍三個通道,可以使用opencv分離圖像的三個通道,代碼如下所示:
import cv2
img=cv2.imread('pic/7.jpg')
B,G,R=cv2.split(img)
cv2.imshow('img',img)
cv2.imshow('B',B)
cv2.imshow('G',G)
cv2.imshow('R',R)
cv2.waitKey(0)
代碼運行結(jié)果如下所示。
2.2 hsv
hsv是圖像的另一種格式,其中h代表圖像的色調(diào),s代表飽和度,v代表圖像亮度,可以通過調(diào)節(jié)h、s、v的值來改變圖像的色調(diào)、飽和度、亮度等信息。
同樣可以使用opencv將圖像從RGB格式轉(zhuǎn)換成hsv格式。然后可以分離h、s、v三個通道并顯示圖像代碼如下所示:
import cv2
img=cv2.imread('pic/7.jpg')
hsv=cv2.cvtColor(img,cv2.COLOR_BGR2HSV)
h,s,v=cv2.split(hsv)
cv2.imshow('hsv',hsv)
cv2.imshow('h',h)
cv2.imshow('s',s)
cv2.imshow('v',v)
cv2.waitKey(0)
運行結(jié)果如下所示:
2.3 Lab
Lab是圖像的另一種格式,也是本文使用的格式,其中L代表灰度圖像,a、b代表顏色通道,本文使用L通道灰度圖作為輸入,ab兩個顏色通道作為輸出,訓(xùn)練生成對抗網(wǎng)絡(luò),將圖像由RGB格式轉(zhuǎn)換成Lab格式的代碼如下所示:
import cv2
img=cv2.imread('pic/7.jpg')
Lab=cv2.cvtColor(img,cv2.COLOR_BGR2Lab)
L,a,b=cv2.split(Lab)
cv2.imshow('Lab',Lab)
cv2.imshow('L',L)
cv2.imshow('a',a)
cv2.imshow('b',b)
cv2.waitKey(0)
3. 生成對抗網(wǎng)絡(luò)(GAN)
生成對抗網(wǎng)絡(luò)主要包含兩部分,分別是生成網(wǎng)絡(luò)和判別網(wǎng)絡(luò)。
生成網(wǎng)絡(luò)負責生成圖像,判別網(wǎng)絡(luò)負責鑒定生成圖像的好壞,二者相輔相成,相互博弈。
本文使用U-net作為生成網(wǎng)絡(luò),使用ResNet18作為判別網(wǎng)絡(luò)。U-net網(wǎng)絡(luò)的結(jié)構(gòu)圖如下所示:
3.1 生成網(wǎng)絡(luò)(Unet)
pytorch構(gòu)建unet網(wǎng)絡(luò)的代碼如下所示:
class DownsampleLayer(nn.Module):
def __init__(self,in_ch,out_ch):
super(DownsampleLayer, self).__init__()
self.Conv_BN_ReLU_2=nn.Sequential(
nn.Conv2d(in_channels=in_ch,out_channels=out_ch,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(),
nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1,padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU()
)
self.downsample=nn.Sequential(
nn.Conv2d(in_channels=out_ch,out_channels=out_ch,kernel_size=3,stride=2,padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU()
)
def forward(self,x):
"""
:param x:
:return: out輸出到深層,out_2輸入到下一層,
"""
out=self.Conv_BN_ReLU_2(x)
out_2=self.downsample(out)
return out,out_2
class UpSampleLayer(nn.Module):
def __init__(self,in_ch,out_ch):
# 512-1024-512
# 1024-512-256
# 512-256-128
# 256-128-64
super(UpSampleLayer, self).__init__()
self.Conv_BN_ReLU_2 = nn.Sequential(
nn.Conv2d(in_channels=in_ch, out_channels=out_ch*2, kernel_size=3, stride=1,padding=1),
nn.BatchNorm2d(out_ch*2),
nn.ReLU(),
nn.Conv2d(in_channels=out_ch*2, out_channels=out_ch*2, kernel_size=3, stride=1,padding=1),
nn.BatchNorm2d(out_ch*2),
nn.ReLU()
)
self.upsample=nn.Sequential(
nn.ConvTranspose2d(in_channels=out_ch*2,out_channels=out_ch,kernel_size=3,stride=2,padding=1,output_padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU()
)
def forward(self,x,out):
'''
:param x: 輸入卷積層
:param out:與上采樣層進行cat
:return:
'''
x_out=self.Conv_BN_ReLU_2(x)
x_out=self.upsample(x_out)
cat_out=torch.cat((x_out,out),dim=1)
return cat_out
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
out_channels=[2**(i+6) for i in range(5)] #[64, 128, 256, 512, 1024]
#下采樣
self.d1=DownsampleLayer(3,out_channels[0])#3-64
self.d2=DownsampleLayer(out_channels[0],out_channels[1])#64-128
self.d3=DownsampleLayer(out_channels[1],out_channels[2])#128-256
self.d4=DownsampleLayer(out_channels[2],out_channels[3])#256-512
#上采樣
self.u1=UpSampleLayer(out_channels[3],out_channels[3])#512-1024-512
self.u2=UpSampleLayer(out_channels[4],out_channels[2])#1024-512-256
self.u3=UpSampleLayer(out_channels[3],out_channels[1])#512-256-128
self.u4=UpSampleLayer(out_channels[2],out_channels[0])#256-128-64
#輸出
self.o=nn.Sequential(
nn.Conv2d(out_channels[1],out_channels[0],kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(out_channels[0]),
nn.ReLU(),
nn.Conv2d(out_channels[0], out_channels[0], kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_channels[0]),
nn.ReLU(),
nn.Conv2d(out_channels[0],3,3,1,1),
nn.Sigmoid(),
# BCELoss
)
def forward(self,x):
out_1,out1=self.d1(x)
out_2,out2=self.d2(out1)
out_3,out3=self.d3(out2)
out_4,out4=self.d4(out3)
out5=self.u1(out4,out_4)
out6=self.u2(out5,out_3)
out7=self.u3(out6,out_2)
out8=self.u4(out7,out_1)
out=self.o(out8)
return out
3.2 判別網(wǎng)絡(luò)(resnet18)
resnet18的結(jié)構(gòu)圖如下所示:
在pytorch內(nèi)部自帶resnet18模型,只需一行代碼即可構(gòu)建resnet18模型,然后還需要去除網(wǎng)絡(luò)最后的全連接層,代碼如下所示:
from torchvision import models
resnet18=models.resnet18(pretrained=False)
del resnet18.fc
print(resnet18)
4. 數(shù)據(jù)集
本文使用的是自然風(fēng)景類的數(shù)據(jù)圖片,在網(wǎng)站上爬取了大概1000多張數(shù)據(jù)圖片,部分圖片如下所示
5. 模型訓(xùn)練與預(yù)測流程圖
5.1 訓(xùn)練流程圖
如下圖所示,首先將RGB圖像轉(zhuǎn)換成Lab圖像,然后將L通道作為生成網(wǎng)絡(luò)輸入,生成網(wǎng)絡(luò)的輸出為新的ab兩通道,然后將圖像原始的ab通道,與生成網(wǎng)絡(luò)生成的ab通道輸入判別網(wǎng)絡(luò)中。
5.2 預(yù)測流程圖
下圖為模型的預(yù)測過程,在預(yù)測過程中判別網(wǎng)絡(luò)已經(jīng)沒有作用了,首先將RGB圖像轉(zhuǎn)換成,Lab圖像,接著將L灰度圖輸入生成網(wǎng)絡(luò)可以得到新的ab通道圖像,接著將L通道圖像與生成的ab通道圖像進行拼接(concate),拼接以后可以得到一張新的Lab圖像,然后再將其轉(zhuǎn)換成RGB格式,此時圖像即為上色以后的圖像。
6. 模型預(yù)測效果
下圖為模型的預(yù)測效果。左側(cè)的為灰度圖像,中間的為原始的彩色圖像,右側(cè)的是模型上色以后的圖像。整體上看,網(wǎng)絡(luò)的上色效果還不錯。
7. GUI界面制作
為了更加方便使用模型,本文使用pyqt5制作操作界面,其界面如下圖所示:首先可以從電腦中加載圖像,還可以切換上一張或者下一張,可以將圖像灰度化顯示??梢詫ζ渖仙?,然后可以調(diào)整上色后圖像的H、S、V信息,最后支持圖像導(dǎo)出,可以將上色后的圖像保存到本地中。文章來源:http://www.zghlxwxcb.cn/news/detail-790758.html
8.代碼下載
鏈接中包含了訓(xùn)練代碼,測試代碼,以及界面代碼。此外還包含1000多張數(shù)據(jù)集,直接運行main.py程序即可彈出操作界面。
代碼下載:下載地址列表1文章來源地址http://www.zghlxwxcb.cn/news/detail-790758.html
到了這里,關(guān)于基于深度學(xué)習(xí)的圖片上色(Opencv,Pytorch,CNN)的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!