???大家好,我是Sonhhxg_柒,希望你看完之后,能對你有所幫助,不足請指正!共同學(xué)習(xí)交流??
??個人主頁-Sonhhxg_柒的博客_CSDN博客???
??歡迎各位→點贊?? + 收藏?? + 留言???
??系列專欄 - 機(jī)器學(xué)習(xí)【ML】?自然語言處理【NLP】? 深度學(xué)習(xí)【DL】
?
???foreword
?說明?本人講解主要包括Python、機(jī)器學(xué)習(xí)(ML)、深度學(xué)習(xí)(DL)、自然語言處理(NLP)等內(nèi)容。
如果你對這個系列感興趣的話,可以關(guān)注訂閱喲??
生成對抗網(wǎng)絡(luò) GAN 的基本原理
說到GAN第一篇要看的paper當(dāng)然是Ian Goodfellow大牛的Generative Adversarial Networks(arxiv:https://arxiv.org/abs/1406.2661),這篇paper算是這個領(lǐng)域的開山之作。
GAN的基本原理其實非常簡單,這里以生成圖片為例進(jìn)行說明。假設(shè)我們有兩個網(wǎng)絡(luò),G(Generator)和D(Discriminator)。正如它的名字所暗示的那樣,它們的功能分別是:
- G是一個生成圖片的網(wǎng)絡(luò),它接收一個隨機(jī)的噪聲z,通過這個噪聲生成圖片,記做G(z)。
- D是一個判別網(wǎng)絡(luò),判別一張圖片是不是“真實的”。它的輸入?yún)?shù)是x,x代表一張圖片,輸出D(x)代表x為真實圖片的概率,如果為1,就代表100%是真實的圖片,而輸出為0,就代表不可能是真實的圖片。
在訓(xùn)練過程中,生成網(wǎng)絡(luò)G的目標(biāo)就是盡量生成真實的圖片去欺騙判別網(wǎng)絡(luò)D。而D的目標(biāo)就是盡量把G生成的圖片和真實的圖片分別開來。這樣,G和D構(gòu)成了一個動態(tài)的“博弈過程”。
最后博弈的結(jié)果是什么?在最理想的狀態(tài)下,G可以生成足以“以假亂真”的圖片G(z)。對于D來說,它難以判定G生成的圖片究竟是不是真實的,因此D(G(z)) = 0.5。
這樣我們的目的就達(dá)成了:我們得到了一個生成式的模型G,它可以用來生成圖片。
以上只是大致說了一下GAN的核心原理,如何用數(shù)學(xué)語言描述呢?這里直接摘錄論文里的公式:
簡單分析一下這個公式:
- 整個式子由兩項構(gòu)成。x表示真實圖片,z表示輸入G網(wǎng)絡(luò)的噪聲,而G(z)表示G網(wǎng)絡(luò)生成的圖片。
- D(x)表示D網(wǎng)絡(luò)判斷真實圖片是否真實的概率(因為x就是真實的,所以對于D來說,這個值越接近1越好)。而D(G(z))是D網(wǎng)絡(luò)判斷G生成的圖片的是否真實的概率。
- G的目的:上面提到過,D(G(z))是D網(wǎng)絡(luò)判斷G生成的圖片是否真實的概率,G應(yīng)該希望自己生成的圖片“越接近真實越好”。也就是說,G希望D(G(z))盡可能得大,這時V(D, G)會變小。因此我們看到式子的最前面的記號是min_G。
- D的目的:D的能力越強(qiáng),D(x)應(yīng)該越大,D(G(x))應(yīng)該越小。這時V(D,G)會變大。因此式子對于D來說是求最大(max_D)
下面這幅圖片很好地描述了這個過程:
那么如何用隨機(jī)梯度下降法訓(xùn)練D和G?論文中也給出了算法:
文章來源:http://www.zghlxwxcb.cn/news/detail-405239.html
這里紅框圈出的部分是我們要額外注意的。第一步我們訓(xùn)練D,D是希望V(G, D)越大越好,所以是加上梯度(ascending)。第二步訓(xùn)練G時,V(G, D)越小越好,所以是減去梯度(descending)。整個訓(xùn)練過程交替進(jìn)行。文章來源地址http://www.zghlxwxcb.cn/news/detail-405239.html
生成對抗網(wǎng)絡(luò)Pytorch的實現(xiàn)
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
# 設(shè)備配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 超參數(shù)
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'
# 如果不存在則創(chuàng)建目錄
if not os.path.exists(sample_dir):
os.makedirs(sample_dir)
# 圖像處理
# transform = transforms.Compose([
# transforms.ToTensor(),
# transforms.Normalize(mean=(0.5, 0.5, 0.5), # 3 for RGB channels
# std=(0.5, 0.5, 0.5))])
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], # 1 for greyscale channels
std=[0.5])])
# MNIST 數(shù)據(jù)集
mnist = torchvision.datasets.MNIST(root='../../data/',
train=True,
transform=transform,
download=True)
# 數(shù)據(jù)加載器
data_loader = torch.utils.data.DataLoader(dataset=mnist,
batch_size=batch_size,
shuffle=True)
# 鑒別器
D = nn.Sequential(
nn.Linear(image_size, hidden_size),
nn.LeakyReLU(0.2),
nn.Linear(hidden_size, hidden_size),
nn.LeakyReLU(0.2),
nn.Linear(hidden_size, 1),
nn.Sigmoid())
# 生成器
G = nn.Sequential(
nn.Linear(latent_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, image_size),
nn.Tanh())
# 設(shè)備設(shè)置
D = D.to(device)
G = G.to(device)
# 二元交叉熵?fù)p失和優(yōu)化器
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
def denorm(x):
out = (x + 1) / 2
return out.clamp(0, 1)
def reset_grad():
d_optimizer.zero_grad()
g_optimizer.zero_grad()
# 開始訓(xùn)練
total_step = len(data_loader)
for epoch in range(num_epochs):
for i, (images, _) in enumerate(data_loader):
images = images.reshape(batch_size, -1).to(device)
# 創(chuàng)建稍后用作 BCE 損失輸入的標(biāo)簽
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)
# ================================================================== #
# 訓(xùn)練判別器 #
# ================================================================== #
# 使用真實圖像計算 BCE_Loss 其中 BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
# 損失的第二項總是為零,因為 real_labels == 1
outputs = D(images)
d_loss_real = criterion(outputs, real_labels)
real_score = outputs
# 使用假圖像計算 BCELoss
# 損失的第一項總是為零,因為 fake_labels == 0
z = torch.randn(batch_size, latent_size).to(device)
fake_images = G(z)
outputs = D(fake_images)
d_loss_fake = criterion(outputs, fake_labels)
fake_score = outputs
# 反向傳播和優(yōu)化
d_loss = d_loss_real + d_loss_fake
reset_grad()
d_loss.backward()
d_optimizer.step()
# ================================================================== #
# 訓(xùn)練生成器 #
# ================================================================== #
# 用假圖像計算損失
z = torch.randn(batch_size, latent_size).to(device)
fake_images = G(z)
outputs = D(fake_images)
# 我們訓(xùn)練 G 最大化 log(D(G(z)) 而不是最小化 log(1-D(G(z)))
# 原因見第3節(jié)最后一段。 https://arxiv.org/pdf/1406.2661.pdf
g_loss = criterion(outputs, real_labels)
# 反向傳播和優(yōu)化
reset_grad()
g_loss.backward()
g_optimizer.step()
if (i+1) % 200 == 0:
print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
real_score.mean().item(), fake_score.mean().item()))
# 保存真實圖片
if (epoch+1) == 1:
images = images.reshape(images.size(0), 1, 28, 28)
save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
# 保存采樣圖像
fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
# 保存模型checkpoints
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')
到了這里,關(guān)于【Pytorch深度學(xué)習(xí)實戰(zhàn)】(10)生成對抗網(wǎng)絡(luò)(GAN)的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!