paper:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
摘要
- 把transformer直接應(yīng)用于圖像塊序列,也可以在圖像分類(lèi)任務(wù)上表現(xiàn)很好。
- 通過(guò)在大數(shù)據(jù)集上預(yù)訓(xùn)練,然后遷移到中等規(guī)模和小規(guī)模數(shù)據(jù)集上,ViT可以取得和SOTA的卷積網(wǎng)絡(luò)同樣出色(甚至更好)的結(jié)果,同時(shí)需要更少的訓(xùn)練資源。
介紹
1、將標(biāo)準(zhǔn)transformer直接應(yīng)用于圖像,只做最小可能修改
將一幅圖像分割成多個(gè)圖像塊,然后將這些圖像塊的embedding序列作為輸入,送到transformer。這里的圖像塊類(lèi)似于NLP中的token。
2、在中等規(guī)模數(shù)據(jù)集(如ImageNet)上訓(xùn)練ViT,模型結(jié)果會(huì)比resnet結(jié)構(gòu)的模型低一點(diǎn)。
和CNN相比,transformer缺乏一些歸納偏置(inductive bias),比如平移不變性和局部性。但是在大規(guī)模數(shù)據(jù)集上,直接從數(shù)據(jù)中學(xué)習(xí),更加有效。
方法
網(wǎng)絡(luò)結(jié)構(gòu)
圖片來(lái)源:https://zhuanlan.zhihu.com/p/342261872
輸入圖像維度為$$H×W×C$$,分割成N個(gè)$$P×P$$大小的圖像塊,N為$$HW/P^2$$,圖像塊通過(guò)線性映射得到D維的向量,D在transformer的所有層中保持不變。
不同層的操作計(jì)算過(guò)程如下:
公式1是將圖像塊映射成embedding,這里加了一個(gè)可學(xué)習(xí)的class token $$x_{class}$$(類(lèi)似BERT),與其他圖像塊嵌入向量一起輸入到 Transformer 編碼器中,其在網(wǎng)絡(luò)最后的輸出,作為整個(gè)圖像的表示y,就是公式4中的結(jié)果。Transformer 編碼器中的具體過(guò)程這里不作展開(kāi),可參考Transformer原理理解_qiumokucao的博客-CSDN博客。
公式2是multiheaded self-attention的計(jì)算過(guò)程,公式3是MLP的計(jì)算過(guò)程。
實(shí)際實(shí)現(xiàn)過(guò)程中,圖像塊映射成embedding可以通過(guò)卷積實(shí)現(xiàn):
# 其中fh,fw是patch的高和寬,讓卷積核的大小和stride與patch大小相等
self.patch_embedding = nn.Conv2d(in_channels, dim, kernel_size=(fh, fw), stride=(fh, fw))
?另外,網(wǎng)絡(luò)最后接MLP head的時(shí)候,可以只使用class token對(duì)應(yīng)的結(jié)果(如公式4中描述),也可以對(duì)所有結(jié)果進(jìn)行pooling,然后接MLP head。參考https://github.com/lucidrains/vit-pytorch.git中實(shí)現(xiàn):
def forward(self, img):
x = self.to_patch_embedding(img) #圖像轉(zhuǎn)成embedding
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1) #引入cls_tokens
x += self.pos_embedding[:, :(n + 1)] #加入位置embedding
x = self.dropout(x)
x = self.transformer(x)
# 根據(jù)設(shè)置選擇cls_tokens對(duì)應(yīng)的輸出或者進(jìn)行pooling
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
x = self.to_latent(x)
return self.mlp_head(x)
Hybrid Architecture
可以將ViT應(yīng)用于CNN的特征之上,區(qū)別就是這里把CNN的特征映射為embedding,其余部分跟ViT的處理過(guò)程一樣
模型微調(diào)(fine-tune)
在大規(guī)模數(shù)據(jù)集上進(jìn)行預(yù)訓(xùn)練,然后在下游任務(wù)中進(jìn)行微調(diào)。微調(diào)時(shí),把預(yù)訓(xùn)練的預(yù)測(cè)頭去掉,添加一個(gè)$$D×K$$的全連接層,K為預(yù)測(cè)類(lèi)別數(shù)。
微調(diào)時(shí)可以采用更大的輸入分辨率,保持patch size不變,這樣輸入到transformer的序列長(zhǎng)度會(huì)變長(zhǎng),事實(shí)上ViT可以處理任意長(zhǎng)的序列,不過(guò)預(yù)訓(xùn)練的position embedding就失去意義了,這時(shí)作者對(duì)position embedding進(jìn)行了2D插值處理。
實(shí)驗(yàn)結(jié)果
數(shù)據(jù)集
ImageNet:1.3M images,1k classes
ImageNet-21k:14M images,21k classes
JFT:303M high-resolution images,18k classes
模型參數(shù)
Layers:Encoder Block 數(shù)量
Hidden Size D:隱藏層特征大小,其在各 Encoder Block 保持一致
MLP Size:MLP 特征大小,通常設(shè)為 4D
Heads:MSA 中的 heads 數(shù)量
Patch Size:模型輸入的 Patch size,ViT 中共有兩個(gè)設(shè)置:14x14 和 16x16,該參數(shù)僅影響計(jì)算量,patch size越小,序列長(zhǎng)度越長(zhǎng),計(jì)算量越大。
實(shí)驗(yàn)結(jié)果
?文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-544622.html
JFT+TPU的鈔能力!?文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-544622.html
到了這里,關(guān)于Vision Transformer (ViT)介紹的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!