變分自編碼器(Variational Autoencoder,VAE)是一種生成模型,通常用于學(xué)習(xí)數(shù)據(jù)的潛在表示,并用于生成新的數(shù)據(jù)樣本。它由兩部分組成:編碼器和解碼器。
-
編碼器(Encoder):接收輸入數(shù)據(jù),并將其映射到潛在空間中的分布。這意味著編碼器將數(shù)據(jù)轉(zhuǎn)換為均值和方差參數(shù)的分布,通常假設(shè)為高斯分布。
-
解碼器(Decoder):接收來自編碼器的潛在表示,并將其映射回原始數(shù)據(jù)空間。解碼器嘗試從潛在空間中的樣本中生成與輸入數(shù)據(jù)盡可能接近的重建數(shù)據(jù)。
VAE的目標(biāo)是學(xué)習(xí)一個(gè)能夠生成與訓(xùn)練數(shù)據(jù)類似的數(shù)據(jù)分布。為了實(shí)現(xiàn)這一點(diǎn),VAE采用了一種被稱為變分推斷的方法,其中引入了一個(gè)額外的損失項(xiàng),即KL散度,用于度量生成的潛在分布與預(yù)先設(shè)定的先驗(yàn)分布之間的差異。
VAE將經(jīng)過神經(jīng)網(wǎng)絡(luò)編碼后的隱藏層假設(shè)為一個(gè)標(biāo)準(zhǔn)的高斯分布,然后再?gòu)倪@個(gè)分布中采樣一個(gè)特征,再用這個(gè)特征進(jìn)行解碼,期望得到與原始輸入相同的結(jié)果,損失和AE幾乎一樣,只是增加編碼推斷分布與標(biāo)準(zhǔn)高斯分布的KL散度的正則項(xiàng),顯然增加這個(gè)正則項(xiàng)的目的就是防止模型退化成普通的AE,因?yàn)榫W(wǎng)絡(luò)訓(xùn)練時(shí)為了盡量減小重構(gòu)誤差,必然使得方差逐漸被降到0,這樣便不再會(huì)有隨機(jī)采樣噪聲,也就變成了普通的AE。
舉例來說,假設(shè)我們有一組手寫數(shù)字的圖像作為輸入數(shù)據(jù)。我們可以使用VAE來學(xué)習(xí)手寫數(shù)字的潛在表示,并用此表示來生成新的手寫數(shù)字圖像。編碼器將輸入圖像轉(zhuǎn)換為潛在空間中的分布,解碼器則將從該分布中采樣的樣本映射回原始圖像空間。通過訓(xùn)練編碼器和解碼器,VAE可以生成與訓(xùn)練數(shù)據(jù)類似的手寫數(shù)字圖像,同時(shí)學(xué)習(xí)數(shù)據(jù)的潛在結(jié)構(gòu)。
以下是使用 PyTorch 實(shí)現(xiàn)的簡(jiǎn)單示例代碼,演示了如何使用變分自編碼器(VAE)學(xué)習(xí)手寫數(shù)字的潛在表示,并用此表示來生成新的手寫數(shù)字圖像:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt
# 定義變分自編碼器模型
class VAE(nn.Module):
def __init__(self, input_dim, latent_dim):
super(VAE, self).__init__()
# 編碼器部分
self.encoder = nn.Sequential(
nn.Linear(input_dim, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
nn.Linear(256, latent_dim * 2) # 輸出均值和方差參數(shù)
)
# 解碼器部分
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, input_dim),
nn.Sigmoid() # 輸出范圍在 0 到 1 之間
)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
# 編碼
z_mu_logvar = self.encoder(x)
mu, logvar = torch.chunk(z_mu_logvar, 2, dim=1)
# 重參數(shù)化
z = self.reparameterize(mu, logvar)
# 解碼
x_recon = self.decoder(z)
return x_recon, mu, logvar
# 計(jì)算重構(gòu)損失和 KL 散度
def loss_function(x_recon, x, mu, logvar):
recon_loss = nn.BCELoss(reduction='sum')(x_recon, x) # 二進(jìn)制交叉熵?fù)p失
kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return recon_loss + kl_divergence
# 數(shù)據(jù)預(yù)處理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x.view(-1)) # 將圖像展平成向量
])
# 加載 MNIST 數(shù)據(jù)集
train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
# 初始化模型和優(yōu)化器
latent_dim = 20
input_dim = 784 # 28x28
model = VAE(input_dim, latent_dim)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# 訓(xùn)練模型
num_epochs = 20
for epoch in range(num_epochs):
total_loss = 0
for batch_idx, (x, _) in enumerate(train_loader):
optimizer.zero_grad()
x_recon, mu, logvar = model(x)
loss = loss_function(x_recon, x, mu, logvar)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader.dataset)}")
# 使用訓(xùn)練好的模型生成新的手寫數(shù)字圖像
with torch.no_grad():
z = torch.randn(10, latent_dim) # 生成 10 個(gè)隨機(jī)潛在向量
generated_images = model.decoder(z)
generated_images = generated_images.view(-1, 1, 28, 28) # 將向量轉(zhuǎn)換成圖像形狀
# 可視化生成的圖像
fig, axes = plt.subplots(1, 10, figsize=(10, 1))
for i, ax in enumerate(axes):
ax.imshow(generated_images[i][0], cmap='gray')
ax.axis('off')
plt.show()
這段代碼首先定義了一個(gè)簡(jiǎn)單的變分自編碼器模型,然后使用 MNIST 數(shù)據(jù)集訓(xùn)練該模型,最后使用訓(xùn)練好的模型生成新的手寫數(shù)字圖像。文章來源:http://www.zghlxwxcb.cn/news/detail-847116.html
參考?【PyTorch】變分自編碼器/Variational Autoencoder(VAE)_variantautoencoder(vae)pytorch-CSDN博客文章來源地址http://www.zghlxwxcb.cn/news/detail-847116.html
到了這里,關(guān)于變分自編碼器生成新的手寫數(shù)字圖像的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!