說(shuō)明
使用pytorch框架,實(shí)現(xiàn)對(duì)MNIST手寫數(shù)字?jǐn)?shù)據(jù)集的訓(xùn)練和識(shí)別。重點(diǎn)是,自己手寫數(shù)字,手機(jī)拍照后傳入電腦,使用你自己訓(xùn)練的權(quán)重和偏置能夠識(shí)別。數(shù)據(jù)預(yù)處理過(guò)程的代碼是重點(diǎn)。
分析
要識(shí)別自己用手在紙上寫的數(shù)字,從特征上來(lái)看,手寫數(shù)字相比于普通的電腦上的數(shù)字最大的 不同就是數(shù)字的邊緣會(huì)發(fā)生不同幅度的抖動(dòng)。而且,在MNIST數(shù)據(jù)集中的數(shù)字是邊緣為黑色的,然后數(shù)字是不同灰度的白色的,如下所示:
在數(shù)據(jù)集中,每個(gè)數(shù)據(jù)都是
28
?
28
28*28
28?28的灰度圖,并且黑色部分都是零,其余白色的灰度值并不統(tǒng)一。因?yàn)槿绻?xùn)練時(shí)背景都是統(tǒng)一的時(shí)候我們測(cè)試用的圖片背景也必須是統(tǒng)一的,否則基本無(wú)法識(shí)別出來(lái)。除非訓(xùn)練的時(shí)候換各種不同的背景大數(shù)據(jù)進(jìn)行訓(xùn)練,這樣特征就不會(huì)依托著背景而存在,剩下的就是要識(shí)別的物體自己所擁有的特征了。所以在這里我要做的就是在圖片預(yù)處理的時(shí)候盡量讓圖片處理成接近測(cè)試圖片的樣子。
訓(xùn)練網(wǎng)絡(luò)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
# 下載訓(xùn)練集
train_dataset = datasets.MNIST(root='./data/',
train=True,
transform=transforms.ToTensor(),
download=False)
# 下載測(cè)試集
test_dataset = datasets.MNIST(root='./data/',
train=False,
transform=transforms.ToTensor(),
download=False)
# 設(shè)置批次數(shù)
batch_size = 100
# 裝載訓(xùn)練集
train_loader = torch.utils.data.DataLoader(dataset = train_dataset,
batch_size = batch_size,
shuffle=True)
# 裝載測(cè)試集
test_loader = torch.utils.data.DataLoader(dataset = test_dataset,
batch_size = batch_size,
shuffle = True)
# 自定義手寫數(shù)字識(shí)別網(wǎng)絡(luò)
class net(nn.Module):
def __init__(self):
super(net, self).__init__()
self.Conn_layers = nn.Sequential(
nn.Linear(784, 100),
nn.Sigmoid(),
nn.Linear(100, 10),
nn.Sigmoid()
)
def forward(self, input):
output = self.Conn_layers(input)
return output
# 定義學(xué)習(xí)率
LR = 0.1
# 定義一個(gè)網(wǎng)絡(luò)對(duì)象
net = net()
# 損失函數(shù)使用交叉熵
loss_function = nn.CrossEntropyLoss()
# 優(yōu)化函數(shù)使用 SGD
optimizer = optim.SGD(
net.parameters(),
lr = LR,
momentum = 0.9,
weight_decay = 0.0005
)
# 定義迭代次數(shù)
epoch = 20
# 進(jìn)行迭代訓(xùn)練
for epoch in range(epoch):
for i, data in enumerate(train_loader):
inputs, labels = data
# 轉(zhuǎn)換下輸入形狀
inputs = inputs.reshape(batch_size, 784)
inputs, labels = Variable(inputs), Variable(labels)
outputs = net(inputs)
loss = loss_function(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 初始化正確結(jié)果數(shù)為0
test_result = 0
# 用測(cè)試數(shù)據(jù)進(jìn)行測(cè)試
for data_test in test_loader:
images, labels = data_test
# 轉(zhuǎn)換下輸入形狀
images = images.reshape(batch_size, 784)
images, labels = Variable(images), Variable(labels)
output_test = net(images)
# 對(duì)一個(gè)批次的數(shù)據(jù)的準(zhǔn)確性進(jìn)行判斷
for i in range(len(labels)):
# 如果輸出結(jié)果的最大值的索引與標(biāo)簽內(nèi)正確數(shù)據(jù)相等,準(zhǔn)確個(gè)數(shù)累加
if torch.argmax(output_test[i]) == labels[i]:
test_result += 1
# 打印每次迭代后正確的結(jié)果數(shù)
print("Epoch {} : {} / {}".format(epoch, test_result, len(test_dataset)))
# 保存權(quán)重模型
torch.save(net, 'weight/test.pkl')
至此,對(duì)手寫數(shù)字網(wǎng)絡(luò)的訓(xùn)練已經(jīng)結(jié)束,且訓(xùn)練的準(zhǔn)確性為:
這個(gè)網(wǎng)絡(luò)比較粗糙,所以準(zhǔn)確性也只是一般,但如果要精確起來(lái)后面有很多文章可做。
圖像預(yù)處理
因?yàn)槲覀兪謾C(jī)拍的照片和訓(xùn)練集的圖片有很大的區(qū)別,所以無(wú)法將手機(jī)上拍的照片直接丟到訓(xùn)練好的網(wǎng)絡(luò)模型中進(jìn)行識(shí)別,需要先對(duì)圖片進(jìn)行預(yù)處理。有幾點(diǎn)需要對(duì)原圖進(jìn)行改變:
- 圖片的大?。嚎隙ǖ脤⑴臄z到的圖片轉(zhuǎn)換成 28 ? 28 28*28 28?28尺寸大小的圖片。
- 圖片的通道數(shù):由于MNIST是灰度圖,所以原圖的channel也得轉(zhuǎn)換成1。
- 圖片的背景:圖片的背景得轉(zhuǎn)換成MNIST相同的黑色,這樣識(shí)別結(jié)果準(zhǔn)確性更高。
- 數(shù)字的顏色:毋庸置疑,數(shù)字的顏色得變成MNIST相同的白色。
- 數(shù)字顏色中間深邊緣前:觀察MNIST的白色部分并不都是255全白,而是有漸變色的,這個(gè)漸變色模擬起來(lái)比較困難,算是難度最大的一點(diǎn)了。
接下來(lái)直接上代碼了:
import cv2
import numpy as np
def image_preprocessing():
# 讀取圖片
img = cv2.imread("picture/test8.jpeg")
# =====================圖像處理======================== #
# 轉(zhuǎn)換成灰度圖像
gray_img = cv2.cvtColor(img , cv2.COLOR_BGR2GRAY)
# 進(jìn)行高斯濾波
gauss_img = cv2.GaussianBlur(gray_img, (5,5), 0, 0, cv2.BORDER_DEFAULT)
# 邊緣檢測(cè)
img_edge1 = cv2.Canny(gauss_img, 100, 200)
# ==================================================== #
# =====================圖像分割======================== #
# 獲取原始圖像的寬和高
high = img.shape[0]
width = img.shape[1]
# 分別初始化高和寬的和
add_width = np.zeros(high, dtype = int)
add_high = np.zeros(width, dtype = int)
# 計(jì)算每一行的灰度圖的值的和
for h in range(high):
for w in range(width):
add_width[h] = add_width[h] + img_edge1[h][w]
# 計(jì)算每一列的值的和
for w in range(width):
for h in range(high):
add_high[w] = add_high[w] + img_edge1[h][w]
# 初始化上下邊界為寬度總值最大的值的索引
acount_high_up = np.argmax(add_width)
acount_high_down = np.argmax(add_width)
# 將上邊界坐標(biāo)值上移,直到?jīng)]有遇到白色點(diǎn)停止,此為數(shù)字的上邊界
while add_width[acount_high_up] != 0:
acount_high_up = acount_high_up + 1
# 將下邊界坐標(biāo)值下移,直到?jīng)]有遇到白色點(diǎn)停止,此為數(shù)字的下邊界
while add_width[acount_high_down] != 0:
acount_high_down = acount_high_down - 1
# 初始化左右邊界為寬度總值最大的值的索引
acount_width_left = np.argmax(add_high)
acount_width_right = np.argmax(add_high)
# 將左邊界坐標(biāo)值左移,直到?jīng)]有遇到白色點(diǎn)停止,此為數(shù)字的左邊界
while add_high[acount_width_left] != 0:
acount_width_left = acount_width_left - 1
# 將右邊界坐標(biāo)值右移,直到?jīng)]有遇到白色點(diǎn)停止,此為數(shù)字的右邊界
while add_high[acount_width_right] != 0:
acount_width_right = acount_width_right + 1
# 求出寬和高的間距
width_spacing = acount_width_right - acount_width_left
high_spacing = acount_high_up - acount_high_down
# 求出寬和高的間距差
poor = width_spacing - high_spacing
# 將數(shù)字進(jìn)行正方形分割,目的是方便之后進(jìn)行圖像壓縮
if poor > 0:
tailor_image = img[acount_high_down - poor // 2 - 5:acount_high_up + poor - poor // 2 + 5, acount_width_left - 5:acount_width_right + 5]
else:
tailor_image = img[acount_high_down - 5:acount_high_up + 5, acount_width_left + poor // 2 - 5:acount_width_right - poor + poor // 2 + 5]
# ==================================================== #
# ======================小圖處理======================= #
# 將裁剪后的圖片進(jìn)行灰度化
gray_img = cv2.cvtColor(tailor_image , cv2.COLOR_BGR2GRAY)
# 高斯去噪
gauss_img = cv2.GaussianBlur(gray_img, (5,5), 0, 0, cv2.BORDER_DEFAULT)
# 將圖像形狀調(diào)整到28*28大小
zoom_image = cv2.resize(gauss_img, (28, 28))
# 獲取圖像的高和寬
high = zoom_image.shape[0]
wide = zoom_image.shape[1]
# 將圖像每個(gè)點(diǎn)的灰度值進(jìn)行閾值比較
for h in range(high):
for w in range(wide):
# 若灰度值大于100,則判斷為背景并賦值0,否則將深灰度值變白處理
if zoom_image[h][w] > 100:
zoom_image[h][w] = 0
else:
zoom_image[h][w] = 255 - zoom_image[h][w]
# ==================================================== #
return zoom_image
在此,我在紙上寫了個(gè)6,如下圖所示:
然后是對(duì)圖像進(jìn)行分割,首先要介紹下我分割圖像的方法。下面是一張進(jìn)行canny邊緣檢測(cè)后的6:
在這里這個(gè)6有個(gè)特點(diǎn),就是被白邊給包圍著了,因?yàn)榘咨幕叶戎禐?55,黑色的灰度值為0,所以我就假設(shè)以高為很坐標(biāo),然后每個(gè)高對(duì)應(yīng)著的寬的灰度值進(jìn)行相加。所以會(huì)很明顯發(fā)現(xiàn)就6這個(gè)字的整體的值比較聚集,當(dāng)然有可能有零星的散點(diǎn),但并不影響對(duì)6所在位置的判斷。最后以高為例,得到的值的坐標(biāo)圖如下:
因?yàn)樽畲笾当容^容易找到,所以就找到最大值然后向兩邊延伸,當(dāng)發(fā)現(xiàn)值為零時(shí)就可以把邊界給標(biāo)定出來(lái)了。
最后進(jìn)行分割分割注意的是后面對(duì)圖像進(jìn)行裁剪的時(shí)候是將寬和高較長(zhǎng)的一邊減去較短的一邊然后除以2平分給較短的一邊的兩側(cè),為了防止邊緣檢測(cè)沒(méi)有包裹著數(shù)字,于是在數(shù)字四周都加了五個(gè)像素點(diǎn)進(jìn)行裁剪,最后裁剪出來(lái)的效果如下:
這個(gè)圖片就是上述代碼中的tailor_image所顯示出來(lái)的圖片,因?yàn)轱@示圖片的代碼只作為測(cè)試使用,而且又很簡(jiǎn)單,這里就沒(méi)有展示出來(lái)。
好了,接下來(lái)就是要對(duì)辛辛苦苦裁剪出來(lái)的小圖進(jìn)行圖像進(jìn)行處理了,首先還是最基本的灰度化和高斯濾波處理,然后就是對(duì)圖像進(jìn)行大小轉(zhuǎn)換,因?yàn)镸NIST數(shù)據(jù)形狀就是
28
?
28
28*28
28?28所以也要將輸入圖片轉(zhuǎn)換成
28
?
28
28*28
28?28的大小。大小轉(zhuǎn)換完成后,就是要完成把灰度圖轉(zhuǎn)換成背景為0,然后數(shù)字變成白色的圖片,因?yàn)檫@樣和MNIST數(shù)據(jù)集里的數(shù)字圖片特別的像。在這里我用了閾值控制的方法將背景變成黑色的。至于這100當(dāng)然是將圖片的灰度值打出來(lái)后觀察得出來(lái)的。但是這種方法是比較危險(xiǎn)的,因?yàn)檫@樣的魯棒性并不強(qiáng),但后面如果要加強(qiáng)魯棒性則同樣可以用邊緣檢測(cè)把數(shù)字包裹住,然后數(shù)字之外的背景清零,這確實(shí)是一個(gè)很好的思路,但在這里就建議的用閾值控制的方法來(lái)實(shí)現(xiàn)背景黑化了。黑化背景后當(dāng)然就是將數(shù)字白化了,之前有將數(shù)字部分都是255值,但發(fā)現(xiàn)識(shí)別的效果并不理想,所以這里我采用了用255-原先數(shù)字的值,這樣如果原先的數(shù)字黑度深的部分就會(huì)變成白色程度深,就簡(jiǎn)單的實(shí)現(xiàn)了數(shù)字邊緣淺,中間深的變換。最后處理得到的圖像如下:
雖說(shuō)看起來(lái)沒(méi)有第一張圖那么完美,但大概還是能達(dá)到驗(yàn)證數(shù)據(jù)所需的要求了。至此,數(shù)據(jù)預(yù)處理已經(jīng)完成了,接下來(lái)就是激動(dòng)的預(yù)測(cè)了。
預(yù)測(cè)
預(yù)測(cè)代碼如下:文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-436087.html
import torch
# pretreatment.py為上面圖片預(yù)處理的文件名,導(dǎo)入圖片預(yù)處理文件
import pretreatment as PRE
# 加載網(wǎng)絡(luò)模型
net = torch.load('weight/test.pkl')
# 得到返回的待預(yù)測(cè)圖片值,就是pretreatment.py中的zoom_image
img = PRE.image_preprocessing()
# 將待預(yù)測(cè)圖片轉(zhuǎn)換形狀
inputs = img.reshape(-1, 784)
# 輸入數(shù)據(jù)轉(zhuǎn)換成tensor張量類型,并轉(zhuǎn)換成浮點(diǎn)類型
inputs = torch.from_numpy(inputs)
inputs = inputs.float()
# 丟入網(wǎng)絡(luò)進(jìn)行預(yù)測(cè),得到預(yù)測(cè)數(shù)據(jù)
predict = net(inputs)
# 打印對(duì)應(yīng)的最后的預(yù)測(cè)結(jié)果
print("The number in this picture is {}".format(torch.argmax(predict).detach().numpy()))
最后得到結(jié)果如圖所示:
這樣,整個(gè)手寫數(shù)字識(shí)別基本已經(jīng)完成了。文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-436087.html
到了這里,關(guān)于手寫數(shù)字識(shí)別(識(shí)別紙上手寫的數(shù)字)的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!