?論文:[1503.02531] Distilling the Knowledge in a Neural Network (arxiv.org)
知識(shí)蒸餾是一種模型壓縮方法,是一種基于“教師-學(xué)生網(wǎng)絡(luò)思想”的訓(xùn)練方式,由于其簡單,有效,并且已經(jīng)在工業(yè)界被廣泛應(yīng)用。
知識(shí)蒸餾使用的是Teacher—Student模型,其中teacher是“知識(shí)”的輸出者,student是“知識(shí)”的接受者。知識(shí)蒸餾的過程分為2個(gè)階段:
????????①原始模型訓(xùn)練: 訓(xùn)練"Teacher模型", 簡稱為Net-T,它的特點(diǎn)是模型相對復(fù)雜,也可以由多個(gè)分別訓(xùn)練的模型集成而成。我們對"Teacher模型"不作任何關(guān)于模型架構(gòu)、參數(shù)量、是否集成方面的限制,唯一的要求就是,對于輸入X, 其都能輸出Y,其中Y經(jīng)過softmax的映射,輸出值對應(yīng)相應(yīng)類別的概率值。
????????②精簡模型訓(xùn)練: 訓(xùn)練"Student模型", 簡稱為Net-S,它是參數(shù)量較小、模型結(jié)構(gòu)相對簡單的單模型。同樣的,對于輸入X,其都能輸出Y,Y經(jīng)過softmax映射后同樣能輸出對應(yīng)相應(yīng)類別的概率值。在本論文中,作者將問題限定在分類問題下,或者其他本質(zhì)上屬于分類問題的問題,該類問題的共同點(diǎn)是模型最后會(huì)有一個(gè)softmax層,其輸出值對應(yīng)了相應(yīng)類別的概率值。
現(xiàn)實(shí)中,由于我們不可能收集到某問題的所有數(shù)據(jù)來作為訓(xùn)練數(shù)據(jù),并且新數(shù)據(jù)總是在源源不斷的產(chǎn)生,因此我們只能退而求其次,訓(xùn)練目標(biāo)變成在已有的訓(xùn)練數(shù)據(jù)集上建模輸入和輸出之間的關(guān)系。由于訓(xùn)練數(shù)據(jù)集是對真實(shí)數(shù)據(jù)分布情況的采樣,訓(xùn)練數(shù)據(jù)集上的最優(yōu)解往往會(huì)多少偏離真正的最優(yōu)解。
而在知識(shí)蒸餾時(shí),由于我們已經(jīng)有了一個(gè)泛化能力較強(qiáng)的Net-T,我們在利用Net-T來蒸餾訓(xùn)練Net-S時(shí),可以直接讓Net-S去學(xué)習(xí)Net-T的泛化能力。一個(gè)很直白且高效的遷移泛化能力的方法就是使用softmax層輸出的類別的概率來作為“soft target”。
? ? ? ? ①傳統(tǒng)training過程(hard targets): 對ground truth求極大似然
? ? ? ? ②KD的training過程(soft targets): 用large model的class probabilities作為soft targets
?例子:
在MNIST手寫數(shù)字識(shí)別任務(wù)中
假設(shè)某個(gè)輸入的“2”更加形似"3",softmax的輸出值中"3"對應(yīng)的概率為0.1,而其他負(fù)標(biāo)簽對應(yīng)的值都很小,而另一個(gè)"2"更加形似"7","7"對應(yīng)的概率為0.1。這兩個(gè)"2"對應(yīng)的hard target的值是相同的,但是它們的soft target卻是不同的,由此我們可見soft target蘊(yùn)含著比hard target多的信息。并且soft target分布的熵相對高時(shí),其soft target蘊(yùn)含的知識(shí)就更豐富。
?兩個(gè)”2“的hard target相同而soft target不同。
這就解釋了為什么通過蒸餾的方法訓(xùn)練出的Net-S相比使用完全相同的模型結(jié)構(gòu)和訓(xùn)練數(shù)據(jù)只使用hard target的訓(xùn)練方法得到的模型,擁有更好的泛化能力。
溫度T
把其他類別的可能性放大,把他們的相對大小充分暴露出來,讓學(xué)生網(wǎng)絡(luò)更加強(qiáng)烈地知道這些非類別的信息。當(dāng)T=1時(shí),與之前沒有變化;當(dāng)T越大,曲線的波峰就會(huì)越來越平滑。
?
?知識(shí)蒸餾的過程:
第一步:有一個(gè)已經(jīng)訓(xùn)練好的Teacher model,把很多數(shù)據(jù)喂給Teacher model,再把數(shù)據(jù)喂給(未訓(xùn)練/半成品)Student model,兩個(gè)都是在T=t時(shí)經(jīng)過Softmax,然后計(jì)算這兩個(gè)的損失函數(shù)值,讓它們兩個(gè)越接近越好,學(xué)生在模擬老師的預(yù)測結(jié)果。
第二步:Student model在T=1情況下經(jīng)過softmax操作,把預(yù)測結(jié)果hard prediction和真實(shí)數(shù)據(jù)的結(jié)果hard label進(jìn)行求損失值,希望它們兩個(gè)越接近越好。
總結(jié):Student model(T=t)與Teacher model(T=t)的預(yù)測結(jié)果越來越接近;Student model(T=1)的預(yù)測結(jié)果與數(shù)據(jù)結(jié)果(標(biāo)準(zhǔn)答案)越來越接近。
Loss = k1*distillation Loss+k2*student Loss。(加權(quán)求和)
?????
在使用Student model時(shí)只需要輸入數(shù)據(jù)就行,不需要T,因?yàn)槟P偷膮?shù)已經(jīng)訓(xùn)練完成了,最后只需要經(jīng)過基礎(chǔ)softmax操作得到最終結(jié)果。
?實(shí)驗(yàn)結(jié)果:
使用MNIST數(shù)據(jù)集訓(xùn)練Teacher model,把MNIST數(shù)據(jù)集中去除”3“相關(guān)的所有數(shù)據(jù)集來訓(xùn)練Student model,實(shí)驗(yàn)結(jié)果證明,經(jīng)過知識(shí)蒸餾后,沒有學(xué)習(xí)過”3“的Student model可以識(shí)別出”3“。
Soft targets可以僅僅使用3%的訓(xùn)練集來訓(xùn)練并達(dá)到近似Teacher model的效果。
知識(shí)蒸餾的應(yīng)用場景:
①模型壓縮
②優(yōu)化訓(xùn)練,防止過擬合
③無限大、無監(jiān)督數(shù)據(jù)集的數(shù)據(jù)挖掘文章來源:http://www.zghlxwxcb.cn/news/detail-740747.html
④少樣本、零樣本學(xué)習(xí)文章來源地址http://www.zghlxwxcb.cn/news/detail-740747.html
到了這里,關(guān)于知識(shí)蒸餾(Knowledge Distillation)的文章就介紹完了。如果您還想了解更多內(nèi)容,請?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!