Step 1: Import the required packages
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.utils.data
torch.manual_seed(0)  # set the random seed for the CPU
torch.cuda.manual_seed(0)  # set the random seed for the GPU
Step 2: Define the teacher model
Teacher network structure (just one example here): convolutional layer - convolutional layer - dropout - dropout - fully connected layer - fully connected layer
class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)   # convolutional layer
        self.conv2 = nn.Conv2d(32, 64, 3, 1)  # convolutional layer
        self.dropout1 = nn.Dropout2d(0.3)     # dropout
        self.dropout2 = nn.Dropout2d(0.5)     # dropout
        self.fc1 = nn.Linear(9216, 128)       # fully connected layer
        self.fc2 = nn.Linear(128, 10)         # fully connected layer

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)  # activation function
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        output = self.fc2(x)
        return output
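The 9216 input size of fc1 follows from the MNIST image dimensions: a 28×28 input becomes 26×26 after conv1 and 24×24 after conv2 (3×3 kernels, stride 1, no padding), then 12×12 after the 2×2 max pool, so flattening 64 channels gives 64 × 12 × 12 = 9216 features. A quick shape check (an illustrative sketch, not part of the original code):

# Illustrative shape check: trace a dummy MNIST-sized batch through the teacher's conv stack.
dummy = torch.zeros(1, 1, 28, 28)  # one fake 28x28 grayscale image
net = TeacherNet()
x = F.relu(net.conv1(dummy))       # -> (1, 32, 26, 26)
x = F.relu(net.conv2(x))           # -> (1, 64, 24, 24)
x = F.max_pool2d(x, 2)             # -> (1, 64, 12, 12)
print(torch.flatten(x, 1).shape)   # torch.Size([1, 9216]), matching fc1's input size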
Step 3: Define the training method for the teacher model
The teacher is trained exactly like an ordinary neural network.
def train_teacher(model, device, train_loader, optimizer, epoch):
    model.train()
    trained_samples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)  # move the data to the CPU/GPU
        optimizer.zero_grad()  # reset all gradients to zero
        output = model(data)  # forward pass through the model
        loss = F.cross_entropy(output, target)  # compute the loss
        loss.backward()  # backpropagation
        optimizer.step()  # update the parameters
        trained_samples += len(data)

        progress = math.ceil(batch_idx / len(train_loader) * 50)  # compute training progress
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')
Step 4: Define the test method for the teacher model
The teacher is evaluated just like an ordinary neural network.
def test_teacher(model, device, test_loader):
    model.eval()  # switch to evaluation mode
    test_loss = 0
    correct = 0
    with torch.no_grad():  # no gradients are computed, which saves memory and time
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)  # move the data to the CPU/GPU
            output = model(data)  # forward pass to get the predictions
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # accumulate the total loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the largest score, i.e. the predicted class
            correct += pred.eq(target.view_as(pred)).sum().item()  # pred.eq(target.view_as(pred)) returns a boolean tensor marking which predictions match the targets; .sum().item() counts the True entries, i.e. the number of correct classifications

    test_loss /= len(test_loader.dataset)  # average loss per sample

    print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, correct / len(test_loader.dataset)
Step 5: Define the main function for the teacher model
The overall flow is the same as for an ordinary model, except that teacher_history is used to record the teacher's test loss and accuracy after every epoch.
def teacher_main():
    epochs = 10
    batch_size = 64
    torch.manual_seed(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # device to run on

    # load the training set
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))  # normalize the data
                       ])),
        batch_size=batch_size, shuffle=True)
    # load the test set
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))  # normalize the data
        ])),
        batch_size=1000, shuffle=True)

    model = TeacherNet().to(device)  # instantiate the teacher network and move it to the device
    optimizer = torch.optim.Adadelta(model.parameters())  # use the Adadelta optimizer

    teacher_history = []  # record the teacher's results over the epochs

    for epoch in range(1, epochs + 1):
        train_teacher(model, device, train_loader, optimizer, epoch)  # train for one epoch
        loss, acc = test_teacher(model, device, test_loader)  # compute test loss and accuracy
        teacher_history.append((loss, acc))  # record the teacher's history

    torch.save(model.state_dict(), "teacher.pt")  # save the weights to a file
    return model, teacher_history
Step 6: Train the teacher model
# train the teacher network
teacher_model, teacher_history = teacher_main()
Step 7: Define the student model's network structure
The student network is usually made simpler than the teacher network; this is what gives knowledge distillation its model-compression (lightweighting) benefit.
class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # fully connected layer
        self.fc2 = nn.Linear(128, 64)       # fully connected layer
        self.fc3 = nn.Linear(64, 10)        # fully connected layer

    def forward(self, x):
        x = torch.flatten(x, 1)  # flatten the input tensor starting from dimension 1
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        output = F.relu(self.fc3(x))
        return output
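To confirm that the student really is lighter than the teacher, you can count the trainable parameters of each network (an illustrative sketch, not part of the original code; the counts in the comments are approximate):

# Illustrative comparison of model sizes via trainable-parameter counts.
def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("teacher:", count_params(TeacherNet()))  # about 1.2 million parameters
print("student:", count_params(StudentNet()))  # about 0.11 million parameters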
Step 8: Define the knowledge-distillation method
Defining knowledge distillation here mainly means implementing its loss function.
def distillation(y, labels, teacher_scores, temp, alpha):
    return nn.KLDivLoss()(F.log_softmax(y / temp, dim=1), F.softmax(teacher_scores / temp, dim=1)) * (
            temp * temp * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)
Written as a formula, this loss is:

KLDivLoss\left(\log\left(\mathrm{softmax}\left(\frac{y}{temp}\right)\right),\ \mathrm{softmax}\left(\frac{teacher\_scores}{temp}\right)\right)\cdot 2\alpha\,temp^2 \;+\; CrossEntropy(y,\ labels)\cdot(1-\alpha)

where \alpha and 1-\alpha are the weights of the soft-target (distillation) term and the hard-label term, and temp^2 rescales the soft-target term so that its magnitude stays comparable after the logits are divided by the temperature.
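As a quick sanity check of the function above (an illustrative snippet with made-up tensors, not part of the original article; it assumes the imports and the distillation definition from earlier), you can feed random logits through it and confirm that it returns a single scalar:

# Illustrative smoke test: random student/teacher logits for a batch of 4 samples.
y = torch.randn(4, 10)               # pretend student logits
teacher_scores = torch.randn(4, 10)  # pretend teacher logits
labels = torch.tensor([3, 1, 7, 0])  # pretend ground-truth classes
loss = distillation(y, labels, teacher_scores, temp=5.0, alpha=0.7)
print(loss.item())                   # a single scalar mixing the soft- and hard-target terms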
Step 9: Define the training and test methods for the student model
Training the student is almost the same as training the teacher, with two differences.
First, pay attention to the line teacher_output = teacher_output.detach(), which cuts off backpropagation through the teacher model.
Second, the loss used here is the knowledge-distillation loss defined above.
def train_student_kd(model, device, train_loader, optimizer, epoch):
    model.train()
    trained_samples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)  # student forward pass
        teacher_output = teacher_model(data)  # teacher forward pass
        teacher_output = teacher_output.detach()  # cut off backpropagation through the teacher network
        loss = distillation(output, target, teacher_output, temp=5.0, alpha=0.7)  # total loss, using the knowledge-distillation loss defined above
        loss.backward()  # backpropagation
        optimizer.step()  # update the parameters
        trained_samples += len(data)

        progress = math.ceil(batch_idx / len(train_loader) * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')
def test_student_kd(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # accumulate the total loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the largest score, i.e. the predicted class
            correct += pred.eq(target.view_as(pred)).sum().item()  # count the correct predictions

    test_loss /= len(test_loader.dataset)

    print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, correct / len(test_loader.dataset)
Step 10: Define the main function for the student model
def student_kd_main():
    epochs = 10
    batch_size = 64
    torch.manual_seed(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load the training set
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)
    # load the test set
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True)

    # create the student model
    model = StudentNet().to(device)
    optimizer = torch.optim.Adadelta(model.parameters())

    student_history = []  # record the student's test loss and accuracy per epoch

    for epoch in range(1, epochs + 1):
        train_student_kd(model, device, train_loader, optimizer, epoch)
        loss, acc = test_student_kd(model, device, test_loader)
        student_history.append((loss, acc))

    torch.save(model.state_dict(), "student_kd.pt")
    return model, student_history
student_kd_model, student_kd_history = student_kd_main()
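With both runs finished, the two histories can be compared side by side. A minimal plotting sketch (assuming matplotlib is installed; not part of the original article):

# Illustrative comparison of test accuracy per epoch for the teacher and the distilled student.
import matplotlib.pyplot as plt

epochs_axis = range(1, len(teacher_history) + 1)
plt.plot(epochs_axis, [acc for _, acc in teacher_history], label='teacher')
plt.plot(epochs_axis, [acc for _, acc in student_kd_history], label='student (KD)')
plt.xlabel('epoch')
plt.ylabel('test accuracy')
plt.legend()
plt.show()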
Summary of the knowledge-distillation procedure
(1) First train the teacher model: define the teacher's training and test methods.
(2) Define the knowledge-distillation loss function.
(3) Then train the student model: define the student's training and test methods.
(4) When training the student, feed the teacher's outputs into the knowledge-distillation loss together with the student's outputs, cut off backpropagation through the teacher (detach), and backpropagate through the distillation loss.
(5) After training finishes you have the distilled student model, which can be reloaded from its saved weights as sketched below.
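A minimal reload sketch (illustrative, assuming the student_kd.pt file saved above):

# Illustrative: reload the distilled student from its saved weights for inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student = StudentNet().to(device)
student.load_state_dict(torch.load("student_kd.pt", map_location=device))
student.eval()  # switch to evaluation mode before running inference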