1 實驗介紹
1.1 簡介
Mnist手寫體圖像識別實驗是深度學習入門經(jīng)典實驗。Mnist數(shù)據(jù)集包含60,000個用于訓練的示例和10,000個用于測試的示例。這些數(shù)字已經(jīng)過尺寸標準化并位于圖像中心,圖像是固定大小(28x28像素),其值為0到255。為簡單起見,每個圖像都被平展并轉(zhuǎn)換為784(28*28)個特征的一維numpy數(shù)組。
1.2 實驗目的
- 學會如何搭建全連接神經(jīng)網(wǎng)絡。
- 掌握搭建網(wǎng)絡過程中的關(guān)鍵點。
- 掌握分類任務的整體流程。
2.2 實驗環(huán)境要求?
推薦在華為云ModelArts實驗平臺完成實驗,也可在本地搭建python3.7.5和MindSpore1.0.0環(huán)境完成實驗。
2.3 實驗總體設計
?文章來源地址http://www.zghlxwxcb.cn/news/detail-401450.html
創(chuàng)建實驗環(huán)境:在本地搭建MindSpore環(huán)境。
導入實驗所需模塊:該步驟通常都是程序編輯的第一步,將實驗代碼所需要用到的模塊包用import命令進行導入。
導入數(shù)據(jù)集并預處理:神經(jīng)網(wǎng)絡的訓練離不開數(shù)據(jù),這里對數(shù)據(jù)進行導入。同時,因為全連接網(wǎng)絡只能接收固定維度的輸入數(shù)據(jù),所以,要對數(shù)據(jù)集進行預處理,以符合網(wǎng)絡的輸入維度要求。同時,設定好每一次訓練的Batch的大小,以Batch Size為單位進行輸入。
模型搭建:利用mindspore.nn的cell模塊搭建全連接網(wǎng)絡,包含輸入層,隱藏層,輸出層。同時,配置好網(wǎng)絡需要的優(yōu)化器,損失函數(shù)和評價指標。傳入數(shù)據(jù),并開始訓練模型。
模型評估:利用測試集進行模型的評估。
2.4 實驗過程
2.4.1 搭建實驗環(huán)境
Windows下MindSpore實驗環(huán)境搭建并配置Pycharm請參考【機器學習】Windows下MindSpore實驗環(huán)境搭建并配置Pycharm_在pycharm上安裝mindspore_弓長纟隹為的博客-CSDN博客
官網(wǎng)下載MNIST數(shù)據(jù)集?MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
在MNIST文件夾下建立train和test兩個文件夾,train中存放train-labels-idx1-ubyte和train-images-idx3-ubyte文件,test中存放t10k-labels-idx1-ubyte和t10k-images-idx3-ubyte文件。
2.4.2? 模型訓練、測試及評估
#導入相關(guān)依賴庫
import os
import numpy as np
from matplotlib import pyplot as plt
import mindspore as ms
#context模塊用于設置實驗環(huán)境和實驗設備
import mindspore.context as context
#dataset模塊用于處理數(shù)據(jù)形成數(shù)據(jù)集
import mindspore.dataset as ds
#c_transforms模塊用于轉(zhuǎn)換數(shù)據(jù)類型
import mindspore.dataset.transforms as C
#vision.c_transforms模塊用于轉(zhuǎn)換圖像,這是一個基于opencv的高級API
import mindspore.dataset.vision as CV
#導入Accuracy作為評價指標
from mindspore.nn.metrics import Accuracy
#nn中有各種神經(jīng)網(wǎng)絡層如:Dense,ReLu
from mindspore import nn
#Model用于創(chuàng)建模型對象,完成網(wǎng)絡搭建和編譯,并用于訓練和評估
from mindspore.train import Model
#LossMonitor可以在訓練過程中返回LOSS值作為監(jiān)控指標
from mindspore.train.callback import LossMonitor
#設定運行模式為動態(tài)圖模式,并且運行設備為昇騰芯片
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
#MindSpore內(nèi)置方法讀取MNIST數(shù)據(jù)集
ds_train = ds.MnistDataset(os.path.join(r'D:\Dataset\MNIST', "train"))
ds_test = ds.MnistDataset(os.path.join(r'D:\Dataset\MNIST', "test"))
print('訓練數(shù)據(jù)集數(shù)量:',ds_train.get_dataset_size())
print('測試數(shù)據(jù)集數(shù)量:',ds_test.get_dataset_size())
#該數(shù)據(jù)集可以通過create_dict_iterator()轉(zhuǎn)換為迭代器形式,然后通過get_next()一個個輸出樣本
image=ds_train.create_dict_iterator().get_next()
#print(type(image))
print('圖像長/寬/通道數(shù):',image['image'].shape)
#一共10類,用0-9的數(shù)字表達類別。
print('一張圖像的標簽樣式:',image['label'])
DATA_DIR_TRAIN = "D:/Dataset/MNIST/train" # 訓練集信息
DATA_DIR_TEST = "D:/Dataset/MNIST/test" # 測試集信息
def create_dataset(training=True, batch_size=128, resize=(28, 28), rescale=1 / 255, shift=-0.5, buffer_size=64):
ds = ms.dataset.MnistDataset(DATA_DIR_TRAIN if training else DATA_DIR_TEST)
# 定義改變形狀、歸一化和更改圖片維度的操作。
# 改為(28,28)的形狀
resize_op = CV.Resize(resize)
# rescale方法可以對數(shù)據(jù)集進行歸一化和標準化操作,這里就是將像素值歸一到0和1之間,shift參數(shù)可以讓值域偏移至-0.5和0.5之間
rescale_op = CV.Rescale(rescale, shift)
# 由高度、寬度、深度改為深度、高度、寬度
hwc2chw_op = CV.HWC2CHW()
# 利用map操作對原數(shù)據(jù)集進行調(diào)整
ds = ds.map(input_columns="image", operations=[resize_op, rescale_op, hwc2chw_op])
ds = ds.map(input_columns="label", operations=C.TypeCast(ms.int32))
# 設定洗牌緩沖區(qū)的大小,從一定程度上控制打亂操作的混亂程度
ds = ds.shuffle(buffer_size=buffer_size)
# 設定數(shù)據(jù)集的batch_size大小,并丟棄剩余的樣本
ds = ds.batch(batch_size, drop_remainder=True)
return ds
#顯示前10張圖片以及對應標簽,檢查圖片是否是正確的數(shù)據(jù)集
dataset_show = create_dataset(training=False)
data = dataset_show.create_dict_iterator().get_next()
images = data['image'].asnumpy()
labels = data['label'].asnumpy()
for i in range(1,11):
plt.subplot(2, 5, i)
#利用squeeze方法去掉多余的一個維度
plt.imshow(np.squeeze(images[i]))
plt.title('Number: %s' % labels[i])
plt.xticks([])
plt.show()
# 利用定義類的方式生成網(wǎng)絡,Mindspore中定義網(wǎng)絡需要繼承nn.cell。在init方法中定義該網(wǎng)絡需要的神經(jīng)網(wǎng)絡層
# 在construct方法中梳理神經(jīng)網(wǎng)絡層與層之間的關(guān)系。
class ForwardNN(nn.Cell):
def __init__(self):
super(ForwardNN, self).__init__()
self.flatten = nn.Flatten()
self.relu = nn.ReLU()
self.fc1 = nn.Dense(784, 512, activation='relu')
self.fc2 = nn.Dense(512, 256, activation='relu')
self.fc3 = nn.Dense(256, 128, activation='relu')
self.fc4 = nn.Dense(128, 64, activation='relu')
self.fc5 = nn.Dense(64, 32, activation='relu')
self.fc6 = nn.Dense(32, 10, activation='softmax')
def construct(self, input_x):
output = self.flatten(input_x)
output = self.fc1(output)
output = self.fc2(output)
output = self.fc3(output)
output = self.fc4(output)
output = self.fc5(output)
output = self.fc6(output)
return output
lr = 0.001
num_epoch = 10
momentum = 0.9
net = ForwardNN()
#定義loss函數(shù),改函數(shù)不需要求導,可以給離散的標簽值,且loss值為均值
loss = nn.loss.SoftmaxCrossEntropyWithLogits( sparse=True, reduction='mean')
#定義準確率為評價指標,用于評價模型
metrics={"Accuracy": Accuracy()}
#定義優(yōu)化器為Adam優(yōu)化器,并設定學習率
opt = nn.Adam(net.trainable_params(), lr)
#生成驗證集,驗證機不需要訓練,所以不需要repeat
ds_eval = create_dataset(False, batch_size=32)
#模型編譯過程,將定義好的網(wǎng)絡、loss函數(shù)、評價指標、優(yōu)化器編譯
model = Model(net, loss, opt, metrics)
#生成訓練集
ds_train = create_dataset(True, batch_size=32)
print("============== Starting Training ==============")
#訓練模型,用loss作為監(jiān)控指標,并利用昇騰芯片的數(shù)據(jù)下沉特性進行訓練
model.train(num_epoch, ds_train,callbacks=[LossMonitor()],dataset_sink_mode=True)
#使用測試集評估模型,打印總體準確率
metrics_result=model.eval(ds_eval)
print(metrics_result)
備注:
若報錯 AttributeError: ‘DictIterator’ object has no attribute ‘get_next’?,這是說MindSpore數(shù)據(jù)類中缺少 “get_next”這個方法,但是在MNIST圖像識別的官方代碼中卻使用了這個方法,這就說明MindSpore官方把這個變成私密方法。
只需要在源碼iterators.py中找到DictIterator這個類,將私有方法變成公有方法就行了(即去掉最前面的下劃線)。
參考mindspore 報錯 AttributeError: ‘DictIterator‘ object has no attribute ‘get_next‘_create_dict_iterator_TNiuB的博客-CSDN博客
MindSpore:前饋神經(jīng)網(wǎng)絡時報錯‘DictIterator‘ has no attribute ‘get_next‘_skytier的博客-CSDN博客
更多問題請參考Window10 上MindSpore(CPU)用LeNet網(wǎng)絡訓練MNIST - 知乎?文章來源:http://www.zghlxwxcb.cn/news/detail-401450.html
?
到了這里,關(guān)于【深度學習】基于華為MindSpore的手寫體圖像識別實驗的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!