首先我們要先了解深度學(xué)習(xí)的概念和AI計(jì)算框架的角色(https://zhuanlan.zhihu.com/p/463019160),本篇文章將演示怎么利用MindSpore來(lái)訓(xùn)練一個(gè)AI模型。和上一章的場(chǎng)景一致,我們要訓(xùn)練的模型是用來(lái)對(duì)手寫(xiě)數(shù)字圖片進(jìn)行分類的LeNet5模型
請(qǐng)參考(http://yann.lecun.com/exdb/lenet/)。
圖1 MindSpore使用流程
安裝MindSpore
MindSpore提供給用戶使用的是Python接口(什么是Python,請(qǐng)參考:
https://zhuanlan.zhihu.com/p/462756985),所以我們首先需要安裝MindSpore的whl包,安裝之后就可以導(dǎo)入(import)MindSpore提供的方法接口了。安裝whl包有兩種方式:
方式一:進(jìn)入MindSpore官網(wǎng),根據(jù)自己的設(shè)備和Python版本選擇安裝命令。比如我的Python版本是3.7.5,我的設(shè)備是筆記本(CPU),那么我就復(fù)制下圖紅框中的命令進(jìn)行安裝:
圖2 MindSpore安裝界面
安裝過(guò)程如下:
圖3 MindSpore安裝過(guò)程
注意:由于MindSpore還依賴于其他的Python三方庫(kù),所以在安裝過(guò)程中,系統(tǒng)還會(huì)自動(dòng)下載、安裝其他的Python三方庫(kù),如numpy、pillow、scipy等等,安裝結(jié)束后,如果能 import mindspore 成功,說(shuō)明MindSpore安裝成功了:
圖4 MindSpore安裝成功
方式二:可以在版本列表中找到對(duì)應(yīng)的whl包,點(diǎn)擊就能下載:
圖5 MindSpore版本下載列表
下載完成后,把whl包放到自己的目錄下,執(zhí)行 pip install xxx.whl:
圖6 MindSpore第二種安裝方式
定義模型
安裝好MindSpore之后,我們就可以導(dǎo)入MindSpore提供的算子(卷積、全連接、池化等函數(shù):https://zhuanlan.zhihu.com/p/463019160)來(lái)構(gòu)建我們的模型了??梢赃@么比喻:我們構(gòu)建一個(gè)AI模型就像建一個(gè)房子,而MindSpore提供給我們的算子就像是磚塊、窗戶、地板等基本組件。
圖7 定義LeNet5模型
如上圖所示,我們用到的“磚塊”都是mindspore.nn模塊提供的。注意:這里用到了Python的類(class),由②和③兩部分組成。我們這里定義的類是class LeNet5,它由初始化函數(shù) __init__(self) 和構(gòu)造函數(shù)construct(self, x)組成。初始化函數(shù)定義了我們構(gòu)造模型所需要用到的算子,比如conv算子、relu算子、flatten算子等等,這些算子都是從mindspore.nn獲取的;構(gòu)造函數(shù)就是把我們?cè)诔跏蓟瘮?shù)中導(dǎo)入的算子按順序排放,構(gòu)成我們最終的模型。construct()函數(shù)的輸入就是我們這個(gè)模型預(yù)測(cè)的對(duì)象,比如第一章講的黑白圖片像素矩陣;而“return y”中的就是預(yù)測(cè)的結(jié)果,對(duì)應(yīng)于第一章講到的10分類手寫(xiě)數(shù)字?jǐn)?shù)據(jù)集,就是一個(gè)行10列的數(shù)組(這里的是指輸入圖片的數(shù)量,AI模型支持多張圖片同時(shí)推理)。
導(dǎo)入訓(xùn)練數(shù)據(jù)集
什么是訓(xùn)練數(shù)據(jù)集?剛剛定義好的模型是不能對(duì)圖片進(jìn)行正確分類的,我們要通過(guò)“訓(xùn)練”過(guò)程來(lái)調(diào)整模型的參數(shù)矩陣的值。訓(xùn)練過(guò)程就需要用到訓(xùn)練樣本,也就是打上了正確標(biāo)簽的圖片。這就好比我們教小孩兒認(rèn)識(shí)動(dòng)物,需要拿幾張圖片給他們看,然后告訴他們這是什么、那是什么,教了幾遍之后,小孩兒就能認(rèn)識(shí)了。那么我們訓(xùn)練LeNet5模型就需要用到MNIST數(shù)據(jù)集,請(qǐng)參考(http://yann.lecun.com/exdb/mnist/)。這個(gè)數(shù)據(jù)集由兩部分組成:訓(xùn)練集(6萬(wàn)張圖片)和測(cè)試集(1萬(wàn)張圖片),都是0~9的黑白手寫(xiě)數(shù)字圖片。訓(xùn)練集是用來(lái)訓(xùn)練AI模型的,測(cè)試集是用來(lái)測(cè)試訓(xùn)練后的模型分類準(zhǔn)確率的。
下載得到的數(shù)據(jù)集最初是壓縮文件,還不能直接傳給MindSpore的訓(xùn)練接口使用,我們要先用MindSpore提供的數(shù)據(jù)處理接口把他們讀進(jìn)來(lái):
import mindspore.dataset as ds
mnist_ds = ds.MnistDataset(data_path) # 導(dǎo)入下載的MNIST數(shù)據(jù)集
然后進(jìn)行數(shù)據(jù)增強(qiáng)(比如把圖片大小轉(zhuǎn)化成相同的尺寸、像素值標(biāo)準(zhǔn)化、歸一化等操作),提升訓(xùn)練效率:
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype
# 定義數(shù)據(jù)增強(qiáng)函數(shù)
def create_dataset(data_path, batch_size=32): # batch_size是每一步訓(xùn)練使用的圖片數(shù)量,一般取32
"""
create dataset for train or test
Args:
data_path (str): Data path
batch_size (int): The number of data records in each group
"""
# define dataset
mnist_ds = ds.MnistDataset(data_path) # 導(dǎo)入下載的MNIST數(shù)據(jù)集
# define some parameters needed for data enhancement and rough justification
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081
# according to the parameters, generate the corresponding data enhancement method
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# using map to apply operations to a dataset
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label")
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image")
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image")
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image")
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image")
# process the generated dataset
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
return mnist_ds
?訓(xùn)練模型
訓(xùn)練數(shù)據(jù)集和模型定義完成之后呢,我們就可以開(kāi)始訓(xùn)練模型了。但是在訓(xùn)練之前,我們還需要從MindSpore導(dǎo)入兩個(gè)函數(shù):
-
損失函數(shù),也就是衡量預(yù)測(cè)結(jié)果和真實(shí)標(biāo)簽之間的差距的函數(shù)??催^(guò)上一章的同學(xué)可能會(huì)記得,我們之前用的損失函數(shù)是真實(shí)值與預(yù)測(cè)值之差的2-范數(shù):
圖8 2-范數(shù)損失
在這里,我們使用業(yè)界最常用的交叉熵?fù)p失函數(shù)SoftmaxCrossEntropyWithLogits,對(duì)于真實(shí)標(biāo)簽
和預(yù)測(cè)值,它們之間的交叉熵?fù)p失計(jì)算公式為:
其中J代表數(shù)組的下標(biāo),。從MindSpore導(dǎo)入損失函數(shù):
from mindspore.nn import SoftmaxCrossEntropyWithLogits
# define the loss function
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
-
優(yōu)化器,優(yōu)化器就是用來(lái)求解損失函數(shù)關(guān)于模型參數(shù)的更新梯度的,它是整個(gè)訓(xùn)練過(guò)程中最重要的工具!我們這里用MindSpore提供的Momentum優(yōu)化器:
import mindspore.nn as nn
lr = 0.01 # 定義學(xué)習(xí)率
momentum = 0.9 # 定義Momentum優(yōu)化器的超參
# define the optimizer
net_opt = nn.Momentum(network.trainable_params(), lr, momentum) # 導(dǎo)入mindspore提供
?準(zhǔn)備好損失函數(shù)和優(yōu)化器之后我們就可以開(kāi)始訓(xùn)練模型了,也非常簡(jiǎn)單,我們先把前面定義好的模型、損失函數(shù)、優(yōu)化器封裝成一個(gè)Model:
from mindspore import Model
net = LeNet5()
model = Model(net, net_loss , net_opt , metrics={'acc', 'loss'})
然后使用model.train接口就可以訓(xùn)練我們定義的LeNet5模型了:
loss_cb = LossMonitor(per_print_times=ds_train.get_dataset_size()) # 用于監(jiān)控訓(xùn)練過(guò)程中損失函數(shù)值的變化
ds_train = create_dataset(train_data_dir) # 傳入下載的訓(xùn)練集的路徑
model.train(num_epochs, ds_train, callbacks=[loss_cb]) # num_epochs是訓(xùn)練的輪數(shù),往往訓(xùn)練多輪才能使模型收斂
測(cè)試訓(xùn)練后的模型準(zhǔn)確率
訓(xùn)練結(jié)束后,調(diào)用model.eval()計(jì)算訓(xùn)練后的模型在測(cè)試集上面的分類準(zhǔn)確率:
ds_eval = create_dataset(test_data_dir) # 傳入下載的訓(xùn)練集的路徑
metrics = model.eval(ds_eval)
小結(jié)
祝賀你耐心看完了MindSpore訓(xùn)練模型的完整過(guò)程,如果你想動(dòng)手操作一遍,但是又沒(méi)有現(xiàn)成的環(huán)境,那么你可以使用官網(wǎng)提供的“在線運(yùn)行”來(lái)體驗(yàn)一番:
圖9 MindSpore官網(wǎng)提供的免費(fèi)體驗(yàn)入口
這是體驗(yàn)過(guò)程的實(shí)操視頻:
https://zhuanlan.zhihu.com/p/463229660
歡迎投稿
歡迎大家踴躍投稿,有想投稿技術(shù)干貨、項(xiàng)目經(jīng)驗(yàn)等分享的同學(xué),可以添加MindSpore官方小助手:小貓子(mindspore0328)的微信,告訴貓哥哦!
昇思MindSpore官方交流QQ群?:?486831414(群里有很多技術(shù)大咖助力答疑!)
MindSpore官方資料
GitHub?:?https://github.com/mindspore-ai/mindspore
Gitee?:?https?:?//gitee.com/mindspore/mindspore文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-445378.html
官方QQ群?:?486831?文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-445378.html
到了這里,關(guān)于手把手教你用MindSpore訓(xùn)練一個(gè)AI模型!的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!