国产 无码 综合区,色欲AV无码国产永久播放,无码天堂亚洲国产AV,国产日韩欧美女同一区二区

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二)

這篇具有很好參考價(jià)值的文章主要介紹了通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二)。希望對(duì)大家有所幫助。如果存在錯(cuò)誤或未考慮完全的地方,請(qǐng)大家不吝賜教,您也可以點(diǎn)擊"舉報(bào)違法"按鈕提交疑問。

生成對(duì)抗網(wǎng)絡(luò)(Generative Adversarial Network, GAN)的原理

學(xué)習(xí)李宏毅機(jī)器學(xué)習(xí)課程總結(jié)。
前面學(xué)習(xí)了GAN的直觀的介紹,現(xiàn)在學(xué)習(xí)GAN的基本理論。現(xiàn)在我們來學(xué)習(xí)GAN背后的理論。

引言

假設(shè)x是一張圖片(一個(gè)高維向量),如64 * 64 * 3的圖片,每個(gè)圖片都是高維空間中的一個(gè)點(diǎn)。為了畫圖方便,我們就畫成二維上的點(diǎn)。在高維空間中,只有一小部分采樣出來的點(diǎn)符合我們的數(shù)據(jù)分布(如:整個(gè)圖中只有藍(lán)色區(qū)域采樣的點(diǎn)的才是人臉,其他地方的就不是)。
我們想要產(chǎn)生的圖片,其數(shù)據(jù)分布為Pdata。
目的: 讓機(jī)器找出這個(gè)分布。
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

原始做法

在有GAN之前,人們?cè)趺醋錾扇蝿?wù)呢?

最大似然估計(jì) (Maximum likelihood estimate)。

  • 假設(shè)數(shù)據(jù)集的數(shù)據(jù)分布為Pdata(x)
    比如數(shù)據(jù)集為二次元人物,我們也不知道Pdata長什么樣
  • 假設(shè)生成數(shù)據(jù)分布為PG(x; θ)
    希望找到θ,使得PG(x; θ)和原始未知分布Pdata(x)越接近越好
    如:服從高斯分布,θ就是均值和方差
  • 從Pdata(x)里采樣一組樣本{x1, x2, …, xm}
  • 對(duì)每個(gè)樣本,計(jì)算其似然:PG(xi; θ)
    通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
    找到一個(gè)θ*,使得該似然值最大

下面有個(gè)很重要的概念:
最大似然估計(jì) = 最小KL散度

下面證明:
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

注:求最大值的θ,多個(gè)log不影響,為了乘積變加和

我們可以先回顧一下KL散度的定義:
設(shè)P(x)和Q(x) 是隨機(jī)變量X 上的兩個(gè)概率分布,則在離散隨機(jī)變量的情形下,KL散度的定義為:
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
在連續(xù)隨機(jī)變量的情形下,KL散度的定義為:
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
接著上面的,所以:
下面多加了一項(xiàng)(紅框),對(duì)結(jié)果不影響對(duì)吧,是為了和KL散度有關(guān)。
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
所以,生成模型目的等價(jià)為:最小化分布PG和分布Pdata的散度。

如何定義一個(gè)廣義的PG?
如果分布為簡單的高斯分布,我們可以計(jì)算PG(x; θ),但實(shí)際數(shù)據(jù)都是更復(fù)雜的數(shù)據(jù),有更復(fù)雜的分布,所以無法計(jì)算出PG的似然。怎么辦?有人提出Generator。

GAN的做法

Generator

圖像生成任務(wù)在80年代就有人做,那個(gè)時(shí)候人們就是用高斯模型做,但生成的圖片非常非常模糊,不管怎么調(diào)整均值和方差,都出不來想要的結(jié)果。所以需要更廣義的方法做生成任務(wù),即生成對(duì)抗網(wǎng)絡(luò)。

G怎么做生成呢?
從高斯分布中采樣的數(shù)據(jù)z(也可以是其他分布,,如均勻分布等,那到底哪種分布輸入好呢?其實(shí)都可以,對(duì)輸出的影響不是很大,因?yàn)镚都能給它變成更復(fù)雜的分布),輸入網(wǎng)絡(luò)G,得到輸出x。

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
我們希望概率分布PG和Pdata越接近越好,也就是最小化它們的某種散度Divergency(有很多散度,不一定是KL散度)。

那怎么計(jì)算這個(gè)散度呢?
Pdata和PG的概率分布公式我們不知道,所以不知道怎么算。所以人們想到了判別器Discriminator。

Discriminator

雖然我們不知道Pdata和PG的概率分布公式,但我們可以從這兩堆數(shù)據(jù)里分別采樣一些出來。

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
GAN的神奇之處就在于,可以通過D來量這兩堆數(shù)據(jù)之間的散度。

把從Pdata和PG分布里取出的樣本數(shù)據(jù)輸入D,訓(xùn)練:

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

D相當(dāng)于二分類器,希望對(duì)真數(shù)據(jù)Pdata,輸出分?jǐn)?shù)越大越好;對(duì)生成數(shù)據(jù)PG,輸出分?jǐn)?shù)越小越好。訓(xùn)練的D的結(jié)果,就會(huì)告訴我們PG和Pdata他們之間的散度有多大。

訓(xùn)D的時(shí)候,G的參數(shù)是固定住的。

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
如果你機(jī)器學(xué)習(xí)基礎(chǔ)很好的話,就可以看出這個(gè)D的優(yōu)化函數(shù)和二分類器的式子一模一樣。

神奇的地方是,當(dāng)你訓(xùn)完D,你可以得到一個(gè)最小的loss或最大的V(D, G ),而這個(gè)值和某個(gè)JS散度有一些關(guān)系,甚至可以說它就是JS散度。

如果D很難區(qū)別兩類數(shù)據(jù)的不同,loss就下不去,目標(biāo)函數(shù)就不會(huì)得到最大,意味著這兩堆數(shù)據(jù)很相似很接近,他們之間的散度就是很小的。反之亦然。

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

GAN的數(shù)學(xué)原理

證明

為什么訓(xùn)練目標(biāo)函數(shù)和散度有關(guān)呢?
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

下面證明:
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
假設(shè):D(x)可以是任何函數(shù)

上式相當(dāng)于,找到一個(gè)D,讓積分里面的部分最大:
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

為了看起來方便,讓Pdata = a, PG = b, D(x) = D。
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

可得到如下,求導(dǎo),讓導(dǎo)數(shù)為0。就可得到D*
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
此時(shí)得到局部最大。
接下來,把剛才求得的D*代入目標(biāo)函數(shù):
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

得到下式:
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

為了把它整理成像JS散度,就作一些變換,分子分母同除以2:
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

把分子的1/2都提出來,放到前面,就是2log(1/2),或 -2log2。

最后式子可以寫成如下:
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
回顧一下JS散度的公式:
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

所以可以看到,目標(biāo)函數(shù)和JS散度的關(guān)系。

那如果把目標(biāo)函數(shù)寫的和上面的不一樣,那就是在量不同的散度。

現(xiàn)在看生成器G的目標(biāo)函數(shù),那就是盡量生成最真的數(shù)據(jù),讓PG和Pdata越接近,即讓它們之間散度最小。

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

但, Div(PG, Pdata)沒有辦法算,但上面證明了最小化散度就等于最大化V(D, G)。所以可以把Div(PG, Pdata)替換掉,變成如下:

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
問題變成min&max問題,看著比較復(fù)雜,那么下面舉個(gè)簡單的例子來說明。

  • 假設(shè):我們只有三個(gè)生成器G。現(xiàn)在要求解下式:
    通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
    也就是這三個(gè)G是已知的,定的。橫坐標(biāo)代表D,假設(shè)D可以用一個(gè)參數(shù)來操作,橫坐標(biāo)在改變的時(shí)候,代表你選擇了不同的D,如藍(lán)色曲線所示,實(shí)際的D由幾百萬個(gè)的神經(jīng)網(wǎng)絡(luò)參數(shù)控制的,非常復(fù)雜,這里為了解釋原理只是簡化成一條曲線。
    那minG 和maxD 在圖中表示什么呢?
    固定G時(shí),曲線最大值紅色點(diǎn)表示max V(G,D),接下來尋找minG,這幾個(gè)G哪個(gè)最好呢?也就是找哪個(gè)最min,顯然三號(hào)生成器G3。
    綠線的高度就代表PG和Pdata的距離,即它們之間的散度。

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

算法

如何求解:
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

之前,我們學(xué)到訓(xùn)練GAN的步驟,固定G,訓(xùn)D,固定D,訓(xùn)G,然后重復(fù)該過程,這個(gè)過程其實(shí)就是在解該式。
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

那為什么這個(gè)過程就是在解這個(gè)式子呢?

把藍(lán)框的這部分先用L(G)表示,就是假設(shè)最大的這個(gè)值是L(G)。

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

那問題就變成,你要找個(gè)最好的G,使得L(G)值最小。這個(gè)問題就和一般網(wǎng)絡(luò)一樣,用梯度下降法求解。

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

但是,現(xiàn)在有個(gè)麻煩的事,就是L(G)式子里有max,那L(G)還可不可以作微分呢?
可以的,比如有個(gè)式子f(x)長這樣:
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

不同的x值,對(duì)應(yīng)的f(x)不同,看看現(xiàn)在的x能讓哪個(gè)f(x)最大,就對(duì)哪個(gè)f(x)微分。
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
f(x)的最大值就是我畫出的桃紅色線。
再通俗一點(diǎn)說,就是拿到一個(gè)x,求出f1(x),f2(x), f3(x),看誰的值最大,就把誰拿出來做微分。
比如有個(gè)x,先算出來f1(x)最大,然后梯度下降,比如向右移動(dòng)一點(diǎn),可能移動(dòng)到了另外一個(gè)區(qū)域f2,那就此時(shí)f2(x)最大。以此類推。

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

解釋了就算函數(shù)有max,也可以求微分。那就接著解這個(gè)式子。
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

也就是交替的用梯度下降訓(xùn)練G和D。G0得到D0*,對(duì)G做微分,得到G1,G1得到D1*,對(duì)G做微分,得到G2…
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

可以看到,這整個(gè)過程和GAN是一模一樣的。每一步背后的含義是什么呢?就是最小化JS散度。
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
但是上圖中的JS散度后面打了個(gè)問號(hào),是什么意思呢?
因?yàn)檫@件事情未必等同于在最小化JS散度。
因?yàn)镚在不斷的更新,比如在G0時(shí),D0*得到的maxV,更新到G1時(shí),不一定還是maxV。

那為什么我們又說這個(gè)過程是在最小化JS散度?
因?yàn)槊看胃露际呛苄〉囊徊?,所以我們假設(shè)更新后的式子和原來的式子還是非常像的。
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
Tip:
所以在訓(xùn)GAN時(shí),G每次更新的不能太多,理論上訓(xùn)D的時(shí)候要更多的迭代次數(shù)來訓(xùn)到底,找到最大V,才是在量散度,而訓(xùn)G不需要太多次的迭代,如果訓(xùn)太多次,D就無法量散度。

實(shí)際訓(xùn)練

目標(biāo)函數(shù):
通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
計(jì)算該式,要求期望,實(shí)際上沒有辦法真算期望,就用sample代替。

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)
這個(gè)式子就等同于訓(xùn)一個(gè)二分類器,是一個(gè)logistic regression邏輯回歸,就是它的輸出接了一個(gè)sigmoid,是介于0到1之間。
就等同于max V,兩個(gè)框里的內(nèi)容等價(jià)。

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

總結(jié):
訓(xùn)D:量散度
訓(xùn)G:最小化散度

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

對(duì)于G來說,第一項(xiàng)和G無關(guān),所以紅線劃掉,只剩后半部分,在真實(shí)操作中,后面的1也去掉了。這兩個(gè)函數(shù)的趨勢(shì)是一樣的 ,但斜率不同,后面發(fā)現(xiàn)都訓(xùn)的起來,差不多。

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

直觀理解GAN

G和D之間的關(guān)系是什么樣子的呢?
假設(shè)綠色是真實(shí)數(shù)據(jù)的分布,藍(lán)色是G生成的數(shù)據(jù)的分布,現(xiàn)在要訓(xùn)一個(gè)D,給綠色較高的分?jǐn)?shù),給藍(lán)色較低的分?jǐn)?shù)。D的目標(biāo)函數(shù)的值就是這兩堆數(shù)據(jù)的某個(gè)散度值。

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

G會(huì)希望D給它生成的數(shù)據(jù)打高分,所以藍(lán)色就往接近綠色的分布移動(dòng)一點(diǎn),但可能一下跑太多了,跑動(dòng)綠色右邊去了,但沒關(guān)系,再訓(xùn)一次D,D的loss會(huì)比較大,說明這兩堆數(shù)據(jù)的散度是比較小的。然后這些點(diǎn)又順著梯度給的方向往左移,最后藍(lán)色的分布就和綠色分布越來越近,讓D分辨不出,最后D會(huì)壞掉。

通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二),深度學(xué)習(xí),生成對(duì)抗網(wǎng)絡(luò),人工智能,神經(jīng)網(wǎng)絡(luò)

代碼

基于pytorch的文章來源地址http://www.zghlxwxcb.cn/news/detail-525457.html

import argparse
import os
import numpy as np
import math
 
import torchvision.transforms as transforms
from torchvision.utils import save_image
 
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
 
import torch.nn as nn
import torch.nn.functional as F
import torch
 
os.makedirs("images", exist_ok=True)
 
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
# 生成原始噪點(diǎn)數(shù)據(jù)大小--latent_dim
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=500, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)
 
img_shape = (opt.channels, opt.img_size, opt.img_size)
# print(img_shape) 1 ,28,28
# print(int(np.prod(img_shape))) 784
cuda = True if torch.cuda.is_available() else False
 
 
# 生成器模型
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
 
        # 參數(shù) 進(jìn)入32 出來 64  歸一化
        def block(in_feat, out_feat, normalize=True):
            # 對(duì)傳入數(shù)據(jù)應(yīng)用線性轉(zhuǎn)換(輸入節(jié)點(diǎn)數(shù),輸出節(jié)點(diǎn)數(shù))
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # 批規(guī)范化
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
                # 激活函數(shù)
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
 
        # 模型定義
        self.model = nn.Sequential(
 
            *block(opt.latent_dim, 128, normalize=False),
 
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            # np.prod 用來計(jì)算所有元素的乘積
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )
 
    # 正向傳播
    def forward(self, z):
        img = self.model(z)  # shape 64 784
        img = img.view(img.size(0), *img_shape)  # 64 1 28 28
        return img
 
 
# 判別模型
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
 
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
 
    def forward(self, img):
        img_flat = img.view(img.size(0), -1)  # 64 1 28 28 =>64 784
        validity = self.model(img_flat)  # 64 784 =>64 1
 
        return validity
 
 
# Loss function 類似 目標(biāo)值-得到值 的差值一種運(yùn)算
adversarial_loss = torch.nn.BCELoss()
 
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
 
# 如果有g(shù)pu
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
 
# Configure data loader
os.makedirs("./data/mnist", exist_ok=True)
print(opt.img_size)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            # 其他地方也許是Resize((opt.img_size,opt.img_size)) 也就是((28,28))因?yàn)楹罄m(xù)重塑格式類似于(64,1,28,28)
            # 這里是(28)  后面重塑格式類似于(64,1,28*28)
            # transforms.Normalize([0.5], [0.5])  這是單通道數(shù)據(jù)集
            # transforms.Normalize((0.5,0.5,0.5), (0.5),(0.5),(0.5))  三通道數(shù)據(jù)集
            # 圖片三個(gè)通道
            # 前一個(gè)(0.5,0.5,0.5)是設(shè)置的mean值 后一個(gè)(0.5,0.5,0.5)是是設(shè)置各通道的標(biāo)準(zhǔn)差
            # 其作用就是先將輸入歸一化到(0,1),再使用公式”(x-mean)/std”,將每個(gè)元素分布到(-1,1)
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    # 一次多少個(gè)處理,小圖片一般64個(gè)
    batch_size=opt.batch_size,
    # 數(shù)據(jù)集打亂,洗牌
    shuffle=True,
)
 
# Optimizers 優(yōu)化器
# lr=opt.lr學(xué)習(xí)率
# betas (Tuple[float, float],可選):用于計(jì)算的系數(shù)
# 梯度及其平方的運(yùn)行平均值(默認(rèn)值:(0.9,0.999))
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
 
# 判斷是否有g(shù)pu
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
 
# ----------
#  Training
# ----------
 
for epoch in range(opt.n_epochs):
    # dataloader中的數(shù)據(jù)是一張圖片對(duì)應(yīng)一個(gè)標(biāo)簽,所以imgs對(duì)應(yīng)的是圖片,_對(duì)應(yīng)的是標(biāo)簽,而i是enumerate輸出的功能
    for i, (imgs, _) in enumerate(dataloader):
 
        # Adversarial ground truths
        # 這部分定義的相當(dāng)于是一個(gè)標(biāo)準(zhǔn),vaild可以想象成是64行1列的向量,就是為了在后面計(jì)算損失時(shí),和1比較;fake也是一樣是全為0的向量,用法和1的用法相同。
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)
 
        # Configure input
        # 這句是將真實(shí)的圖片轉(zhuǎn)化為神經(jīng)網(wǎng)絡(luò)可以處理的變量。變?yōu)門ensor
        # print(type(imgs)) Tensor
        real_imgs = Variable(imgs.type(Tensor))
        # print(type(real_imgs)) Tensor
        # -----------------
        #  Train Generator
        # -----------------
 
        # optimizer.zero_grad()意思是把梯度置零
        # 每次的訓(xùn)練之前都將上一次的梯度置為零,以避免上一次的梯度的干擾
        optimizer_G.zero_grad()
 
        # Sample noise as generator input
        # 這部分就是在上面訓(xùn)練生成網(wǎng)絡(luò)的z的輸入值,np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)的意思就是
        # 64個(gè)噪音(基礎(chǔ)值為100大小的) 0,代表正態(tài)分布的均值,1,代表正態(tài)分布的方差
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
        # Generate a batch of images 返回一個(gè)批次即64個(gè)
        gen_imgs = generator(z)
 
        # Loss measures generator's ability to fool the discriminator
        # 計(jì)算這64個(gè)圖片總損失  生成器損失
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        # 反向傳播
        g_loss.backward()
        optimizer_G.step()
        # ---------------------
        #  Train Discriminator
        # ---------------------
        # 梯度清零
        optimizer_D.zero_grad()
        # Measure discriminator's ability to classify real from generated samples
        # 判別器判別真實(shí)圖片是真的的損失
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # 判別器判別假的圖片是假的的損失
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        # 判別器去判別真實(shí)圖片是真的的概率大,并且判別假圖片是真的的概率小,說明判別器越準(zhǔn)確所以說是maxD,
        # 生成器就是想生成真實(shí)的圖片來迷惑判別器,所以理論上想讓生成器生成真實(shí)的圖片概率大,
        # 由于公式第二部分表示生成器的損失,G(z)前有個(gè)負(fù)號(hào),所以如果結(jié)果小則證明G生成的越真實(shí),所以說minG
        d_loss = (real_loss + fake_loss) / 2
 
        # 反向傳播
        d_loss.backward()
        optimizer_D.step()
 
        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )
 
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

到了這里,關(guān)于通俗易懂生成對(duì)抗網(wǎng)絡(luò)GAN原理(二)的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!

本文來自互聯(lián)網(wǎng)用戶投稿,該文觀點(diǎn)僅代表作者本人,不代表本站立場。本站僅提供信息存儲(chǔ)空間服務(wù),不擁有所有權(quán),不承擔(dān)相關(guān)法律責(zé)任。如若轉(zhuǎn)載,請(qǐng)注明出處: 如若內(nèi)容造成侵權(quán)/違法違規(guī)/事實(shí)不符,請(qǐng)點(diǎn)擊違法舉報(bào)進(jìn)行投訴反饋,一經(jīng)查實(shí),立即刪除!

領(lǐng)支付寶紅包贊助服務(wù)器費(fèi)用

相關(guān)文章

  • 【Pytorch深度學(xué)習(xí)實(shí)戰(zhàn)】(10)生成對(duì)抗網(wǎng)絡(luò)(GAN)

    【Pytorch深度學(xué)習(xí)實(shí)戰(zhàn)】(10)生成對(duì)抗網(wǎng)絡(luò)(GAN)

    ???大家好,我是Sonhhxg_柒,希望你看完之后,能對(duì)你有所幫助,不足請(qǐng)指正!共同學(xué)習(xí)交流?? ??個(gè)人主頁-Sonhhxg_柒的博客_CSDN博客??? ??歡迎各位→點(diǎn)贊?? + 收藏?? + 留言??? ??系列專欄 - 機(jī)器學(xué)習(xí)【ML】?自然語言處理【NLP】? 深度學(xué)習(xí)【DL】 ? ???foreword ?說

    2023年04月08日
    瀏覽(32)
  • 深度學(xué)習(xí)7:生成對(duì)抗網(wǎng)絡(luò) – Generative Adversarial Networks | GAN

    深度學(xué)習(xí)7:生成對(duì)抗網(wǎng)絡(luò) – Generative Adversarial Networks | GAN

    生成對(duì)抗網(wǎng)絡(luò) – GAN 是最近2年很熱門的一種無監(jiān)督算法,他能生成出非常逼真的照片,圖像甚至視頻。我們手機(jī)里的照片處理軟件中就會(huì)使用到它。 目錄 生成對(duì)抗網(wǎng)絡(luò) GAN 的基本原理 大白話版本 非大白話版本 第一階段:固定「判別器D」,訓(xùn)練「生成器G」 第二階段:固定

    2024年02月11日
    瀏覽(21)
  • PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(31)——生成對(duì)抗網(wǎng)絡(luò)(Generative Adversarial Network, GAN)

    PyTorch深度學(xué)習(xí)實(shí)戰(zhàn)(31)——生成對(duì)抗網(wǎng)絡(luò)(Generative Adversarial Network, GAN)

    生成對(duì)抗網(wǎng)絡(luò) ( Generative Adversarial Networks , GAN ) 是一種由兩個(gè)相互競爭的神經(jīng)網(wǎng)絡(luò)組成的深度學(xué)習(xí)模型,它由一個(gè)生成網(wǎng)絡(luò)和一個(gè)判別網(wǎng)絡(luò)組成,通過彼此之間的博弈來提高生成網(wǎng)絡(luò)的性能。生成對(duì)抗網(wǎng)絡(luò)使用神經(jīng)網(wǎng)絡(luò)生成與原始圖像集非常相似的新圖像,它在圖像生成中應(yīng)用

    2024年01月22日
    瀏覽(24)
  • 深度學(xué)習(xí)8:詳解生成對(duì)抗網(wǎng)絡(luò)原理

    深度學(xué)習(xí)8:詳解生成對(duì)抗網(wǎng)絡(luò)原理

    目錄 大綱 生成隨機(jī)變量 可以偽隨機(jī)生成均勻隨機(jī)變量 隨機(jī)變量表示為操作或過程的結(jié)果 逆變換方法 生成模型 我們?cè)噲D生成非常復(fù)雜的隨機(jī)變量…… …所以讓我們使用神經(jīng)網(wǎng)絡(luò)的變換方法作為函數(shù)! 生成匹配網(wǎng)絡(luò) 培養(yǎng)生成模型 比較基于樣本的兩個(gè)概率分布 反向傳播分布

    2024年02月11日
    瀏覽(21)
  • 深度學(xué)習(xí)9:簡單理解生成對(duì)抗網(wǎng)絡(luò)原理

    深度學(xué)習(xí)9:簡單理解生成對(duì)抗網(wǎng)絡(luò)原理

    目錄 生成算法 生成對(duì)抗網(wǎng)絡(luò)(GAN) “生成”部分 “對(duì)抗性”部分 GAN如何運(yùn)作? 培訓(xùn)GAN的技巧? GAN代碼示例 如何改善GAN? 結(jié)論 您可以將生成算法分組到三個(gè)桶中的一個(gè): 鑒于標(biāo)簽,他們預(yù)測(cè)相關(guān)的功能(樸素貝葉斯) 給定隱藏的表示,他們預(yù)測(cè)相關(guān)的特征(變分自動(dòng)編

    2024年02月10日
    瀏覽(16)
  • 深入探究生成對(duì)抗網(wǎng)絡(luò)(GAN):原理與代碼分析

    生成對(duì)抗網(wǎng)絡(luò)(Generative Adversarial Network, GAN)是一種強(qiáng)大的深度學(xué)習(xí)模型,由生成器和判別器兩個(gè)神經(jīng)網(wǎng)絡(luò)組成。GAN的目標(biāo)是讓生成器網(wǎng)絡(luò)生成逼真的樣本,以盡可能欺騙判別器網(wǎng)絡(luò),同時(shí)判別器網(wǎng)絡(luò)要盡可能準(zhǔn)確地區(qū)分真實(shí)樣本和生成樣本。 GAN在圖像生成領(lǐng)域非常流行。通

    2024年02月11日
    瀏覽(22)
  • 產(chǎn)品經(jīng)理看AIGC--GAN(生成對(duì)抗網(wǎng)絡(luò))白話原理

    《AIGC:智能創(chuàng)作時(shí)代》的閱讀隨筆(推薦單獨(dú)閱讀第二章,其余章節(jié)快速略過),期待從業(yè)務(wù)角度而非推導(dǎo)角度更好的理解,為產(chǎn)品從業(yè)人員提供更好的了解溝通渠道。 如何從白話角度理解生成對(duì)抗網(wǎng)絡(luò),核心在于如何理解“對(duì)抗”,通俗的字面理解“對(duì)抗”,我們會(huì)聯(lián)想

    2024年02月15日
    瀏覽(13)
  • 3易懂AI深度學(xué)習(xí)算法:長短期記憶網(wǎng)絡(luò)(Long Short-Term Memory, LSTM)生成對(duì)抗網(wǎng)絡(luò) 優(yōu)化算法進(jìn)化算法

    繼續(xù)寫:https://blog.csdn.net/chenhao0568/article/details/134920391?spm=1001.2014.3001.5502 1.https://blog.csdn.net/chenhao0568/article/details/134931993?spm=1001.2014.3001.5502 2.https://blog.csdn.net/chenhao0568/article/details/134932800?spm=1001.2014.3001.5502 長短期記憶網(wǎng)絡(luò)(LSTM)是一種特殊的循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN),主要用于處

    2024年02月04日
    瀏覽(22)
  • 深度學(xué)習(xí)進(jìn)階篇[9]:對(duì)抗生成網(wǎng)絡(luò)GANs綜述、代表變體模型、訓(xùn)練策略、GAN在計(jì)算機(jī)視覺應(yīng)用和常見數(shù)據(jù)集介紹,以及前沿問題解決

    深度學(xué)習(xí)進(jìn)階篇[9]:對(duì)抗生成網(wǎng)絡(luò)GANs綜述、代表變體模型、訓(xùn)練策略、GAN在計(jì)算機(jī)視覺應(yīng)用和常見數(shù)據(jù)集介紹,以及前沿問題解決

    【深度學(xué)習(xí)入門到進(jìn)階】必看系列,含激活函數(shù)、優(yōu)化策略、損失函數(shù)、模型調(diào)優(yōu)、歸一化算法、卷積模型、序列模型、預(yù)訓(xùn)練模型、對(duì)抗神經(jīng)網(wǎng)絡(luò)等 專欄詳細(xì)介紹:【深度學(xué)習(xí)入門到進(jìn)階】必看系列,含激活函數(shù)、優(yōu)化策略、損失函數(shù)、模型調(diào)優(yōu)、歸一化算法、卷積模型、

    2024年02月08日
    瀏覽(29)
  • 大數(shù)據(jù)機(jī)器學(xué)習(xí)GAN:生成對(duì)抗網(wǎng)絡(luò)GAN全維度介紹與實(shí)戰(zhàn)

    大數(shù)據(jù)機(jī)器學(xué)習(xí)GAN:生成對(duì)抗網(wǎng)絡(luò)GAN全維度介紹與實(shí)戰(zhàn)

    本文為生成對(duì)抗網(wǎng)絡(luò)GAN的研究者和實(shí)踐者提供全面、深入和實(shí)用的指導(dǎo)。通過本文的理論解釋和實(shí)際操作指南,讀者能夠掌握GAN的核心概念,理解其工作原理,學(xué)會(huì)設(shè)計(jì)和訓(xùn)練自己的GAN模型,并能夠?qū)Y(jié)果進(jìn)行有效的分析和評(píng)估。 生成對(duì)抗網(wǎng)絡(luò)(GAN)是深度學(xué)習(xí)的一種創(chuàng)新架

    2024年02月03日
    瀏覽(16)

覺得文章有用就打賞一下文章作者

支付寶掃一掃打賞

博客贊助

微信掃一掃打賞

請(qǐng)作者喝杯咖啡吧~博客贊助

支付寶掃一掃領(lǐng)取紅包,優(yōu)惠每天領(lǐng)

二維碼1

領(lǐng)取紅包

二維碼2

領(lǐng)紅包