前言:CVPR2022oral 用transformer應(yīng)用到low-level任務(wù)
Restormer: Efficient Transformer for High-Resolution Image Restoration
引言
low-level task 如deblurring\denoising\dehazing等任務(wù)多是基于CNN做的,這樣的局限性有二:
第一是卷積操作的感受野受限,很難建立起全局依賴,
第二就是卷積操作的卷積核初始化是固定的,而attention的設(shè)計(jì)可以通過(guò)像素之間的關(guān)系自適應(yīng)的調(diào)整權(quán)重
現(xiàn)有的transformer用于low-level任務(wù)最大的瓶頸在于分辨率太大了,自注意力機(jī)制的復(fù)雜度隨著空間分辨率的增加二次增長(zhǎng),現(xiàn)有的一些解決方案有:
1.劃成很多個(gè)8 * 8的像素小窗口,在這個(gè)小窗口內(nèi)進(jìn)行應(yīng)用自注意力
2.化成不重疊的48 * 48的塊,塊與塊之間進(jìn)行自注意力機(jī)制
然而,這樣的設(shè)計(jì)和transformer建立全局依賴的初衷是矛盾的
因此,本文解決了用transformer處理這類問(wèn)題的計(jì)算復(fù)雜性,將其計(jì)算復(fù)雜度降低成和空間分辨率線性相關(guān)
改進(jìn)了SA self-attention部分和feed-forward部分,并提出了一種漸進(jìn)式patch訓(xùn)練方式來(lái)處理基于transformer的圖像復(fù)原問(wèn)題
相關(guān)工作
(這里不得不感嘆看到這位作者介紹相關(guān)工作,都有一種被俯視的感覺(jué),之前的一篇論文直接點(diǎn)某某,某某,are good examples, 這次直接建議閱讀 NTIRE 挑戰(zhàn)報(bào)告了)
方法
文章pipeline,類似Unet結(jié)構(gòu)
SA設(shè)計(jì)
這里最大的改動(dòng)就是把HW * HW的attention變成了通道 * 通道的attention,計(jì)算量是降下來(lái)了,但是不過(guò)是把全局特征通道重組,沒(méi)有辦法建立空間像素關(guān)系的依賴,建立像素依賴的部分實(shí)際上還是3 * 3的按通道分組卷積Dconv(綠色方框)部分,(看到這樣的設(shè)計(jì)都能有效果也是驚了)
其中,消融實(shí)驗(yàn),可以看到 (a)(b)差別不大,但是MTA加上一個(gè)3 * 3的Dconv的提升很大,SA代碼
class Attention(nn.Module):
def __init__(self, dim, num_heads, bias):
super(Attention, self).__init__()
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
def forward(self, x):
b,c,h,w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
q,k,v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
out = (attn @ v)
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
out = self.project_out(out)
return out
FN設(shè)計(jì)
和傳統(tǒng)的Feed-forward部分不同,這里分了兩支進(jìn)行MLP,并且HW依舊保持排列好的狀態(tài)所以還是可以用3 * 3 分組卷積,下面的分支過(guò)了一個(gè)GeLU激活函數(shù)與上面的分支相乘
消融實(shí)驗(yàn)
可以看到 (b)(d)比較,單加上一個(gè)gated分支反倒效果不好,但(b)(e)直接上3 * 3的按通道分組卷積效果提升很明顯,起作用的還是3 * 3的卷積核來(lái)學(xué)習(xí)空間信息
FN的設(shè)計(jì)代碼
class FeedForward(nn.Module):
def __init__(self, dim, ffn_expansion_factor, bias):
super(FeedForward, self).__init__()
hidden_features = int(dim*ffn_expansion_factor)
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
x = F.gelu(x1) * x2
x = self.project_out(x)
return x
實(shí)驗(yàn)
作者做了去雨、去糊、去噪等實(shí)驗(yàn),在各個(gè)數(shù)據(jù)集上效果都挺好的
去糊實(shí)驗(yàn)結(jié)果文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-745224.html
總結(jié)
雖然這是一篇transformer的文章,但是通道與通道之間的注意力和傳統(tǒng)的Transformer也沒(méi)什么聯(lián)系了,并且前文花了很多篇幅講transformer可以建立起 long-range pixel interactions,但是網(wǎng)絡(luò)設(shè)計(jì)卻仍然還是沒(méi)有利用到transformer的全局像素依賴的這個(gè)屬性
(個(gè)人疑惑的一個(gè)點(diǎn)是在于,既然簡(jiǎn)單的幾層堆疊 [4,6,6,8] 的3*3的空間像素層上的卷積依賴已經(jīng)能有這么好的效果,long-range pixel interactions對(duì)于low-level的任務(wù)真的有必要嗎…)文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-745224.html
到了這里,關(guān)于論文閱讀 | Restormer: Efficient Transformer for High-Resolution Image Restoration的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!