一、代碼
import torch
import torch.nn as nn
from einops import rearrange
from self_attention_cv import TransformerEncoder
class ViT(nn.Module):
def __init__(self, *,
img_dim,
in_channels=3,
patch_dim=16,
num_classes=10,
dim=512,
blocks=6,
heads=4,
dim_linear_block=1024,
dim_head=None,
dropout=0, transformer=None, classification=True):
"""
Args:
img_dim: the spatial image size
in_channels: number of img channels
patch_dim: desired patch dim
num_classes: classification task classes
dim: the linear layer's dim to project the patches for MHSA
blocks: number of transformer blocks
heads: number of heads
dim_linear_block: inner dim of the transformer linear block
dim_head: dim head in case you want to define it. defaults to dim/heads
dropout: for pos emb and transformer
transformer: in case you want to provide another transformer implementation
classification: creates an extra CLS token
"""
super().__init__()
assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
self.p = patch_dim
self.classification = classification
tokens = (img_dim // patch_dim) ** 2
self.token_dim = in_channels * (patch_dim ** 2)
self.dim = dim
self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
self.project_patches = nn.Linear(self.token_dim, dim)
self.emb_dropout = nn.Dropout(dropout)
if self.classification:
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
self.mlp_head = nn.Linear(dim, num_classes)
else:
self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))
if transformer is None:
self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
dim_head=self.dim_head,
dim_linear_block=dim_linear_block,
dropout=dropout)
else:
self.transformer = transformer
def expand_cls_to_batch(self, batch):
"""
Args:
batch: batch size
Returns: cls token expanded to the batch size
"""
return self.cls_token.expand([batch, -1, -1])
def forward(self, img, mask=None):
batch_size = img.shape[0]
img_patches = rearrange(
img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
patch_x=self.p, patch_y=self.p)
# project patches with linear layer + add pos emb
img_patches = self.project_patches(img_patches)
if self.classification:
img_patches = torch.cat(
(self.expand_cls_to_batch(batch_size), img_patches), dim=1)
patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)
# feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
y = self.transformer(patch_embeddings, mask)
if self.classification:
# we index only the cls token for classification. nlp tricks :P
return self.mlp_head(y[:, 0, :])
else:
return y
二、代碼解讀
2.1 大體理解
這段代碼是一個實現(xiàn)了 Vision Transformer(ViT)模型的 PyTorch 實現(xiàn)。
ViT 是一個基于 Transformer 架構(gòu)的圖像分類模型,其主要思想是將圖像分成一個個固定大小的 patch ,并將這些 patch 看做是一個個 token 輸入到 Transformer 中進行特征提取和分類。
以下是對代碼的解讀:
- ViT類繼承自
nn.Module
類,其構(gòu)造函數(shù)有一系列參數(shù),包括輸入圖像的尺寸、patch的大小、輸出類別數(shù)、注意力機制中的頭數(shù)等等。 -
project_patches
函數(shù)通過一個全連接層將每個patch映射到一個d維的特征空間中。 - 如果
classification = True
,則將一個額外的CLS token添加到輸入的token序列的開頭,即對于每張圖像添加一個形狀為[1, 1, d]的CLS token。同時,在ViT中采用的是絕對位置編碼,因此還添加了一個1D的位置編碼向量,其形狀為[num_patches + 1, d],其中num_patches表示圖像被劃分成的patch數(shù)目。如果classification = False,則不添加CLS token。 -
forward
函數(shù)首先將輸入的圖像進行patch劃分,并通過project_patches函數(shù)將每個patch映射到d維特征空間中。接著,將位置編碼向量加到映射后的patch特征向量上,并進行dropout處理。如果classification=True,則在特征序列開頭添加CLS token。接著將這些特征輸入到Transformer中,進行特征提取。最后輸出分類結(jié)果,如果classification=True,則只返回CLS token的分類結(jié)果。
2.2 詳細理解
from self_attention_cv import TransformerEncoder
self_attention_cv
是一個基于PyTorch實現(xiàn)的庫,提供了在計算機視覺任務(wù)中使用自注意力機制的模塊和網(wǎng)絡(luò),例如Transformer Encoder
和Attention Modules
。
它主要針對圖像分類、對象檢測、語義分割等任務(wù),支持多種自注意力模塊的實現(xiàn),包括Simplified Self-Attention
、Full Self-Attention
和Local Self-Attention
等。此外,該庫還提供了一些常見的計算機視覺任務(wù)模型的實現(xiàn),例如Vision Transformer(ViT)
和Swin Transformer
等。
TransformerEncoder
是一個自注意力機制的編碼器,用于將輸入序列轉(zhuǎn)換為編碼后的序列。自注意力機制允許模型能夠根據(jù)輸入序列中的其他位置來加權(quán)計算每個位置的表示。這種機制在自然語言處理中的應(yīng)用非常廣泛,比如BERT、GPT等模型都采用了自注意力機制。
TransformerEncoder
是基于PyTorch實現(xiàn)的,可以在計算機視覺任務(wù)中使用,例如圖像分類、對象檢測、語義分割等。它支持多頭注意力、殘差連接和LayerNorm等特性。在這個代碼中,ViT模型中的Transformer部分采用了TransformerEncoder作為默認的實現(xiàn)。
def __init__(self, *,
img_dim,
in_channels=3,
patch_dim=16,
num_classes=10,
dim=512,
blocks=6,
heads=4,
dim_linear_block=1024,
dim_head=None,
dropout=0, transformer=None, classification=True):
super().__init__()
assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
self.p = patch_dim
self.classification = classification
tokens = (img_dim // patch_dim) ** 2
self.token_dim = in_channels * (patch_dim ** 2)
self.dim = dim
self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
self.project_patches = nn.Linear(self.token_dim, dim)
self.emb_dropout = nn.Dropout(dropout)
if self.classification:
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
self.mlp_head = nn.Linear(dim, num_classes)
else:
self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))
if transformer is None:
self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
dim_head=self.dim_head,
dim_linear_block=dim_linear_block,
dropout=dropout)
else:
self.transformer = transformer
這段代碼定義了一個名為 ViT 的 PyTorch 模型類,它是一個使用自注意力機制(Self-Attention)實現(xiàn)的視覺 Transformer 模型。其中主要參數(shù)包括:
-
img_dim
:輸入圖片的空間大小 -
in_channels
:輸入圖片的通道數(shù) -
patch_dim
:將圖片劃分成固定大小的 patch 的大小 -
num_classes
:分類任務(wù)的類別數(shù) -
dim
:線性層的維度,用于將每個 patch 投影到 MHSA 空間 -
blocks
:Transformer 模型中的塊數(shù) -
heads
:注意力頭的數(shù)量 -
dim_linear_block
:線性塊內(nèi)部的維度 -
dim_head
:每個頭的維度,如果沒有指定則默認為 dim/heads -
dropout
:用于位置編碼和 Transformer 的 dropout 概率 -
transformer
:可選的 TransformerEncoder 類實例 -
classification
:是否包含額外的 CLS 標記以用于分類任務(wù)
def __init__(self, *,
img_dim,
in_channels=3,
patch_dim=16,
num_classes=10,
dim=512,
blocks=6,
heads=4,
dim_linear_block=1024,
dim_head=None,
dropout=0, transformer=None, classification=True):
super().__init__()
這里定義了 ViT 類的構(gòu)造函數(shù),其包含多個參數(shù),包括輸入圖像大小 img_dim
,輸入通道數(shù) in_channels
,分塊大小 patch_dim
,分類數(shù)目 num_classes
,嵌入維度 dim
,Transformer編碼器的塊數(shù) blocks
,頭數(shù) heads
,線性塊的維度 dim_linear_block
,注意力頭維度 dim_head
,Dropout概率 dropout
,可選的Transformer編碼器 transformer
,以及是否進行分類的標志 classification
。
assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
self.p = patch_dim
self.classification = classification
這里檢查 img_dim
是否能夠被 patch_dim
整除,如果不能整除,則會引發(fā)斷言錯誤。同時,將 patch_dim
存儲到 self.p
中,并將是否進行分類的標志存儲到 self.classification
中。
tokens = (img_dim // patch_dim) ** 2
self.token_dim = in_channels * (patch_dim ** 2)
self.dim = dim
self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
self.project_patches = nn.Linear(self.token_dim, dim)
這里計算了輸入圖像中可分塊的數(shù)量 tokens
,并將每個塊的維度 self.token_dim
設(shè)置為 in_channels * (patch_dim ** 2)
。
將嵌入維度 dim
存儲到 self.dim
中,并根據(jù) dim_head
是否為 None,設(shè)置注意力頭維度 self.dim_head
。self.project_patches
是一個線性層,用于將每個塊投影到嵌入空間中。
self.emb_dropout = nn.Dropout(dropout)
if self.classification:
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
self.mlp_head = nn.Linear(dim, num_classes)
else:
self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))
這里定義了嵌入層的Dropout
層,并根據(jù)是否進行分類的標志,設(shè)置類別標記 self.cls_token
、位置嵌入 self.pos_emb1D
和 MLP
頭 self.mlp_head
。如果不進行分類,則不需要 self.cls_token
和 self.mlp_head
。
if transformer is None:
self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
dim_head=self.dim_head,
dim_linear_block=dim_linear_block,
dropout=dropout)
else:
self.transformer = transformer
self.emb_dropout = nn.Dropout(dropout):
定義了一個dropout
層,用于在embedding后對其進行dropout操作。
if self.classification::
如果是分類任務(wù),就執(zhí)行下面的操作,否則跳過。
self.cls_token = nn.Parameter(torch.randn(1, 1, dim)):
定義了一個可訓練參數(shù)cls_token
,表示分類token,它是一個1x1xdim的tensor,其中dim表示embedding維度。
self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim)):
定義了一個可訓練參數(shù)pos_emb1D,表示位置嵌入,它是一個(tokens+1)xdim的tensor,其中tokens表示圖像被分成的patch數(shù),dim表示embedding維度。
self.mlp_head = nn.Linear(dim, num_classes):
定義了一個全連接層,將embedding映射到輸出類別的數(shù)量。
最后,根據(jù)傳入的參數(shù)來選擇使用默認的TransformerEncoder,還是使用傳入的transformer。如果沒有傳入,則使用默認的TransformerEncoder,否則使用傳入的transformer。
def expand_cls_to_batch(self, batch):
"""
Args:
batch: batch size
Returns: cls token expanded to the batch size
"""
return self.cls_token.expand([batch, -1, -1])
該方法的作用是將 Transformer 中的分類 token 擴展到整個批次的樣本數(shù)。它接受一個 batch 參數(shù)作為批次大小,返回一個形狀為 [batch, 1, dim] 的張量,其中 dim 是 Transformer 模型的維度大小。在這個方法中,使用了 PyTorch 的 expand() 方法來實現(xiàn)擴展操作。
def forward(self, img, mask=None):
batch_size = img.shape[0]
img_patches = rearrange(
img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
patch_x=self.p, patch_y=self.p)
# project patches with linear layer + add pos emb
img_patches = self.project_patches(img_patches)
if self.classification:
img_patches = torch.cat(
(self.expand_cls_to_batch(batch_size), img_patches), dim=1)
patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)
# feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
y = self.transformer(patch_embeddings, mask)
if self.classification:
# we index only the cls token for classification. nlp tricks :P
return self.mlp_head(y[:, 0, :])
else:
return y
在 forward
函數(shù)中,接收輸入的 img
和 mask
。
通過 img_dim
和 patch_dim
計算出 tokens
數(shù)量,其中 tokens 為圖像分割成的塊的數(shù)量。
將輸入的 img 分成 patch,并通過 rearrange 函數(shù)重組成形狀為 [batch_size, tokens, patch_dim * patch_dim * in_channels]
的張量。
通過 Linear
層將每個 patch 映射到 dim 維度,并加上位置編碼 pos_emb1D
。
如果是用于分類任務(wù),則在序列的開頭插入一個 CLS token
,然后與處理后的 patch 張量按列拼接。
對 patch_embeddings 應(yīng)用 dropout,并輸入到 TransformerEncoder 中,返回輸出張量 y,形狀為 [batch_size, tokens, dim]
。
如果是用于分類任務(wù),則從 y 中取出 CLS token,輸入到一個 Linear 層中進行分類,輸出分類結(jié)果。文章來源:http://www.zghlxwxcb.cn/news/detail-436562.html
如果不是分類任務(wù),則直接返回 y。文章來源地址http://www.zghlxwxcb.cn/news/detail-436562.html
到了這里,關(guān)于【計算機視覺】ViT:代碼逐行解讀的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!