目錄
GAN的訓(xùn)練過程:
L1和L2損失函數(shù)的區(qū)別
基礎(chǔ)概念
相同點
差異
GAN的訓(xùn)練過程:
1、先定義一個標(biāo)簽:real = 1,fake = 0。當(dāng)然這兩個值的維度是按照數(shù)據(jù)的輸出來看的。再定義了兩個優(yōu)化器。用于生成器和判別器。
2、隨機生成一個噪聲z。將z作為生成器的輸入,輸出gen_imgs(假樣本)。
3、計算生成器的損失
定義:生成器的損失為g_loss。損失函數(shù)為adverisal_loss()。判別器為discriminator()。
g_loss = adverisal_loss(discriminator(gen_imgs), real)
g_loss.backward()
optimizer_G.step()
可以看出來,g_loss是根據(jù)一個輸出(將生成的樣本作為輸入的判別器的輸出)與real的一個損失。
1)discriminator(gen_imgs) 的輸出是個什么?
既然是判別器,意思就是判別gen_imgs是不是真樣本。如果是用softmax輸出,是一個概率,為真樣本的概率。
2)g_loss = adverisal_loss(discriminator(gen_imgs), real)
計算g_loss就是判別器的輸出與real的差距,讓g_loss越來越小,就是讓gen_imgs作為判別器的輸出的概率更接近valid。就是讓gen_imgs更像真樣本。
3)要注意的是,這個g_loss用于去更新了生成器的權(quán)重。這個時候,判別器的權(quán)重并沒有被更新。
4、分別把假樣本和真樣本都送入到判別器。
real_loss = adverisal_loss(discriminator(real_imgs), real)
fake_loss = adverisal_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()
optimizer_D.step()
real_loss是判別器去判別真樣本的輸出,讓這個輸出更接近與real。
fake_loss是判別器去判別假樣本的輸出,讓這個輸出更接近與fake。
d_loss是前兩者的平均。
損失函數(shù)向后傳播,就是為了讓d_loss ---> 0。也就是讓:
real_loss ---> 0 ===> 讓判別器的輸出(真樣本概率)接近 real
fake_loss ---> 0 ===> 讓判別器的輸出(假樣本概率)接近 fake
也就是說,讓判別器按照真假樣本的類別,分別按照不同的要求去更新參數(shù)。
5、損失函數(shù)的走向?
g_loss 越小,說明生成器生產(chǎn)的假樣本作為判別器的輸入的輸出(概率)越接近real,就是生成的假樣本越像真樣本。
d_loss越小,說明判別器越能夠?qū)⒆R別出真樣本和假樣本。
所以,最后是要讓g_loss更小,d_loss更接近0.5。以至于d_loss最后為0.5的時候,達(dá)到最好的效果。這個0.5的意思就是:判別器將真樣本全部識別正確,所以real_loss=0。把所有的生成的假樣本識別錯誤(生成的樣本很真),此時fake_loss = 1。最后的d_loss = 1/2。
補充:
L1和L2損失函數(shù)的區(qū)別
基礎(chǔ)概念
??? L1損失函數(shù)又稱為MAE(mean abs error),即平均絕對誤差,也就是預(yù)測值和真實值之間差值的絕對值。
??? L2損失函數(shù)又稱為MSE(mean square error),即平均平方誤差,也就是預(yù)測值和真實值之間差值的平方。
相同點
??? 因為計算的方式類似,只有一個平方的差異,因此使用的場合都很相近,通常用于回歸任務(wù)中。文章來源:http://www.zghlxwxcb.cn/news/detail-844533.html
差異
??? 1)L2沒有L1魯棒,直觀來說,L2會將誤差平方,如果誤差大于1,則誤差會被放大很多,因此模型會對異常樣本更敏感,這樣會犧牲許多正常的樣本。當(dāng)訓(xùn)練集中含有更多異常值的時候,L1會更有效。
??? 2)如果是圖像重建任務(wù),如超分辨率、深度估計、視頻插幀等,L2會更加有效,這是由任務(wù)特性決定了,圖像重建任務(wù)中通常預(yù)測值和真實值之間的差異不大,因此需要用L2損失來放大差異,進(jìn)而指導(dǎo)模型的優(yōu)化。
??? 3)L1的問題在于它的梯度在極值點會發(fā)生躍變,并且很小的差異也會帶來很大的梯度,不利于學(xué)習(xí),因此在使用時通常會設(shè)定學(xué)習(xí)率衰減策略。而L2作為損失函數(shù)的時候本身由于其函數(shù)的特性,自身就會對梯度進(jìn)行縮放,因此有的任務(wù)在使用L2時甚至不會調(diào)整學(xué)習(xí)率,不過隨著現(xiàn)在的行業(yè)認(rèn)知,學(xué)習(xí)率衰減策略在很多場景中依然是獲得更優(yōu)模型的手段。
?文章來源地址http://www.zghlxwxcb.cn/news/detail-844533.html
到了這里,關(guān)于對抗生成網(wǎng)絡(luò)(GAN)中的損失函數(shù)的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!