簡介:GAN生成對抗網(wǎng)絡(luò)本質(zhì)上是一種思想,其依靠神經(jīng)網(wǎng)絡(luò)能夠擬合任意函數(shù)的能力,設(shè)計了一種架構(gòu)來實現(xiàn)數(shù)據(jù)的生成。
原理:GAN的原理就是最小化生成器Generator的損失,但是在最小化損失的過程中加入了一個約束,這個約束就是使Generator生成的數(shù)據(jù)滿足我們指定數(shù)據(jù)的分布,GAN的巧妙之處在于使用一個神經(jīng)網(wǎng)絡(luò)(鑒別器Discriminator)來自動判斷生成的數(shù)據(jù)是否符合我們所需要的分布。
實現(xiàn)細(xì)節(jié):
一:
????????準(zhǔn)備好我們想要讓生成器生成的數(shù)據(jù)類型,比如MINIST手寫數(shù)字集,包含1-10十個數(shù)字,一共60000張圖片。生成器的目的就是學(xué)習(xí)這個數(shù)據(jù)集的分布。
二,
????????定義一個生成器,用于判別一張圖片是實際的還是生成器生成的,當(dāng)生成器完美學(xué)習(xí)得到數(shù)據(jù)分布之后,鑒別器可能就分不清圖片是生成器的還是實際的,這樣的話生成器就能生成我們想要的圖片了。
????????生成器的訓(xùn)練過程為:實際數(shù)據(jù)輸出結(jié)果1,生成數(shù)據(jù)輸出結(jié)果為0,目的是學(xué)會區(qū)分真假數(shù)據(jù),相當(dāng)于提供一個約束,使生成數(shù)據(jù)符合指定分布。當(dāng)鑒別生成器的數(shù)據(jù)分布時,只需要更新鑒別器的參數(shù)權(quán)重,不能夠通過計算圖將生成器的參數(shù)進(jìn)行更新。
三,
????????定義一個生成器,給定一個輸入,他就能生成1-10里面的一個數(shù)字的圖片。生成器的反向更新是根據(jù)鑒別器的損失來確定(被約束進(jìn)行反向更新)。生成器的網(wǎng)絡(luò)權(quán)重參數(shù)是單獨(dú)的,反向更新時,只需要更新計算圖當(dāng)中屬于生成器部分的參數(shù)。
下面給出生成1-0-1-0數(shù)據(jù)格式的代碼:文章來源:http://www.zghlxwxcb.cn/news/detail-667857.html
# %% import torch import numpy import torch.nn as nn import matplotlib.pyplot as plt # %% def gennerate1010(): return torch.FloatTensor([numpy.random.uniform(0.9,1.1), numpy.random.uniform(0.,.1), numpy.random.uniform(0.9,1.1), numpy.random.uniform(0.0,.1)]) # %% def genneratexxxx(): return torch.rand(4) # %% class Discrimer(nn.Module): def __init__(self) -> None: father_obj = super(Discrimer,self) father_obj.__init__() self.create_model() self.counter = 0 self.progress = [] def create_model(self): self.model = nn.Sequential( nn.Linear(4,3), nn.Sigmoid(), nn.Linear(3,1), nn.Sigmoid(), ) self.loss_functon = nn.MSELoss() self.optimiser = torch.optim.SGD(self.parameters(),lr=0.01) def forward(self,x): return self.model(x) def train(self,x,targets): outputs = self.forward(x) loss = self.loss_functon(outputs,targets) self.counter += 1 if self.counter%10 == 0: self.progress.append(loss.item()) if self.counter%10000 == 0: print(self.counter) self.optimiser.zero_grad() loss.backward() self.optimiser.step() def plotprogress(self): plt.plot(self.progress,marker='*') plt.show() # %% class Gennerater(nn.Module): def __init__(self) -> None: father_obj = super(Gennerater,self) father_obj.__init__() self.create_model() self.counter = 0 self.progress = [] def create_model(self): self.model = nn.Sequential( nn.Linear(1,3), nn.Sigmoid(), nn.Linear(3,4), nn.Sigmoid(), ) # 這個優(yōu)化器只能優(yōu)化生成器部分的參數(shù) self.optimiser = torch.optim.SGD(self.parameters(),lr=0.01) def forward(self,x): return self.model(x) def train(self,D,x,targets): g_outputs = self.forward(x) d_outputs = D.forward(g_outputs) # 使用鑒別器的loss函數(shù),但是只更新生成器的參數(shù),生成器的參數(shù)需要根據(jù)鑒別器的約束進(jìn)行更新 loss = D.loss_functon(d_outputs,targets) self.counter += 1 if self.counter%10 == 0: self.progress.append(loss.item()) if self.counter%10000 == 0: print(self.counter) self.optimiser.zero_grad() loss.backward() self.optimiser.step() def plotprogress(self): plt.plot(self.progress,marker='*') plt.show() # %% D = Discrimer() # %% G = Gennerater() # %% for id in range(15000): # 喂入實際數(shù)據(jù)給鑒別器 D.train(gennerate1010(),torch.FloatTensor([1.])) # 喂入生成的數(shù)據(jù),使用detach從計算圖脫離,用于更新鑒別器,而生成器得不到更新 D.train(G.forward(torch.FloatTensor([0.5]).detach()),torch.FloatTensor([0.0])) G.train(D,torch.FloatTensor([0.5]),torch.FloatTensor([1.])) # %% D.plotprogress() # %% G.plotprogress() # %% G.forward(torch.FloatTensor([0.5]))
參考:PyTorch生成對抗網(wǎng)絡(luò)編程文章來源地址http://www.zghlxwxcb.cn/news/detail-667857.html
到了這里,關(guān)于GAN(生成對抗網(wǎng)絡(luò))的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!