筆記為自我總結(jié)整理的學(xué)習(xí)筆記,若有錯(cuò)誤歡迎指出喲~
往期文章:
【深度學(xué)習(xí)】pytorch——快速入門(mén)
CIFAR-10簡(jiǎn)介
CIFAR-10是一個(gè)常用的圖像分類(lèi)數(shù)據(jù)集,每張圖片都是 3×32×32,3通道彩色圖片,分辨率為 32×32。
它包含了10個(gè)不同類(lèi)別,每個(gè)類(lèi)別有6000張圖像,其中5000張用于訓(xùn)練,1000張用于測(cè)試。這10個(gè)類(lèi)別分別為:飛機(jī)、汽車(chē)、鳥(niǎo)類(lèi)、貓、鹿、狗、青蛙、馬、船和卡車(chē)。
CIFAR-10分類(lèi)任務(wù)是將這些圖像正確地分類(lèi)到它們所屬的類(lèi)別中。對(duì)于這個(gè)任務(wù),可以使用深度學(xué)習(xí)模型,如卷積神經(jīng)網(wǎng)絡(luò)(CNN)來(lái)實(shí)現(xiàn)高效的分類(lèi)。
CIFAR-10分類(lèi)任務(wù)是一個(gè)比較典型的圖像分類(lèi)問(wèn)題,在計(jì)算機(jī)視覺(jué)領(lǐng)域中被廣泛使用,是檢驗(yàn)深度學(xué)習(xí)模型表現(xiàn)的一個(gè)重要基準(zhǔn)。
CIFAR-10數(shù)據(jù)集分類(lèi)實(shí)現(xiàn)步驟
- 使用torchvision加載并預(yù)處理CIFAR-10數(shù)據(jù)集
- 定義網(wǎng)絡(luò)
- 定義損失函數(shù)和優(yōu)化器
- 訓(xùn)練網(wǎng)絡(luò)并更新網(wǎng)絡(luò)參數(shù)
- 測(cè)試網(wǎng)絡(luò)
一、數(shù)據(jù)加載及預(yù)處理
實(shí)現(xiàn)數(shù)據(jù)加載及預(yù)處理
import torch as t
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage() # 可以把Tensor轉(zhuǎn)成Image,方便可視化
# 第一次運(yùn)行程序torchvision會(huì)自動(dòng)下載CIFAR-10數(shù)據(jù)集,大約100M。
# 如果已經(jīng)下載有CIFAR-10,可通過(guò)root參數(shù)指定
# 定義對(duì)數(shù)據(jù)的預(yù)處理
transform = transforms.Compose([
transforms.ToTensor(), # 轉(zhuǎn)為T(mén)ensor
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 歸一化
])
# 訓(xùn)練集
trainset = tv.datasets.CIFAR10( # PyTorch提供的CIFAR-10數(shù)據(jù)集的類(lèi),用于加載CIFAR-10數(shù)據(jù)集。
root='D:/深度學(xué)習(xí)基礎(chǔ)/pytorch/data/', # 設(shè)置數(shù)據(jù)集存儲(chǔ)的根目錄。
train=True, # 指定加載的是CIFAR-10的訓(xùn)練集。
download=True, # 如果數(shù)據(jù)集尚未下載,設(shè)置為T(mén)rue會(huì)自動(dòng)下載CIFAR-10數(shù)據(jù)集。
transform=transform) # 設(shè)置數(shù)據(jù)集的預(yù)處理方式。
# 數(shù)據(jù)加載器
trainloader = t.utils.data.DataLoader(
trainset, # 指定了要加載的訓(xùn)練集數(shù)據(jù),即CIFAR-10數(shù)據(jù)集。
batch_size=4, # 每個(gè)小批量(batch)的大小是4,即每次會(huì)加載4張圖片進(jìn)行訓(xùn)練。
shuffle=True, # 在每個(gè)epoch訓(xùn)練開(kāi)始前,會(huì)打亂訓(xùn)練集中數(shù)據(jù)的順序,以增加訓(xùn)練效果。
num_workers=2) # 使用2個(gè)進(jìn)程來(lái)加載數(shù)據(jù),以提高數(shù)據(jù)的加載速度。
# 測(cè)試集
testset = tv.datasets.CIFAR10(
'D:/深度學(xué)習(xí)基礎(chǔ)/pytorch/data/',
train=False,
download=True,
transform=transform)
testloader = t.utils.data.DataLoader(
testset,
batch_size=4,
shuffle=False,
num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
這段代碼主要是使用PyTorch和torchvision庫(kù)來(lái)加載并處理CIFAR-10數(shù)據(jù)集,其中包括訓(xùn)練集和測(cè)試集。
-
import torch as t
和import torchvision as tv
導(dǎo)入了PyTorch和torchvision庫(kù)。 -
import torchvision.transforms as transforms
導(dǎo)入了torchvision.transforms模塊,用于進(jìn)行數(shù)據(jù)轉(zhuǎn)換和增強(qiáng)操作。 -
from torchvision.transforms import ToPILImage
導(dǎo)入了ToPILImage類(lèi),它可以將Tensor對(duì)象轉(zhuǎn)換為PIL Image對(duì)象,以方便后續(xù)的可視化操作。 -
show = ToPILImage()
創(chuàng)建一個(gè)ToPILImage對(duì)象,用于將張量(Tensor)對(duì)象轉(zhuǎn)換為PIL Image對(duì)象,以便于后續(xù)的可視化操作。 -
transform = transforms.Compose([...])
定義對(duì)數(shù)據(jù)的預(yù)處理操作,將多個(gè)預(yù)處理操作組合在一起,形成一個(gè)數(shù)據(jù)預(yù)處理的管道。該管道首先使用transforms.ToTensor()
函數(shù)將圖像轉(zhuǎn)換為張量(Tensor)對(duì)象,然后使用transforms.Normalize()
函數(shù)對(duì)圖像進(jìn)行歸一化操作,以便于后續(xù)的訓(xùn)練。 -
trainset = tv.datasets.CIFAR10([...])
使用tv.datasets.CIFAR10()
函數(shù)加載CIFAR-10數(shù)據(jù)集,并指定數(shù)據(jù)集的存儲(chǔ)位置、是否為訓(xùn)練集、是否需要下載等參數(shù)。還可以通過(guò)transform
參數(shù)來(lái)指定對(duì)數(shù)據(jù)進(jìn)行的預(yù)處理操作。 -
trainloader = t.utils.data.DataLoader([...])
使用PyTorch的DataLoader
類(lèi)來(lái)創(chuàng)建一個(gè)數(shù)據(jù)加載器,該加載器可以按照指定的批量大小將數(shù)據(jù)集分成小批量進(jìn)行加載??梢灾付虞d器的參數(shù),如批量大小、是否隨機(jī)洗牌、使用的進(jìn)程數(shù)等。 -
testset = tv.datasets.CIFAR10([...])
和testloader = t.utils.data.DataLoader([...])
與訓(xùn)練集的加載方式類(lèi)似,只是將參數(shù)中的train
改為False
,表示這是測(cè)試集。 -
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
定義了CIFAR-10數(shù)據(jù)集中包含的10個(gè)類(lèi)別。
注:tv.datasets.CIFAR10()
函數(shù)會(huì)自動(dòng)下載CIFAR-10數(shù)據(jù)集并存儲(chǔ)到指定位置,如果已經(jīng)下載過(guò)該數(shù)據(jù)集,可以通過(guò)root
參數(shù)來(lái)指定數(shù)據(jù)集的存儲(chǔ)位置,避免重復(fù)下載浪費(fèi)時(shí)間和帶寬。
歸一化的理解
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 歸一化
transforms.Normalize()
函數(shù)實(shí)現(xiàn)了對(duì)圖像數(shù)據(jù)進(jìn)行歸一化操作。該函數(shù)的參數(shù)是均值和標(biāo)準(zhǔn)差,在CIFAR-10數(shù)據(jù)集中,每個(gè)像素有3個(gè)通道(R,G,B),因此傳入的均值和標(biāo)準(zhǔn)差是一個(gè)長(zhǎng)度為3的元組。這里(0.5, 0.5, 0.5)
表示每個(gè)通道的均值為0.5,(0.5, 0.5, 0.5)
表示每個(gè)通道的標(biāo)準(zhǔn)差也為0.5。具體地,對(duì)于每個(gè)像素的每個(gè)通道,該函數(shù)執(zhí)行以下計(jì)算:
input[channel] = (input[channel] - mean[channel]) / std[channel]
其中,input[channel]
表示一個(gè)像素的某個(gè)通道的像素值,mean[channel]
和std[channel]
分別表示該通道的均值和標(biāo)準(zhǔn)差。通過(guò)這樣的歸一化操作,每個(gè)通道的像素值都將落在-1到1之間,從而便于模型的訓(xùn)練。
因此,這行代碼的作用是對(duì)CIFAR-10數(shù)據(jù)集中的圖像進(jìn)行歸一化,將每個(gè)通道的像素值映射到-1到1之間。
訪問(wèn)數(shù)據(jù)集
Dataset對(duì)象
Dataset對(duì)象是一個(gè)數(shù)據(jù)集,可以按下標(biāo)訪問(wèn),返回形如(data, label)的數(shù)據(jù)。
(data, label) = trainset[100] # 從訓(xùn)練集中獲取第100個(gè)樣本的數(shù)據(jù)(圖像)和標(biāo)簽。
print(classes[label])
# (data + 1) / 2是為了還原被歸一化的數(shù)據(jù),將之前歸一化的數(shù)據(jù)重新映射到0到1的范圍內(nèi)。
show((data + 1) / 2).resize((200, 200))
輸出為:
ship
Dataloader對(duì)象
Dataloader是一個(gè)可迭代的對(duì)象,它將dataset返回的每一條數(shù)據(jù)拼接成一個(gè)batch,并提供多線程加速優(yōu)化和數(shù)據(jù)打亂等操作。當(dāng)程序?qū)ataset的所有數(shù)據(jù)遍歷完一遍之后,相應(yīng)的對(duì)Dataloader也完成了一次迭代
dataiter = iter(trainloader)
images, labels = next(dataiter) # 返回4張圖片及標(biāo)簽
print(','.join('%11s'%classes[labels[j]] for j in range(4)))
show(tv.utils.make_grid((images+1)/2)).resize((400,100))
-
使用
iter(trainloader)
將訓(xùn)練數(shù)據(jù)加載器轉(zhuǎn)換成一個(gè)迭代器對(duì)象dataiter
。 -
使用
next(dataiter)
從迭代器中獲取下一個(gè)批次的數(shù)據(jù)。這里假設(shè)每個(gè)批次的大小為4,所以images
和labels
分別是一個(gè)包含4張圖片和對(duì)應(yīng)標(biāo)簽的張量。 -
通過(guò)一個(gè)循環(huán)遍歷了這4張圖片的標(biāo)簽,并使用
classes[labels[j]]
將每個(gè)標(biāo)簽轉(zhuǎn)換為對(duì)應(yīng)的類(lèi)別名稱(chēng)。classes
是一個(gè)包含CIFAR-10數(shù)據(jù)集各個(gè)類(lèi)別名稱(chēng)的列表。 -
使用
tv.utils.make_grid()
函數(shù)將這4張圖片拼接成一張網(wǎng)格圖,并通過(guò)(images+1)/2
將像素值從[-1, 1]的范圍映射到[0, 1]的范圍。使用show()
函數(shù)顯示圖像,并調(diào)用resize()
對(duì)圖像進(jìn)行調(diào)整大小,再使用print()
輸出調(diào)整大小后的圖像。
輸出為:
cat, truck, plane, deer
二、定義網(wǎng)絡(luò)
LeNet網(wǎng)絡(luò),self.conv1第一個(gè)參數(shù)為3通道,因?yàn)镃IFAR-10是3通道彩圖
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(x.size()[0], -1) # -1表示會(huì)自適應(yīng)的調(diào)整剩余的維度
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
print(net)
輸出為:
Net(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
模型包含以下層:
-
self.conv1
: 輸入通道數(shù)為3,輸出通道數(shù)為6,卷積核大小為5x5的卷積層。 -
self.conv2
: 輸入通道數(shù)為6,輸出通道數(shù)為16,卷積核大小為5x5的卷積層。 -
self.fc1
: 輸入大小為16x5x5,輸出大小為120的全連接層。 -
self.fc2
: 輸入大小為120,輸出大小為84的全連接層。 -
self.fc3
: 輸入大小為84,輸出大小為10的全連接層。
模型的前向傳播函數(shù)(forward
):
- 先經(jīng)過(guò)第一個(gè)卷積層,然后應(yīng)用ReLU激活函數(shù)和2x2的最大池化操作。
- 再經(jīng)過(guò)第二個(gè)卷積層,同樣應(yīng)用ReLU激活函數(shù)和2x2的最大池化操作。
- 通過(guò)
x.view(x.size()[0], -1)
將特征張量x展平為一維向量,以便輸入全連接層。 - 依次經(jīng)過(guò)兩個(gè)全連接層,并使用ReLU激活函數(shù)進(jìn)行非線性變換。
- 最后一層是一個(gè)全連接層,輸出大小為10,對(duì)應(yīng)CIFAR-10數(shù)據(jù)集的10個(gè)類(lèi)別。這里沒(méi)有使用激活函數(shù),因?yàn)樵撃P蛯⑵漭敵鲋苯幼鳛榉诸?lèi)的得分。
總體而言,該模型由兩個(gè)卷積層和三個(gè)全連接層組成,用于對(duì)CIFAR-10數(shù)據(jù)集進(jìn)行圖像分類(lèi)。
三、定義損失函數(shù)和優(yōu)化器(loss和optimizer)
from torch import optim
criterion = nn.CrossEntropyLoss() # 交叉熵?fù)p失函數(shù)
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
-
nn.CrossEntropyLoss()
創(chuàng)建了一個(gè)交叉熵?fù)p失函數(shù)的實(shí)例,用于計(jì)算分類(lèi)任務(wù)中的損失。交叉熵?fù)p失函數(shù)通常用于多類(lèi)別分類(lèi)問(wèn)題,它將模型的輸出與真實(shí)標(biāo)簽進(jìn)行比較,并計(jì)算出一個(gè)數(shù)值作為損失值,用來(lái)衡量模型預(yù)測(cè)與真實(shí)標(biāo)簽之間的差異。 -
optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
創(chuàng)建了一個(gè)隨機(jī)梯度下降(SGD)優(yōu)化器的實(shí)例。net.parameters()
表示要優(yōu)化的模型參數(shù),即神經(jīng)網(wǎng)絡(luò)中的權(quán)重和偏置。lr=0.001
是學(xué)習(xí)率(learning rate),控制每次參數(shù)更新的步長(zhǎng)大小。momentum=0.9
表示動(dòng)量(momentum)參數(shù),用于加速優(yōu)化過(guò)程并避免陷入局部最優(yōu)解。
四、訓(xùn)練網(wǎng)絡(luò)并更新網(wǎng)絡(luò)參數(shù)
t.set_num_threads(8) # 設(shè)置線程數(shù)為 8,以加速訓(xùn)練過(guò)程。
for epoch in range(2): # 指定訓(xùn)練的輪數(shù)為 2 輪(epoch),即遍歷整個(gè)數(shù)據(jù)集兩次。
running_loss = 0.0 # 記錄當(dāng)前訓(xùn)練階段的損失值
for i, data in enumerate(trainloader, 0):
# 輸入數(shù)據(jù)
inputs, labels = data
# 梯度清零
optimizer.zero_grad() # 每個(gè) batch 開(kāi)始時(shí),將優(yōu)化器的梯度緩存清零,以避免梯度累積
# forward + backward
outputs = net(inputs)
loss = criterion(outputs, labels) # 進(jìn)行前向傳播,然后計(jì)算損失函數(shù) loss
loss.backward() # 自動(dòng)計(jì)算損失函數(shù)相對(duì)于模型參數(shù)的梯度
# 更新參數(shù)
optimizer.step() # 使用優(yōu)化器 optimizer 來(lái)更新模型的權(quán)重和偏置,以最小化損失函數(shù)
# 打印log信息
# loss 是一個(gè)scalar,需要使用loss.item()來(lái)獲取數(shù)值,不能使用loss[0]
running_loss += loss.item()
if i % 2000 == 1999: # 每2000個(gè)batch打印一下訓(xùn)練狀態(tài)
print('[%d, %5d] loss: %.3f' \
% (epoch+1, i+1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
輸出結(jié)果:
[1, 2000] loss: 2.247
[1, 4000] loss: 1.974
[1, 6000] loss: 1.753
[1, 8000] loss: 1.605
[1, 10000] loss: 1.527
[1, 12000] loss: 1.472
[2, 2000] loss: 1.424
[2, 4000] loss: 1.386
[2, 6000] loss: 1.331
[2, 8000] loss: 1.303
[2, 10000] loss: 1.300
[2, 12000] loss: 1.275
Finished Training
enumerate函數(shù)
enumerate
是Python內(nèi)置函數(shù)之一,用于將一個(gè)可迭代的對(duì)象(如列表、元組、字符串等)組合為一個(gè)索引序列。它返回一個(gè)枚舉對(duì)象,包含了原始對(duì)象中的元素以及對(duì)應(yīng)的索引值。
enumerate
函數(shù)的一般語(yǔ)法如下:
enumerate(iterable, start=0)
其中,iterable
是要進(jìn)行枚舉的可迭代對(duì)象,start
是可選參數(shù),表示起始的索引值,默認(rèn)為0。
下面是一個(gè)簡(jiǎn)單的例子來(lái)說(shuō)明enumerate
函數(shù)的用法:
fruits = ['apple', 'banana', 'cherry']
for index, fruit in enumerate(fruits):
print(index, fruit)
輸出結(jié)果:
0 apple
1 banana
2 cherry
在上述示例中,enumerate
函數(shù)將列表fruits
中的元素與對(duì)應(yīng)的索引值配對(duì),然后通過(guò)for
循環(huán)依次取出每個(gè)元素和索引值進(jìn)行打印。
在機(jī)器學(xué)習(xí)或深度學(xué)習(xí)中,enumerate
函數(shù)常常與循環(huán)結(jié)合使用,用于遍歷數(shù)據(jù)集或批次數(shù)據(jù),并同時(shí)獲取數(shù)據(jù)的索引值。這在模型訓(xùn)練過(guò)程中很有用,可以方便地記錄當(dāng)前處理的數(shù)據(jù)的位置信息。
五、測(cè)試網(wǎng)絡(luò)
部分?jǐn)?shù)據(jù)集(實(shí)際的label)
dataiter = iter(testloader)
images, labels = next(dataiter) # 一個(gè)batch返回4張圖片
print('實(shí)際的label: ', ' '.join(\
'%08s'%classes[labels[j]] for j in range(4)))
show(tv.utils.make_grid(images+1)/2).resize((400,100))
輸出結(jié)果:
實(shí)際的label: cat ship ship plane
部分?jǐn)?shù)據(jù)集(預(yù)測(cè)的label)
# 計(jì)算圖片在每個(gè)類(lèi)別上的分?jǐn)?shù)
outputs = net(images)
# 得分最高的那個(gè)類(lèi)
_, predicted = t.max(outputs.data, 1)
print('預(yù)測(cè)結(jié)果: ', ' '.join('%5s'\
% classes[predicted[j]] for j in range(4)))
輸出結(jié)果:
預(yù)測(cè)結(jié)果: cat car ship plane
整個(gè)測(cè)試集
correct = 0 # 預(yù)測(cè)正確的圖片數(shù)
total = 0 # 總共的圖片數(shù)
# 使用 torch.no_grad() 上下文管理器,表示在測(cè)試過(guò)程中不需要計(jì)算梯度,以提高速度和節(jié)約內(nèi)存
with t.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = t.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
print('10000張測(cè)試集中的準(zhǔn)確率為: %d %%' % (100 * correct / total))
輸出結(jié)果:文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-740641.html
10000張測(cè)試集中的準(zhǔn)確率為: 54 %
訓(xùn)練的準(zhǔn)確率遠(yuǎn)比隨機(jī)猜測(cè)(準(zhǔn)確率10%)好。文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-740641.html
到了這里,關(guān)于【深度學(xué)習(xí)】pytorch——實(shí)現(xiàn)CIFAR-10數(shù)據(jù)集的分類(lèi)的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!