In recent years, the development of models built on the self-attention mechanism, and in particular the introduction of the Transformer model, has greatly advanced natural language processing. Thanks to the Transformer's computational efficiency and scalability, it has become possible to train models of unprecedented scale with more than 100B parameters.
ViT is the result of bringing natural language processing and computer vision together: without relying on convolution operations, it can still achieve excellent results on image classification tasks.
Model Structure
The main body of the ViT model is based on the Encoder part of the Transformer model (with some adjustments to the ordering of components, e.g., the position of Normalization differs from the standard Transformer). Its structure is shown in figure [1] below:
Model Characteristics
The ViT model is mainly applied to image classification. Compared with the traditional Transformer, its structure therefore has the following characteristics (a shape sketch follows the list):
1. After the original image in the dataset is divided into multiple patches, each two-dimensional patch (ignoring the channel dimension) is flattened into a one-dimensional vector; a class token and position embeddings are then added to form the model input.
2. The main Block of the model is based on the Transformer Encoder structure, but the position of Normalization is adjusted; the most important component is still the Multi-head Attention structure.
3. After the stacked Blocks, the model attaches a fully connected layer that takes the output at the class-token position as its input and uses it for classification. The final fully connected layer is usually called the Head, and the Transformer Encoder part is the backbone.
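To make the shapes above concrete, here is a minimal back-of-the-envelope sketch; the values are the ViT-B/16 defaults used later in this article (224 x 224 input, 16 x 16 patches, 768-dimensional embeddings, 1000 classes):
image_size, patch_size, embed_dim, num_classes = 224, 16, 768, 1000
num_patches = (image_size // patch_size) ** 2   # 14 * 14 = 196 patch tokens
seq_len = num_patches + 1                       # plus one class token -> 197
print(num_patches, seq_len)                     # 196 197
# (B, 3, 224, 224) -> patch embedding -> (B, 196, 768) -> + class token and position embedding
# -> (B, 197, 768) -> Transformer Encoder -> take token 0 -> Head Dense(768, 1000) -> (B, 1000)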
The following sections explain in detail, through code examples, how to implement an ImageNet classification task based on ViT.
If you are interested in MindSpore, you can follow the MindSpore community.
I. Environment Preparation
1. Go to the ModelArts website
The cloud platform helps users quickly create and deploy models and manage the full-lifecycle AI workflow. Choose the cloud platform below to start using MindSpore, obtain the installation command, and install the MindSpore 2.0.0-alpha version; you can reach the ModelArts website from the MindSpore tutorials.
Select CodeLab below to try it out immediately.
Wait for the environment to finish setting up.
2. Use CodeLab to run a Notebook instance
Download the sample Notebook code for Vision Transformer image classification; the .ipynb file is the sample code.
Choose ModelArts Upload Files to upload the .ipynb file.
Select the Kernel environment.
Switch to the GPU environment, choosing the first option, which is free for a limited time.
Go to the MindSpore website, click Install at the top, and get the installation command.
Back in the Notebook, add the following command before the first code block:
conda update -n base -c defaults conda
Install the MindSpore 2.0 GPU version:
conda install mindspore=2.0.0a0 -c mindspore -c conda-forge
Install mindvision:
pip install mindvision
Install the download package:
pip install download
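Optionally, you can verify the installation from Python before continuing; this small check is not part of the original steps:
import mindspore as ms
ms.run_check()          # reports whether MindSpore is installed correctly and can run a simple computation
print(ms.__version__)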
II. Environment Preparation and Data Loading
Before starting the experiment, make sure a Python environment and MindSpore are installed locally.
First we need to download the dataset for this case. The full ImageNet dataset can be downloaded from http://image-net.org; the dataset used here is a subset filtered from ImageNet.
Running the first code block downloads and extracts the dataset automatically. Make sure your dataset path has the following structure.
./dataset/
├── ILSVRC2012_devkit_t12.tar.gz
├── train/
├── infer/
└── val/
from download import download
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/vit_imagenet_dataset.zip"
path = "./"
path = download(dataset_url, path, kind="zip", replace=True)
import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms
data_path = './dataset/'
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)
trans_train = [
transforms.RandomCropDecodeResize(size=224,
scale=(0.08, 1.0),
ratio=(0.75, 1.333)),
transforms.RandomHorizontalFlip(prob=0.5),
transforms.Normalize(mean=mean, std=std),
transforms.HWC2CHW()
]
dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)
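As an optional sanity check (not part of the original tutorial code), you can pull one batch from the pipeline and confirm its shape:
# With the transforms above, each batch should hold 16 CHW images of size 224 x 224 plus 16 integer labels.
sample = next(dataset_train.create_dict_iterator())
print(sample["image"].shape, sample["label"].shape)   # expect (16, 3, 224, 224) and (16,)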
III. Model Analysis
The following code takes a close look at the internal structure of the ViT model.
Transformer Fundamentals
The Transformer model originates from a 2017 paper [2]. The attention-based encoder-decoder architecture proposed in that paper achieved enormous success in natural language processing. The model structure is shown in the figure below:
Its main structure consists of multiple Encoder and Decoder modules; the detailed structure of the Encoder and Decoder is shown in figure [2] below:
The Encoder and Decoder are made up of many components, such as Multi-Head Attention layers, Feed Forward layers, Normalization layers, and residual connections (Residual Connection, the "Add" in the figure). The most important component, however, is the Multi-Head Attention structure, which is built on the Self-Attention mechanism and consists of several Self-Attention heads running in parallel. Understanding Self-Attention therefore gets you to the core of the Transformer.
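Before looking at the full module, the core computation of Self-Attention, Attention(Q, K, V) = softmax(Q K^T / sqrt(d)) V, can be sketched with the same MindSpore primitives used below; the shapes here are small and purely illustrative:
import numpy as np
import mindspore as ms
from mindspore import nn, ops

b, n, d = 1, 4, 8                                           # batch, number of tokens, head dimension
q = ms.Tensor(np.random.randn(b, n, d), ms.float32)
k = ms.Tensor(np.random.randn(b, n, d), ms.float32)
v = ms.Tensor(np.random.randn(b, n, d), ms.float32)

scores = ops.BatchMatMul(transpose_b=True)(q, k)            # (b, n, n): similarity of every token pair
scores = ops.mul(scores, ms.Tensor(d ** -0.5, ms.float32))  # scale by 1/sqrt(d)
weights = nn.Softmax(axis=-1)(scores)                       # each row sums to 1
out = ops.BatchMatMul()(weights, v)                         # weighted sum of the values, (b, n, d)
print(weights.shape, out.shape)                             # (1, 4, 4) (1, 4, 8)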
Attention Module
from mindspore import nn, ops
class Attention(nn.Cell):
def __init__(self,
dim: int,
num_heads: int = 8,
keep_prob: float = 1.0,
attention_keep_prob: float = 1.0):
super(Attention, self).__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = ms.Tensor(head_dim ** -0.5)
self.qkv = nn.Dense(dim, dim * 3)
self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
self.out = nn.Dense(dim, dim)
self.out_drop = nn.Dropout(p=1.0-keep_prob)
self.attn_matmul_v = ops.BatchMatMul()
self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
self.softmax = nn.Softmax(axis=-1)
def construct(self, x):
"""Attention construct."""
b, n, c = x.shape
qkv = self.qkv(x)
qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
q, k, v = ops.unstack(qkv, axis=0)
attn = self.q_matmul_k(q, k)
attn = ops.mul(attn, self.scale)
attn = self.softmax(attn)
attn = self.attn_drop(attn)
out = self.attn_matmul_v(attn, v)
out = ops.transpose(out, (0, 2, 1, 3))
out = ops.reshape(out, (b, n, c))
out = self.out(out)
out = self.out_drop(out)
return out
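A quick usage check of the Attention block defined above (illustrative only): a dummy ViT-B-sized token sequence should keep its (batch, tokens, dim) shape.
import numpy as np
attn_block = Attention(dim=768, num_heads=12)
tokens = ms.Tensor(np.random.randn(1, 197, 768), ms.float32)   # 196 patch tokens + 1 class token
print(attn_block(tokens).shape)                                # expect (1, 197, 768)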
Transformer Encoder
With the Self-Attention structure understood, the basic building block of the Transformer is formed by combining it with Feed Forward, Residual Connection, and other structures. The following code implements the Feed Forward and Residual Connection structures.
from typing import Optional, Dict
class FeedForward(nn.Cell):
def __init__(self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
activation: nn.Cell = nn.GELU,
keep_prob: float = 1.0):
super(FeedForward, self).__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.dense1 = nn.Dense(in_features, hidden_features)
self.activation = activation()
self.dense2 = nn.Dense(hidden_features, out_features)
self.dropout = nn.Dropout(p=1.0-keep_prob)
def construct(self, x):
"""Feed Forward construct."""
x = self.dense1(x)
x = self.activation(x)
x = self.dropout(x)
x = self.dense2(x)
x = self.dropout(x)
return x
class ResidualCell(nn.Cell):
def __init__(self, cell):
super(ResidualCell, self).__init__()
self.cell = cell
def construct(self, x):
"""ResidualCell construct."""
return self.cell(x) + x
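As a small illustration (not part of the original tutorial), one pre-norm residual sub-block, the pattern the TransformerEncoder below stacks, can be assembled from these pieces:
import numpy as np
# output = x + FeedForward(LayerNorm(x)), i.e. normalization is applied before the sub-layer
ffn_block = ResidualCell(nn.SequentialCell([nn.LayerNorm((768,)), FeedForward(768, 3072)]))
x = ms.Tensor(np.random.randn(1, 197, 768), ms.float32)
print(ffn_block(x).shape)   # expect (1, 197, 768)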
Next, Self-Attention is used to build the TransformerEncoder part of the ViT model, which is analogous to building the encoder part of a Transformer, as shown in figure [1] below:
vit-encoder
The basic block in the ViT model differs from the standard Transformer mainly in that Normalization is placed before Self-Attention and Feed Forward; other structures such as Residual Connection, Feed Forward, and Normalization follow the Transformer design. As the Transformer structure figure shows, stacking multiple sub-encoders completes the encoder; the ViT model keeps this idea, and the number of stacked layers is set through the hyperparameter num_layers.
The Residual Connection and Normalization structures keep the model highly scalable (the Residual Connection ensures that information does not degrade as it passes through deep layers), while Normalization and dropout improve generalization. The structure of the Transformer is clearly visible in the source code below. Combining the TransformerEncoder structure with a multilayer perceptron (MLP) forms the backbone of the ViT model.
class TransformerEncoder(nn.Cell):
def __init__(self,
dim: int,
num_layers: int,
num_heads: int,
mlp_dim: int,
keep_prob: float = 1.,
attention_keep_prob: float = 1.0,
drop_path_keep_prob: float = 1.0,
activation: nn.Cell = nn.GELU,
norm: nn.Cell = nn.LayerNorm):
super(TransformerEncoder, self).__init__()
layers = []
for _ in range(num_layers):
normalization1 = norm((dim,))
normalization2 = norm((dim,))
attention = Attention(dim=dim,
num_heads=num_heads,
keep_prob=keep_prob,
attention_keep_prob=attention_keep_prob)
feedforward = FeedForward(in_features=dim,
hidden_features=mlp_dim,
activation=activation,
keep_prob=keep_prob)
layers.append(
nn.SequentialCell([
ResidualCell(nn.SequentialCell([normalization1, attention])),
ResidualCell(nn.SequentialCell([normalization2, feedforward]))
])
)
self.layers = nn.SequentialCell(layers)
def construct(self, x):
"""Transformer construct."""
return self.layers(x)
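An illustrative shape check for a ViT-B-sized encoder (12 layers, 12 heads, hidden width 3072): the encoder keeps the (batch, tokens, dim) shape end to end.
import numpy as np
encoder = TransformerEncoder(dim=768, num_layers=12, num_heads=12, mlp_dim=3072)
tokens = ms.Tensor(np.random.randn(1, 197, 768), ms.float32)
print(encoder(tokens).shape)   # expect (1, 197, 768)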
ViT Model Input
The traditional Transformer mainly processes word embeddings (word vectors) in natural language. The main difference between word vectors and image data is that word vectors are stacks of one-dimensional vectors, whereas an image is a stack of two-dimensional matrices. When multi-head attention processes a stack of one-dimensional word vectors, it extracts the relationships between them, that is, the contextual semantics, which is what makes the Transformer so effective in natural language processing. How to convert a two-dimensional image matrix into something like one-dimensional word vectors is therefore the small hurdle the Transformer has to clear to enter image processing.
In the ViT model:
The input image is divided into 16 x 16 pixel patches. This step is performed with a convolution; it could also be done by hand, but the convolution achieves the same goal while performing an extra round of processing. For example, a 224 x 224 input image passes through the convolution (kernel size and stride both equal to the patch size of 16) and yields a 14 x 14 grid of patches, i.e., 196 patches in total.
The 14 x 14 grid is then flattened into a sequence of 196 patch tokens, each projected to an embedding vector, which gives an effect similar to a stack of word vectors.
This is the first processing step that an input image goes through in the network. The Patch Embedding code is shown below:
class PatchEmbedding(nn.Cell):
MIN_NUM_PATCHES = 4
def __init__(self,
image_size: int = 224,
patch_size: int = 16,
embed_dim: int = 768,
input_channels: int = 3):
super(PatchEmbedding, self).__init__()
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = (image_size // patch_size) ** 2
self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)
def construct(self, x):
"""Path Embedding construct."""
x = self.conv(x)
b, c, h, w = x.shape
x = ops.reshape(x, (b, c, h * w))
x = ops.transpose(x, (0, 2, 1))
return x
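An illustrative check of the patch embedding: a 224 x 224 RGB image becomes 196 patch tokens, each a 768-dimensional embedding.
import numpy as np
patch_embed = PatchEmbedding(image_size=224, patch_size=16, embed_dim=768, input_channels=3)
img = ms.Tensor(np.random.randn(1, 3, 224, 224), ms.float32)
print(patch_embed(img).shape)   # expect (1, 196, 768)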
After the input image has been divided into patches, it goes through two further steps: class_embedding and pos_embedding.
The class_embedding borrows the idea BERT uses for text classification: an extra class token is prepended to the token sequence, usually in the first position, so the 196 patch tokens from the previous step become a sequence of 197 tokens. The class_embedding is a learnable parameter; after the network has been trained, the output at this first position is used to decide the final class. Note that classification uses only the output at the class-token position, not the outputs of the patch tokens.
The pos_embedding is also a set of learnable parameters, which is added to the processed patch token matrix.
Because the pos_embedding is learnable, adding it is somewhat like the bias in a fully connected layer or a convolution. This step creates a trainable embedding for each of the 197 token positions and adds it to the sequence obtained after the class_embedding step.
In fact, there are four pos_embedding schemes in total. The authors' experiments show that the only clearly noticeable difference is between using and not using a pos_embedding; whether the pos_embedding is one-dimensional or two-dimensional has little effect on classification accuracy. Our code therefore uses a one-dimensional pos_embedding. Because the class_embedding is prepended before the pos_embedding is added, the pos_embedding covers one position more than the number of patch tokens.
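A minimal shape-only sketch of these two steps (illustrative; in the real ViT class below, cls_token and pos_embedding are learnable parameters rather than random tensors):
import numpy as np
patch_tokens = ms.Tensor(np.random.randn(2, 196, 768), ms.float32)    # output of PatchEmbedding
cls_token = ms.Tensor(np.random.randn(1, 1, 768), ms.float32)         # learnable in the real model
pos_embedding = ms.Tensor(np.random.randn(1, 197, 768), ms.float32)   # learnable in the real model

cls_tokens = ops.tile(cls_token, (patch_tokens.shape[0], 1, 1))        # one class token per image
x = ops.concat((cls_tokens, patch_tokens), axis=1)                     # prepend the class token -> (2, 197, 768)
x = x + pos_embedding                                                  # broadcast over the batch
print(x.shape)                                                         # (2, 197, 768)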
Overall, the ViT model still exploits the Transformer's strength at handling contextual semantics: the image is converted into a kind of "variant word vector" sequence and then processed. The point of this conversion is that patches naturally have spatial relationships with one another, a kind of "spatial semantics", which is what leads to the good results.
Building the Complete ViT
The following code builds a complete ViT model.
from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameter
def init(init_type, shape, dtype, name, requires_grad):
"""Init."""
initial = initializer(init_type, shape, dtype).init_data()
return Parameter(initial, name=name, requires_grad=requires_grad)
class ViT(nn.Cell):
def __init__(self,
image_size: int = 224,
input_channels: int = 3,
patch_size: int = 16,
embed_dim: int = 768,
num_layers: int = 12,
num_heads: int = 12,
mlp_dim: int = 3072,
num_classes: int = 1000,
keep_prob: float = 1.0,
attention_keep_prob: float = 1.0,
drop_path_keep_prob: float = 1.0,
activation: nn.Cell = nn.GELU,
norm: Optional[nn.Cell] = nn.LayerNorm,
pool: str = 'cls') -> None:
super(ViT, self).__init__()
self.patch_embedding = PatchEmbedding(image_size=image_size,
patch_size=patch_size,
embed_dim=embed_dim,
input_channels=input_channels)
num_patches = self.patch_embedding.num_patches
self.cls_token = init(init_type=Normal(sigma=1.0),
shape=(1, 1, embed_dim),
dtype=ms.float32,
name='cls',
requires_grad=True)
self.pos_embedding = init(init_type=Normal(sigma=1.0),
shape=(1, num_patches + 1, embed_dim),
dtype=ms.float32,
name='pos_embedding',
requires_grad=True)
self.pool = pool
self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
self.norm = norm((embed_dim,))
self.transformer = TransformerEncoder(dim=embed_dim,
num_layers=num_layers,
num_heads=num_heads,
mlp_dim=mlp_dim,
keep_prob=keep_prob,
attention_keep_prob=attention_keep_prob,
drop_path_keep_prob=drop_path_keep_prob,
activation=activation,
norm=norm)
self.dropout = nn.Dropout(p=1.0-keep_prob)
self.dense = nn.Dense(embed_dim, num_classes)
def construct(self, x):
"""ViT construct."""
x = self.patch_embedding(x)
cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
x = ops.concat((cls_tokens, x), axis=1)
x += self.pos_embedding
x = self.pos_dropout(x)
x = self.transformer(x)
x = self.norm(x)
x = x[:, 0]
if self.training:
x = self.dropout(x)
x = self.dense(x)
return x
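An illustrative end-to-end shape check with the default ViT-B/16 configuration defined above:
import numpy as np
net = ViT(num_classes=1000)
images = ms.Tensor(np.random.randn(2, 3, 224, 224), ms.float32)
print(net(images).shape)   # expect (2, 1000)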
The overall flow is shown in the figure below:
IV. Model Training and Inference
Model Training
from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train
# define hyperparameters
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()
# construct model
network = ViT()
# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"
vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),
max_lr=0.00005,
total_step=epoch_size * step_size,
step_per_epoch=step_size,
decay_epoch=10)
# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)
# define loss function
class CrossEntropySmooth(LossBase):
"""CrossEntropy."""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = ops.OneHot()
self.sparse = sparse
self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss
network_loss = CrossEntropySmooth(sparse=True,
reduction="mean",
smooth_factor=0.1,
num_classes=num_classes)
# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)
# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:
model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")
# train model
model.train(epoch_size,
dataset_train,
callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
dataset_sink_mode=False,)
Model Validation
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)
trans_val = [
transforms.Decode(),
transforms.Resize(224 + 32),
transforms.CenterCrop(224),
transforms.Normalize(mean=mean, std=std),
transforms.HWC2CHW()
]
dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)
# construct model
network = ViT()
# load ckpt
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)
network_loss = CrossEntropySmooth(sparse=True,
reduction="mean",
smooth_factor=0.1,
num_classes=num_classes)
# define metric
eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
'Top_5_Accuracy': train.Top5CategoricalAccuracy()}
if ascend_target:
model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:
model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")
# evaluate model
result = model.eval(dataset_val)
print(result)
Model Inference
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)
trans_infer = [
transforms.Decode(),
transforms.Resize([224, 224]),
transforms.Normalize(mean=mean, std=std),
transforms.HWC2CHW()
]
dataset_infer = dataset_infer.map(operations=trans_infer,
input_columns=["image"],
num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)
import os
import pathlib
import cv2
import numpy as np
from PIL import Image
from enum import Enum
from scipy import io
class Color(Enum):
"""Define enum color."""
red = (0, 0, 255)
green = (0, 255, 0)
blue = (255, 0, 0)
cyan = (255, 255, 0)
yellow = (0, 255, 255)
magenta = (255, 0, 255)
white = (255, 255, 255)
black = (0, 0, 0)
def check_file_exist(file_name: str):
"""check_file_exist."""
if not os.path.isfile(file_name):
raise FileNotFoundError(f"File `{file_name}` does not exist.")
def color_val(color):
"""color_val."""
if isinstance(color, str):
return Color[color].value
if isinstance(color, Color):
return color.value
if isinstance(color, tuple):
assert len(color) == 3
for channel in color:
assert 0 <= channel <= 255
return color
if isinstance(color, int):
assert 0 <= color <= 255
return color, color, color
if isinstance(color, np.ndarray):
assert color.ndim == 1 and color.size == 3
assert np.all((color >= 0) & (color <= 255))
color = color.astype(np.uint8)
return tuple(color)
raise TypeError(f'Invalid type for color: {type(color)}')
def imread(image, mode=None):
"""imread."""
if isinstance(image, pathlib.Path):
image = str(image)
if isinstance(image, np.ndarray):
pass
elif isinstance(image, str):
check_file_exist(image)
image = Image.open(image)
if mode:
image = np.array(image.convert(mode))
else:
raise TypeError("Image must be a `ndarray`, `str` or Path object.")
return image
def imwrite(image, image_path, auto_mkdir=True):
"""imwrite."""
if auto_mkdir:
dir_name = os.path.abspath(os.path.dirname(image_path))
if dir_name != '':
dir_name = os.path.expanduser(dir_name)
os.makedirs(dir_name, mode=0o777, exist_ok=True)
image = Image.fromarray(image)
image.save(image_path)
def imshow(img, win_name='', wait_time=0):
"""imshow"""
cv2.imshow(win_name, imread(img))
if wait_time == 0: # prevent from hanging if windows was closed
while True:
ret = cv2.waitKey(1)
closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
# if user closed window or if some key pressed
if closed or ret != -1:
break
else:
ret = cv2.waitKey(wait_time)
def show_result(img: str,
result: Dict[int, float],
text_color: str = 'green',
font_scale: float = 0.5,
row_width: int = 20,
show: bool = False,
win_name: str = '',
wait_time: int = 0,
out_file: Optional[str] = None) -> None:
"""Mark the prediction results on the picture."""
img = imread(img, mode="RGB")
img = img.copy()
x, y = 0, row_width
text_color = color_val(text_color)
for k, v in result.items():
if isinstance(v, float):
v = f'{v:.2f}'
label_text = f'{k}: {v}'
cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
font_scale, text_color)
y += row_width
if out_file:
show = False
imwrite(img, out_file)
if show:
imshow(img, win_name, wait_time)
def index2label():
"""Dictionary output for image numbers and categories of the ImageNet dataset."""
metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
meta = io.loadmat(metafile, squeeze_me=True)['synsets']
nums_children = list(zip(*meta))[4]
meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]
_, wnids, classes = list(zip(*meta))[:3]
clssname = [tuple(clss.split(', ')) for clss in classes]
wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])
mapping = {}
for index, (_, class_name) in enumerate(wind2class_name):
mapping[index] = class_name[0]
return mapping
# Read data for inference
for i, image in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
image = image["image"]
image = ms.Tensor(image)
prob = model.predict(image)
label = np.argmax(prob.asnumpy(), axis=1)
mapping = index2label()
output = {int(label): mapping[int(label)]}
print(output)
show_result(img="./dataset/infer/n01440764/ILSVRC2012_test_00000279.JPEG",
result=output,
out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")
After the inference process completes, the inference result for the image can be found in the inference folder. The prediction is Doberman, which matches the expected result and verifies the model's accuracy.