国产 无码 综合区,色欲AV无码国产永久播放,无码天堂亚洲国产AV,国产日韩欧美女同一区二区

變分自編碼器(VAE)PyTorch Lightning 實(shí)現(xiàn)

這篇具有很好參考價(jià)值的文章主要介紹了變分自編碼器(VAE)PyTorch Lightning 實(shí)現(xiàn)。希望對(duì)大家有所幫助。如果存在錯(cuò)誤或未考慮完全的地方,請(qǐng)大家不吝賜教,您也可以點(diǎn)擊"舉報(bào)違法"按鈕提交疑問(wèn)。

變分自編碼器(VAE)PyTorch Lightning 實(shí)現(xiàn),Python,深度學(xué)習(xí),pytorch,人工智能,python

?作者簡(jiǎn)介:人工智能專業(yè)本科在讀,喜歡計(jì)算機(jī)與編程,寫(xiě)博客記錄自己的學(xué)習(xí)歷程。
??個(gè)人主頁(yè):小嗷犬的個(gè)人主頁(yè)
??個(gè)人網(wǎng)站:小嗷犬的技術(shù)小站
??個(gè)人信條:為天地立心,為生民立命,為往圣繼絕學(xué),為萬(wàn)世開(kāi)太平。



VAE 簡(jiǎn)介

變分自編碼器(Variational Autoencoder,VAE)是一種深度學(xué)習(xí)中的生成模型,它結(jié)合了自編碼器(Autoencoder, AE)和概率建模的思想,在無(wú)監(jiān)督學(xué)習(xí)環(huán)境中表現(xiàn)出了強(qiáng)大的能力。VAE 在 2013 年由 Diederik P. Kingma 和 Max Welling 首次提出,并迅速成為生成模型領(lǐng)域的重要組成部分。

基本原理

自編碼器(AE)基礎(chǔ):
自編碼器是一種神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),通常由兩部分組成:編碼器(Encoder)和解碼器(Decoder)。原始數(shù)據(jù)通過(guò)編碼器映射到一個(gè)低維的潛在空間(或稱為隱空間),這個(gè)低維向量被稱為潛變量(latent variable)。然后,潛變量再通過(guò)解碼器重構(gòu)回原始數(shù)據(jù)的近似版本。在訓(xùn)練過(guò)程中,自編碼器的目標(biāo)是使得輸入數(shù)據(jù)經(jīng)過(guò)編碼-解碼過(guò)程后能夠盡可能地恢復(fù)原貌,從而學(xué)習(xí)到數(shù)據(jù)的有效表示。

VAE的引入與擴(kuò)展:
VAE 將自編碼器的概念推廣到了概率框架下。在 VAE 中,潛變量不再是確定性的,而是被賦予了概率分布。具體來(lái)說(shuō),對(duì)于給定的輸入數(shù)據(jù),編碼器不直接輸出一個(gè)點(diǎn)估計(jì)值,而是輸出潛變量的均值和方差(假設(shè)潛變量服從高斯分布)。這樣,每個(gè)輸入數(shù)據(jù)可以被視為是從某個(gè)潛在的概率分布中采樣得到的。

變分推斷(Variational Inference):
訓(xùn)練 VA E時(shí),由于真實(shí)的后驗(yàn)概率分布難以直接計(jì)算,因此采用變分推斷來(lái)近似后驗(yàn)分布。編碼器實(shí)際上輸出的是一個(gè)參數(shù)化的概率分布 q ( z ∣ x ) q(z|x) q(zx),即給定輸入 x x x 時(shí)潛變量 z z z 的概率分布。然后通過(guò)最小化 KL 散度(Kullback-Leibler divergence)來(lái)優(yōu)化這個(gè)近似分布,使其盡可能接近真實(shí)的后驗(yàn)分布 p ( z ∣ x ) p(z|x) p(zx)。

目標(biāo)函數(shù) - Evidence Lower Bound (ELBO):
VAE 的目標(biāo)函數(shù)是證據(jù)下界(ELBO),它是原始數(shù)據(jù) log-likelihood 的下界。優(yōu)化該目標(biāo)函數(shù)既鼓勵(lì)編碼器找到數(shù)據(jù)的高效潛在表示,又促使解碼器基于這些表示重建出類似原始數(shù)據(jù)的新樣本。

數(shù)學(xué)表達(dá)上,ELBO 通常分解為兩個(gè)部分:

  1. 重構(gòu)損失(Reconstruction Loss):衡量從潛變量重構(gòu)出來(lái)的數(shù)據(jù)與原始數(shù)據(jù)之間的差異。
  2. KL散度損失(KL Divergence Loss):衡量編碼器產(chǎn)生的潛變量分布與預(yù)設(shè)的標(biāo)準(zhǔn)正態(tài)分布(或其他先驗(yàn)分布)之間的距離。

應(yīng)用與優(yōu)點(diǎn)

  • VAE 可以用于生成新數(shù)據(jù),例如圖像、文本、音頻等。
  • 由于其對(duì)潛變量進(jìn)行概率建模,所以它可以提供連續(xù)的數(shù)據(jù)生成,并且能夠探索數(shù)據(jù)的不同模式。
  • 在處理連續(xù)和離散數(shù)據(jù)時(shí)具有一定的靈活性。
  • 可以用于特征學(xué)習(xí),提取數(shù)據(jù)的有效低維表示。

缺點(diǎn)與挑戰(zhàn)

  • 訓(xùn)練 VAE 可能需要大量的計(jì)算資源和時(shí)間。
  • 生成的樣本有時(shí)可能不夠清晰或細(xì)節(jié)模糊,尤其是在復(fù)雜數(shù)據(jù)集上。
  • 對(duì)于某些復(fù)雜的分布形式,VAE 可能無(wú)法完美捕獲所有細(xì)節(jié)。

使用 VAE 生成 MNIST 手寫(xiě)數(shù)字

下面我們將使用 PyTorch Lightning 來(lái)實(shí)現(xiàn)一個(gè)簡(jiǎn)單的 VAE 模型,并使用 MNIST 數(shù)據(jù)集來(lái)進(jìn)行訓(xùn)練和生成。

在線 Notebook:https://www.kaggle.com/code/marquis03/vae-mnist

忽略警告

import warnings
warnings.filterwarnings("ignore")

導(dǎo)入必要的庫(kù)

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(style="darkgrid", font_scale=1.5, font="SimHei", rc={"axes.unicode_minus":False})

import torch
import torchmetrics
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

import lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

設(shè)置隨機(jī)種子

seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

cuDNN 設(shè)置

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

超參數(shù)設(shè)置

batch_size = 64

epochs = 10
KLD_weight = 1
lr = 0.001

input_dim = 784  # 28 * 28
h_dim = 256  # 隱藏層維度  
z_dim = 2  # 潛變量維度

數(shù)據(jù)加載

train_dataset = datasets.MNIST(root="data", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

定義 VAE 模型

class VAE(nn.Module):
    def __init__(self, input_dim=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()

        self.input_dim = input_dim
        self.h_dim = h_dim
        self.z_dim = z_dim

        # Encoder
        self.fc1 = nn.Linear(input_dim, h_dim)
        self.fc21 = nn.Linear(h_dim, z_dim)  # mu
        self.fc22 = nn.Linear(h_dim, z_dim)  # log_var

        # Decoder
        self.fc3 = nn.Linear(z_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, input_dim)

    def encode(self, x):
        h = torch.relu(self.fc1(x))
        mean = self.fc21(h)
        log_var = self.fc22(h)
        return mean, log_var

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = torch.relu(self.fc3(z))
        out = torch.sigmoid(self.fc4(h))
        return out

    def forward(self, x):
        mean, log_var = self.encode(x)
        z = self.reparameterize(mean, log_var)
        reconstructed_x = self.decode(z)
        return reconstructed_x, mean, log_var

vae = VAE(input_dim, h_dim, z_dim)
x = torch.randn((10, input_dim))
reconstructed_x, mean, log_var = vae(x)
print(reconstructed_x.shape, mean.shape, log_var.shape)
# torch.Size([10, 784]) torch.Size([10, 2]) torch.Size([10, 2])

定義損失函數(shù)

def loss_function(x_hat, x, mu, log_var, KLD_weight=1):
    BCE_loss = F.binary_cross_entropy(x_hat, x, reduction="sum") # 重構(gòu)損失
    KLD_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) # KL 散度損失
    loss = BCE_loss + KLD_loss * KLD_weight
    return loss, BCE_loss, KLD_loss

定義 Lightning 模型

class LitModel(pl.LightningModule):
    def __init__(self, input_dim=784, h_dim=400, z_dim=20):
        super().__init__()
        self.model = VAE(input_dim, h_dim, z_dim)

    def forward(self, x):
        x = self.model(x)
        return x

    def configure_optimizers(self):
        optimizer = optim.Adam(
            self.parameters(), lr=lr, betas=(0.9, 0.99), eps=1e-08, weight_decay=1e-5
        )
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        reconstructed_x, mean, log_var = self(x)
        loss, BCE_loss, KLD_loss = loss_function(reconstructed_x, x, mean, log_var, KLD_weight=KLD_weight)
        self.log("loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log_dict(
            {
                "BCE_loss": BCE_loss,
                "KLD_loss": KLD_loss,
            },
            on_step=False,
            on_epoch=True,
            logger=True,
        )
        return loss
    
    def decode(self, z):
        out = self.model.decode(z)
        return out

訓(xùn)練模型

model = LitModel(input_dim, h_dim, z_dim)
logger = CSVLogger("./")
early_stop_callback = EarlyStopping(monitor="loss", min_delta=0.00, patience=5, verbose=False, mode="min")
trainer = pl.Trainer(
    max_epochs=epochs,
    enable_progress_bar=True,
    logger=logger,
    callbacks=[early_stop_callback],
)
trainer.fit(model, train_loader)

繪制訓(xùn)練過(guò)程

log_path = logger.log_dir + "/metrics.csv"
metrics = pd.read_csv(log_path)
x_name = "epoch"

plt.figure(figsize=(8, 6), dpi=100)
sns.lineplot(x=x_name, y="loss", data=metrics, label="Loss", linewidth=2, marker="o", markersize=10)
sns.lineplot(x=x_name, y="BCE_loss", data=metrics, label="BCE Loss", linewidth=2, marker="^", markersize=12)
sns.lineplot(x=x_name, y="KLD_loss", data=metrics, label="KLD Loss", linewidth=2, marker="s", markersize=10)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.tight_layout()
plt.show()

變分自編碼器(VAE)PyTorch Lightning 實(shí)現(xiàn),Python,深度學(xué)習(xí),pytorch,人工智能,python

隨機(jī)生成新樣本

row, col = 4, 18
z = torch.randn(row * col, z_dim)
random_res = model.model.decode(z).view(-1, 1, 28, 28).detach().numpy()

plt.figure(figsize=(col, row))
for i in range(row * col):
    plt.subplot(row, col, i + 1)
    plt.imshow(random_res[i].squeeze(), cmap="gray")
    plt.xticks([])
    plt.yticks([])
    plt.axis("off")
plt.show()

變分自編碼器(VAE)PyTorch Lightning 實(shí)現(xiàn),Python,深度學(xué)習(xí),pytorch,人工智能,python

根據(jù)潛變量插值生成新樣本

from scipy.stats import norm

n = 15
digit_size = 28

grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

figure = np.zeros((digit_size * n, digit_size * n))
for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        t = [xi, yi]
        z_sampled = torch.FloatTensor(t)
        with torch.no_grad():
            decode = model.decode(z_sampled)
            digit = decode.view((digit_size, digit_size))
            figure[
                i * digit_size : (i + 1) * digit_size,
                j * digit_size : (j + 1) * digit_size,
            ] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap="gray")
plt.xticks([])
plt.yticks([])
plt.axis("off")
plt.show()

變分自編碼器(VAE)PyTorch Lightning 實(shí)現(xiàn),Python,深度學(xué)習(xí),pytorch,人工智能,python文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-831469.html

到了這里,關(guān)于變分自編碼器(VAE)PyTorch Lightning 實(shí)現(xiàn)的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!

本文來(lái)自互聯(lián)網(wǎng)用戶投稿,該文觀點(diǎn)僅代表作者本人,不代表本站立場(chǎng)。本站僅提供信息存儲(chǔ)空間服務(wù),不擁有所有權(quán),不承擔(dān)相關(guān)法律責(zé)任。如若轉(zhuǎn)載,請(qǐng)注明出處: 如若內(nèi)容造成侵權(quán)/違法違規(guī)/事實(shí)不符,請(qǐng)點(diǎn)擊違法舉報(bào)進(jìn)行投訴反饋,一經(jīng)查實(shí),立即刪除!

領(lǐng)支付寶紅包贊助服務(wù)器費(fèi)用

相關(guān)文章

  • 變分自編碼器(Variational AutoEncoder,VAE)

    變分自編碼器(Variational AutoEncoder,VAE)

    說(shuō)到編碼器這塊,不可避免地要講起 AE (AutoEncoder)自編碼器。它的結(jié)構(gòu)下圖所示: 據(jù)圖可知,AE通過(guò)自監(jiān)督的訓(xùn)練方式,能夠?qū)⑤斎氲脑继卣魍ㄟ^(guò)編碼encoder后得到潛在的特征編碼,實(shí)現(xiàn)了自動(dòng)化的特征工程,并且達(dá)到了降維和泛化的目的。而后通過(guò)對(duì)進(jìn)行decoder后,我們

    2024年01月18日
    瀏覽(25)
  • AIGC實(shí)戰(zhàn)——變分自編碼器(Variational Autoencoder, VAE)

    我們已經(jīng)學(xué)習(xí)了如何實(shí)現(xiàn)自編碼器,并了解了自編碼器無(wú)法在潛空間中的空白位置處生成逼真的圖像,且空間分布并不均勻,為了解決這些問(wèn)題#

    2024年02月05日
    瀏覽(22)
  • 理解 Stable Diffusion、模型檢查點(diǎn)(ckpt)和變分自編碼器(VAE)

    ? ? ? ? 在探索深度學(xué)習(xí)和人工智能領(lǐng)域的旅途中,理解Stable Diffusion、模型檢查點(diǎn)(ckpt)以及變分自編碼器(VAE)之間的關(guān)系至關(guān)重要。這些組件共同構(gòu)成了當(dāng)下一些最先進(jìn)圖像生成系統(tǒng)的基礎(chǔ)。本文將為初學(xué)者提供一個(gè)詳細(xì)的概述,幫助您理解這些概念以及它們是如何協(xié)同工作

    2024年01月21日
    瀏覽(27)
  • 簡(jiǎn)要介紹 | 生成模型的演進(jìn):從自編碼器(AE)到變分自編碼器(VAE)和生成對(duì)抗網(wǎng)絡(luò)(GAN),再到擴(kuò)散模型

    簡(jiǎn)要介紹 | 生成模型的演進(jìn):從自編碼器(AE)到變分自編碼器(VAE)和生成對(duì)抗網(wǎng)絡(luò)(GAN),再到擴(kuò)散模型

    注1:本文系“簡(jiǎn)要介紹”系列之一,僅從概念上對(duì)生成模型(包括AE, VAE, GAN,以及擴(kuò)散模型)進(jìn)行非常簡(jiǎn)要的介紹,不適合用于深入和詳細(xì)的了解。 生成模型在機(jī)器學(xué)習(xí)領(lǐng)域已經(jīng)成為了一個(gè)熱門(mén)的研究領(lǐng)域。它們的主要目標(biāo)是學(xué)習(xí)數(shù)據(jù)的真實(shí)分布,以便能夠生成新的、與真

    2024年02月14日
    瀏覽(15)
  • AI繪畫(huà)——Stable Diffusion模型,變分自編碼器(VAE)模型 , lora模型——調(diào)配設(shè)置與分享

    AI繪畫(huà)——Stable Diffusion模型,變分自編碼器(VAE)模型 , lora模型——調(diào)配設(shè)置與分享

    目錄 Stable Diffusion模型 模型調(diào)配 模型設(shè)置? 變分自編碼器(VAE)模型? 模型調(diào)配 模型設(shè)置? ?lora模型(原生)(插件) 模型調(diào)配 模型設(shè)置? ?AI生成prompt及模型分享 Stable Diffusion模型?pastel-mix+對(duì)應(yīng)的VAE ?Stable Diffusion模型國(guó)風(fēng)+Lora模型 墨心+疏可走馬 Stable Diffusion模型國(guó)風(fēng)+Lo

    2024年02月04日
    瀏覽(22)
  • 變分自編碼器生成新的手寫(xiě)數(shù)字圖像

    變分自編碼器生成新的手寫(xiě)數(shù)字圖像

    變分自編碼器(Variational Autoencoder,VAE) 是一種生成模型,通常用于學(xué)習(xí)數(shù)據(jù)的潛在表示,并用于生成新的數(shù)據(jù)樣本。它由兩部分組成:編碼器和解碼器。 編碼器(Encoder) :接收輸入數(shù)據(jù),并將其映射到潛在空間中的分布。這意味著編碼器將數(shù)據(jù)轉(zhuǎn)換為均值和方差參數(shù)的分

    2024年04月11日
    瀏覽(25)
  • AIGC實(shí)戰(zhàn)——使用變分自編碼器生成面部圖像

    在自編碼器和變分自編碼器上,我們都僅使用具有兩個(gè)維度的潛空間。這有助于我們可視化自編碼器和變分自編碼器的內(nèi)部工作原理,并理解自編碼器和變分自編碼潛空間分布的區(qū)別。在本節(jié)中,我們將使用更復(fù)雜的數(shù)據(jù)集,并了解增加潛空間的維度時(shí),變

    2024年02月05日
    瀏覽(22)
  • AE(自動(dòng)編碼器)與VAE(變分自動(dòng)編碼器)的區(qū)別和聯(lián)系?

    AE(自動(dòng)編碼器)與VAE(變分自動(dòng)編碼器)的區(qū)別和聯(lián)系?

    他們各自的概念看以下鏈接就可以了:https://blog.csdn.net/weixin_43135178/category_11543123.html ?這里主要談一下他們的區(qū)別? VAE是AE的升級(jí)版,VAE也可以被看作是一種特殊的AE AE主要用于數(shù)據(jù)的 壓縮與還原 ,VAE主要用于 生成 。 AE是將數(shù)據(jù)映直接映射為數(shù)值 code(確定的數(shù)值) ,而

    2024年02月03日
    瀏覽(164)
  • 在 CelebA 數(shù)據(jù)集上訓(xùn)練的 PyTorch 中的基本變分自動(dòng)編碼器

    在 CelebA 數(shù)據(jù)集上訓(xùn)練的 PyTorch 中的基本變分自動(dòng)編碼器

    摩西·西珀博士 ????????我最近發(fā)現(xiàn)自己需要一種方法將圖像 編碼到潛在嵌入中, 調(diào)整 嵌入,然后 生成 新圖像。有一些強(qiáng)大的方法可以創(chuàng)建嵌入 或 從嵌入生成。如果你想同時(shí)做到這兩點(diǎn),一種自然且相當(dāng)簡(jiǎn)單的方法是使用變分自動(dòng)編碼器。

    2024年02月05日
    瀏覽(18)
  • 圖像生成模型【自編碼器、RNN、VAE、GAN、Diffusion、AIGC等】

    圖像生成模型【自編碼器、RNN、VAE、GAN、Diffusion、AIGC等】

    目錄 監(jiān)督學(xué)習(xí) 與 無(wú)監(jiān)督學(xué)習(xí) 生成模型 自編碼器 從線性維度壓縮角度: 2D-1D 線性維度壓縮: 3D-2D 推廣線性維度壓縮 流形 自編碼器:流形數(shù)據(jù)的維度壓縮 全圖像空間 自然圖像流形 自編碼器的去噪效果 自編碼器的問(wèn)題 圖像預(yù)測(cè) (“結(jié)構(gòu)化預(yù)測(cè)”) 顯式密度模型 RNN PixelRNN [van

    2024年02月10日
    瀏覽(19)

覺(jué)得文章有用就打賞一下文章作者

支付寶掃一掃打賞

博客贊助

微信掃一掃打賞

請(qǐng)作者喝杯咖啡吧~博客贊助

支付寶掃一掃領(lǐng)取紅包,優(yōu)惠每天領(lǐng)

二維碼1

領(lǐng)取紅包

二維碼2

領(lǐng)紅包