目錄
一,Pytorch簡介;
二,環(huán)境配置;
三,自定義數(shù)據(jù)集;
四,模型訓(xùn)練;
五,模型驗證;
一,Pytorch簡介;
????????PyTorch是一個開源的Python機(jī)器學(xué)習(xí)庫,基于Torch,用于自然語言處理等應(yīng)用程序。PyTorch 基于 Python:?PyTorch 以 Python 為中心或“pythonic”,旨在深度集成 Python 代碼,而不是作為其他語言編寫的庫的接口。Python 是數(shù)據(jù)科學(xué)家使用的最流行的語言之一,也是用于構(gòu)建機(jī)器學(xué)習(xí)模型和 ML 研究的最流行的語言之一。由于其語法類似于 Python 等傳統(tǒng)編程語言,PyTorch 比其他深度學(xué)習(xí)框架更容易學(xué)習(xí)。
二,環(huán)境配置;
? ? ? ?版本:
????????系統(tǒng):window10;
? ? ? ? Python:3.11.5;
? ? ? ? pytorch:2.0.1;
? ? ? ?Python安裝:
? ? ? ? Python官網(wǎng):python.org;
? ? ? ? 下載3.11.5版本Python安裝版進(jìn)行安裝;
? ? ? ? 配置Python環(huán)境變量;
? ? ? ? 在系統(tǒng)變量path中添加Python的bin路徑和Script路徑;
? ? ? ? 查看Python是否安裝成功;
????????
? ? ? ? 正常如上顯示表示安裝成功。
????????同時查看Python對應(yīng)的Pip版本;
? ? ? ? Pytorch安裝:
? ? ? ? pytorch官網(wǎng):PyTorch;
????????
????????進(jìn)入Pytorch官網(wǎng)后點擊左上角Get Started查看Pytorch對于的Python版本,GPU版本。默認(rèn)安裝的是CPU版本,本文使用Pip安裝Pytorch方式,直接運(yùn)行Run this Command會報錯,安裝了幾次都不行,所以自己找對應(yīng)的安裝文件進(jìn)行安裝更方便。
? ? ? ? 根據(jù)Pytorch官網(wǎng)介紹的對應(yīng)版本找到我們需要的依賴文件。
? ? ? ? 網(wǎng)址:download.pytorch.org/whl/torch_stable.html
? ? ? ??
? ? ? ? 找到對應(yīng)安裝的版本,cu開頭表示是GPU版本和版本號,torch后面對應(yīng)的是Pytorch版本號,cp對應(yīng)Python版本;點擊下載安裝文件;
? ? ? ? 下載好以后打開文件所在位置,進(jìn)入window命令界面,執(zhí)行命令;
pip install?torch-2.0.1+cu117-cp311-cp311-win_amd64.whl
????????英偉達(dá)GPU安裝:
????????選擇對應(yīng)的GPU版本安裝,安裝完成后驗證下是否安裝成功,正常顯示版本表示安裝成功。
三,自定義數(shù)據(jù)集;
? ? ? ? 從網(wǎng)上下載數(shù)據(jù)集,按照文件夾分類,首先將數(shù)據(jù)集制作成包含圖片路徑,和對應(yīng)索引的csv文件。
import torch import os, glob import random, csv # 所有自定義數(shù)據(jù)集的一個母類 from torch.utils.data import Dataset, DataLoader # 常用的圖片變換器 from torchvision import transforms # 從圖片讀取出數(shù)據(jù) from PIL import Image # 自定義數(shù)據(jù)集的類,繼承自Dataset class Pokemon(Dataset): # 一、初始化函數(shù)init # 第一個參數(shù)root:總的圖片所在的位置,可以是任意的位置,我們的圖片可以放在任意的位置,我們這里就存儲在當(dāng)前目錄文件夾下。 # 第二個參數(shù)resize:圖片輸出的size,是由這個參數(shù)所進(jìn)行設(shè)定。 # 第三個參數(shù)mode:這里我們需要做train、validation以及test,對應(yīng)這三種數(shù)據(jù)結(jié)構(gòu),因此我們用一個list[0,1,2]來代表是哪個模式。 def __init__(self, root, resize, mode): # 先調(diào)用母類的初始化函數(shù): super(Pokemon, self).__init__() # 1、首先我們將這個參數(shù)保存下來 self.root = root self.resize = resize # 2、給每一個分類做一個映射,即當(dāng)前的皮卡丘、妙蛙種子等這個string類型所對應(yīng)的label是多少,這個是需要我們?nèi)藶檫M(jìn)行編碼的。 self.name2label = {} # 用字典來表示映射關(guān)系 # 通過循環(huán)方式,將root路徑下的文件夾名進(jìn)行編碼 for name in sorted(os.listdir(os.path.join(root))): # 過濾掉非文件夾:如果不是dir,就過濾掉,此外我們還通過sorted排序的方法,將鍵值對關(guān)系固定下來 if not os.path.isdir(os.path.join(root, name)): continue # 文件名做key,當(dāng)前name2label的長度做value self.name2label[name] = len(self.name2label.keys()) print(self.name2label) # image, label self.load_csv('images.csv') # 二、創(chuàng)建一個csv,用于保存圖片全路徑和對應(yīng)的標(biāo)簽label # 這個函數(shù)接受一個參數(shù)filename # 這個函數(shù)中需要將所有圖片都load進(jìn)來 def load_csv(self, filename): images = [] for name in self.name2label.keys(): # 類別信息我們可以使用路徑來判斷 # 上面路徑的mewtwo就是類別 images += glob.glob(os.path.join(self.root, name, '*.png')) images += glob.glob(os.path.join(self.root, name, '*.jpg')) images += glob.glob(os.path.join(self.root, name, '*.jpeg')) print(len(images), images) # 將images順序打亂 random.shuffle(images) # 打開這個文件 with open(os.path.join(self.root, filename), mode='w', newline='') as f: # 新建writer,獲得csv這個文件對象 writer = csv.writer(f) for img in images: # 獲得每行信息 # 通過分割符,將每行信息的內(nèi)容分割開,取導(dǎo)數(shù)第二個,類型 name = img.split(os.sep)[-2] # 通過獲取的類型名來獲取label label = self.name2label[name] # 將這個label信息寫到csv中 # csv是以逗號作為分割的 writer.writerow([img, label]) print('writen into csv file:', filename) # 三、完成兩個自定義的邏輯 # 1、樣本的總體數(shù)量(圖片總體數(shù)量),返回的是一個數(shù)字,總體圖片大概有1168張,60%用于training,因此返回6-7百張圖片 def __len__(self): pass # 2、用于返回當(dāng)前index上面元素的值,這里是返回兩個數(shù)據(jù): # 需要返回當(dāng)前image的data,以及image所對應(yīng)的label[0,1,2,3,4] def __getitem__(self, idx): pass # 創(chuàng)建一個調(diào)試函數(shù): def main(): db = Pokemon('F:\\train', 224, 'train') if __name__ == '__main__': main()
????????將圖片路徑改成自己數(shù)據(jù)的文件夾路徑,運(yùn)行代碼在對應(yīng)路徑下生成.csv格式文件
? ? ? ? 類別索引根據(jù)文件夾種類順序生成,要和csv文件中索引對應(yīng)。數(shù)據(jù)集制作完成后就可以開始訓(xùn)練了。
? ? ? ? 首先定義加載數(shù)據(jù)集類;
import torch import os, glob import random, csv # 所有自定義數(shù)據(jù)集的一個母類 from torch.utils.data import Dataset, DataLoader # 常用的圖片變換器 from torchvision import transforms # 從圖片讀取出數(shù)據(jù) from PIL import Image # 自定義數(shù)據(jù)集的類,繼承自Dataset class Pokemon(Dataset): # 一、初始化函數(shù)init # 第一個參數(shù)root:總的圖片所在的位置,可以是任意的位置,我們的圖片可以放在任意的位置,我們這里就存儲在當(dāng)前目錄文件夾下。 # 第二個參數(shù)resize:圖片輸出的size,是由這個參數(shù)所進(jìn)行設(shè)定。 # 第三個參數(shù)mode:這里我們需要做train、validation以及test,對應(yīng)這三種數(shù)據(jù)結(jié)構(gòu),因此我們用一個list[0,1,2]來代表是哪個模式。 def __init__(self, root, resize, mode): # 先調(diào)用母類的初始化函數(shù): super(Pokemon, self).__init__() # 1、首先我們將這個參數(shù)保存下來 self.root = root self.resize = resize # 2、給每一個分類做一個映射,這個string類型所對應(yīng)的label是多少,這個是需要我們?nèi)藶檫M(jìn)行編碼的。 self.name2label = {} # 用字典來表示映射關(guān)系 # 通過循環(huán)方式,將root路徑下的文件夾名進(jìn)行編碼 for name in sorted(os.listdir(os.path.join(root))): # 過濾掉非文件夾:如果不是dir,就過濾掉,此外我們還通過sorted排序的方法,將鍵值對關(guān)系固定下來 if not os.path.isdir(os.path.join(root, name)): continue # 文件名做key,當(dāng)前name2label的長度做value self.name2label[name] = len(self.name2label.keys()) # print(self.name2label) # 將self.load_csv的返回值images, labels賦予self.images, self.labels self.images, self.labels = self.load_csv('images.csv') # 四、不同比例模式下對圖片數(shù)量進(jìn)行劃分 if mode == 'train': # 取60%做training # len(self.images)的長度是1167,取60%做為train模式的圖片 self.images = self.images[:int(0.6 * len(self.images))] self.labels = self.labels[:int(0.6 * len(self.labels))] elif mode == 'val': # 取20%做validation, 60%-80% self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))] self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))] else: # mode為test,取80%到最末尾 self.images = self.images[int(0.8 * len(self.images)):] self.labels = self.labels[int(0.8 * len(self.labels)):] # 二、創(chuàng)建一個csv,用于保存圖片全路徑和對應(yīng)的標(biāo)簽label # 這個函數(shù)接受一個參數(shù)filename # 這個函數(shù)中需要將所有圖片都load進(jìn)來 def load_csv(self, filename): # 需要一個判斷,如果文件不存在,就需要創(chuàng)建csv,直接讀取創(chuàng)建好的csv文件內(nèi)容即可: # 如果不存在,就需要創(chuàng)建csv if not os.path.exists(os.path.join(self.root, filename)): images = [] for name in self.name2label.keys(): # 類別信息我們可以使用路徑來判斷 # 上面路徑的mewtwo就是類別 images += glob.glob(os.path.join(self.root, name, '*.png')) images += glob.glob(os.path.join(self.root, name, '*.jpg')) images += glob.glob(os.path.join(self.root, name, '*.jpeg')) print(len(images), images) # 將images順序打亂 random.shuffle(images) # 打開這個文件 with open(os.path.join(self.root, filename), mode='w', newline='') as f: # 新建writer,寫入csv這個文件對象 writer = csv.writer(f) for img in images: # 通過分割符,將每行信息的內(nèi)容分割開,取導(dǎo)數(shù)第二個,類型 name = img.split(os.sep)[-2] # 通過獲取的類型名來獲取label label = self.name2label[name] # 將這個label信息寫到csv中 # csv是以逗號作為分割的 writer.writerow([img, label]) print('writen into csv file:', filename) # 三、讀取csv文件過程: # 這里需要在開頭有一個判斷,如果csv存在,就不用寫入csv了,直接進(jìn)行讀取 # 下次運(yùn)行的時候只需加載進(jìn)來即可 images, labels = [], [] with open(os.path.join(self.root, filename)) as f: # 新建reader,讀取csv這個文件對象 reader = csv.reader(f) for row in reader: img, label = row label = int(label) # 將這個label轉(zhuǎn)碼為int類型 # 將img每個圖片路徑,以及l(fā)abel保存在建立好的列表對象中。 images.append(img) labels.append(label) assert len(images) == len(labels) return images, labels # 完成兩個自定義的邏輯: # 1、樣本的總體數(shù)量(圖片總體數(shù)量),返回的是一個數(shù)字,總體圖片大概有1168張,60%用于training,因此返回6-7百張圖片 # 五、完成總體樣本數(shù)量函數(shù)的內(nèi)容 def __len__(self): # 這里的樣本長度是跟模型類別來決定的,上面已經(jīng)根據(jù)不同模型類型劃分了樣本數(shù)量了。 # 不同模式下,樣本長度是不同的。 # 因此這里的總體樣本長度,就是不同模式下的樣本數(shù)量。 return len(self.images) # 九、解決normalize處理后,visdom無法正常顯示的問題 # 這里傳入的參數(shù)x是normalize過后的 def denormalize(self, x_hat): mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1) std = torch.tensor(std).unsqueeze(1).unsqueeze(1) print('mean.shape,std.shape:', mean.shape, std.shape) x = x_hat * std + mean return x # 2、用于返回當(dāng)前index上面元素的值,這里是返回兩個數(shù)據(jù): # 需要返回當(dāng)前image的data,以及image所對應(yīng)的label[0,1,2,3,4] # 六、完成index與樣本的一一對應(yīng) def __getitem__(self, idx): # idx數(shù)值范圍是[0-len(images)] # self.images保存了所有的數(shù)據(jù);self.labels保存了所有數(shù)據(jù)對應(yīng)的label信息; # img是一個string類型(還不是具體的圖片,只是路徑) # label是一個整數(shù)類型 img, label = self.images[idx], self.labels[idx] # 這里就需要將img所對應(yīng)的路徑讀取出圖片,并轉(zhuǎn)為tensor類型 # 這里我們可以Compose組合操作步驟 # 八、增加數(shù)據(jù)預(yù)處理的工作,在Compose中增加這些內(nèi)容,data augmentation數(shù)據(jù)增強(qiáng) # 這里我們做放大、旋轉(zhuǎn)、裁切這三個數(shù)據(jù)增強(qiáng)的操作 tf = transforms.Compose([ # 這里需要將路徑變成具體的圖片數(shù)據(jù)類型 # 即:string path => image data lambda x: Image.open(x).convert('RGB'), # Resize工作,這里的size是我們實例化時的self.resize的值 # 1、data augmentation放大:在Resize設(shè)置的基礎(chǔ)上,稍微調(diào)大一些size, 調(diào)整為1.25倍 transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))), # 2、data augmentation旋轉(zhuǎn):增加隨機(jī)旋轉(zhuǎn),注意:這里旋轉(zhuǎn)角度不能太大,會增加學(xué)習(xí)的難度。 transforms.RandomRotation(15), # 3、data augmentation中心裁切:裁切為我們所需要的大小 transforms.CenterCrop(self.resize), # 將數(shù)據(jù)變?yōu)閠ensor類型 transforms.ToTensor(), # 4、normalize處理,希望圖片數(shù)值范圍在0左右分布,而不希望數(shù)值只分布在0的右側(cè)或只在左側(cè) # 其中參數(shù)統(tǒng)計的所有image net數(shù)據(jù)集幾百萬張圖片的mean=[R的mean,G的mean,B的mean]和std=[R的方差,G的方差,B的方差] # 基本上這個數(shù)值是通用的 # 數(shù)據(jù)通過Normalize處理后,就是在-1到1之間分布了。 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) img = tf(img) label = torch.tensor(label) return img, label # 創(chuàng)建一個調(diào)試函數(shù): def main(): # 七、驗證自定義數(shù)據(jù)集 # 驗證需要一些輔助函數(shù),用visdom做一些可視化。 import visdom import time import torchvision # 通過API較為簡便的加載自定義數(shù)據(jù)集,需要引入torchvision # 創(chuàng)建一個visdom這個對象 viz = visdom.Visdom() # 十一、通過API較為簡便的加載自定義數(shù)據(jù)集(前提是數(shù)據(jù)集按照不同類型存儲在對應(yīng)類型命名的文件夾下面,并且這些不同類別的文件夾都存儲在統(tǒng)一的一個文件夾下,只有這種固定的二級目錄存儲形式才能用這個API進(jìn)行加載。) tf = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor() ]) # 參數(shù)1:傳入路徑 # 參數(shù)2:變換器,這個變換器就是進(jìn)行resize操作 db = torchvision.datasets.ImageFolder(root='F:\\train', transform=tf) loader = DataLoader(db, batch_size=32, shuffle=True) print(db.class_to_idx) # 通過這個就能知道不同類別是如何編碼的了。 if __name__ == '__main__': main()
????????將上面代碼修改即可;
四,模型訓(xùn)練;
? ? ? ? 這里我們需要用到可視化工具來查看我們訓(xùn)練效果。
? ? ? ? 安裝visdom:
pip install visdom
? ? ? ? 在pycharm命令界面啟動visdom:
python -m visdom.server
? ? ? ? 正常啟動在瀏覽器輸入localhost:8097打開可視化界面;
? ? ? ? 準(zhǔn)備工作完成,編寫模型訓(xùn)練代碼,這么我們直接使用Pytorch自帶的神經(jīng)網(wǎng)絡(luò)resnet18模型;
import torch from torch import optim, nn import visdom import torchvision from torch.utils.data import DataLoader from pokemon import Pokemon from torchvision.models import resnet18 # 這個resnet18是已經(jīng)training好的狀態(tài) from utils import Flatten # 用于打平,這個是自己來實現(xiàn)的打平層 batchsz = 32 lr = 1e-3 epochs = 40 device = torch.device('cuda') torch.manual_seed(1234) # 這個是隨機(jī)數(shù)種子,保證每次都能復(fù)現(xiàn)出來。 # 這里是需要實例化Pokemon類 # 這里之所以使用224,是因為是ResNet最適合的大小。 train_db = Pokemon('F:\\train', 224, 'train') val_db = Pokemon('F:\\train', 224, 'val') test_db = Pokemon('F:\\train', 224, 'test') # 批量加載數(shù)據(jù) # 參數(shù)num_workers表示工作線程數(shù): train_loader = DataLoader(train_db , batch_size=batchsz , shuffle=True , num_workers=4) val_loader = DataLoader(val_db , batch_size=batchsz , num_workers=2) test_loader = DataLoader(test_db , batch_size=batchsz , num_workers=2) # 需要把train的進(jìn)度保存下來,需要用到visdom viz = visdom.Visdom() # 建立一個測試函數(shù):測試函數(shù)針對validation和test功能是一樣的 def evalute(model, loader): # 用于統(tǒng)計總的預(yù)測正確的數(shù)量 correct = 0 # 總的測試數(shù)量 total = len(loader.dataset) for x, y in loader: x, y = x.to(device), y.to(device) with torch.no_grad(): # test和validation是不需要梯度信息的 logits = model(x) pred = logits.argmax(dim=1) # 最大的值所在的位置 # 總的預(yù)測正確的數(shù)量,累加操作 correct += torch.eq(pred, y).sum().float().item() accuracy = correct / total return accuracy def main(): # 實例化模型 # 使用已經(jīng)訓(xùn)練好的resnet18模型,一定要設(shè)置這個參數(shù)pretrained=True trained_model = resnet18(pretrained=True) # 我們要使用訓(xùn)練好的resnet18模型的A部分,即取出前17層: # Sequential結(jié)束的是一個打散的數(shù)據(jù),所有我們在list前加一個*,*args:接收若干個位置參數(shù),轉(zhuǎn)換成元組tuple形式。 model = nn.Sequential(*list(trained_model.children())[:-1] # model的前17層(即A部分)返回的結(jié)果是:[b,512,1,1] , Flatten() # 打平操作從[b,512,1,1]=>[b,512] , nn.Linear(512, 14) # 這層是最后那層,用于從新學(xué)習(xí)分成14類。(第二個參數(shù)為自定義數(shù)據(jù)集實際訓(xùn)練種類數(shù)量,根據(jù)自己數(shù)據(jù)集的種類數(shù)據(jù)傳遞實際值) ).to(device) # 我們從已經(jīng)訓(xùn)練好的resnet18開始訓(xùn)練效果會好很多 # # 這里我們測試一下 # x = torch.randn(2,3,224,224) # print(model(x).shape)#打印結(jié)果為:torch.Size([2, 5]) # #這樣就實現(xiàn)了transfer learning # ====================================================== # 創(chuàng)建一個優(yōu)化器Adam,這個優(yōu)化器比較好 optimizer = optim.Adam(model.parameters(), lr=lr) # Loss的計算方法:CrossEntropyLoss; # 這個Loss所接受的參數(shù)是logits,logits是不需要經(jīng)過一個softmax的,只需要得到logits即可。 criteon = nn.CrossEntropyLoss() # 用于保存模型的訓(xùn)練狀態(tài) best_acc, best_epoch = 0, 0 # step每次都是從0開始的,因此這里我們創(chuàng)建一個全局step global_step = 0 # 用visdom工具保存下accuracy和loss # training和loss的曲線 # x=0,y=-1是初始狀態(tài) viz.line([0], [-1], win='loss', opts=dict(title='loss(損失值)')) # training和validation accuracy的曲線 viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc(準(zhǔn)確率)')) # training邏輯 for epoch in range(epochs): for step, (x, y) in enumerate(train_loader): # x:[b,3,224,224]; y:[b] x, y = x.to(device), y.to(device) # x和y都轉(zhuǎn)移到cuda上面 # 執(zhí)行forward函數(shù) logits = model(x) # 學(xué)出的預(yù)測結(jié)果 # 在pytorch中crossEntropyLoss中,傳入的真實值y不需要進(jìn)行one-hot操作,不需要做one-hot編碼,會在內(nèi)部做one-hot。 # 所以我們直接傳入y就可以了。 loss = criteon(logits, y) # 預(yù)測結(jié)果與真實值進(jìn)行交叉熵計算 # 前向傳播和迭代過程 # 優(yōu)化器 optimizer.zero_grad() loss.backward() optimizer.step() # 用visdom工具保存下accuracy和loss # 每一個step我都要記錄下來 # validation和loss的曲線 # x=loss.item()loss是一個tensor,因此需要通過item轉(zhuǎn)為具體數(shù)值,y=-1是初始狀態(tài) # 參數(shù)update為append,表示添加到曲線的末尾。 viz.line([loss.item()], [global_step], win='loss', update='append') global_step += 1 # 這里我們每完成兩個epoch就做一組validation if epoch % 1 == 0: # 我們根據(jù)validation accuracy來選擇要不要保存這個模型的訓(xùn)練狀態(tài)。 val_acc = evalute(model, val_loader) # 如果當(dāng)前accuracy大于best_acc,就保存當(dāng)前的狀態(tài): if val_acc > best_acc: best_epoch = epoch best_acc = val_acc # 保存當(dāng)前模型的狀態(tài): # 參數(shù)一:模型狀態(tài)值 # 參數(shù)二:模型狀態(tài)保存的文件名,文件名后綴隨意 torch.save(model, 'best-pro.pth') # validation和 accuracy的曲線 # 這里val_acc是數(shù)值型,所以不需要轉(zhuǎn)換。 viz.line([val_acc], [global_step], win='val_acc', update='append') print('best acc:', best_acc, 'best epoch:', best_epoch) # 從最好的狀態(tài)加載模型: # model.load_state_dict(torch.load('best-pro.ptl')) # print('loaded from check point!') # # # 上面加載了最好的模型狀態(tài),這里使用的模型也是最好的狀態(tài)時的模型 # test_acc = evalute(model, test_loader) # print('test_acc:', test_acc) if __name__ == '__main__': main()
這里我們用到了一個util:
from matplotlib import pyplot as plt import torch from torch import nn # 該函數(shù)是一個標(biāo)準(zhǔn)的打平層 class Flatten(nn.Module): # 該文件utils包含一些輔助函數(shù)。 def __init__(self): super(Flatten, self).__init__() def forward(self, x): shape = torch.prod(torch.tensor(x.shape[1:])).item() return x.view(-1, shape) # 該函數(shù)是將img打印到matplotlib上 def plot_image(img, label, name): fig = plt.figure() for i in range(6): plt.subplot(2, 3, i + 1) plt.tight_layout() plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none') plt.title("{}: {}".format(name, label[i].item())) plt.xticks([]) plt.yticks([]) plt.show()
運(yùn)行函數(shù)打開可視化界面,查看訓(xùn)練情況;
? ? ? ? 剛開始訓(xùn)練的情況,使用數(shù)據(jù)量大概1.6w張最終結(jié)果大概是準(zhǔn)確率96%。已經(jīng)非常好了。
五,模型驗證;
import numpy as np import torch import torch.nn.functional as F import torchvision.transforms as transforms from PIL import Image device = torch.device('cuda') def main(): labels = ['兔子', '吊蘭', '文竹', '月季', '枸骨', '狗', '獅子', '貓', '綠蘿', '老虎', '菊花', '蛇', '迎春花', '龜背竹'] image_path = "C:/Users/LENOVO/Desktop/dog.png" image = Image.open(image_path) image = image.resize((256, 256), Image.BILINEAR).convert("RGB") image = np.array(image) to_tensor = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) image = to_tensor(image) image = torch.unsqueeze(image, 0) image = image.cuda() model = torch.load("剛才訓(xùn)練好的模型") model.eval() model.to(device) output = model(image) output1 = F.softmax(output, dim=1) predicted = torch.max(output1, dim=1)[1].cpu().item() outputs2 = output1.squeeze(0) confidence = outputs2[predicted].item() confidence = round(confidence, 3) print("識別結(jié)果: ", labels[predicted], " 準(zhǔn)確率為: ", confidence * 100, "%") if __name__ == '__main__': main()
????????測試圖片:
????????labels為我們訓(xùn)練的類別數(shù)組,和cvs的索引對應(yīng)。
文章來源:http://www.zghlxwxcb.cn/news/detail-728881.html
多次測試結(jié)果全對,準(zhǔn)確率不低于95%。文章來源地址http://www.zghlxwxcb.cn/news/detail-728881.html
到了這里,關(guān)于Pytorch目標(biāo)分類深度學(xué)習(xí)自定義數(shù)據(jù)集訓(xùn)練的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!