When doing classification with deep learning, you often need cross-validation, and PyTorch currently has no general-purpose utility for it. You can use StratifiedKFold or KFold from sklearn instead. StratifiedKFold splits the data according to each class's sample count, so with 5 folds every class is split roughly 4:1 between training and validation.
A minimal code example:
import numpy as np
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=5)
for i, (train_idx, val_idx) in enumerate(skf.split(imgs, labels)):
    trainset, valset = np.array(imgs)[train_idx], np.array(imgs)[val_idx]
    traintag, valtag = np.array(labels)[train_idx], np.array(labels)[val_idx]
The example above splits the full imgs list together with its corresponding labels list: train_idx holds the indices of the training samples and val_idx the indices of the validation samples. The rest of the pipeline only needs to feed the resulting trainset and valset into a Dataset.
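If you want to confirm the roughly 4:1 per-class split, a quick check like the following can be added; this is only an illustrative sketch, assuming imgs and labels are the lists described above and labels contains integer class indices:

import numpy as np
from sklearn.model_selection import StratifiedKFold

labels_arr = np.array(labels)
skf = StratifiedKFold(n_splits=5)
for fold, (train_idx, val_idx) in enumerate(skf.split(imgs, labels)):
    # np.bincount gives the number of samples per class in each subset
    print(f"fold {fold}: train {np.bincount(labels_arr[train_idx])}, "
          f"val {np.bincount(labels_arr[val_idx])}")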
Next, I will walk through the complete process on my own dataset, from reading the data to starting training. If your data is stored differently from mine, just change the data-reading code; the key point is how to obtain imgs and the corresponding labels.
My data is stored like this (each class is a folder name, and the images belonging to that class sit inside that folder):
"""A generic data loader where the images are arranged in this way: :: root/dog/xxx.png root/dog/xxy.png root/dog/xxz.png root/cat/123.png root/cat/nsdf3.png root/cat/asd932_.png
The following code collects imgs and the corresponding labels:
import os
import numpy as np

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

def is_image_file(filename):
    # Keep only files whose extension looks like an image
    return filename.lower().endswith(IMG_EXTENSIONS)

def find_classes(dir):
    # Each sub-directory name under the root is treated as one class
    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
    classes.sort()
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

if __name__ == "__main__":
    dir = 'your root path'
    classes, class_to_idx = find_classes(dir)
    imgs = []    # full paths of all image files
    labels = []  # class index for each image, aligned with imgs
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(dir, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_image_file(path):
                    imgs.append(path)
                    labels.append(class_index)
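As a side note, since this directory layout is exactly what torchvision expects, you could also let torchvision.datasets.ImageFolder collect the paths and labels and then reuse them for the split. This is just an alternative sketch, assuming torchvision is installed and 'your root path' points at the same root directory:

from torchvision import datasets

dataset = datasets.ImageFolder('your root path')
imgs = [path for path, _ in dataset.samples]  # image file paths
labels = list(dataset.targets)                # class indices aligned with imgs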
In the path-collecting code, you only need to change dir to your own root path. Next, split all the data into 5 folds. I wrote the MyDataset class below myself; you can use it as-is.
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=5)  # 5 folds
for i, (train_idx, val_idx) in enumerate(skf.split(imgs, labels)):
    trainset, valset = np.array(imgs)[train_idx], np.array(imgs)[val_idx]
    traintag, valtag = np.array(labels)[train_idx], np.array(labels)[val_idx]
    train_dataset = MyDataset(trainset, traintag, data_transforms['train'])
    val_dataset = MyDataset(valset, valtag, data_transforms['val'])
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, imgs, labels, transform=None, target_transform=None):
        self.imgs = imgs                          # list/array of image paths
        self.labels = labels                      # class index for each path
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        path = self.imgs[idx]
        target = self.labels[idx]
        # Load the image lazily and force RGB so grayscale files also work
        with open(path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')
        if self.transform:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
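The data_transforms dictionary used with MyDataset is not shown in the original code. A minimal sketch of what it might look like, built with torchvision.transforms; the crop sizes and normalization values below are common ImageNet-style defaults, not something required by the method:

from torchvision import transforms

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}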
Once the datasets are ready, you can create the dataloaders; after that it is just ordinary training code:
import torch.optim as optim
from torchvision.models import resnet18
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=5)  # 5 folds
for i, (train_idx, val_idx) in enumerate(skf.split(imgs, labels)):
    trainset, valset = np.array(imgs)[train_idx], np.array(imgs)[val_idx]
    traintag, valtag = np.array(labels)[train_idx], np.array(labels)[val_idx]
    train_dataset = MyDataset(trainset, traintag, data_transforms['train'])
    val_dataset = MyDataset(valset, valtag, data_transforms['val'])
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                                   shuffle=True, num_workers=args.workers)
    test_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size,
                                                  shuffle=True, num_workers=args.workers)

    # define model
    model = resnet18().cuda()
    # define criterion
    criterion = torch.nn.CrossEntropyLoss()
    # Observe that all parameters are being optimized.
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    for epoch in range(args.epoch):
        train_acc, train_loss = train(train_dataloader, model, criterion, args)
        test_acc, test_acc_top5, test_loss = validate(test_dataloader, model, criterion, args)
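The train and validate functions are my own and are not shown in full here. The following is only a minimal sketch of what they could look like; note that this sketch passes the optimizer explicitly, so its call would be train(train_dataloader, model, criterion, optimizer, args) rather than the exact call above, and the top-5 accuracy bookkeeping is omitted (a placeholder 0.0 is returned):

def train(loader, model, criterion, optimizer, args):
    # One training epoch over the dataloader
    model.train()
    total, correct, loss_sum = 0, 0, 0.0
    for images, targets in loader:
        images, targets = images.cuda(), targets.cuda()
        outputs = model(images)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_sum += loss.item() * targets.size(0)
        correct += (outputs.argmax(dim=1) == targets).sum().item()
        total += targets.size(0)
    return correct / total, loss_sum / total

def validate(loader, model, criterion, args):
    # Evaluation pass without gradient tracking; top-5 accuracy omitted
    model.eval()
    total, correct, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for images, targets in loader:
            images, targets = images.cuda(), targets.cuda()
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss_sum += loss.item() * targets.size(0)
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            total += targets.size(0)
    return correct / total, 0.0, loss_sum / total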
To make sure the data split is identical on every run, keep shuffle=False (the default):

StratifiedKFold(n_splits=5, shuffle=False)
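If you do want the folds shuffled (for example, to break any ordering introduced by how the files were collected) while still getting the same split on every run, fixing the random seed is an option; this is just an optional variant, with 42 as an arbitrary example seed:

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)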
That covers the basic implementation. The reason for doing k-fold at the code level rather than at the data level (for example, pre-splitting the data into 5 fixed parts on disk) is that this approach tolerates adding or removing samples freely, with no need to re-partition the data by hand, which is very convenient.