摘要:本文介紹深度學(xué)習(xí)的SRGAN圖像超分重建算法,使用Python以及Pytorch框架實現(xiàn),包含完整訓(xùn)練、測試代碼,以及訓(xùn)練數(shù)據(jù)集文件。博文介紹圖像超分算法的原理,包括生成對抗網(wǎng)絡(luò)和SRGAN模型原理和實現(xiàn)的代碼,同時結(jié)合具體內(nèi)容進(jìn)行解釋說明,完整代碼資源文件請轉(zhuǎn)至文末的下載鏈接。本博文目錄如下:
?點擊跳轉(zhuǎn)至文末所有涉及的完整代碼文件下載頁?
前言
????????一張低分辨率的圖像要想放大為更高尺寸的圖像,需要對確實的細(xì)節(jié)進(jìn)行插值,常見的線性插值方法利用相鄰像素的信息進(jìn)行補(bǔ)充,但放大后圖像模糊、質(zhì)量低下的問題仍然存在。圖像的超分辨率重建技術(shù)指的是將給定的低分辨率圖像通過特定的算法恢復(fù)成相應(yīng)的高分辨率圖像。簡單來理解超分辨率重建就是將小尺寸圖像變?yōu)榇蟪叽鐖D像,使圖像更加“清晰”,但放大時通過了深度學(xué)習(xí)的技術(shù)補(bǔ)充了更多細(xì)節(jié)。
????????這種細(xì)節(jié)的補(bǔ)充并不是簡單插值,而是在經(jīng)過大量現(xiàn)實數(shù)據(jù)的訓(xùn)練后,針對細(xì)節(jié)的“推理”填充,你可以簡單理解為和人在見過大量的高清照片后,對于模糊的相似部分能夠“腦補(bǔ)”畫面,根據(jù)畫面大體輪廓將具體細(xì)節(jié)勾畫出來。圖像超分效果如下圖所示:
????????可以看到,通過特定的超分辨率重建算法,使得原本模糊的圖像變得清晰了,直至今日,依托深度學(xué)習(xí)技術(shù),圖像的超分辨率重建已經(jīng)取得了非凡的成績,在效果上愈發(fā)真實和清晰。至于應(yīng)用就更加廣泛,如醫(yī)學(xué)成像、遙感、公共安防、視頻感知等。在影視素材畫質(zhì)的增強(qiáng)恢復(fù)中,許多基于深度學(xué)習(xí)的超分重建技術(shù)得到了實際應(yīng)用,比如Topaz Video Enhance AI軟件。有了圖像超分技術(shù),從此再也不用忍受渣渣畫質(zhì)了,老司機(jī)萌覺得還可以怎么用呢。
1. 實現(xiàn)原理
????????圖像的超分重建算法按照時間和效果,可以分為傳統(tǒng)算法和深度學(xué)習(xí)算法兩類。傳統(tǒng)的超分辨率重建算法主要依靠基本的數(shù)字圖像處理技術(shù)進(jìn)行重建,常見的有基于插值的超分辨率重建、基于退化模型的超分辨率重建、基于學(xué)習(xí)的超分辨率重建等。
1.1 超分重建流程
????????基于深度學(xué)習(xí)進(jìn)行超分辨率重建的算法,較早的要屬SRCNN(Super-Resolution Convolutional Neural Network)算法了,作為開山之作,其原理簡單。SRCNN利用深度學(xué)習(xí)模型和大批量樣本數(shù)據(jù)的訓(xùn)練,在超分性能上超越了一大批傳統(tǒng)圖像處理算法,從此深度學(xué)習(xí)開始向超分辨率領(lǐng)域研究邁進(jìn)。SRCNN的網(wǎng)絡(luò)結(jié)構(gòu)如下圖所示:
????????以上模型(博主已添加中文注釋)來自Chao Dong等人的論文“Image Super-Resolution Using Deep Convolutional Networks”,主要由一個三層結(jié)構(gòu)的卷積神經(jīng)網(wǎng)絡(luò)(CNN)構(gòu)成。對于一張低分辨率圖像,首先使用雙立方插值將其放大至目標(biāo)尺寸,使用以上的CNN模型去擬合低分辨率圖像與高分辨率圖像之間的非線性映射,最后通過重構(gòu)將網(wǎng)絡(luò)輸出的結(jié)果作為高分辨率圖像。
????????SRCNN的流程可以簡單理解為兩步:圖像放大和修復(fù),如下圖所示。其中,放大是采用某種方式(SRCNN采用插值上采樣)將圖像放大到指定倍數(shù),再利用大數(shù)據(jù)的學(xué)習(xí)模型結(jié)合圖像修復(fù)原理,將放大后的圖像映射為最終輸出目標(biāo)。可以看出,超分辨率重建相比簡單的插值放大,其在此基礎(chǔ)上又具備了圖像修復(fù)的作用,因此在超分性能上無疑大大增強(qiáng)。因此,超分辨率重建的很多算法也被學(xué)者遷移到圖像修復(fù)領(lǐng)域中,完成一些諸如jpep壓縮去燥、去模糊等任務(wù)。
????????除此之外,對于模型的訓(xùn)練其流程也具有參考意義:(1)尋找大量真實場景下的高清圖像樣本,對每張圖片進(jìn)行下采樣處理以降低圖像分辨率(如2倍下采樣、4倍下采樣等),這樣經(jīng)過下采樣圖像長寬均得到等比例縮??;(2)將采樣后的圖像作為低分辨率圖像用于輸入,采樣前的圖像作為高分辨率圖像作為真實值,以此構(gòu)成有效的訓(xùn)練樣本集;(3)利用深度學(xué)習(xí)模型對低分辨率圖像進(jìn)行放大重建為高分辨率的輸出結(jié)果,將其與原始高分辨率圖像進(jìn)行比較計算誤差,調(diào)整模型參數(shù)并不斷迭代,使得誤差下降至最低;(4)訓(xùn)練完的模型可以用于對新的低分辨率圖像進(jìn)行重建,得到高分辨率圖像。
1.2 SRResNet的深度網(wǎng)絡(luò)
????????相比只有3個卷積層的SRCNN,SRResNet采用更深的網(wǎng)絡(luò)結(jié)構(gòu)模型,抽取出更高級的圖像特征,深層模型對圖像可以更好的進(jìn)行表達(dá),實現(xiàn)超分重構(gòu)的性能也得到加強(qiáng)。深度殘差網(wǎng)絡(luò)(ResNet)的提出,很好解決了深層模型不能很好收斂的問題,其在圖像分類、圖像分割、目標(biāo)檢測等領(lǐng)域有著廣泛應(yīng)用。
????????ResNet成功的重要一點,是在傳統(tǒng)網(wǎng)絡(luò)中引入了殘差學(xué)習(xí)(Residual Learning),從而有效解決深層網(wǎng)絡(luò)中梯度消失和精度下降的問題,使得網(wǎng)絡(luò)層數(shù)能夠大大加深。殘差網(wǎng)絡(luò)的原理圖如下圖所示,從圖中可以看出原始數(shù)據(jù)x不僅有直接進(jìn)入下一層的鏈接,還有一條跨越兩層網(wǎng)絡(luò)(跳鏈)的鏈接,將x帶入到輸出中,此時輸出改為F(x)+x,使得整個模型訓(xùn)練時不容易發(fā)散。這里我繪制了一個殘差模塊,如下圖所示:
????????至此可以借住ResNet的特性,在SRCNN的基礎(chǔ)上我們就可以構(gòu)建更加強(qiáng)大的網(wǎng)絡(luò)結(jié)構(gòu),用于超分重建的深度神經(jīng)網(wǎng)絡(luò)。SRResNet模型的主干網(wǎng)絡(luò)其實采用了這種網(wǎng)絡(luò)結(jié)構(gòu),如下圖所示:
????????SRResNet模型中采用了多個深度殘差模塊(16個殘差模塊)對圖像特征進(jìn)行提取,保證整個網(wǎng)絡(luò)穩(wěn)定的同時,采用深度模型提升性能。以上模型中的卷積層僅僅改變了圖像的通道數(shù),并未修改圖像尺寸,由此可見目前為止的模型仍然可以看出是SRCNN類似的修復(fù)模型。
????????SRResNet模型利用子像素卷積來放大圖像,即在以上模型后繼續(xù)添加兩個子像素卷積模塊,每個子像素卷積模塊使得輸入圖像放大2倍,因此這個模型最終可以將圖像放大4倍。SRResNet模型主要包含兩部分:深度殘差模型、子像素卷積模型。深度殘差模型用來進(jìn)行高效的特征提取,可以在一定程度上削弱圖像噪點;子像素卷積模型主要用來放大圖像尺寸,其結(jié)構(gòu)如下圖所示:
????????以上模型中,k表示卷積核大小,s為步長,n表示通道數(shù)。最后模型在輸出前增加了一個卷積層用于數(shù)據(jù)調(diào)整和增強(qiáng)。為了訓(xùn)練模型SRResNet算法采用了MSE作為目標(biāo)函數(shù),即最小化模型輸出的高分辨率圖像(F(X)與原始分辨率圖像(Y)的均方誤差,其目標(biāo)函數(shù)公式如下:
L
=
1
n
∑
i
=
1
n
∥
F
(
X
i
;
θ
)
?
Y
i
∥
2
L=\frac{1}{n}\sum_{i=1}^{n}\left \| F(X_{i};\theta )-Y_i\right \|^2
L=n1?i=1∑n?∥F(Xi?;θ)?Yi?∥2
????????MSE被廣泛應(yīng)用于超分重建算法的目標(biāo)函數(shù),但使用該目標(biāo)函數(shù)重建的超分圖像可能出現(xiàn)不能很好符合人眼主觀感受的問題,SRGAN算法則針對該問題進(jìn)行了改進(jìn)。
2.SRGAN 原理與代碼實現(xiàn)
????????SRResNet算法通過深層的卷積模塊完成特征映射,但也存在重建出的圖像過于平滑,紋理細(xì)節(jié)信息丟失的缺陷。究其原因是采用MSE的目標(biāo)函數(shù),紋理細(xì)節(jié)處理難以滿足人眼主觀感受,為此如何“無中生有”重建紋理細(xì)節(jié),那就需要利用生成對抗網(wǎng)絡(luò)(Generative Adversarial Network, GAN)。
2.1 生成對抗網(wǎng)絡(luò)簡介
????????生成對抗網(wǎng)絡(luò)(GAN)的靈感與博弈論中博弈的思想相契合,對于深度學(xué)習(xí)而言,不再是簡單的單一模型(如SRResNet),而是構(gòu)造兩個深度學(xué)習(xí)模型:生成網(wǎng)絡(luò)(Generator)和判別網(wǎng)絡(luò)(Discriminator),兩個模型相互博弈,即生成網(wǎng)絡(luò)Generator產(chǎn)生以假亂真的圖像,而判別網(wǎng)絡(luò)Discriminator具備辨別圖像真?zhèn)蔚哪芰?,彼此在相互競爭對抗中達(dá)到更好效果。GAN的模型結(jié)構(gòu)如下圖所示:
????????上圖中生成網(wǎng)絡(luò)和判別網(wǎng)絡(luò)的主要功能:(1)生成網(wǎng)絡(luò)(Generator),它通過某種特定的網(wǎng)絡(luò)結(jié)構(gòu)以及目標(biāo)函數(shù)來生成圖像;(2)判別網(wǎng)絡(luò)(Discriminator),判別一張圖片是不是“真實的”,即判斷輸入的照片是不是由Generator生成;Generator的作用就是盡可能的生成逼真的圖像來迷惑Discriminator,使得Discriminator判斷失??;而Discriminator的作用就是盡可能的挖掘Generator的破綻,來判斷圖像到底是不是由Generator生成的“假冒偽劣”。
????????GAN已經(jīng)應(yīng)用于圖像補(bǔ)全、去噪,風(fēng)格遷移,超分重建等圖像領(lǐng)域,這里運(yùn)用GAN能夠減少損失函數(shù)的設(shè)計成本,從功能上看利用一定的基準(zhǔn),直接加上判別器,對抗訓(xùn)練會幫助我們解決很多問題。相比之前的簡單模型,GAN可以產(chǎn)生更加清晰、真實的效果。
2.2 感知損失函數(shù)
????????在SRGAN中重新設(shè)計了新的損失函數(shù)——感知損失(Perceptual Loss),它由內(nèi)容損失和對抗損失構(gòu)成:
- 對抗損失:與一般GAN定義類似,即重建出的圖像被判別器正確判斷的損失;
- 內(nèi)容損失:內(nèi)容損失更加關(guān)注重建圖片與真實高清圖像的語義特征差異,而不是逐個像素之間的顏色亮度差異;SRGAN的作者考慮計算圖像的固有特征差異,而固有特征提取其實早有專門模型被提出用于分類等任務(wù)。因此截取這些模型的特征提取模塊,用于計算重建圖像和真實圖像的特征(語義特征)提取,然后在提取的特征層上再進(jìn)行MSE計算。
????????值得一說的是,SRGAN在進(jìn)行語義特征提取時,選取了VGG19模型,截取模型的有用部分后,截取的模型被稱為truncated_vgg19模型。至此內(nèi)容損失的計算總結(jié)如下:
- 根據(jù)SRResNet模型重建出超分圖像(Super-Resolution,SR);
- 對于原始高清圖像H和重建出的超分圖像SR,分別應(yīng)用truncated_vgg19模型,計算得到兩幅圖像的特征圖H_fea和SR_fea;
- 計算推理后的特征圖H_fea和SR_fea的MSE值;
2.3 SRGAN網(wǎng)絡(luò)結(jié)構(gòu)
????????SRGAN 是由 Christian Ledig 和他的團(tuán)隊在 2017 年的論文 “Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network” 中提出的。在這篇論文中,他們提出了一種新的超分辨率方法,不僅可以恢復(fù)高分辨率圖像的細(xì)節(jié),還能使得生成的圖像在視覺上更接近于真實圖像。這種方法結(jié)合了深度學(xué)習(xí)中的生成對抗網(wǎng)絡(luò)(GAN)和殘差網(wǎng)絡(luò)(ResNet),兩者的結(jié)合提高了超分辨率的效果。SRGAN的網(wǎng)絡(luò)結(jié)構(gòu)由兩部分組成,分別為生成器模型(Generator)和判別器模型(Discriminator)。
????????(1)生成器:生成器的目標(biāo)是將低分辨率的輸入圖像變換為高分辨率圖像。在 SRGAN 中,生成器是一個深度殘差網(wǎng)絡(luò)(Deep Residual Network)。其核心是一系列的殘差塊,每個殘差塊中包含兩個卷積層,每個卷積層后面都跟有批量歸一化(Batch Normalization)和參數(shù)化ReLU(PReLU)激活函數(shù)。在所有的殘差塊后,通過兩個卷積層和一個像素級卷積層(PixelShuffle)將特征映射轉(zhuǎn)換回高分辨率圖像。這個結(jié)構(gòu)允許模型學(xué)習(xí)低分辨率和高分辨率圖像之間的殘差映射,從而使得網(wǎng)絡(luò)能夠有效地重建高分辨率圖像的細(xì)節(jié)。生成器的網(wǎng)絡(luò)結(jié)構(gòu)如下圖所示:
????????(2)判別器:判別器是一個卷積神經(jīng)網(wǎng)絡(luò),其目標(biāo)是區(qū)分生成的圖像是否來自真實的高分辨率圖像。在 SRGAN 中,判別器的網(wǎng)絡(luò)結(jié)構(gòu)是一個深度卷積神經(jīng)網(wǎng)絡(luò),其中包括一系列的卷積層、批量歸一化層和LeakyReLU激活函數(shù)。最后通過全連接層和sigmoid激活函數(shù)輸出圖像的真實性概率。判別器的網(wǎng)絡(luò)結(jié)構(gòu)如下圖所示:
????????這兩個網(wǎng)絡(luò)相互對抗:生成器嘗試生成越來越真實的圖像以欺騙判別器,而判別器則努力提高其區(qū)分真實圖像和生成圖像的能力。通過這種對抗過程,模型最終可以生成出具有高質(zhì)量細(xì)節(jié)的超分辨率圖像。在實際操作中,SRGAN 需要大量的訓(xùn)練數(shù)據(jù)和計算資源,且訓(xùn)練過程需要一定的技巧和經(jīng)驗。盡管如此,SRGAN 仍然是圖像超分辨率領(lǐng)域的一種重要技術(shù),為生成逼真的高分辨率圖像提供了一種有效的方法。
2.4 SRGAN網(wǎng)絡(luò)訓(xùn)練
????????SRGAN 的訓(xùn)練主要分為兩個階段:預(yù)訓(xùn)練階段和對抗訓(xùn)練階段。
????????預(yù)訓(xùn)練階段:這個階段主要是為了訓(xùn)練生成器。SRGAN 中的生成器是一個深度殘差網(wǎng)絡(luò),其目標(biāo)是學(xué)習(xí)一個從低分辨率圖像到高分辨率圖像的映射。在預(yù)訓(xùn)練階段,我們主要使用均方誤差(MSE)作為損失函數(shù),這樣可以確保網(wǎng)絡(luò)可以學(xué)習(xí)到一個相對精確的映射。這個階段的訓(xùn)練可以使用高分辨率圖像和對應(yīng)的低分辨率圖像作為訓(xùn)練數(shù)據(jù)。
????????對抗訓(xùn)練階段:在預(yù)訓(xùn)練階段結(jié)束后,我們得到了一個可以生成相對準(zhǔn)確的高分辨率圖像的生成器。然后,我們進(jìn)入對抗訓(xùn)練階段,這個階段的目標(biāo)是訓(xùn)練生成器和判別器進(jìn)行對抗。在這個階段,生成器的目標(biāo)是生成盡可能真實的高分辨率圖像以欺騙判別器,而判別器的目標(biāo)是盡可能準(zhǔn)確地區(qū)分真實的高分辨率圖像和生成器生成的高分辨率圖像。對抗訓(xùn)練的損失函數(shù)通常包括對抗損失和內(nèi)容損失兩部分。對抗損失來自判別器對生成圖像的判別結(jié)果,內(nèi)容損失則是生成圖像和真實高分辨率圖像在特征空間上的差異。
????????SRGAN 的損失函數(shù)主要包括兩部分:對抗損失(Adversarial Loss)和感知損失(Perceptual Loss)。
????????對抗損失(Adversarial Loss):對抗損失主要用于衡量生成器生成的圖像和真實圖像在判別器中的判別結(jié)果的差距。對抗損失的目標(biāo)是鼓勵生成器生成能夠欺騙判別器的圖像。在 SRGAN 中,使用了交叉熵?fù)p失(Cross-Entropy Loss)作為對抗損失,其公式如下:
L adv ( G , D ) = E I h r ~ p train ( I h r ) [ log ? D ( I h r ) ] + E I l r ~ p I l r ( I l r ) [ log ? ( 1 ? D ( G ( I l r ) ) ) ] L_{\text{adv}}(G,D) = \mathbb{E}_{I_{hr}\sim p_{\text{train}}(I_{hr})}[\log D(I_{hr})] + \mathbb{E}_{I_{lr}\sim p_{I_{lr}}(I_{lr})}[\log(1-D(G(I_{lr})))] Ladv?(G,D)=EIhr?~ptrain?(Ihr?)?[logD(Ihr?)]+EIlr?~pIlr??(Ilr?)?[log(1?D(G(Ilr?)))]
其中,G 是生成器,D 是判別器, I h r I_{hr} Ihr?是真實的高分辨率圖像, I l r I_{lr} Ilr? 是低分辨率圖像。
????????感知損失(Perceptual Loss):感知損失則用于衡量生成圖像和真實圖像在特征空間上的差距。在 SRGAN 中,感知損失包括內(nèi)容損失(Content Loss)和紋理損失(Texture Loss)。內(nèi)容損失是通過預(yù)訓(xùn)練的 VGG19 網(wǎng)絡(luò)提取出的特征圖之間的歐幾里得距離,紋理損失則是生成圖像和真實圖像的 Gram 矩陣之間的差距。感知損失的公式如下:
L perc ( G ) = E I l r ~ p I l r ( I l r ) , I h r ~ p train ( I h r ) [ ∥ ? ( I h r ) ? ? ( G ( I l r ) ) ∥ 1 + λ ∥ ? gram ( I h r ) ? ? gram ( G ( I l r ) ) ∥ 1 ] L_{\text{perc}}(G) = \mathbb{E}_{I_{lr}\sim p_{I_{lr}}(I_{lr}), I_{hr}\sim p_{\text{train}}(I_{hr})}[\| \phi(I_{hr}) - \phi(G(I_{lr})) \|_1 + \lambda \| \phi_{\text{gram}}(I_{hr}) - \phi_{\text{gram}}(G(I_{lr})) \|_1] Lperc?(G)=EIlr?~pIlr??(Ilr?),Ihr?~ptrain?(Ihr?)?[∥?(Ihr?)??(G(Ilr?))∥1?+λ∥?gram?(Ihr?)??gram?(G(Ilr?))∥1?]
其中,?? 是 VGG19 的特征提取函數(shù), ? g r a m \phi_{gram} ?gram? 是 Gram 矩陣的計算函數(shù),λ 是紋理損失的權(quán)重。
????????總的損失函數(shù)就是這兩部分的加權(quán)和:
L
(
G
,
D
)
=
L
adv
(
G
,
D
)
+
α
L
perc
(
G
)
L(G,D) = L_{\text{adv}}(G,D) + \alpha L_{\text{perc}}(G)
L(G,D)=Ladv?(G,D)+αLperc?(G)
其中,α 是感知損失的權(quán)重。
3. 代碼編寫
????????這里我們使用pytorch實現(xiàn)以上的SRGAN網(wǎng)絡(luò),模型代碼如下:
class SRResNet(nn.Module):
"""
SRResNet模型
"""
def __init__(self, large_kernel_size=9, small_kernel_size=3, n_channels=64, n_blocks=16, scaling_factor=4):
"""
:參數(shù) large_kernel_size: 第一層卷積和最后一層卷積核大小
:參數(shù) small_kernel_size: 中間層卷積核大小
:參數(shù) n_channels: 中間層通道數(shù)
:參數(shù) n_blocks: 殘差模塊數(shù)
:參數(shù) scaling_factor: 放大比例
"""
super(SRResNet, self).__init__()
# 放大比例必須為 2、 4 或 8
scaling_factor = int(scaling_factor)
assert scaling_factor in {2, 4, 8}, "放大比例必須為 2、 4 或 8!"
# 第一個卷積塊
self.conv_block1 = ConvolutionalBlock(in_channels=3, out_channels=n_channels, kernel_size=large_kernel_size,
batch_norm=False, activation='PReLu')
# 一系列殘差模塊, 每個殘差模塊包含一個跳連接
self.residual_blocks = nn.Sequential(
*[ResidualBlock(kernel_size=small_kernel_size, n_channels=n_channels) for i in range(n_blocks)])
# 第二個卷積塊
self.conv_block2 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels,
kernel_size=small_kernel_size,
batch_norm=True, activation=None)
# 放大通過子像素卷積模塊實現(xiàn), 每個模塊放大兩倍
n_subpixel_convolution_blocks = int(math.log2(scaling_factor))
self.subpixel_convolutional_blocks = nn.Sequential(
*[SubPixelConvolutionalBlock(kernel_size=small_kernel_size, n_channels=n_channels, scaling_factor=2) for i
in range(n_subpixel_convolution_blocks)])
# 最后一個卷積模塊
self.conv_block3 = ConvolutionalBlock(in_channels=n_channels, out_channels=3, kernel_size=large_kernel_size,
batch_norm=False, activation='Tanh')
def forward(self, lr_imgs):
"""
前向傳播.
:參數(shù) lr_imgs: 低分辨率輸入圖像集, 張量表示,大小為 (N, 3, w, h)
:返回: 高分辨率輸出圖像集, 張量表示, 大小為 (N, 3, w * scaling factor, h * scaling factor)
"""
output = self.conv_block1(lr_imgs) # (16, 3, 24, 24)
residual = output # (16, 64, 24, 24)
output = self.residual_blocks(output) # (16, 64, 24, 24)
output = self.conv_block2(output) # (16, 64, 24, 24)
output = output + residual # (16, 64, 24, 24)
output = self.subpixel_convolutional_blocks(output) # (16, 64, 24 * 4, 24 * 4)
sr_imgs = self.conv_block3(output) # (16, 3, 24 * 4, 24 * 4)
return sr_imgs
????????在SRGAN模型中,SRResNet是核心的一部分,也就是生成器模型。生成器的任務(wù)是從低分辨率圖像生成高分辨率圖像。以下是 SRResNet 模型的主要結(jié)構(gòu):
- 第一卷積塊(conv_block1):這個模塊用于接收輸入的低分辨率圖像,并進(jìn)行初始的特征提取。這里使用了預(yù)激活的ReLU (PReLU) 作為激活函數(shù),并且不使用批歸一化。卷積核的大小是大核大小(large_kernel_size),默認(rèn)為9。
- 殘差塊(residual_blocks):這是一系列的殘差模塊。每個殘差模塊都包含兩個卷積層和一個跳躍連接。這里默認(rèn)使用了16個殘差模塊。
- 第二卷積塊(conv_block2):這個模塊用于提取特征圖的更深層次的信息。這里使用了批歸一化和ReLU激活函數(shù),但是沒有使用偏置項。
- 子像素卷積塊(subpixel_convolutional_blocks):這些模塊用于將圖像放大到目標(biāo)的高分辨率。每個子像素卷積模塊都能將圖像的分辨率放大兩倍。根據(jù)我們設(shè)置的放大比例(scaling_factor),可能會有多個子像素卷積模塊串聯(lián)在一起。
- 最后的卷積塊(conv_block3):這個模塊用于生成最后的高分辨率圖像。這里使用了tanh作為激活函數(shù),可以將像素值約束在-1到1之間。
????????其中forward 函數(shù)描述了模型的前向傳播過程。首先,我們通過第一卷積塊處理輸入的低分辨率圖像,然后將結(jié)果保存在residual變量中,作為跳躍連接的參考。將處理后的結(jié)果送入殘差模塊進(jìn)行特征提取和非線性變換。然后再次使用卷積操作對特征圖進(jìn)行處理,并將結(jié)果與residual變量相加,實現(xiàn)了特征圖的跳躍連接。通過子像素卷積模塊進(jìn)行上采樣操作,將圖像的分辨率提升到目標(biāo)的高分辨率。最后通過最后的卷積塊生成最終的高分辨率圖像。
3.1 生成器模型代碼
????????生成器(Generator)模型是 SRGAN 中的一個關(guān)鍵部分,其核心任務(wù)是從低分辨率圖像生成高分辨率圖像。在這段代碼中,生成器的結(jié)構(gòu)與 SRResNet 完全一致,其主要代碼如下:
class Generator(nn.Module):
"""
生成器模型,其結(jié)構(gòu)與SRResNet完全一致.
"""
def __init__(self, large_kernel_size=9, small_kernel_size=3, n_channels=64, n_blocks=16, scaling_factor=4):
"""
參數(shù) large_kernel_size:第一層和最后一層卷積核大小
參數(shù) small_kernel_size:中間層卷積核大小
參數(shù) n_channels:中間層卷積通道數(shù)
參數(shù) n_blocks: 殘差模塊數(shù)量
參數(shù) scaling_factor: 放大比例
"""
super(Generator, self).__init__()
self.net = SRResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size,
n_channels=n_channels, n_blocks=n_blocks, scaling_factor=scaling_factor)
def forward(self, lr_imgs):
"""
前向傳播.
參數(shù) lr_imgs: 低精度圖像 (N, 3, w, h)
返回: 超分重建圖像 (N, 3, w * scaling factor, h * scaling factor)
"""
sr_imgs = self.net(lr_imgs) # (N, n_channels, w * scaling factor, h * scaling factor)
return sr_imgs
????????以上代碼將 SRResNet 作為一個內(nèi)部網(wǎng)絡(luò) (self.net),并在 forward 方法中調(diào)用它來執(zhí)行超分辨率轉(zhuǎn)換,包括以下生成器模型的主要部分:
????????(1)內(nèi)部網(wǎng)絡(luò)(net):這個模塊是我們之前定義的 SRResNet 模型。其參數(shù),如大核尺寸(large_kernel_size)、小核尺寸(small_kernel_size)、通道數(shù)(n_channels)、殘差模塊數(shù)量(n_blocks)和放大比例(scaling_factor),都會直接傳遞給 SRResNet 模型。
????????(2)forward 方法描述了模型的前向傳播過程。它接收低分辨率圖像,然后通過 SRResNet 模型生成超分辨率圖像。其中,lr_imgs 是輸入的低分辨率圖像,形狀為 (N, 3, w, h);sr_imgs 是輸出的超分辨率圖像,形狀為 (N, 3, w * scaling_factor, h * scaling_factor)。
3.2 判別器模型代碼
????????判別器(Discriminator)是 SRGAN 模型的另一個關(guān)鍵部分,其任務(wù)是判斷輸入圖像是否為真實的高分辨率圖像。在訓(xùn)練過程中,判別器和生成器進(jìn)行博弈,共同推動模型的進(jìn)步,其主要代碼如下:
class Discriminator(nn.Module):
"""
SRGAN判別器
"""
def __init__(self, kernel_size=3, n_channels=64, n_blocks=8, fc_size=1024):
"""
參數(shù) kernel_size: 所有卷積層的核大小
參數(shù) n_channels: 初始卷積層輸出通道數(shù), 后面每隔一個卷積層通道數(shù)翻倍
參數(shù) n_blocks: 卷積塊數(shù)量
參數(shù) fc_size: 全連接層連接數(shù)
"""
super(Discriminator, self).__init__()
in_channels = 3
# 卷積系列,參照論文SRGAN進(jìn)行設(shè)計
conv_blocks = list()
for i in range(n_blocks):
out_channels = (n_channels if i is 0 else in_channels * 2) if i % 2 is 0 else in_channels
conv_blocks.append(
ConvolutionalBlock(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
stride=1 if i % 2 is 0 else 2, batch_norm=i is not 0, activation='LeakyReLu'))
in_channels = out_channels
self.conv_blocks = nn.Sequential(*conv_blocks)
# 固定輸出大小
self.adaptive_pool = nn.AdaptiveAvgPool2d((6, 6))
self.fc1 = nn.Linear(out_channels * 6 * 6, fc_size)
self.leaky_relu = nn.LeakyReLU(0.2)
self.fc2 = nn.Linear(1024, 1)
# 最后不需要添加sigmoid層,因為PyTorch的nn.BCEWithLogitsLoss()已經(jīng)包含了這個步驟
def forward(self, imgs):
"""
前向傳播.
參數(shù) imgs: 用于作判別的原始高清圖或超分重建圖,張量表示,大小為(N, 3, w * scaling factor, h * scaling factor)
返回: 一個評分值, 用于判斷一副圖像是否是高清圖, 張量表示,大小為 (N)
"""
batch_size = imgs.size(0)
output = self.conv_blocks(imgs)
output = self.adaptive_pool(output)
output = self.fc1(output.view(batch_size, -1))
output = self.leaky_relu(output)
logit = self.fc2(output)
return logit
????????在以上代碼中,定義了判別器模型的主要部分:
- 卷積塊序列(conv_blocks):這是一個由多個卷積塊組成的序列。每個卷積塊都包含一個卷積層,然后可能跟隨一個批歸一化層,最后是一個LeakyReLU 激活函數(shù)。這些卷積塊的參數(shù)(比如卷積核大小、輸入/輸出通道數(shù)、是否使用批歸一化等)都是根據(jù) SRGAN論文中的說明進(jìn)行設(shè)置的。
- 自適應(yīng)平均池化層(adaptive_pool):這一層的作用是將卷積塊序列的輸出調(diào)整到固定的大?。?x6),以便接下來可以連接全連接層。
- 全連接層(fc1和fc2):第一個全連接層(fc1)用于將自適應(yīng)平均池化層的輸出扁平化,并通過線性變換降低維度到指定的尺寸(fc_size,這里設(shè)置為1024)。然后經(jīng)過LeakyReLU激活函數(shù),再連接到第二個全連接層(fc2),最終輸出一個分?jǐn)?shù)值,用于判斷輸入的圖像是否為真實的高分辨率圖像。
????????forward 函數(shù)描述了模型的前向傳播過程。它接收輸入的圖像,首先經(jīng)過卷積塊序列進(jìn)行特征提取,然后經(jīng)過自適應(yīng)平均池化層將特征調(diào)整到固定的大小,接著通過兩個全連接層輸出一個評分值。其中,imgs 是輸入的圖像,形狀為 (N, 3, w * scaling_factor, h * scaling_factor);logit 是輸出的評分值,形狀為 (N)。
????????需要注意的是,這里并沒有在模型的最后添加 Sigmoid 層,因為在計算損失時,我們會使用 PyTorch 的 nn.BCEWithLogitsLoss() 函數(shù),這個函數(shù)內(nèi)部已經(jīng)包含了 Sigmoid 函數(shù)的計算步驟。
3.3 測試生成圖像代碼
????????以下代碼介紹了如何使用訓(xùn)練好的 SRGAN 生成器模型進(jìn)行圖像的超分辨率恢復(fù)。先貼上代碼然后我后面再詳細(xì)解釋這個過程:
# -*- coding: utf-8 -*-
import time
from models import Generator
from utils import *
# 測試圖像
imgPath = './images/girl1.jpg'
# 模型參數(shù)
large_kernel_size = 9 # 第一層卷積和最后一層卷積的核大小
small_kernel_size = 3 # 中間層卷積的核大小
n_channels = 64 # 中間層通道數(shù)
n_blocks = 16 # 殘差模塊數(shù)量
scaling_factor = 4 # 放大比例
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if __name__ == '__main__':
# 預(yù)訓(xùn)練模型
srgan_checkpoint = "./models/checkpoint_srgan.pth"
# 加載模型SRResNet 或 SRGAN
checkpoint = torch.load(srgan_checkpoint, map_location=device)
generator = Generator(large_kernel_size=large_kernel_size,
small_kernel_size=small_kernel_size,
n_channels=n_channels,
n_blocks=n_blocks,
scaling_factor=scaling_factor)
generator = generator.to(device)
generator.load_state_dict(checkpoint['generator'])
generator.eval()
model = generator
# 加載圖像
img = Image.open(imgPath, mode='r')
img = img.convert('RGB')
# 雙線性上采樣
Bicubic_img = img.resize((int(img.width * scaling_factor), int(img.height * scaling_factor)), Image.BICUBIC)
Bicubic_img.save('./results/test_bicubic.jpg')
# 圖像預(yù)處理
lr_img = convert_image(img, source='pil', target='imagenet-norm')
lr_img.unsqueeze_(0)
# 記錄時間
start = time.time()
# 轉(zhuǎn)移數(shù)據(jù)至設(shè)備
lr_img = lr_img.to(device) # (1, 3, w, h ), imagenet-normed
# 模型推理
with torch.no_grad():
sr_img = model(lr_img).squeeze(0).cpu().detach() # (1, 3, w*scale, h*scale), in [-1, 1]
sr_img = convert_image(sr_img, source='[-1, 1]', target='pil')
sr_img.save('./results/test_srgan.jpg')
print('用時 {:.3f} 秒'.format(time.time() - start))
????????以上代碼給出了如何使用 SRGAN 進(jìn)行圖像超分辨率恢復(fù)的整個過程,你可以通過改變測試圖像或模型的參數(shù)來看看模型的效果如何變化,這段代碼包括了以下環(huán)節(jié)。
-
設(shè)置參數(shù):首先定義了一些模型和測試圖像的參數(shù),包括圖像的路徑,模型的參數(shù)(如卷積核大小,通道數(shù),殘差模塊數(shù)量和放大比例等)和設(shè)備(優(yōu)先使用 GPU,如果沒有則使用 CPU)。
-
加載模型:使用 torch.load() 函數(shù)加載預(yù)訓(xùn)練的 SRGAN 模型,并把模型移到相應(yīng)的設(shè)備上。之后設(shè)置模型為評估模式,這意味著模型中的某些層(如批歸一化和丟棄)會根據(jù)需要更改行為。
-
加載和處理圖像:使用 PIL 庫加載測試圖像,并將其轉(zhuǎn)換為 RGB 格式。然后使用雙線性插值方法將圖像大小調(diào)整到目標(biāo)大小,并保存結(jié)果。接下來,對圖像進(jìn)行預(yù)處理,將其從 PIL 圖像轉(zhuǎn)換為適合模型輸入的張量,并增加一個批處理維度。
-
模型推理:首先記錄推理開始的時間,然后將預(yù)處理后的圖像移到相應(yīng)的設(shè)備上。然后,使用 torch.no_grad() 上下文管理器禁止梯度計算(因為在推理過程中不需要計算梯度,這樣可以節(jié)省內(nèi)存),并將圖像輸入模型進(jìn)行超分辨率恢復(fù)。最后,將模型輸出的張量轉(zhuǎn)換回 PIL 圖像,并保存結(jié)果。
-
打印推理時間:計算模型推理的時間,并打印結(jié)果。
????????運(yùn)行出來的結(jié)果如下圖所示,可以對比一下效果,當(dāng)然不同的圖片可能恢復(fù)的效果不一樣。
4. 下載鏈接
????若您想獲得博文中涉及的實現(xiàn)完整全部程序文件(包括測試圖片、視頻,py文件等,如下圖),這里已打包上傳至博主的csdn下載頻道獲取。
完整代碼下載地址:https://download.csdn.net/download/qq_32892383/87953641
COCO訓(xùn)練數(shù)據(jù)集:https://pan.baidu.com/s/18xiqkK2m34TKo1FcKo0RJw?pwd=y5gf 提取碼:y5gf
Python版本:3.8,請勿使用其他版本,需要安裝的依賴請見requirements.txt文件
安裝環(huán)境步驟如下:
(1)首先打開系統(tǒng)的cmd終端(不要用powershell),使用以下命令將命令行路徑切換到你的代碼所在的文件夾(…/你的路徑/SRGAN)下:
cd G:\BlogCode\SRGAN
G:
上面我的代碼在G盤,你應(yīng)該切換到自己的文件夾路徑,然后再次輸入"G:"命令防止沒切換過來:
(2)輸入conda命令創(chuàng)建一個python 3.8的環(huán)境,代碼如下:
conda create -n env_rec python=3.8
等待環(huán)境創(chuàng)建完畢后,使用以下命令激活環(huán)境:
conda activate env_rec
(3)激活環(huán)境后可以使用pip讀取requirements.txt中的依賴庫版本進(jìn)行安裝:
pip install -r requirements.txt
等待完全安裝完畢,此時你可以在pycharm的環(huán)境配置中選擇剛剛新建的環(huán)境運(yùn)行了。如果需要重新訓(xùn)練模型,你需要先下載COCO數(shù)據(jù)集然后解壓到SRGAN文件夾下的data文件夾中,我已經(jīng)打包好該數(shù)據(jù)集其網(wǎng)盤地址如下:
文章來源:http://www.zghlxwxcb.cn/news/detail-530235.html
結(jié)束語
????????由于博主能力有限,博文中提及的方法即使經(jīng)過試驗,也難免會有疏漏之處。希望您能熱心指出其中的錯誤,以便下次修改時能以一個更完美更嚴(yán)謹(jǐn)?shù)臉幼?,呈現(xiàn)在大家面前。同時如果有更好的實現(xiàn)方法也請您不吝賜教。文章來源地址http://www.zghlxwxcb.cn/news/detail-530235.html
到了這里,關(guān)于SRGAN圖像超分重建算法Python實現(xiàn)(含數(shù)據(jù)集代碼)的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!