1.代碼運行
- 輸入 1 測試一張圖片并預測結果
-
輸入 2 對測試集整體進行測試,得出準確率(10秒左右)
-
輸入其他數(shù)字自動退出程序
2.注意事項
-
本程序包含python庫較多,請自行配置(pip),如有需求,請評論或私信!
-
回復其他數(shù)字會自動退出程序
-
輸入圖片要求是28*28像素
-
模型訓練大概需要2分鐘,請耐心等候!
-
本代碼使用在線MNIST數(shù)據(jù)庫,無需本地MNIST數(shù)據(jù)庫!
-
文件會自動在同目錄下面生成
Model
文件夾,里面包含兩個文件model.pdopt
、model.pdparams
-
如果需要可視化,可以將
callbacks
行注釋去除 -
如果需要下圖格式,請將上面代碼中
verbose=0
修改為verbose=1
3.代碼分析
- 數(shù)據(jù)預處理,從paddle庫得到mnist數(shù)據(jù)
def data_process():
transform = T.Normalize(mean=[127.5], std=[127.5])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('訓練樣本量:{},測試樣本量:{}'.format(len(train_dataset), len(test_dataset)))
return train_dataset, test_dataset
2.訓練模型
def create_model(train_dataset, test_dataset):
print('查找是否存在模型.')
# 網(wǎng)絡結構代碼實現(xiàn),調(diào)用paddle的網(wǎng)絡
network = paddle.vision.models.LeNet(num_classes=10)
model = paddle.Model(network)
model.prepare(paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()), # 優(yōu)化器
paddle.nn.CrossEntropyLoss(), # 損失函數(shù)
paddle.metric.Accuracy()) # 評估指標
if not os.path.exists('Model/model.pdopt') or not os.path.exists('Model/model.pdparams'):
print('不存在模型,開始訓練模型.')
# callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir_LeNet學習率0.001')
# 啟動全流程訓練
model.fit(train_dataset, # 訓練數(shù)據(jù)集
test_dataset, # 評估數(shù)據(jù)集
epochs=5, # 訓練輪次
batch_size=64, # 單次計算數(shù)據(jù)樣本量
verbose=0,
# callbacks=callback
) # 日志展示形式
print("模型訓練結束")
# 進行預測操作
# result = model_1.predict(test_dataset)
model.save('Model/model')
else:
model.load('Model/model')
print("已經(jīng)存在訓練好的模型?。?!")
return model
3.測試單張和多張文章來源:http://www.zghlxwxcb.cn/news/detail-740087.html
def Test_one(imgPath, modelPath):
model = modelPath
# image = preprocessing.StandardScaler().fit_transform(np.array(Image.open(imgPath).convert('L'), dtype='float32'))
im = Image.open(imgPath).convert('L')
# 為灰度圖像,每個像素用8個bit表示,0表示黑,255表示白,其他數(shù)字表示不同的灰度。 轉(zhuǎn)換公式:L = R * 299/1000 + G * 587/1000+ B * 114/1000。
# im = im.resize((28, 28), Image.ANTIALIAS)
im = numpy.array(im).reshape(-1, 1, 28, 28).astype('float32')
im = im / 255.0 - 1.0
result = model.predict([im], verbose=True)
print('預測結果是:', result[0][0].argmax(), '\n') # argmax()得到最大值下標
# 測試準確率
def Test_all(modelPath):
model = modelPath
result = model.evaluate(test_dataset, verbose=1)
print('準確率為:', result['acc'])
4.主程序文章來源地址http://www.zghlxwxcb.cn/news/detail-740087.html
if __name__ == '__main__':
print('訓練數(shù)據(jù)自動使用paddle自帶的MINST數(shù)據(jù)庫')
print('回復數(shù)字1為測試一張圖片,回復數(shù)字2為測試測試集準確率,回復其他數(shù)字自動退出程序?。?!\n')
train_dataset, test_dataset = data_process()
model = create_model(train_dataset, test_dataset)
while 1:
ans = input('測試一張(1)還是測試集準確率(2):')
match (ans):
case '1':
modelPath = model # 模型文件會生成在本文件同目錄下不需要選擇,這里Path默認直接加載模型
root = tk.Tk()
root.withdraw()
print('請選擇測試圖片')
imgPath = filedialog.askopenfilename()
Test_one(imgPath, modelPath)
continue
case '2':
modelPath = model
Test_all(modelPath)
continue
case _:
print('測試結束?。。?)
exit()
4.源代碼
如果是完成大作業(yè)需求,請各位自己適配?。。?! |
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName :CNN3.py
# @Time :2023/4/24 14:28
# @Author :YKW
import os
import paddle
import numpy
import tkinter as tk
from PIL import Image
from tkinter import filedialog
import paddle.vision.transforms as T
'protobuf版本建議使用3.20.0,否則會不兼容'
'訓練數(shù)據(jù)自動使用paddle自帶的MINST數(shù)據(jù)庫'
# 數(shù)據(jù)預處理,從paddle庫得到mnist數(shù)據(jù)
def data_process():
transform = T.Normalize(mean=[127.5], std=[127.5])
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('訓練樣本量:{},測試樣本量:{}'.format(len(train_dataset), len(test_dataset)))
return train_dataset, test_dataset
# 訓練模型
def create_model(train_dataset, test_dataset):
print('查找是否存在模型.')
# 網(wǎng)絡結構代碼實現(xiàn),調(diào)用paddle的網(wǎng)絡
network = paddle.vision.models.LeNet(num_classes=10)
model = paddle.Model(network)
model.prepare(paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()), # 優(yōu)化器
paddle.nn.CrossEntropyLoss(), # 損失函數(shù)
paddle.metric.Accuracy()) # 評估指標
if not os.path.exists('Model/model.pdopt') or not os.path.exists('Model/model.pdparams'):
print('不存在模型,開始訓練模型.')
# callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir_LeNet學習率0.001')
# 啟動全流程訓練
model.fit(train_dataset, # 訓練數(shù)據(jù)集
test_dataset, # 評估數(shù)據(jù)集
epochs=5, # 訓練輪次
batch_size=64, # 單次計算數(shù)據(jù)樣本量
verbose=0,
# callbacks=callback
) # 日志展示形式
print("模型訓練結束")
# 進行預測操作
# result = model_1.predict(test_dataset)
'''
indexs = [5, 20, 48, 210]
for idx in indexs:
show_img(test_dataset[idx][0], np.argmax(result[0][idx]))'''
model.save('Model/model')
else:
model.load('Model/model')
print("已經(jīng)存在訓練好的模型?。。?)
return model
# 測試單張
def Test_one(imgPath, modelPath):
model = modelPath
# image = preprocessing.StandardScaler().fit_transform(np.array(Image.open(imgPath).convert('L'), dtype='float32'))
im = Image.open(imgPath).convert('L')
# 為灰度圖像,每個像素用8個bit表示,0表示黑,255表示白,其他數(shù)字表示不同的灰度。 轉(zhuǎn)換公式:L = R * 299/1000 + G * 587/1000+ B * 114/1000。
# im = im.resize((28, 28), Image.ANTIALIAS)
im = numpy.array(im).reshape(-1, 1, 28, 28).astype('float32')
im = im / 255.0 - 1.0
result = model.predict([im], verbose=True)
print('預測結果是:', result[0][0].argmax(), '\n') # argmax()得到最大值下標
# 測試準確率
def Test_all(modelPath):
model = modelPath
result = model.evaluate(test_dataset, verbose=1)
print('準確率為:', result['acc'])
if __name__ == '__main__':
print('訓練數(shù)據(jù)自動使用paddle自帶的MINST數(shù)據(jù)庫')
print('回復數(shù)字1為測試一張圖片,回復數(shù)字2為測試測試集準確率,回復其他數(shù)字自動退出程序?。?!\n')
train_dataset, test_dataset = data_process()
model = create_model(train_dataset, test_dataset)
while 1:
ans = input('測試一張(1)還是測試集準確率(2):')
match (ans):
case '1':
modelPath = model # 模型文件會生成在本文件同目錄下不需要選擇,這里Path默認直接加載模型
root = tk.Tk()
root.withdraw()
print('請選擇測試圖片')
imgPath = filedialog.askopenfilename()
Test_one(imgPath, modelPath)
continue
case '2':
modelPath = model
Test_all(modelPath)
continue
case _:
print('測試結束?。。?)
exit()
到了這里,關于卷積神經(jīng)網(wǎng)絡(CNN)實現(xiàn)圖像分類——Python的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關文章,希望大家以后多多支持TOY模板網(wǎng)!