国产 无码 综合区,色欲AV无码国产永久播放,无码天堂亚洲国产AV,国产日韩欧美女同一区二区

【計算機視覺】ViT:代碼逐行解讀

這篇具有很好參考價值的文章主要介紹了【計算機視覺】ViT:代碼逐行解讀。希望對大家有所幫助。如果存在錯誤或未考慮完全的地方,請大家不吝賜教,您也可以點擊"舉報違法"按鈕提交疑問。

一、代碼

import torch
import torch.nn as nn
from einops import rearrange

from self_attention_cv import TransformerEncoder


class ViT(nn.Module):
    def __init__(self, *,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
        """
        Args:
            img_dim: the spatial image size
            in_channels: number of img channels
            patch_dim: desired patch dim
            num_classes: classification task classes
            dim: the linear layer's dim to project the patches for MHSA
            blocks: number of transformer blocks
            heads: number of heads
            dim_linear_block: inner dim of the transformer linear block
            dim_head: dim head in case you want to define it. defaults to dim/heads
            dropout: for pos emb and transformer
            transformer: in case you want to provide another transformer implementation
            classification: creates an extra CLS token
        """
        super().__init__()
        assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
        self.p = patch_dim
        self.classification = classification
        tokens = (img_dim // patch_dim) ** 2
        self.token_dim = in_channels * (patch_dim ** 2)
        self.dim = dim
        self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
        self.project_patches = nn.Linear(self.token_dim, dim)

        self.emb_dropout = nn.Dropout(dropout)
        if self.classification:
            self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
            self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
            self.mlp_head = nn.Linear(dim, num_classes)
        else:
            self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

        if transformer is None:
            self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                  dim_head=self.dim_head,
                                                  dim_linear_block=dim_linear_block,
                                                  dropout=dropout)
        else:
            self.transformer = transformer

    def expand_cls_to_batch(self, batch):
        """
        Args:
            batch: batch size
        Returns: cls token expanded to the batch size
        """
        return self.cls_token.expand([batch, -1, -1])

    def forward(self, img, mask=None):
        batch_size = img.shape[0]
        img_patches = rearrange(
            img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                                patch_x=self.p, patch_y=self.p)
        # project patches with linear layer + add pos emb
        img_patches = self.project_patches(img_patches)

        if self.classification:
            img_patches = torch.cat(
                (self.expand_cls_to_batch(batch_size), img_patches), dim=1)

        patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)

        # feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
        y = self.transformer(patch_embeddings, mask)

        if self.classification:
            # we index only the cls token for classification. nlp tricks :P
            return self.mlp_head(y[:, 0, :])
        else:
            return y

二、代碼解讀

2.1 大體理解

這段代碼是一個實現(xiàn)了 Vision Transformer(ViT)模型的 PyTorch 實現(xiàn)。

ViT 是一個基于 Transformer 架構(gòu)的圖像分類模型,其主要思想是將圖像分成一個個固定大小的 patch ,并將這些 patch 看做是一個個 token 輸入到 Transformer 中進行特征提取和分類。

以下是對代碼的解讀:

  1. ViT類繼承自nn.Module類,其構(gòu)造函數(shù)有一系列參數(shù),包括輸入圖像的尺寸、patch的大小、輸出類別數(shù)、注意力機制中的頭數(shù)等等。
  2. project_patches函數(shù)通過一個全連接層將每個patch映射到一個d維的特征空間中。
  3. 如果classification = True,則將一個額外的CLS token添加到輸入的token序列的開頭,即對于每張圖像添加一個形狀為[1, 1, d]的CLS token。同時,在ViT中采用的是絕對位置編碼,因此還添加了一個1D的位置編碼向量,其形狀為[num_patches + 1, d],其中num_patches表示圖像被劃分成的patch數(shù)目。如果classification = False,則不添加CLS token。
  4. forward函數(shù)首先將輸入的圖像進行patch劃分,并通過project_patches函數(shù)將每個patch映射到d維特征空間中。接著,將位置編碼向量加到映射后的patch特征向量上,并進行dropout處理。如果classification=True,則在特征序列開頭添加CLS token。接著將這些特征輸入到Transformer中,進行特征提取。最后輸出分類結(jié)果,如果classification=True,則只返回CLS token的分類結(jié)果。

2.2 詳細理解

from self_attention_cv import TransformerEncoder

self_attention_cv是一個基于PyTorch實現(xiàn)的庫,提供了在計算機視覺任務(wù)中使用自注意力機制的模塊和網(wǎng)絡(luò),例如Transformer EncoderAttention Modules。

它主要針對圖像分類、對象檢測、語義分割等任務(wù),支持多種自注意力模塊的實現(xiàn),包括Simplified Self-AttentionFull Self-AttentionLocal Self-Attention等。此外,該庫還提供了一些常見的計算機視覺任務(wù)模型的實現(xiàn),例如Vision Transformer(ViT)Swin Transformer等。

TransformerEncoder是一個自注意力機制的編碼器,用于將輸入序列轉(zhuǎn)換為編碼后的序列。自注意力機制允許模型能夠根據(jù)輸入序列中的其他位置來加權(quán)計算每個位置的表示。這種機制在自然語言處理中的應(yīng)用非常廣泛,比如BERT、GPT等模型都采用了自注意力機制。

TransformerEncoder是基于PyTorch實現(xiàn)的,可以在計算機視覺任務(wù)中使用,例如圖像分類、對象檢測、語義分割等。它支持多頭注意力、殘差連接和LayerNorm等特性。在這個代碼中,ViT模型中的Transformer部分采用了TransformerEncoder作為默認的實現(xiàn)。

def __init__(self, *,
                img_dim,
                in_channels=3,
                patch_dim=16,
                num_classes=10,
                dim=512,
                blocks=6,
                heads=4,
                dim_linear_block=1024,
                dim_head=None,
                dropout=0, transformer=None, classification=True):
    super().__init__()
    assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
    self.p = patch_dim
    self.classification = classification
    tokens = (img_dim // patch_dim) ** 2
    self.token_dim = in_channels * (patch_dim ** 2)
    self.dim = dim
    self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
    self.project_patches = nn.Linear(self.token_dim, dim)

    self.emb_dropout = nn.Dropout(dropout)
    if self.classification:
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
        self.mlp_head = nn.Linear(dim, num_classes)
    else:
        self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

    if transformer is None:
        self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                dim_head=self.dim_head,
                                                dim_linear_block=dim_linear_block,
                                                dropout=dropout)
    else:
        self.transformer = transformer

這段代碼定義了一個名為 ViT 的 PyTorch 模型類,它是一個使用自注意力機制(Self-Attention)實現(xiàn)的視覺 Transformer 模型。其中主要參數(shù)包括:

  • img_dim:輸入圖片的空間大小
  • in_channels:輸入圖片的通道數(shù)
  • patch_dim:將圖片劃分成固定大小的 patch 的大小
  • num_classes:分類任務(wù)的類別數(shù)
  • dim:線性層的維度,用于將每個 patch 投影到 MHSA 空間
  • blocks:Transformer 模型中的塊數(shù)
  • heads:注意力頭的數(shù)量
  • dim_linear_block:線性塊內(nèi)部的維度
  • dim_head:每個頭的維度,如果沒有指定則默認為 dim/heads
  • dropout:用于位置編碼和 Transformer 的 dropout 概率
  • transformer:可選的 TransformerEncoder 類實例
  • classification:是否包含額外的 CLS 標記以用于分類任務(wù)
def __init__(self, *,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
    super().__init__()

這里定義了 ViT 類的構(gòu)造函數(shù),其包含多個參數(shù),包括輸入圖像大小 img_dim,輸入通道數(shù) in_channels,分塊大小 patch_dim,分類數(shù)目 num_classes,嵌入維度 dim,Transformer編碼器的塊數(shù) blocks,頭數(shù) heads,線性塊的維度 dim_linear_block,注意力頭維度 dim_head,Dropout概率 dropout,可選的Transformer編碼器 transformer,以及是否進行分類的標志 classification。

    assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
    self.p = patch_dim
    self.classification = classification

這里檢查 img_dim 是否能夠被 patch_dim 整除,如果不能整除,則會引發(fā)斷言錯誤。同時,將 patch_dim 存儲到 self.p 中,并將是否進行分類的標志存儲到 self.classification 中。

    tokens = (img_dim // patch_dim) ** 2
    self.token_dim = in_channels * (patch_dim ** 2)
    self.dim = dim
    self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
    self.project_patches = nn.Linear(self.token_dim, dim)

這里計算了輸入圖像中可分塊的數(shù)量 tokens,并將每個塊的維度 self.token_dim 設(shè)置為 in_channels * (patch_dim ** 2)。

將嵌入維度 dim 存儲到 self.dim 中,并根據(jù) dim_head 是否為 None,設(shè)置注意力頭維度 self.dim_head。self.project_patches 是一個線性層,用于將每個塊投影到嵌入空間中。

    self.emb_dropout = nn.Dropout(dropout)
    if self.classification:
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
        self.mlp_head = nn.Linear(dim, num_classes)
    else:
        self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

這里定義了嵌入層的Dropout層,并根據(jù)是否進行分類的標志,設(shè)置類別標記 self.cls_token、位置嵌入 self.pos_emb1DMLPself.mlp_head。如果不進行分類,則不需要 self.cls_tokenself.mlp_head。

if transformer is None:
        self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                dim_head=self.dim_head,
                                                dim_linear_block=dim_linear_block,
                                                dropout=dropout)
    else:
        self.transformer = transformer

self.emb_dropout = nn.Dropout(dropout): 定義了一個dropout層,用于在embedding后對其進行dropout操作。

if self.classification:: 如果是分類任務(wù),就執(zhí)行下面的操作,否則跳過。

self.cls_token = nn.Parameter(torch.randn(1, 1, dim)): 定義了一個可訓練參數(shù)cls_token,表示分類token,它是一個1x1xdim的tensor,其中dim表示embedding維度。

self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim)): 定義了一個可訓練參數(shù)pos_emb1D,表示位置嵌入,它是一個(tokens+1)xdim的tensor,其中tokens表示圖像被分成的patch數(shù),dim表示embedding維度。

self.mlp_head = nn.Linear(dim, num_classes): 定義了一個全連接層,將embedding映射到輸出類別的數(shù)量。

最后,根據(jù)傳入的參數(shù)來選擇使用默認的TransformerEncoder,還是使用傳入的transformer。如果沒有傳入,則使用默認的TransformerEncoder,否則使用傳入的transformer。

def expand_cls_to_batch(self, batch):
    """
    Args:
        batch: batch size
    Returns: cls token expanded to the batch size
    """
    return self.cls_token.expand([batch, -1, -1])

該方法的作用是將 Transformer 中的分類 token 擴展到整個批次的樣本數(shù)。它接受一個 batch 參數(shù)作為批次大小,返回一個形狀為 [batch, 1, dim] 的張量,其中 dim 是 Transformer 模型的維度大小。在這個方法中,使用了 PyTorch 的 expand() 方法來實現(xiàn)擴展操作。

def forward(self, img, mask=None):
    batch_size = img.shape[0]
    img_patches = rearrange(
        img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                            patch_x=self.p, patch_y=self.p)
    # project patches with linear layer + add pos emb
    img_patches = self.project_patches(img_patches)

    if self.classification:
        img_patches = torch.cat(
            (self.expand_cls_to_batch(batch_size), img_patches), dim=1)

    patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)

    # feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
    y = self.transformer(patch_embeddings, mask)

    if self.classification:
        # we index only the cls token for classification. nlp tricks :P
        return self.mlp_head(y[:, 0, :])
    else:
        return y

forward 函數(shù)中,接收輸入的 imgmask。

通過 img_dimpatch_dim 計算出 tokens 數(shù)量,其中 tokens 為圖像分割成的塊的數(shù)量。

將輸入的 img 分成 patch,并通過 rearrange 函數(shù)重組成形狀為 [batch_size, tokens, patch_dim * patch_dim * in_channels] 的張量。

通過 Linear 層將每個 patch 映射到 dim 維度,并加上位置編碼 pos_emb1D。

如果是用于分類任務(wù),則在序列的開頭插入一個 CLS token,然后與處理后的 patch 張量按列拼接。

對 patch_embeddings 應(yīng)用 dropout,并輸入到 TransformerEncoder 中,返回輸出張量 y,形狀為 [batch_size, tokens, dim]。

如果是用于分類任務(wù),則從 y 中取出 CLS token,輸入到一個 Linear 層中進行分類,輸出分類結(jié)果。

如果不是分類任務(wù),則直接返回 y。文章來源地址http://www.zghlxwxcb.cn/news/detail-436562.html

到了這里,關(guān)于【計算機視覺】ViT:代碼逐行解讀的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!

本文來自互聯(lián)網(wǎng)用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務(wù),不擁有所有權(quán),不承擔相關(guān)法律責任。如若轉(zhuǎn)載,請注明出處: 如若內(nèi)容造成侵權(quán)/違法違規(guī)/事實不符,請點擊違法舉報進行投訴反饋,一經(jīng)查實,立即刪除!

領(lǐng)支付寶紅包贊助服務(wù)器費用

相關(guān)文章

  • 【計算機視覺 | Kaggle】飛機凝結(jié)軌跡識別 Baseline 分享和解讀(含源代碼)

    【計算機視覺 | Kaggle】飛機凝結(jié)軌跡識別 Baseline 分享和解讀(含源代碼)

    比賽名稱:Google Research - Identify Contrails to Reduce Global Warming 訓練 ML 模型以識別衛(wèi)星圖像中的尾跡 比賽類型:計算機視覺、語義分割 Contrails 是“凝結(jié)軌跡”的縮寫,是在飛機發(fā)動機排氣中形成的線狀冰晶云,由飛機飛過大氣中的超潮濕區(qū)域時產(chǎn)生。持續(xù)的尾跡對全球變暖的貢

    2024年02月14日
    瀏覽(21)
  • 【計算機視覺】Visual Transformer (ViT)模型結(jié)構(gòu)以及原理解析

    【計算機視覺】Visual Transformer (ViT)模型結(jié)構(gòu)以及原理解析

    Visual Transformer (ViT) 出自于論文《AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》,是基于Transformer的模型在視覺領(lǐng)域的開篇之作。 本文將盡可能簡潔地介紹一下ViT模型的整體架構(gòu)以及基本原理。 ViT模型是基于Transformer Encoder模型的,在這里假設(shè)讀者已經(jīng)了解Transfo

    2024年02月02日
    瀏覽(33)
  • 【計算機視覺】Gaussian Splatting源碼解讀補充(二)

    【計算機視覺】Gaussian Splatting源碼解讀補充(二)

    第一部分 本文是對學習筆記之——3D Gaussian Splatting源碼解讀的補充,并訂正了一些錯誤。 其中出現(xiàn)的輔助函數(shù): 這部分的參考資料: [1] CUDA Tutorial [2] An Even Easier Introduction to CUDA [3] Introduction to CUDA Programming CUDA是一個為支持CUDA的GPU提供的平臺和編程模型。該平臺使GPU能夠進

    2024年04月10日
    瀏覽(72)
  • 【計算機視覺】Gaussian Splatting源碼解讀補充(一)

    【計算機視覺】Gaussian Splatting源碼解讀補充(一)

    本文旨在補充@gwpscut創(chuàng)作的博文學習筆記之——3D Gaussian Splatting源碼解讀。 Gaussian Splatting Github地址:https://github.com/graphdeco-inria/gaussian-splatting 論文地址:https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/3d_gaussian_splatting_high.pdf 這部分可以參考PlenOctrees論文的附錄B。 有時候從不同的

    2024年04月09日
    瀏覽(52)
  • 【計算機視覺 | 目標檢測】Grounding DINO:開集目標檢測論文解讀

    【計算機視覺 | 目標檢測】Grounding DINO:開集目標檢測論文解讀

    介紹一篇較新的目標檢測工作: 論文地址為: github 地址為: 作者展示一種開集目標檢測方案: Grounding DINO ,將將基于 Transformer 的檢測器 DINO 與真值預訓練相結(jié)合。 開集檢測關(guān)鍵是引入 language 至閉集檢測器,用于開集概念泛化。作者將閉集檢測器分為三個階段,提出一種

    2024年02月10日
    瀏覽(92)
  • 【計算機視覺 | 目標檢測】Open-Vocabulary DETR with Conditional Matching論文解讀

    【計算機視覺 | 目標檢測】Open-Vocabulary DETR with Conditional Matching論文解讀

    論文題目:具有條件匹配的開放詞匯表DETR 開放詞匯對象檢測是指在自然語言的引導下對新對象進行檢測的問題,越來越受到社會的關(guān)注。理想情況下,我們希望擴展一個開放詞匯表檢測器,這樣它就可以基于自然語言或范例圖像形式的用戶輸入生成邊界框預測。這為人機交

    2024年01月21日
    瀏覽(18)
  • 13 計算機視覺-代碼詳解

    為了防止在訓練集上過擬合,有兩種辦法,第一種是擴大訓練集數(shù)量,但是需要大量的成本;第二種就是應(yīng)用遷移學習,將源數(shù)據(jù)學習到的知識遷移到目標數(shù)據(jù)集,即在把在源數(shù)據(jù)訓練好的參數(shù)和模型(除去輸出層)直接復制到目標數(shù)據(jù)集訓練。 13.2.1 獲取數(shù)據(jù)集 ?13.2.2 初始

    2024年02月12日
    瀏覽(19)
  • 計算機視覺之姿態(tài)識別(原理+代碼實操)

    計算機視覺之姿態(tài)識別(原理+代碼實操)

    ?人體分割使用的方法可以大體分為人體骨骼關(guān)鍵點檢測、語義分割等方式實現(xiàn)。這里主要分析與姿態(tài)相關(guān)的人體骨骼關(guān)鍵點檢測。人體骨骼關(guān)鍵點檢測輸出是人體的骨架信息,一般主要作為人體姿態(tài)識別的基礎(chǔ)部分,主要用于分割、對齊等。一般實現(xiàn)流程為: ?主要檢測人

    2023年04月16日
    瀏覽(22)
  • 【計算機視覺】DINOv2(視覺大模型)代碼使用和測試(完整的源代碼)

    【計算機視覺】DINOv2(視覺大模型)代碼使用和測試(完整的源代碼)

    輸出為: 命令是一個Git命令,用于克?。–lone)名為\\\"dinov2\\\"的存儲庫。它使用了一個名為\\\"ghproxy.com\\\"的代理,用于加速GitHub的克隆操作。 我們需要切換為output的路徑: 以下是代碼的逐行中文解讀: 這段代碼的功能是對給定的圖像進行一系列處理和特征提取,并使用PCA對特征進

    2024年02月16日
    瀏覽(26)
  • 【計算機視覺】YOLOv8如何使用?(含源代碼)

    comments description keywords true Boost your Python projects with object detection, segmentation and classification using YOLOv8. Explore how to load, train, validate, predict, export, track and benchmark models with ease. YOLOv8, Ultralytics, Python, object detection, segmentation, classification, model training, validation, prediction, model export, bench

    2024年02月04日
    瀏覽(29)

覺得文章有用就打賞一下文章作者

支付寶掃一掃打賞

博客贊助

微信掃一掃打賞

請作者喝杯咖啡吧~博客贊助

支付寶掃一掃領(lǐng)取紅包,優(yōu)惠每天領(lǐng)

二維碼1

領(lǐng)取紅包

二維碼2

領(lǐng)紅包