国产 无码 综合区,色欲AV无码国产永久播放,无码天堂亚洲国产AV,国产日韩欧美女同一区二区

通俗易懂的知識(shí)蒸餾 Knowledge Distillation(下)——代碼實(shí)踐(附詳細(xì)注釋)

這篇具有很好參考價(jià)值的文章主要介紹了通俗易懂的知識(shí)蒸餾 Knowledge Distillation(下)——代碼實(shí)踐(附詳細(xì)注釋)。希望對(duì)大家有所幫助。如果存在錯(cuò)誤或未考慮完全的地方,請(qǐng)大家不吝賜教,您也可以點(diǎn)擊"舉報(bào)違法"按鈕提交疑問(wèn)。

第一步:導(dǎo)入所需要的包

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.utils.data

torch.manual_seed(0)		# 為CPU設(shè)置種子
torch.cuda.manual_seed(0)	# 為GPU設(shè)置種子

第二步:定義教師模型

教師模型網(wǎng)絡(luò)結(jié)構(gòu)(此處僅舉一個(gè)例子):卷積層-卷積層-dropout-dropout-全連接層-全連接層

class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)		# 卷積層
        self.conv2 = nn.Conv2d(32, 64, 3, 1)	# 卷積層
        self.dropout1 = nn.Dropout2d(0.3)		# dropout
        self.dropout2 = nn.Dropout2d(0.5)		# dropout
        self.fc1 = nn.Linear(9216, 128)			# 全連接層
        self.fc2 = nn.Linear(128, 10)			# 全連接層

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)			# 激活函數(shù)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        output = self.fc2(x)
        return output

第三步:定義訓(xùn)練教師模型方法

正常的定義一個(gè)神經(jīng)網(wǎng)絡(luò)模型

def train_teacher(model, device, train_loader, optimizer, epoch):
    model.train()
    trained_samples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)	# 將數(shù)據(jù)轉(zhuǎn)移到CPU/GPU
        optimizer.zero_grad()	# 優(yōu)化器將梯度全部置為0
        output = model(data)	# 數(shù)據(jù)經(jīng)過(guò)模型向前傳播
        loss = F.cross_entropy(output, target)  # 計(jì)算損失函數(shù)
        loss.backward()			# 反向傳播
        optimizer.step()		# 更新梯度

        trained_samples += len(data)
        progress = math.ceil(batch_idx / len(train_loader) * 50) # 計(jì)算訓(xùn)練進(jìn)度
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')

第四步:定義教師模型測(cè)試方法

正常的定義一個(gè)神經(jīng)網(wǎng)絡(luò)模型

def test_teacher(model, device, test_loader):
    model.eval()  # 設(shè)置為評(píng)估模式
    test_loss = 0
    correct = 0
    with torch.no_grad():  # 不計(jì)算梯度,減少計(jì)算量
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)  # 將數(shù)據(jù)轉(zhuǎn)移到CPU/GPU
            output = model(data)  # 經(jīng)過(guò)模型正向傳播得到結(jié)果
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # 計(jì)算總的損失函數(shù)
            pred = output.argmax(dim=1, keepdim=True)  # 獲取最大對(duì)數(shù)概率索引
            correct += pred.eq(target.view_as(pred)).sum().item()  # pred.eq(target.view_as(pred)) 會(huì)返回一個(gè)布爾張量,其中每個(gè)元素表示預(yù)測(cè)值是否等于目標(biāo)值。然后,.sum().item() 會(huì)將所有為 True 的元素相加,從而得到正確分類的數(shù)量。 

    test_loss /= len(test_loader.dataset)  # 計(jì)算損失函數(shù)

    print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, correct / len(test_loader.dataset)

第五步:定義教師模型主函數(shù)

整體也和正常模型一樣,但是這里使用了teacher_history去保留需要知識(shí)蒸餾的數(shù)據(jù)。

def teacher_main():
    epochs = 10
    batch_size = 64
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 使用的設(shè)備類型
    
	# 導(dǎo)入訓(xùn)練集
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))  # 數(shù)據(jù)正則化
                       ])),
        batch_size=batch_size, shuffle=True)
    
    # 導(dǎo)入測(cè)試集
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))	# 數(shù)據(jù)正則化
        ])),
        batch_size=1000, shuffle=True)

    model = TeacherNet().to(device)  # 傳輸經(jīng)過(guò)教師模型網(wǎng)絡(luò)
    optimizer = torch.optim.Adadelta(model.parameters())  # 使用Adadelta優(yōu)化器
    
    teacher_history = []  # 記錄教師得到結(jié)果的歷史
    for epoch in range(1, epochs + 1):
        train_teacher(model, device, train_loader, optimizer, epoch)  # 開(kāi)始訓(xùn)練模型
        loss, acc = test_teacher(model, device, test_loader)  # 計(jì)算損失函數(shù)和準(zhǔn)確率
        teacher_history.append((loss, acc))  # 記錄教師模型得到的歷史數(shù)據(jù)

    torch.save(model.state_dict(), "teacher.pt")  # 保存到權(quán)重文件
    return model, teacher_history

第六步:開(kāi)始訓(xùn)練教師模型

# 訓(xùn)練教師網(wǎng)絡(luò)
teacher_model, teacher_history = teacher_main()

第七步:定義學(xué)生模型網(wǎng)絡(luò)結(jié)構(gòu)

學(xué)生模型的網(wǎng)絡(luò)結(jié)構(gòu)定義時(shí)一般要比教師模型簡(jiǎn)單一些,這樣才能達(dá)到知識(shí)蒸餾輕量化的目的

class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)	# 全連接層
        self.fc2 = nn.Linear(128, 64)		# 全連接層
        self.fc3 = nn.Linear(64, 10)		# 全連接層

    def forward(self, x):
        x = torch.flatten(x, 1)		# 將輸入張量沿著第二維度平
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        output = F.relu(self.fc3(x))
        return output

第八步:定義知識(shí)蒸餾方法

這里定義知識(shí)蒸餾主要是實(shí)現(xiàn)其損失函數(shù)。

def distillation(y, labels, teacher_scores, temp, alpha):
    return nn.KLDivLoss()(F.log_softmax(y / temp, dim=1), F.softmax(teacher_scores / temp, dim=1)) * (
            temp * temp * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)

我們這里寫(xiě)一下這個(gè)公式:
K L D i v L o s s ( l o g ( s o f t m a x ( y t e m p ) ) , s o f t m a x ( t e a c h e r ? s c o r e s t e m p ) ) ? 2 α t e m p 2 C r o s s E n t r o p y ( y , l a b e l s ) ( 1 ? α ) KLDivLoss(log(softmax(\frac{y}{temp})),softmax(\frac{teacher~scores}{temp}))*2\alpha temp^2\\CrossEntropy(y,labels)(1-\alpha) KLDivLoss(log(softmax(tempy?)),softmax(tempteacher?scores?))?2αtemp2CrossEntropy(y,labels)(1?α)
其中 α \alpha α 1 ? α 1-\alpha 1?α為系數(shù), t e m p 2 temp^2 temp2用于調(diào)節(jié)量綱。

第九步:定義學(xué)生模型訓(xùn)練和測(cè)試方法

學(xué)生模型訓(xùn)練部分和教師模型訓(xùn)練部分基本一樣,除了兩個(gè)部分。

第一個(gè)部分是,需要重點(diǎn)關(guān)注teacher_output = teacher_output.detach()切斷教師模型反向傳播這一行。

第二個(gè)部分是,這里訓(xùn)練使用的損失函數(shù)是上面定義的知識(shí)蒸餾的損失函數(shù)

def train_student_kd(model, device, train_loader, optimizer, epoch):
    model.train()
    trained_samples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)  # 學(xué)生模型前向傳播
        teacher_output = teacher_model(data)  # 教師模型前向傳播
        teacher_output = teacher_output.detach()  # 切斷老師網(wǎng)絡(luò)的反向傳播
        loss = distillation(output, target, teacher_output, temp=5.0, alpha=0.7)  # 計(jì)算總損失函數(shù),這里使用的是知識(shí)蒸餾的損失函數(shù)
        loss.backward()  # 反向傳播
        optimizer.step()  # 更新參數(shù)

        trained_samples += len(data)
        progress = math.ceil(batch_idx / len(train_loader) * 50)
        print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
              (epoch, trained_samples, len(train_loader.dataset),
               '-' * progress + '>', progress * 2), end='')

def test_student_kd(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # 計(jì)算總的損失函數(shù)
            pred = output.argmax(dim=1, keepdim=True)  # 獲取最大對(duì)數(shù)概率索引
            correct += pred.eq(target.view_as(pred)).sum().item()  # 計(jì)算準(zhǔn)確率
    test_loss /= len(test_loader.dataset)

    print('\nTest: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, correct / len(test_loader.dataset)

第十步:定義學(xué)生模型主函數(shù)

def student_kd_main():
    epochs = 10
    batch_size = 64
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	# 加載訓(xùn)練集
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=batch_size, shuffle=True)
    # 加載測(cè)試集
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data/MNIST', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=1000, shuffle=True)
	# 加載學(xué)生模型
    model = StudentNet().to(device)
    optimizer = torch.optim.Adadelta(model.parameters())
    
    student_history = []  # 記錄學(xué)生訓(xùn)練的模型
    for epoch in range(1, epochs + 1):
        train_student_kd(model, device, train_loader, optimizer, epoch)
        loss, acc = test_student_kd(model, device, test_loader)
        student_history.append((loss, acc))

    torch.save(model.state_dict(), "student_kd.pt")
    return model, student_history
student_kd_model, student_kd_history = student_kd_main()

知識(shí)蒸餾步驟總結(jié)

(1)先訓(xùn)練教師模型,定義教師模型的訓(xùn)練方法和測(cè)試方法

(2)定義知識(shí)蒸餾損失函數(shù)

(3)再訓(xùn)練學(xué)生模型,定義學(xué)生模型的訓(xùn)練方法和測(cè)試方法

(4)訓(xùn)練學(xué)生模型的時(shí)候需要將教師模型得到數(shù)據(jù)輸出經(jīng)過(guò)知識(shí)蒸餾作為輸入,并且要阻斷教師模型的反向傳播,并利用知識(shí)蒸餾損失函數(shù)進(jìn)行反向傳播

(5)訓(xùn)練結(jié)束后得到學(xué)生模型文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-524288.html

到了這里,關(guān)于通俗易懂的知識(shí)蒸餾 Knowledge Distillation(下)——代碼實(shí)踐(附詳細(xì)注釋)的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!

本文來(lái)自互聯(lián)網(wǎng)用戶投稿,該文觀點(diǎn)僅代表作者本人,不代表本站立場(chǎng)。本站僅提供信息存儲(chǔ)空間服務(wù),不擁有所有權(quán),不承擔(dān)相關(guān)法律責(zé)任。如若轉(zhuǎn)載,請(qǐng)注明出處: 如若內(nèi)容造成侵權(quán)/違法違規(guī)/事實(shí)不符,請(qǐng)點(diǎn)擊違法舉報(bào)進(jìn)行投訴反饋,一經(jīng)查實(shí),立即刪除!

領(lǐng)支付寶紅包贊助服務(wù)器費(fèi)用

相關(guān)文章

  • 擴(kuò)散模型相關(guān)論文閱讀,擴(kuò)散模型和知識(shí)蒸餾的結(jié)合提升預(yù)測(cè)速度:Progressive Distillation for Fast Sampling of Diffusion Models

    擴(kuò)散模型相關(guān)論文閱讀,擴(kuò)散模型和知識(shí)蒸餾的結(jié)合提升預(yù)測(cè)速度:Progressive Distillation for Fast Sampling of Diffusion Models

    谷歌research的成果,ICLR 2022 https://arxiv.org/abs/2202.00512 tenserflow官方開(kāi)源代碼: https://github.com/google-research/google-research/tree/master/diffusion_distillation pytorch非官方代碼:https://github.com/lucidrains/imagen-pytorch 1.擴(kuò)散模型雖然取得了很好的效果,但是預(yù)測(cè)速度慢。 2.作者提出了一種逐步蒸餾

    2024年02月16日
    瀏覽(20)
  • [讀論文][backbone]Knowledge Diffusion for Distillation

    DiffKD 摘要 The representation gap between teacher and student is an emerging topic in knowledge distillation (KD). To reduce the gap and improve the performance, current methods often resort to complicated training schemes, loss functions, and feature alignments, which are task-specific and feature-specific. In this paper, we state that the essence of the

    2024年02月08日
    瀏覽(32)
  • 高頻知識(shí)匯總 |【計(jì)算機(jī)網(wǎng)絡(luò)】面試題匯總(萬(wàn)字長(zhǎng)文通俗易懂)

    高頻知識(shí)匯總 |【計(jì)算機(jī)網(wǎng)絡(luò)】面試題匯總(萬(wàn)字長(zhǎng)文通俗易懂)

    這篇【計(jì)算機(jī)網(wǎng)絡(luò)】是我在學(xué)習(xí)時(shí)自己整理的,大部分都是按我個(gè)人的理解來(lái)寫(xiě)的答案。廢話不多說(shuō)直接鋪干貨。 這就一道題,回答時(shí)遵循如下原則即可: 橫向看(為哪兩個(gè)平行的東西提供服務(wù)) 縱向看(為上層提供什么服務(wù)) 協(xié)議舉例 應(yīng)用層: 為應(yīng)用程序間提供通信和

    2024年02月09日
    瀏覽(22)
  • 建造者模式深入理解:演示建造單個(gè)和多個(gè)產(chǎn)品的實(shí)踐,結(jié)合模板模式,通俗易懂

    建造者模式深入理解:演示建造單個(gè)和多個(gè)產(chǎn)品的實(shí)踐,結(jié)合模板模式,通俗易懂

    首先呢看下建造者的定義是什么樣的,先讀一遍 建造者模式 (Builder Pattern)是一種創(chuàng)建型設(shè)計(jì)模式,它主要用于將一個(gè)復(fù)雜對(duì)象的構(gòu)建過(guò)程與它的表示分離,使得同樣的構(gòu)建過(guò)程可以創(chuàng)建不同的表現(xiàn)形式。這種模式通過(guò)一系列可重用的獨(dú)立的類(稱為建造者或構(gòu)建器)來(lái)一

    2024年01月22日
    瀏覽(28)
  • 深度學(xué)習(xí)概念(術(shù)語(yǔ)):Fine-tuning、Knowledge Distillation, etc

    這里的相關(guān)概念都是基于已有預(yù)訓(xùn)練模型,就是模型本身已經(jīng)訓(xùn)練好,有一定泛化能力。需要“再加工”滿足別的任務(wù)需求。 進(jìn)入后GPT時(shí)代,對(duì)模型的Fine-tuning也將成為趨勢(shì),借此機(jī)會(huì),我來(lái)科普下相關(guān)概念。 有些人認(rèn)為微調(diào)和訓(xùn)練沒(méi)有區(qū)別,都是訓(xùn)練模型,但是微調(diào)是在原

    2024年02月09日
    瀏覽(61)
  • 【論文閱讀】SKDBERT: Compressing BERT via Stochastic Knowledge Distillation

    【論文閱讀】SKDBERT: Compressing BERT via Stochastic Knowledge Distillation

    2022-2023年論文系列之模型輕量化和推理加速 通過(guò)Connected Papers搜索引用PaBEE/DeeBERT/FastBERT的最新工作,涵蓋: 模型推理加速 邊緣設(shè)備應(yīng)用 生成模型 BERT模型 知識(shí)蒸餾 SmartBERT: A Promotion of Dynamic Early Exiting Mechanism for Accelerating BERT Inference SKDBERT: Compressing BERT via Stochastic Knowledge Di

    2024年02月12日
    瀏覽(25)
  • 回溯法解01背包問(wèn)題(最通俗易懂,附C++代碼)

    回溯法解01背包問(wèn)題(最通俗易懂,附C++代碼)

    01背包問(wèn)題是算法中的經(jīng)典問(wèn)題,問(wèn)題描述如下: 對(duì)于給定的N個(gè)物品,第i個(gè)物品的重量為Wi,價(jià)值為Vi,對(duì)于一個(gè)最多能裝重量C的背包,應(yīng)該如何選擇放入包中的物品,使得包中物品的總價(jià)值最大? 回溯法的本質(zhì)其實(shí)就是一種蠻力法,只是通過(guò)一定的方法可以使得蠻力法中

    2023年04月08日
    瀏覽(26)
  • SVM(支持向量機(jī))進(jìn)行分類的原理和python代碼----通俗易懂

    SVM(支持向量機(jī))進(jìn)行分類的原理和python代碼----通俗易懂

    SVM(支持向量機(jī),Support Vector Machine)是一種非常流行的機(jī)器學(xué)習(xí)算法,可用于二分類和多分類問(wèn)題。其基本思想是通過(guò)在不同類別的樣本之間構(gòu)建最大化分類間隔的線性或非線性超平面來(lái)實(shí)現(xiàn)分類。 SVM分類的基本步驟如下: 根據(jù)訓(xùn)練集數(shù)據(jù),選取最優(yōu)的超平面(通常為線性

    2024年02月11日
    瀏覽(24)
  • 實(shí)際開(kāi)發(fā)中常用的設(shè)計(jì)模式--------策略模式(知識(shí)跟業(yè)務(wù)場(chǎng)景結(jié)合)-----小白也能看懂(通俗易懂版本)

    1.策略模式定義: 策略模式是一種行為型設(shè)計(jì)模式,它允許在運(yùn)行時(shí)動(dòng)態(tài)地改變對(duì)象的行為。策略模式將將每一個(gè)算法封裝到具有共同接口的獨(dú)立的類中,從而使得它們可以相互替換從而使得算法的變化不會(huì)影響到客戶端 2.簡(jiǎn)單的策略模式示例代碼: 在上述代碼中,SortStra

    2024年02月13日
    瀏覽(28)
  • 論文筆記|CVPR2023:Supervised Masked Knowledge Distillation for Few-Shot Transformers

    論文筆記|CVPR2023:Supervised Masked Knowledge Distillation for Few-Shot Transformers

    這篇論文的題目是 用于小樣本Transformers的監(jiān)督遮掩知識(shí)蒸餾 論文接收: CVPR 2023 論文地址: https://arxiv.org/pdf/2303.15466.pdf 代碼鏈接: https://github.com/HL-hanlin/SMKD 1.ViT在小樣本學(xué)習(xí)(只有少量標(biāo)記數(shù)據(jù)的小型數(shù)據(jù)集)中往往會(huì) 過(guò)擬合,并且由于缺乏 歸納偏置 而導(dǎo)致性能較差;

    2024年02月06日
    瀏覽(26)

覺(jué)得文章有用就打賞一下文章作者

支付寶掃一掃打賞

博客贊助

微信掃一掃打賞

請(qǐng)作者喝杯咖啡吧~博客贊助

支付寶掃一掃領(lǐng)取紅包,優(yōu)惠每天領(lǐng)

二維碼1

領(lǐng)取紅包

二維碼2

領(lǐng)紅包