
?作者簡(jiǎn)介:人工智能專業(yè)本科在讀,喜歡計(jì)算機(jī)與編程,寫(xiě)博客記錄自己的學(xué)習(xí)歷程。
??個(gè)人主頁(yè):小嗷犬的個(gè)人主頁(yè)
??個(gè)人網(wǎng)站:小嗷犬的技術(shù)小站
??個(gè)人信條:為天地立心,為生民立命,為往圣繼絕學(xué),為萬(wàn)世開(kāi)太平。
VAE 簡(jiǎn)介
變分自編碼器(Variational Autoencoder,VAE)是一種深度學(xué)習(xí)中的生成模型,它結(jié)合了自編碼器(Autoencoder, AE)和概率建模的思想,在無(wú)監(jiān)督學(xué)習(xí)環(huán)境中表現(xiàn)出了強(qiáng)大的能力。VAE 在 2013 年由 Diederik P. Kingma 和 Max Welling 首次提出,并迅速成為生成模型領(lǐng)域的重要組成部分。
基本原理
自編碼器(AE)基礎(chǔ):
自編碼器是一種神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),通常由兩部分組成:編碼器(Encoder)和解碼器(Decoder)。原始數(shù)據(jù)通過(guò)編碼器映射到一個(gè)低維的潛在空間(或稱為隱空間),這個(gè)低維向量被稱為潛變量(latent variable)。然后,潛變量再通過(guò)解碼器重構(gòu)回原始數(shù)據(jù)的近似版本。在訓(xùn)練過(guò)程中,自編碼器的目標(biāo)是使得輸入數(shù)據(jù)經(jīng)過(guò)編碼-解碼過(guò)程后能夠盡可能地恢復(fù)原貌,從而學(xué)習(xí)到數(shù)據(jù)的有效表示。
VAE的引入與擴(kuò)展:
VAE 將自編碼器的概念推廣到了概率框架下。在 VAE 中,潛變量不再是確定性的,而是被賦予了概率分布。具體來(lái)說(shuō),對(duì)于給定的輸入數(shù)據(jù),編碼器不直接輸出一個(gè)點(diǎn)估計(jì)值,而是輸出潛變量的均值和方差(假設(shè)潛變量服從高斯分布)。這樣,每個(gè)輸入數(shù)據(jù)可以被視為是從某個(gè)潛在的概率分布中采樣得到的。
變分推斷(Variational Inference):
訓(xùn)練 VA E時(shí),由于真實(shí)的后驗(yàn)概率分布難以直接計(jì)算,因此采用變分推斷來(lái)近似后驗(yàn)分布。編碼器實(shí)際上輸出的是一個(gè)參數(shù)化的概率分布
q
(
z
∣
x
)
q(z|x)
q(z∣x),即給定輸入
x
x
x 時(shí)潛變量
z
z
z 的概率分布。然后通過(guò)最小化 KL 散度(Kullback-Leibler divergence)來(lái)優(yōu)化這個(gè)近似分布,使其盡可能接近真實(shí)的后驗(yàn)分布
p
(
z
∣
x
)
p(z|x)
p(z∣x)。
目標(biāo)函數(shù) - Evidence Lower Bound (ELBO):
VAE 的目標(biāo)函數(shù)是證據(jù)下界(ELBO),它是原始數(shù)據(jù) log-likelihood 的下界。優(yōu)化該目標(biāo)函數(shù)既鼓勵(lì)編碼器找到數(shù)據(jù)的高效潛在表示,又促使解碼器基于這些表示重建出類似原始數(shù)據(jù)的新樣本。
數(shù)學(xué)表達(dá)上,ELBO 通常分解為兩個(gè)部分:
- 重構(gòu)損失(Reconstruction Loss):衡量從潛變量重構(gòu)出來(lái)的數(shù)據(jù)與原始數(shù)據(jù)之間的差異。
- KL散度損失(KL Divergence Loss):衡量編碼器產(chǎn)生的潛變量分布與預(yù)設(shè)的標(biāo)準(zhǔn)正態(tài)分布(或其他先驗(yàn)分布)之間的距離。
應(yīng)用與優(yōu)點(diǎn)
- VAE 可以用于生成新數(shù)據(jù),例如圖像、文本、音頻等。
- 由于其對(duì)潛變量進(jìn)行概率建模,所以它可以提供連續(xù)的數(shù)據(jù)生成,并且能夠探索數(shù)據(jù)的不同模式。
- 在處理連續(xù)和離散數(shù)據(jù)時(shí)具有一定的靈活性。
- 可以用于特征學(xué)習(xí),提取數(shù)據(jù)的有效低維表示。
缺點(diǎn)與挑戰(zhàn)
- 訓(xùn)練 VAE 可能需要大量的計(jì)算資源和時(shí)間。
- 生成的樣本有時(shí)可能不夠清晰或細(xì)節(jié)模糊,尤其是在復(fù)雜數(shù)據(jù)集上。
- 對(duì)于某些復(fù)雜的分布形式,VAE 可能無(wú)法完美捕獲所有細(xì)節(jié)。
使用 VAE 生成 MNIST 手寫(xiě)數(shù)字
下面我們將使用 PyTorch Lightning 來(lái)實(shí)現(xiàn)一個(gè)簡(jiǎn)單的 VAE 模型,并使用 MNIST 數(shù)據(jù)集來(lái)進(jìn)行訓(xùn)練和生成。
在線 Notebook:https://www.kaggle.com/code/marquis03/vae-mnist
忽略警告
import warnings
warnings.filterwarnings("ignore")
導(dǎo)入必要的庫(kù)
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme(style="darkgrid", font_scale=1.5, font="SimHei", rc={"axes.unicode_minus":False})
import torch
import torchmetrics
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
設(shè)置隨機(jī)種子
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
cuDNN 設(shè)置
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
超參數(shù)設(shè)置
batch_size = 64
epochs = 10
KLD_weight = 1
lr = 0.001
input_dim = 784 # 28 * 28
h_dim = 256 # 隱藏層維度
z_dim = 2 # 潛變量維度
數(shù)據(jù)加載
train_dataset = datasets.MNIST(root="data", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
定義 VAE 模型
class VAE(nn.Module):
def __init__(self, input_dim=784, h_dim=400, z_dim=20):
super(VAE, self).__init__()
self.input_dim = input_dim
self.h_dim = h_dim
self.z_dim = z_dim
# Encoder
self.fc1 = nn.Linear(input_dim, h_dim)
self.fc21 = nn.Linear(h_dim, z_dim) # mu
self.fc22 = nn.Linear(h_dim, z_dim) # log_var
# Decoder
self.fc3 = nn.Linear(z_dim, h_dim)
self.fc4 = nn.Linear(h_dim, input_dim)
def encode(self, x):
h = torch.relu(self.fc1(x))
mean = self.fc21(h)
log_var = self.fc22(h)
return mean, log_var
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
h = torch.relu(self.fc3(z))
out = torch.sigmoid(self.fc4(h))
return out
def forward(self, x):
mean, log_var = self.encode(x)
z = self.reparameterize(mean, log_var)
reconstructed_x = self.decode(z)
return reconstructed_x, mean, log_var
vae = VAE(input_dim, h_dim, z_dim)
x = torch.randn((10, input_dim))
reconstructed_x, mean, log_var = vae(x)
print(reconstructed_x.shape, mean.shape, log_var.shape)
# torch.Size([10, 784]) torch.Size([10, 2]) torch.Size([10, 2])
定義損失函數(shù)
def loss_function(x_hat, x, mu, log_var, KLD_weight=1):
BCE_loss = F.binary_cross_entropy(x_hat, x, reduction="sum") # 重構(gòu)損失
KLD_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) # KL 散度損失
loss = BCE_loss + KLD_loss * KLD_weight
return loss, BCE_loss, KLD_loss
定義 Lightning 模型
class LitModel(pl.LightningModule):
def __init__(self, input_dim=784, h_dim=400, z_dim=20):
super().__init__()
self.model = VAE(input_dim, h_dim, z_dim)
def forward(self, x):
x = self.model(x)
return x
def configure_optimizers(self):
optimizer = optim.Adam(
self.parameters(), lr=lr, betas=(0.9, 0.99), eps=1e-08, weight_decay=1e-5
)
return optimizer
def training_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
reconstructed_x, mean, log_var = self(x)
loss, BCE_loss, KLD_loss = loss_function(reconstructed_x, x, mean, log_var, KLD_weight=KLD_weight)
self.log("loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log_dict(
{
"BCE_loss": BCE_loss,
"KLD_loss": KLD_loss,
},
on_step=False,
on_epoch=True,
logger=True,
)
return loss
def decode(self, z):
out = self.model.decode(z)
return out
訓(xùn)練模型
model = LitModel(input_dim, h_dim, z_dim)
logger = CSVLogger("./")
early_stop_callback = EarlyStopping(monitor="loss", min_delta=0.00, patience=5, verbose=False, mode="min")
trainer = pl.Trainer(
max_epochs=epochs,
enable_progress_bar=True,
logger=logger,
callbacks=[early_stop_callback],
)
trainer.fit(model, train_loader)
繪制訓(xùn)練過(guò)程
log_path = logger.log_dir + "/metrics.csv"
metrics = pd.read_csv(log_path)
x_name = "epoch"
plt.figure(figsize=(8, 6), dpi=100)
sns.lineplot(x=x_name, y="loss", data=metrics, label="Loss", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="BCE_loss", data=metrics, label="BCE Loss", linewidth=2, marker="^", markersize=12)
sns.lineplot(x=x_name, y="KLD_loss", data=metrics, label="KLD Loss", linewidth=2, marker="s", markersize=10)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.tight_layout()
plt.show()
隨機(jī)生成新樣本
row, col = 4, 18
z = torch.randn(row * col, z_dim)
random_res = model.model.decode(z).view(-1, 1, 28, 28).detach().numpy()
plt.figure(figsize=(col, row))
for i in range(row * col):
plt.subplot(row, col, i + 1)
plt.imshow(random_res[i].squeeze(), cmap="gray")
plt.xticks([])
plt.yticks([])
plt.axis("off")
plt.show()
文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-831469.html
根據(jù)潛變量插值生成新樣本
from scipy.stats import norm
n = 15
digit_size = 28
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))
figure = np.zeros((digit_size * n, digit_size * n))
for i, yi in enumerate(grid_y):
for j, xi in enumerate(grid_x):
t = [xi, yi]
z_sampled = torch.FloatTensor(t)
with torch.no_grad():
decode = model.decode(z_sampled)
digit = decode.view((digit_size, digit_size))
figure[
i * digit_size : (i + 1) * digit_size,
j * digit_size : (j + 1) * digit_size,
] = digit
plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap="gray")
plt.xticks([])
plt.yticks([])
plt.axis("off")
plt.show()
文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-831469.html
到了這里,關(guān)于變分自編碼器(VAE)PyTorch Lightning 實(shí)現(xiàn)的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!