Abstract:
Video restoration tasks, including super-resolution and deblurring, are drawing increasing attention in the computer vision community. A challenging benchmark named REDS was released in the NTIRE19 Challenge. This new benchmark challenges existing methods from two aspects: (1) how to align multiple frames given large motions, and (2) how to effectively fuse different frames with diverse motion and blur. In this work, we propose a novel video restoration framework with enhanced deformable convolutions, termed EDVR, to address these challenges. First, to handle large motions, we devise a Pyramid, Cascading and Deformable (PCD) alignment module, in which frame alignment is done at the feature level using deformable convolutions in a coarse-to-fine manner. Second, we propose a Temporal and Spatial Attention (TSA) fusion module, in which attention is applied both temporally and spatially to emphasize important features for subsequent restoration. Thanks to these modules, our EDVR wins the champions and outperforms the second place by a large margin in all four tracks of the NTIRE19 video restoration and enhancement challenges. EDVR also demonstrates superior performance to published state-of-the-art methods on video super-resolution and deblurring.
1. Introduction
Alignment. Most existing methods perform alignment by explicitly estimating optical flow fields between the reference frame and its neighboring frames [2, 48, 13]; the neighboring frames are then warped according to the estimated motion fields. Another line of work achieves implicit motion compensation through dynamic filtering [10] or deformable convolution [40]. REDS poses a great challenge to existing alignment algorithms. In particular, accurate flow estimation and precise warping can be challenging and time-consuming for flow-based methods, and in the presence of large motions it is difficult to perform motion compensation, either explicitly or implicitly, within a single resolution scale.
Fusion. Fusing the features of aligned frames is another key step in video restoration. Most existing methods either use convolutions for early fusion over all frames [2] or adopt recurrent networks to fuse multiple frames gradually [32, 6]. Liu et al. [22] propose a temporally adaptive network that can dynamically fuse across different temporal scales. None of these methods considers the importance of the underlying visual information in each frame: different frames and locations do not contribute equally to the reconstruction, because some frames or regions suffer from imperfect alignment and blur.
Solution. We propose a unified framework, called EDVR, that is extensible to various video restoration tasks, including super-resolution and deblurring. The core of EDVR consists of: (1) an alignment module known as Pyramid, Cascading and Deformable convolution (PCD), and (2) a fusion module known as Temporal and Spatial Attention (TSA).
The PCD module is inspired by TDAN [40] and uses deformable convolutions to align each neighboring frame to the reference frame at the feature level. Unlike TDAN, alignment is performed in a coarse-to-fine manner to handle large and complex motions. Specifically, we use a pyramid structure that first aligns features at lower scales with coarse estimates, and then propagates the offsets and aligned features to higher scales to facilitate precise motion compensation, similar to the notion adopted in optical flow estimation [7, 9]. Moreover, an additional deformable convolution is cascaded after the pyramidal alignment to further improve the robustness of alignment.
The proposed TSA is a fusion module that helps aggregate information across multiple aligned features. To better account for the visual information of each frame, temporal attention is introduced by computing the element-wise correlation between the features of the reference frame and those of each neighboring frame. At every location, the correlation coefficients then weigh each neighboring feature, indicating how informative it is for reconstructing the reference image. The weighted features of all frames are then convolved and fused together. After fusion with temporal attention, we further apply spatial attention that assigns a weight to each location in each channel, so as to exploit cross-channel and spatial information more effectively.
We participated in all four tracks of the NTIRE19 video restoration and enhancement challenges [29, 28], covering video super-resolution (clean/blur) and video deblurring (clean/compression artifacts). Thanks to the effective alignment and fusion modules, our EDVR won the champions in all four challenging tracks, demonstrating the effectiveness and generalizability of our method. Beyond the competition results, we also report comparative results on existing video super-resolution and deblurring benchmarks, where EDVR shows superior performance to existing methods.
1.1 Deformable Convolution
Deformable convolution is an operation in convolutional neural networks that lets the network locally deform the sampling grid on the input feature map while convolving. A standard convolution samples only within a fixed window, whereas deformable convolution introduces additional offsets to dynamically adjust the sampling positions of the local receptive field.
In a deformable convolution, the sampling positions of the kernel at each location are computed by applying learned offsets. These offsets are learned during training and represent local displacement information at each position of the input feature map. Using them, the deformable convolution adaptively adjusts the sampling positions of the kernel according to the local structure of the input, and thus better accommodates deformations and irregular patterns in the image.
In short, deformable convolution adds an offset on top of a regular convolution; the offset is produced by another convolution and is usually fractional. Because the offsets are fractional, the sampling positions do not fall exactly on pixel locations, so bilinear interpolation is used to obtain the values at those positions.
Advantages: objects appear at different scales, while a conventional CNN has a fixed receptive field, which limits feature extraction. Deformable convolution lets the sampling points within the receptive field learn offsets and adaptively adjust their positions, improving the feature extraction ability and robustness of the model.
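As a minimal illustration (not the EDVR code; it assumes torchvision's DeformConv2d and a hypothetical SimpleDeformConv wrapper), an ordinary convolution predicts the offsets, which are then passed to the deformable convolution that samples with bilinear interpolation internally:

import torch
import torch.nn as nn
from torchvision.ops import DeformConv2d

class SimpleDeformConv(nn.Module):
    """Plain deformable convolution: a regular conv predicts fractional offsets."""
    def __init__(self, channels=64, kernel_size=3):
        super().__init__()
        # 2 * k * k values: an (x, y) shift for each of the k*k sampling points
        self.offset_conv = nn.Conv2d(channels, 2 * kernel_size * kernel_size,
                                     kernel_size, padding=kernel_size // 2)
        self.dcn = DeformConv2d(channels, channels, kernel_size,
                                padding=kernel_size // 2)

    def forward(self, x):
        offset = self.offset_conv(x)   # learned end-to-end, usually fractional
        return self.dcn(x, offset)     # bilinear sampling at the shifted positions

feat = torch.randn(1, 64, 32, 32)
out = SimpleDeformConv()(feat)         # same shape: (1, 64, 32, 32)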
2. Model
2.1 Overall Framework
Taking video super-resolution as an example, EDVR takes 2N+1 low-resolution frames as input and generates one high-resolution output. Each neighboring frame is aligned to the reference frame at the feature level by the PCD alignment module, and the TSA fusion module then fuses the information of the different frames; these two modules are detailed in Sections 2.2 and 2.3. The fused features pass through a reconstruction module, which in EDVR is a cascade of residual blocks and can easily be replaced with any other advanced module from single-image super-resolution [46, 51]. An upsampling operation is performed at the end of the network to increase the spatial size. Finally, the high-resolution frame is obtained by adding the predicted image residual to the directly upsampled image.
For other tasks with high-resolution inputs, such as video deblurring, the input frames are first downsampled with strided convolution layers. Most of the computation is then carried out in the low-resolution space, which largely saves computational cost, and the upsampling layers at the end resize the features back to the original input resolution. A pre-deblur module is used before the alignment module to pre-process the blurry inputs and improve alignment accuracy.
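For reference, a usage sketch of the overall interface, assuming the EDVR class listed in Section 4 (and its BasicSR dependencies) is importable; shapes follow the default x4 SR configuration with five input frames:

import torch
# EDVR as defined in Section 4; defaults: 3-channel frames, 64 features, 5 frames.
model = EDVR(num_in_ch=3, num_out_ch=3, num_feat=64, num_frame=5)
lr_clip = torch.randn(1, 5, 3, 64, 64)   # (b, 2N+1, c, h, w); h, w multiples of 4
sr_center = model(lr_clip)               # (1, 3, 256, 256): x4 SR of the center frame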
2.2 Alignment with Pyramid, Cascading and Deformable Convolution
We first briefly review the use of deformable convolution for alignment [40], i.e., aligning the features of each neighboring frame to those of the reference frame. Unlike flow-based methods, deformable alignment is applied to the features of each frame, denoted by F_{t+i}, i ∈ [-N : +N]. We use the modulated deformable module [53]. For a deformable convolution kernel with K sampling locations, let w_k and p_k denote the weight and the pre-specified offset of the k-th location, respectively. For example, a 3×3 kernel has K = 9 and p_k ∈ {(-1,-1), (-1,0), ..., (1,1)}. The aligned feature F^a_{t+i} at each position p_0 is obtained as

$F^a_{t+i}(p_0) = \sum_{k=1}^{K} w_k \cdot F_{t+i}(p_0 + p_k + \Delta p_k) \cdot \Delta m_k$,    (1)

where Δp_k and Δm_k are the learnable offset and the modulation scalar of the k-th location. The offsets are predicted from the concatenated features of the neighboring frame and the reference frame:

$\Delta P_{t+i} = f([F_{t+i}, F_t]), \quad i \in [-N:+N]$,    (2)

where ΔP = {Δp_k}, f is a general function consisting of several convolution layers, and [·, ·] denotes the concatenation operation. For simplicity, only the learnable offsets Δp_k are considered in the descriptions and figures, ignoring the modulation Δm_k. Since p_0 + p_k + Δp_k is fractional, bilinear interpolation is applied as in [3].
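A minimal sketch of this single-level deformable alignment (Eqs. (1)-(2)), assuming torchvision's DeformConv2d in place of the DCNv2Pack wrapper used in Section 4; the hypothetical SingleLevelAlign module predicts the offsets from the concatenated neighbor/reference features:

import torch
import torch.nn as nn
from torchvision.ops import DeformConv2d

class SingleLevelAlign(nn.Module):
    """Align a neighboring feature map to the reference features (one level)."""
    def __init__(self, channels=64):
        super().__init__()
        # f([F_{t+i}, F_t]): predict 2*3*3 offsets from the concatenated features
        self.offset_pred = nn.Sequential(
            nn.Conv2d(channels * 2, channels, 3, 1, 1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(channels, 2 * 3 * 3, 3, 1, 1))
        self.dcn = DeformConv2d(channels, channels, 3, padding=1)

    def forward(self, feat_nbr, feat_ref):
        offset = self.offset_pred(torch.cat([feat_nbr, feat_ref], dim=1))  # Eq. (2)
        return self.dcn(feat_nbr, offset)                                  # Eq. (1)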
To address complex motions and large parallax in alignment, we propose the PCD module based on well-established principles in optical flow: pyramidal processing [31, 35] and cascading refinement [7, 8, 9]. Specifically, as shown with black dashed lines in Figure 3, to generate the feature F^l_{t+i} at the l-th level, we use strided convolution filters to downsample the features at the (l-1)-th level by a factor of 2, obtaining an L-level pyramid of feature representations. At the l-th level, the offsets and aligned features are predicted also from the ×2 upsampled offsets and aligned features of the (l+1)-th level (purple dashed lines in Figure 3):

$\Delta P^l_{t+i} = f\big([F^l_{t+i}, F^l_t], (\Delta P^{l+1}_{t+i})^{\uparrow 2}\big)$,
$(F^a_{t+i})^l = g\big(\mathrm{DConv}(F^l_{t+i}, \Delta P^l_{t+i}), ((F^a_{t+i})^{l+1})^{\uparrow 2}\big)$,    (3)

where (·)^{↑s} denotes upscaling by a factor of s, DConv is the deformable convolution described in Eq. (1), and g is a general function with several convolution layers. The ×2 upsampling is implemented with bilinear interpolation. We use a three-level pyramid in EDVR, i.e., L = 3. To reduce computational cost, we do not increase the channel number as the spatial size decreases.
On top of the pyramidal alignment, a subsequent deformable alignment is cascaded to further refine the coarsely aligned features (the region with a light-purple background in Figure 3). This coarse-to-fine design allows the PCD module to achieve sub-pixel alignment accuracy. Notably, the PCD alignment module is jointly learned with the whole framework, without additional supervision [40] or pretraining on other tasks such as optical flow [48]; its effectiveness is demonstrated in the ablation studies of the original paper (Sec. 4.3).
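A condensed, hypothetical sketch of the coarse-to-fine propagation in Eq. (3), using raw torchvision offsets instead of the offset features handled by DCNv2Pack; the full cascaded version is the PCDAlignment class in Section 4:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import DeformConv2d

class TinyPCD(nn.Module):
    """Three-level coarse-to-fine alignment; offsets/features flow from L3 to L1."""
    def __init__(self, c=64):
        super().__init__()
        # the coarsest level (index 2) sees only [nbr, ref]; finer levels also see
        # the 18-channel offsets upsampled from the level below
        self.offset_pred = nn.ModuleList(
            [nn.Conv2d(c * 2 + (0 if l == 2 else 18), 18, 3, 1, 1) for l in range(3)])
        self.dcn = nn.ModuleList([DeformConv2d(c, c, 3, padding=1) for _ in range(3)])
        self.merge = nn.ModuleList([nn.Conv2d(c * 2, c, 3, 1, 1) for _ in range(2)])

    def forward(self, nbr_pyr, ref_pyr):     # lists [L1, L2, L3], coarsest last
        up_offset, up_feat = None, None
        for l in (2, 1, 0):                  # coarse (L3) -> fine (L1)
            x = torch.cat([nbr_pyr[l], ref_pyr[l]], dim=1)
            if up_offset is not None:        # guidance from the coarser level
                x = torch.cat([x, up_offset], dim=1)
            offset = self.offset_pred[l](x)
            feat = self.dcn[l](nbr_pyr[l], offset)
            if up_feat is not None:
                feat = self.merge[l](torch.cat([feat, up_feat], dim=1))
            if l > 0:
                # x2: a displacement doubles when the resolution doubles
                up_offset = 2 * F.interpolate(offset, scale_factor=2,
                                              mode='bilinear', align_corners=False)
                up_feat = F.interpolate(feat, scale_factor=2,
                                        mode='bilinear', align_corners=False)
        return feat                          # aligned features at full resolution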
2.3 Fusion with Temporal and Spatial Attention
Inter-frame temporal relations and intra-frame spatial relations are critical in fusion because 1) different neighboring frames are not equally informative due to occlusion, blurry regions and parallax, and 2) misalignment and unalignment arising from the preceding alignment stage adversely affect the subsequent reconstruction. Dynamically aggregating neighboring frames at the pixel level is therefore indispensable for effective fusion. To address these problems, we propose the TSA fusion module, which assigns pixel-level aggregation weights to each frame. Specifically, we adopt temporal and spatial attention during fusion, as shown in Figure 4. The goal of temporal attention is to compute frame similarity in an embedding space: intuitively, a neighboring frame that is more similar to the reference frame in that space should receive more attention. For each frame i ∈ [-N : +N], the similarity distance h can be computed as

$h(F^a_{t+i}, F^a_t) = \mathrm{sigmoid}\big(\theta(F^a_{t+i})^{\mathsf{T}}\, \phi(F^a_t)\big)$,    (4)

where θ(F^a_{t+i}) and φ(F^a_t) are two embeddings that can be realized with simple convolution filters. The sigmoid activation restricts the outputs to [0, 1], stabilizing gradient back-propagation. Note that the temporal attention is spatially specific, i.e., the spatial size of h(F^a_{t+i}, F^a_t) is the same as that of F^a_{t+i}.
The temporal attention maps are then multiplied pixel-wise with the original aligned features F^a_{t+i}, and an extra fusion convolution layer aggregates the attention-modulated features:

$\tilde{F}^a_{t+i} = F^a_{t+i} \odot h(F^a_{t+i}, F^a_t)$,    (5)
$F_{\mathrm{fusion}} = \mathrm{Conv}\big([\tilde{F}^a_{t-N}, \cdots, \tilde{F}^a_t, \cdots, \tilde{F}^a_{t+N}]\big)$,    (6)

where ⊙ and [·, ·, ·] denote element-wise multiplication and concatenation, respectively.
Spatial attention masks are then computed from the fused features. A pyramid design is adopted to enlarge the attention receptive field, and the fused features are modulated by the masks through element-wise multiplication and addition, similar to [45]. The effectiveness of the TSA module is demonstrated in the ablation studies of the original paper (Sec. 4.3); a compact sketch of the temporal-attention part is given below, and the complete TSAFusion implementation is listed in Section 4.
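A compact sketch of the temporal attention in Eqs. (4)-(6) for 2N+1 = 5 aligned 64-channel feature maps (a hypothetical TemporalAttentionFusion module; the spatial attention pyramid is omitted here):

import torch
import torch.nn as nn

class TemporalAttentionFusion(nn.Module):
    def __init__(self, num_feat=64, num_frame=5, center=2):
        super().__init__()
        self.center = center
        self.embed_ref = nn.Conv2d(num_feat, num_feat, 3, 1, 1)   # phi in Eq. (4)
        self.embed_nbr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)   # theta in Eq. (4)
        self.fuse = nn.Conv2d(num_frame * num_feat, num_feat, 1)

    def forward(self, aligned):                   # (b, t, c, h, w)
        b, t, c, h, w = aligned.size()
        ref = self.embed_ref(aligned[:, self.center])
        emb = self.embed_nbr(aligned.view(-1, c, h, w)).view(b, t, -1, h, w)
        # channel-wise dot product -> per-pixel similarity, squashed to [0, 1]
        attn = torch.sigmoid((emb * ref.unsqueeze(1)).sum(dim=2, keepdim=True))
        weighted = (aligned * attn).view(b, -1, h, w)   # Eq. (5)
        return self.fuse(weighted)                      # Eq. (6)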
2.4 Two-Stage Restoration
Although a single EDVR equipped with the PCD alignment module and the TSA fusion module already achieves state-of-the-art performance, the restored images are observed to be imperfect, especially when the input frames are blurry or severely distorted. Under such harsh conditions, motion compensation and detail aggregation are affected, leading to inferior reconstruction performance.
Intuitively, coarsely restored frames would greatly mitigate the pressure on alignment and fusion. We therefore adopt a two-stage strategy to further boost performance. Specifically, a similar but shallower EDVR network is cascaded to refine the output frames of the first stage. The benefits are two-fold: 1) it effectively removes the severe motion blur that the preceding model cannot handle, improving restoration quality; 2) it alleviates the inconsistency among the output frames. A hedged sketch of the cascade is given below.
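A sketch of the cascade, assuming the EDVR class from Section 4; the exact second-stage configuration and the input handling here are illustrative only and may differ from the released training code:

import torch

# Stage 1: x4 SR model (128 channels, 40 reconstruction blocks as in the training details).
stage1 = EDVR(num_feat=128, num_frame=5, num_reconstruct_block=40)
# Stage 2: a similar but shallower EDVR that refines high-resolution inputs;
# hr_in=True together with the pre-deblur module keeps the output at the input size.
stage2 = EDVR(num_feat=128, num_frame=5, num_reconstruct_block=20,
              hr_in=True, with_predeblur=True)

lr_clip = torch.randn(1, 5, 3, 64, 64)       # five consecutive LR frames
sr_frame = stage1(lr_clip)                   # (1, 3, 256, 256) restored center frame
# In practice stage 2 consumes a temporal window of consecutive stage-1 outputs;
# a single frame is repeated here only to keep the sketch self-contained.
sr_refined = stage2(sr_frame.unsqueeze(1).repeat(1, 5, 1, 1, 1))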
3. Experiments
3.1 Datasets and Training
Training datasets. Previous studies on video processing [21, 10, 34] were usually developed or evaluated on private datasets, and the lack of standard, open video datasets has limited fair comparison. REDS [26] is a high-quality (720p) video dataset newly proposed for the NTIRE19 competition. It consists of 240 training clips, 30 validation clips and 30 testing clips, each with 100 consecutive frames. During the competition, since the ground truth of the test partition was not released, we selected four representative clips (with diverse scenes and motions) as our test set, denoted REDS4. The remaining training and validation clips were regrouped as our training dataset (266 clips in total). This paper adopts the same configuration to stay consistent with our method and procedure in the competition.
Vimeo-90K [48] is a widely used training dataset, usually evaluated together with the Vid4 [21] and Vimeo-90K testing sets (denoted Vimeo-90K-T). We observe a dataset bias when the distribution of the training set deviates from that of the testing set.
Training details. The PCD alignment module adopts five residual blocks (RBs) for feature extraction. We use 40 RBs in the reconstruction module and 20 RBs in the second-stage model. The channel size in each residual block is set to 128. We use RGB patches of size 64×64 and 256×256 as inputs for the video SR and deblurring tasks, respectively. The mini-batch size is set to 32. The network takes five consecutive frames (i.e., N = 2) as inputs unless otherwise specified. Training data are augmented with random horizontal flips and 90° rotations. We adopt the Charbonnier penalty function [17] as the final loss, defined as

$\mathcal{L} = \sqrt{\lVert \hat{O}_t - O_t \rVert^2 + \varepsilon^2}$,

where $\hat{O}_t$ and $O_t$ denote the restored frame and its ground truth, and ε is a small constant (e.g., 10^{-3}).
We train our models with the Adam optimizer [14], setting β1 = 0.9 and β2 = 0.999. The learning rate is initialized to 4×10^{-4}. Deeper networks are initialized with the parameters of shallower ones for faster convergence. The models are implemented in the PyTorch framework and trained on 8 NVIDIA Titan Xp GPUs.
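A minimal sketch of the loss and optimizer setup (a common element-wise form of the Charbonnier penalty is used; the patch shapes and EDVR configuration follow the training details above, while the tiny batch and the loop itself are illustrative):

import torch

def charbonnier_loss(pred, target, eps=1e-3):
    # element-wise sqrt((x)^2 + eps^2): a smooth, differentiable variant of L1
    return torch.sqrt((pred - target) ** 2 + eps ** 2).mean()

model = EDVR(num_feat=128, num_frame=5, num_reconstruct_block=40)
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4, betas=(0.9, 0.999))

lr_patch = torch.randn(2, 5, 3, 64, 64)      # (b, 2N+1, c, 64, 64) SR training patches
gt_patch = torch.randn(2, 3, 256, 256)       # ground-truth HR center frames
loss = charbonnier_loss(model(lr_patch), gt_patch)
optimizer.zero_grad()
loss.backward()
optimizer.step()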
4. Code
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.models.archs.arch_util import (DCNv2Pack, ResidualBlockNoBN,
make_layer)
class PCDAlignment(nn.Module):
"""Alignment module using Pyramid, Cascading and Deformable convolution
(PCD). It is used in EDVR.
Ref:
EDVR: Video Restoration with Enhanced Deformable Convolutional Networks
Args:
num_feat (int): Channel number of middle features. Default: 64.
deformable_groups (int): Deformable groups. Defaults: 8.
"""
def __init__(self, num_feat=64, deformable_groups=8):
super(PCDAlignment, self).__init__()
# Pyramid has three levels:
# L3: level 3, 1/4 spatial size
# L2: level 2, 1/2 spatial size
# L1: level 1, original spatial size
self.offset_conv1 = nn.ModuleDict()
self.offset_conv2 = nn.ModuleDict()
self.offset_conv3 = nn.ModuleDict()
self.dcn_pack = nn.ModuleDict()
self.feat_conv = nn.ModuleDict()
# Pyramids
for i in range(3, 0, -1):
level = f'l{i}'
self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1,
1)
if i == 3:
self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1,
1)
else:
self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3,
1, 1)
self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1,
1)
self.dcn_pack[level] = DCNv2Pack(
num_feat,
num_feat,
3,
padding=1,
deformable_groups=deformable_groups)
if i < 3:
self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1,
1)
# Cascading dcn
self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.cas_dcnpack = DCNv2Pack(
num_feat,
num_feat,
3,
padding=1,
deformable_groups=deformable_groups)
self.upsample = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=False)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, nbr_feat_l, ref_feat_l):
"""Align neighboring frame features to the reference frame features.
Args:
nbr_feat_l (list[Tensor]): Neighboring feature list. It
contains three pyramid levels (L1, L2, L3),
each with shape (b, c, h, w).
ref_feat_l (list[Tensor]): Reference feature list. It
contains three pyramid levels (L1, L2, L3),
each with shape (b, c, h, w).
Returns:
Tensor: Aligned features.
"""
# Pyramids
upsampled_offset, upsampled_feat = None, None
for i in range(3, 0, -1):
level = f'l{i}'
offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1)
offset = self.lrelu(self.offset_conv1[level](offset))
if i == 3:
offset = self.lrelu(self.offset_conv2[level](offset))
else:
offset = self.lrelu(self.offset_conv2[level](torch.cat(
[offset, upsampled_offset], dim=1)))
offset = self.lrelu(self.offset_conv3[level](offset))
feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset)
if i < 3:
feat = self.feat_conv[level](
torch.cat([feat, upsampled_feat], dim=1))
if i > 1:
feat = self.lrelu(feat)
if i > 1: # upsample offset and features
# x2: when we upsample the offset, we should also enlarge
# the magnitude.
upsampled_offset = self.upsample(offset) * 2
upsampled_feat = self.upsample(feat)
# Cascading
offset = torch.cat([feat, ref_feat_l[0]], dim=1)
offset = self.lrelu(
self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset))))
feat = self.lrelu(self.cas_dcnpack(feat, offset))
return feat
class TSAFusion(nn.Module):
"""Temporal Spatial Attention (TSA) fusion module.
Temporal: Calculate the correlation between center frame and
neighboring frames;
Spatial: It has 3 pyramid levels, the attention is similar to SFT.
(SFT: Recovering realistic texture in image super-resolution by deep
spatial feature transform.)
Args:
num_feat (int): Channel number of middle features. Default: 64.
num_frame (int): Number of frames. Default: 5.
center_frame_idx (int): The index of center frame. Default: 2.
"""
def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2):
super(TSAFusion, self).__init__()
self.center_frame_idx = center_frame_idx
# temporal attention (before fusion conv)
self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
# spatial attention (after fusion conv)
self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1)
self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1)
self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1)
self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1)
self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1)
self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.upsample = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=False)
def forward(self, aligned_feat):
"""
Args:
aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w).
Returns:
Tensor: Features after TSA with the shape (b, c, h, w).
"""
b, t, c, h, w = aligned_feat.size()
# temporal attention
embedding_ref = self.temporal_attn1(
aligned_feat[:, self.center_frame_idx, :, :, :].clone())
embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w))
embedding = embedding.view(b, t, -1, h, w) # (b, t, c, h, w)
corr_l = [] # correlation list
for i in range(t):
emb_neighbor = embedding[:, i, :, :, :]
corr = torch.sum(emb_neighbor * embedding_ref, 1) # (b, h, w)
corr_l.append(corr.unsqueeze(1)) # (b, 1, h, w)
corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1)) # (b, t, h, w)
corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w)
corr_prob = corr_prob.contiguous().view(b, -1, h, w) # (b, t*c, h, w)
aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob
# fusion
feat = self.lrelu(self.feat_fusion(aligned_feat))
# spatial attention
attn = self.lrelu(self.spatial_attn1(aligned_feat))
attn_max = self.max_pool(attn)
attn_avg = self.avg_pool(attn)
attn = self.lrelu(
self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1)))
# pyramid levels
attn_level = self.lrelu(self.spatial_attn_l1(attn))
attn_max = self.max_pool(attn_level)
attn_avg = self.avg_pool(attn_level)
attn_level = self.lrelu(
self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1)))
attn_level = self.lrelu(self.spatial_attn_l3(attn_level))
attn_level = self.upsample(attn_level)
attn = self.lrelu(self.spatial_attn3(attn)) + attn_level
attn = self.lrelu(self.spatial_attn4(attn))
attn = self.upsample(attn)
attn = self.spatial_attn5(attn)
attn_add = self.spatial_attn_add2(
self.lrelu(self.spatial_attn_add1(attn)))
attn = torch.sigmoid(attn)
# after initialization, * 2 makes (attn * 2) to be close to 1.
feat = feat * attn * 2 + attn_add
return feat
class PredeblurModule(nn.Module):
"""Pre-dublur module.
Args:
num_in_ch (int): Channel number of input image. Default: 3.
num_feat (int): Channel number of intermediate features. Default: 64.
hr_in (bool): Whether the input has high resolution. Default: False.
"""
def __init__(self, num_in_ch=3, num_feat=64, hr_in=False):
super(PredeblurModule, self).__init__()
self.hr_in = hr_in
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
if self.hr_in:
# downsample x4 by stride conv
self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
# generate feature pyramid
self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat)
self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat)
self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat)
self.resblock_l1 = nn.ModuleList(
[ResidualBlockNoBN(num_feat=num_feat) for i in range(5)])
self.upsample = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=False)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, x):
feat_l1 = self.lrelu(self.conv_first(x))
if self.hr_in:
feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1))
feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1))
# generate feature pyramid
feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1))
feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2))
feat_l3 = self.upsample(self.resblock_l3(feat_l3))
feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3
feat_l2 = self.upsample(self.resblock_l2_2(feat_l2))
for i in range(2):
feat_l1 = self.resblock_l1[i](feat_l1)
feat_l1 = feat_l1 + feat_l2
for i in range(2, 5):
feat_l1 = self.resblock_l1[i](feat_l1)
return feat_l1
class EDVR(nn.Module):
"""EDVR network structure for video super-resolution.
Now only support X4 upsampling factor.
Paper:
EDVR: Video Restoration with Enhanced Deformable Convolutional Networks
Args:
num_in_ch (int): Channel number of input image. Default: 3.
num_out_ch (int): Channel number of output image. Default: 3.
num_feat (int): Channel number of intermediate features. Default: 64.
num_frame (int): Number of input frames. Default: 5.
deformable_groups (int): Deformable groups. Defaults: 8.
num_extract_block (int): Number of blocks for feature extraction.
Default: 5.
num_reconstruct_block (int): Number of blocks for reconstruction.
Default: 10.
center_frame_idx (int): The index of center frame. Frame counting from
0. Default: 2.
hr_in (bool): Whether the input has high resolution. Default: False.
with_predeblur (bool): Whether has predeblur module.
Default: False.
with_tsa (bool): Whether has TSA module. Default: True.
"""
def __init__(self,
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_frame=5,
deformable_groups=8,
num_extract_block=5,
num_reconstruct_block=10,
center_frame_idx=2,
hr_in=False,
with_predeblur=False,
with_tsa=True):
super(EDVR, self).__init__()
if center_frame_idx is None:
self.center_frame_idx = num_frame // 2
else:
self.center_frame_idx = center_frame_idx
self.hr_in = hr_in
self.with_predeblur = with_predeblur
self.with_tsa = with_tsa
# extract features for each frame
if self.with_predeblur:
self.predeblur = PredeblurModule(
num_feat=num_feat, hr_in=self.hr_in)
self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
else:
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
# extract pyramid features
self.feature_extraction = make_layer(
ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
# pcd and tsa module
self.pcd_align = PCDAlignment(
num_feat=num_feat, deformable_groups=deformable_groups)
if self.with_tsa:
self.fusion = TSAFusion(
num_feat=num_feat,
num_frame=num_frame,
center_frame_idx=self.center_frame_idx)
else:
self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
# reconstruction
self.reconstruction = make_layer(
ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat)
# upsample
self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)
self.pixel_shuffle = nn.PixelShuffle(2)
self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
# activation function
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
def forward(self, x):
b, t, c, h, w = x.size()
if self.hr_in:
assert h % 16 == 0 and w % 16 == 0, (
'The height and width must be multiple of 16.')
else:
assert h % 4 == 0 and w % 4 == 0, (
'The height and width must be multiple of 4.')
x_center = x[:, self.center_frame_idx, :, :, :].contiguous()
# extract features for each frame
# L1
if self.with_predeblur:
feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w)))
if self.hr_in:
h, w = h // 4, w // 4
else:
feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
feat_l1 = self.feature_extraction(feat_l1)
# L2
feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
# L3
feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
feat_l1 = feat_l1.view(b, t, -1, h, w)
feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2)
feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4)
# PCD alignment
ref_feat_l = [ # reference feature list
feat_l1[:, self.center_frame_idx, :, :, :].clone(),
feat_l2[:, self.center_frame_idx, :, :, :].clone(),
feat_l3[:, self.center_frame_idx, :, :, :].clone()
]
aligned_feat = []
for i in range(t):
nbr_feat_l = [ # neighboring feature list
feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(),
feat_l3[:, i, :, :, :].clone()
]
aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
if not self.with_tsa:
aligned_feat = aligned_feat.view(b, -1, h, w)
feat = self.fusion(aligned_feat)
out = self.reconstruction(feat)
out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
out = self.lrelu(self.conv_hr(out))
out = self.conv_last(out)
if self.hr_in:
base = x_center
else:
base = F.interpolate(
x_center, scale_factor=4, mode='bilinear', align_corners=False)
out += base
return out