前言
我們之前講過的模型通常聚焦單個任務(wù),比如預(yù)測圖片的類別等,在訓(xùn)練的時候,我們會關(guān)注某一個特定指標(biāo)的優(yōu)化.
但是有時候,我們需要知道一個圖片,從它身上知道新聞的類型(政治/體育/娛樂)和是男性的新聞還是女性的.
我們關(guān)注某一個特定指標(biāo)的優(yōu)化,可能忽略了對有關(guān)注的指標(biāo)的有用信息.具體來說就是訓(xùn)練相關(guān)任務(wù)所帶來的額外信息,通過在多個相關(guān)任務(wù)中共享表示,我們可以使得模型在我們原本任務(wù)上獲得更好的泛化能力.這種方法就叫做多任務(wù)學(xué)習(xí).
1.多任務(wù)學(xué)習(xí)
1.1 定義
同時完成多個預(yù)測,共享表示,共享特征提取.使得模型關(guān)注到一些特有的特征.其實一套提取特征的網(wǎng)絡(luò),配合多個損失函數(shù),就是多任務(wù)損失.
圖像定位是單任務(wù),若還需要知道類別,就變成了多任務(wù)學(xué)習(xí).
1.2 原理
多任務(wù)學(xué)習(xí)的模型通常通過所有任務(wù)重共用隱藏層(特征提取層),而針對不同任務(wù)使用多個輸出層來實現(xiàn).自動學(xué)習(xí)到的任務(wù)越多,模型就能獲得捕捉所有任務(wù)的表示,而原本任務(wù)上過擬合的風(fēng)險更小.
多任務(wù)學(xué)習(xí)中,針對一個任務(wù)的特征提取,由于其它任務(wù)也能對提取的特征做出篩選,所以可以幫助模型將注意力集中到那些真正起作用的特征上.
模型會學(xué)習(xí)那些盡量表達(dá)多個任務(wù)的特征,而這些特征泛化能力會很好.
2. 多任務(wù)學(xué)習(xí)code
同時預(yù)測一個物品的顏色和類別.
2.1 數(shù)據(jù)集初探
一個分支用于分類給定輸入圖像的服裝種類(比如襯衫、裙子、牛仔褲、鞋子等);
另一個分支負(fù)責(zé)分類該服裝的顏色(黑色、紅色、藍(lán)色等)。
總體而言,我們的數(shù)據(jù)集由 2525 張圖像構(gòu)成,分為 7 種「顏色+類別」組合,包括:
黑色牛仔褲(344 張圖像)
黑色鞋子(358 張圖像)
藍(lán)色裙子(386 張圖像)
藍(lán)色牛仔褲(356 張圖像)
藍(lán)色襯衫(369 張圖像)
紅色裙子(380 張圖像)
紅色襯衫(332 張圖像)
數(shù)據(jù)集下載鏈接:https://pan.baidu.com/s/1JtKt7KCR2lEqAirjIXzvgg 提取碼:2kbc
2.2 預(yù)處理
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision
import glob
from torchvision import transforms
from torch.utils import data
from PIL import Image
img_paths = glob.glob(r"F:\multi-output-classification\dataset\*\*.jpg")
img_paths[:5]
路徑文件夾就表示了標(biāo)簽,所以要獲取其標(biāo)簽:
label_names = [img_path.split("\\")[-2] for img_path in img_paths]
label_names[:5]
label_array = np.array([la.split("_") for la in label_names])
label_array
label_color = label_array[:,0]
label_color
label_item = label_array[:,1]
label_item
吧他們轉(zhuǎn)成index,因為torch中只認(rèn)數(shù)字
unique_color = np.unique(label_color)
unique_color
unique_item = np.unique(label_item)
unique_item
item_to_idx = dict((v,k) for k, v in enumerate(unique_item))
item_to_idx
color_to_idx = dict((v,k) for k, v in enumerate(unique_color))
color_to_idx
label_item = [item_to_idx.get(k) for k in label_item]
label_color = [color_to_idx.get(k) for k in label_color ]
transform = transforms.Compose([
transforms.Resize((96,96)),
transforms.ToTensor(),
])
自定義數(shù)據(jù)集
class Multi_dataset(data.Dataset):
def __init__(self,imgs_path, label_color, label_item) -> None:
super().__init__()
self.imgs_path = imgs_path
self.label_color = label_color
self.label_item = label_item
def __getitem__(self, index):
img_path = self.imgs_path[index]
pil_img = Image.open(img_path)
# 防止有圖片有黑白圖
pil_img = pil_img.convert('RGB')
pil_img = transform(pil_img)
label_c = self.label_color[index]
label_i = self.label_item[index]
return pil_img, (label_c,label_i)
def __len__(self):
return len(self.imgs_path)
劃分訓(xùn)練集
count = len(multi_dataset)
count
# 劃分訓(xùn)練集 測試集
train_count = int(count*0.8)
test_count = count - train_count
train_ds, test_ds = data.random_split(multi_dataset,[train_count, test_count])
len(train_ds),len(test_ds)
BATCHSIZE = 32
train_dl = data.DataLoader(train_ds,batch_size=BATCHSIZE,shuffle=True)
test_dl = data.DataLoader(test_ds,batch_size=BATCHSIZE)
文章來源:http://www.zghlxwxcb.cn/news/detail-859928.html
2.3 網(wǎng)絡(luò)結(jié)構(gòu)設(shè)計
## 定義網(wǎng)絡(luò)
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3,16,3)
self.conv2 = nn.Conv2d(16,32,3)
self.conv3 = nn.Conv2d(32,64,3)
self.fc = nn.Linear(64*10*10, 1024)
self.fc1 = nn.Linear(1024,3)
self.fc2 = nn.Linear(1024,4)
def forward(self,x):
# 3X96X96-->3X48*48--->3X24X24--->3X12X12
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x,2)
x = F.relu(self.conv3(x))
x = F.max_pool2d(x,2)
x = x.view(-1,64*10*10)
c = F.relu(self.fc(x))
i = self.fc2(x)
return c,i
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Net().to(device)
model
Net(
(conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
(conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
(conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(fc): Linear(in_features=6400, out_features=1024, bias=True)
(fc1): Linear(in_features=1024, out_features=3, bias=True)
(fc2): Linear(in_features=1024, out_features=4, bias=True)
)
2.4 訓(xùn)練
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
3. 總結(jié)
未完待續(xù)文章來源地址http://www.zghlxwxcb.cn/news/detail-859928.html
到了這里,關(guān)于【多任務(wù)學(xué)習(xí)】Multi-task Learning 手把手編碼帶數(shù)據(jù)集, 一文吃透多任務(wù)學(xué)習(xí)的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!