通過(guò)估計(jì)數(shù)據(jù)分布梯度進(jìn)行生成建模
一文解釋 Diffusion Model (一) DDPM 理論推導(dǎo)
1 GAN到Stable Diffusion的改朝換代
隨著人工智能在圖像生成,文本生成以及多模態(tài)生成等生成領(lǐng)域
的技術(shù)不斷累積,生成對(duì)抗網(wǎng)絡(luò)(GAN)、變微分自動(dòng)編碼器(VAE)、normalizing flow models、自回歸模型(AR)、energy-based models以及近年來(lái)大火的擴(kuò)散模型(Diffusion Model)。
GAN:額外的判別器
VAE:對(duì)準(zhǔn)后驗(yàn)分布
EBM基于能量的模型:處理分區(qū)函數(shù)
歸一化流:施加網(wǎng)絡(luò)約束
生成領(lǐng)域 G A N 已經(jīng)有點(diǎn)過(guò)時(shí), S t a b l e D i f f u s i o n 替代了他的位置 生成領(lǐng)域GAN已經(jīng)有點(diǎn)過(guò)時(shí),Stable Diffusion替代了他的位置 生成領(lǐng)域GAN已經(jīng)有點(diǎn)過(guò)時(shí),StableDiffusion替代了他的位置
- GAN要訓(xùn)練倆網(wǎng)絡(luò),感覺難度較大,容易不收斂,而且多樣性比較差,只關(guān)注能騙過(guò)判別器就得了。
- Diffusion Model用一種更簡(jiǎn)單的方法來(lái)詮釋了生成模型該如何學(xué)習(xí)以及生成,其實(shí)感覺更簡(jiǎn)單。
采樣正態(tài)分布z,經(jīng)過(guò)Network G(·),得到生成圖像x,期望生成圖像x與目標(biāo)圖像real盡可能相似。
生成式模型的共同目標(biāo):
使用神經(jīng)網(wǎng)絡(luò),將正態(tài)分布模擬生成圖像的概率分布,盡可能接近真正圖像的概率分布
使用神經(jīng)網(wǎng)絡(luò),將正態(tài)分布模擬生成圖像的概率分布,盡可能接近真正圖像的概率分布
使用神經(jīng)網(wǎng)絡(luò),將正態(tài)分布模擬生成圖像的概率分布,盡可能接近真正圖像的概率分布
求Network的最佳參數(shù)
θ
\theta
θ -> 最大似然估計(jì)(使生成圖像分布
P
θ
P_\theta
Pθ?與標(biāo)簽圖片分布
P
d
a
t
a
P_{data}
Pdata?中
x
i
x_i
xi?的概率最大)最大似然估計(jì)
(使生成圖像的分布接近目標(biāo)圖像) -> 最小化KL散度
(最小化兩種分布的差異)
———————————————————————————————————————————————————————
2 從DDPM到Stable Diffusion發(fā)展史
2.1 DDPM
Diffusion擴(kuò)散模型是一類生成式模型,從隨機(jī)噪聲直接生成圖片。[DDPM: Denoising Diffusion Probabilistic Models]
擴(kuò)散過(guò)程(正向)
輸入原始圖像
x
0
x_0
x0?,經(jīng)過(guò)T步不斷將高斯噪聲
?
t
?
1
∈
N
(
0
,
1
)
\epsilon_{t-1}\in N(0,1)
?t?1?∈N(0,1) 加入到原始圖片
x
0
x_0
x0?中,得到破壞圖片
x
t
x_t
xt?。(擴(kuò)散階段是不含訓(xùn)練參數(shù)的,噪聲的標(biāo)準(zhǔn)差是固定的,均值由標(biāo)準(zhǔn)差和X0決定
)
理想的擴(kuò)散過(guò)程,是分為N步進(jìn)行加入高斯噪聲
因?yàn)槊看渭尤氲脑肼暿仟?dú)立的,可以使用遞歸帶入實(shí)現(xiàn)一次加N個(gè)獨(dú)立采樣的高斯噪聲
更進(jìn)一步,因?yàn)閮蓚€(gè)高斯分布X~
N
(
μ
1
,
σ
1
)
N(\mu_1, \sigma_1)
N(μ1?,σ1?)和Y~
N
(
μ
2
,
σ
2
)
N(\mu_2, \sigma_2)
N(μ2?,σ2?)疊加后的分布 aX+bY ~
N
(
a
μ
1
+
b
μ
2
,
a
2
σ
1
+
b
2
σ
2
2
)
N(a\mu_1+b\mu_2, \sqrt{a^2\sigma_1+b^2\sigma_2^2})
N(aμ1?+bμ2?,a2σ1?+b2σ22??)。所以可以把使用不同weight獨(dú)立采樣的噪聲
簡(jiǎn)化成 使用綜合weight只采樣一次高斯噪聲(t 越大,噪聲權(quán)重越大,beta 越大)
以此類推,從
x
0
x_0
x0?擴(kuò)散為
x
t
x_t
xt?,可以只用一次采樣的高斯噪聲
即可完成。同時(shí)做變量代換
,簡(jiǎn)化
β
\beta
β系列的weight為
α
\alpha
α系列,實(shí)際上
(
1
?
α
)
(1-\alpha)
(1?α)和
β
\beta
β就是噪聲的權(quán)重。(一階高斯-馬爾可夫過(guò)程
是指一個(gè)連續(xù)時(shí)間的隨機(jī)過(guò)程,其中狀態(tài)變量服從高斯分布,并且滿足馬爾可夫性質(zhì),即未來(lái)狀態(tài)只取決于當(dāng)前狀態(tài),可以使用參數(shù)重整化
的技巧完成單次采樣)
最后實(shí)際上,只做一次采樣噪聲,就可以完成
x
0
x_0
x0?到
x
t
x_t
xt?的轉(zhuǎn)換。任意時(shí)刻的圖像分布
q
(
x
t
∣
x
0
)
q(x_t|x_0)
q(xt?∣x0?)都可以直接推導(dǎo)出來(lái)
q
(
x
t
∣
x
0
)
=
N
(
x
t
;
a
t
ˉ
x
0
,
(
1
?
a
t
ˉ
)
ε
)
q(x_t|x_0)=N(x_t; \sqrt{\bar{a_t}}x_0, (1-\bar{a_t})\varepsilon)
q(xt?∣x0?)=N(xt?;at?ˉ??x0?,(1?at?ˉ?)ε),不用迭代訓(xùn)練。
擴(kuò)散過(guò)程總結(jié)
:①多次采樣化單次采樣(
x
0
x_0
x0?到
x
t
x_t
xt?)。②變量代換(
β
\beta
β系列的weight代換為
α
\alpha
α系列)
去噪過(guò)程(反向)
輸入含噪圖片
x
t
x_t
xt?,通過(guò)預(yù)測(cè)噪聲
?
θ
(
x
t
,
t
)
\epsilon_\theta( x_t,t)
?θ?(xt?,t) ,將預(yù)測(cè)到的噪聲從圖片中去除,多次迭代逐漸將被破壞的
x
t
x_t
xt? 恢復(fù)成
x
0
x_0
x0?,實(shí)現(xiàn)還原成圖片。訓(xùn)練一個(gè)噪聲預(yù)測(cè)模型,并將輸入隨機(jī)噪聲還原成圖片,其中噪聲就是標(biāo)簽,還原的時(shí)候,模型根據(jù)噪聲生成對(duì)應(yīng)的圖像
。
實(shí)際上,在含噪圖像 減去 預(yù)測(cè)噪聲
?
θ
(
x
t
,
t
)
\epsilon_\theta( x_t,t)
?θ?(xt?,t)后,還要再加上一個(gè)隨機(jī)生成的噪聲Z。
為什么最后還要再加上一個(gè)隨機(jī)噪聲Z ? 因?yàn)槟P?code>預(yù)測(cè)的是高斯噪聲的均值mean,所以還要加上一個(gè)方差
(用z表示)
但為什么不直接取mean呢?(為什么需要保持隨機(jī)性:在denoise時(shí),加點(diǎn)隨機(jī)性
,效果會(huì)更好,保證了生成多樣性
)
DDPM的關(guān)鍵在于去噪過(guò)程 :訓(xùn)練一個(gè) (根據(jù) 含噪圖片
x
t
x_t
xt? 和 輪次
t
t
t 來(lái))噪聲估計(jì)模型
?
θ
(
x
t
,
t
)
\epsilon_\theta( x_t,t)
?θ?(xt?,t) ,其中
θ
\theta
θ 就是模型的訓(xùn)練參數(shù), 使模型
?
θ
(
x
t
,
t
)
\epsilon_\theta( x_t,t)
?θ?(xt?,t) 預(yù)測(cè)的噪聲
σ
^
\hat{\sigma}
σ^ 與真實(shí)用于破壞圖片的噪聲標(biāo)簽
σ
\sigma
σ的L2 Loss更小。在DDPM中,使用U-Net作為預(yù)測(cè)噪聲的模型。
Unet可以嵌入更多的信息:時(shí)間步time
、生成文本描述context
總結(jié)
擴(kuò)散模型是怎么工作的?
-
前向擴(kuò)散過(guò)程: 一個(gè)
固定不含參(或預(yù)定義)的
前向擴(kuò)散過(guò)程 q q q ,這個(gè)前向過(guò)程 q ( x T ∣ x 0 ) = N ( x t ; a t ˉ x 0 , ( 1 ? a t ˉ ) I ) q(x_T|x_0)=N(x_t; \sqrt{\bar{a_t}}x_0, (1-\bar{a_t})I) q(xT?∣x0?)=N(xt?;at?ˉ??x0?,(1?at?ˉ?)I) 會(huì)逐漸向圖像添加高斯噪聲 z t z_t zt?,直到你最終得到純?cè)肼暋?span id="n5n3t3z" class="katex--inline"> x t = a t ˉ x 0 + 1 ? a t ˉ z t x_t=\sqrt{\bar{a_t}}x_0+\sqrt{1-\bar{a_t}}z_t xt?=at?ˉ??x0?+1?at?ˉ??zt? -
反向生成過(guò)程: 一個(gè)
通過(guò)學(xué)習(xí)得到的
含參 θ \theta θ反向去噪擴(kuò)散過(guò)程 p θ p_\theta pθ? ???,這個(gè)反向過(guò)程 p θ ( x 0 ∣ x T ) p_{\theta}(x_0|x_T) pθ?(x0?∣xT?)就是神經(jīng)網(wǎng)絡(luò)從訓(xùn)練中學(xué)會(huì)如何從純?cè)肼曢_始逐漸對(duì)一個(gè)圖像進(jìn)行去噪,直到你最終能夠得到一個(gè)實(shí)際圖像。,從一個(gè)隨機(jī)噪音開始逐漸去噪音 z t z_t zt?,直至生成一張圖像。 x 0 = 1 a t ˉ ( x t ? 1 ? a t ˉ z t ) x_0=\frac{1}{\sqrt{\bar{a_t}}}(x_t-\sqrt{1-\bar{a_t}}z_t) x0?=at?ˉ??1?(xt??1?at?ˉ??zt?),后驗(yàn)方差 β t ~ = 1 ? a t ? 1 ˉ 1 ? a t ˉ β t \tilde{\beta_t}=\frac{1-\bar{a_{t-1}}}{1-\bar{a_t}}\beta_t βt?~?=1?at?ˉ?1?at?1?ˉ??βt?,后驗(yàn)均值 μ t ~ = 1 a t ( x t ? β t 1 ? a t ˉ z t ) \tilde{\mu_t}=\frac{1}{\sqrt{a_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar{a_t}}}z_t) μt?~?=at??1?(xt??1?at?ˉ??βt??zt?)
- 正向和反向都是
馬爾可夫鏈
:由一組狀態(tài)和狀態(tài)之間的轉(zhuǎn)移概率組成。每個(gè)狀態(tài)代表一個(gè)可能的事件或狀態(tài),轉(zhuǎn)移概率表示從一個(gè)狀態(tài)轉(zhuǎn)移到另一個(gè)狀態(tài)的概率。根據(jù)這些轉(zhuǎn)移概率,可以使用馬爾可夫鏈進(jìn)行模擬和預(yù)測(cè)。
優(yōu)化目標(biāo)
https://zhuanlan.zhihu.com/p/563661713
高斯分布p和高斯分布q的KL散度:
相比VAE來(lái)說(shuō),擴(kuò)散模型的隱變量是和原始數(shù)據(jù)同維度的,而且encoder(即擴(kuò)散過(guò)程)是固定的。既然擴(kuò)散模型是隱變量模型,那么我們可以就可以基于變分推斷來(lái)得到variational lower bound(VLB,又稱evidence lower bound,ELBO證據(jù)下界)作為最大化優(yōu)化目標(biāo)
生成目標(biāo)
x
0
x_0
x0?的分布
p
θ
(
x
0
)
p_{\theta}(x_0)
pθ?(x0?)的似然函數(shù):求對(duì)數(shù)似然最大值 -> 求負(fù)對(duì)數(shù)似然最小值 -> 求負(fù)似然函數(shù)上界最小值 ->
求
L
V
L
B
L_{VLB}
LVLB?最小值-> 優(yōu)化預(yù)測(cè)噪聲和真實(shí)噪聲的L2誤差
這里最后一步是利用了Jensen’s inequality,對(duì)于網(wǎng)絡(luò)訓(xùn)練來(lái)說(shuō),其訓(xùn)練目標(biāo)為VLB取負(fù):
雖然擴(kuò)散模型背后的推導(dǎo)比較復(fù)雜,但是我們最終得到的優(yōu)化目標(biāo)非常簡(jiǎn)單,就是讓網(wǎng)絡(luò)預(yù)測(cè)的噪音和真實(shí)的噪音一致
。DDPM的訓(xùn)練過(guò)程也非常簡(jiǎn)單:隨機(jī)選擇一個(gè)訓(xùn)練樣本
->從[1,T]隨機(jī)抽樣一個(gè)t
->隨機(jī)產(chǎn)生噪音
-計(jì)算當(dāng)前所產(chǎn)生的帶噪音數(shù)據(jù)
->輸入網(wǎng)絡(luò)預(yù)測(cè)噪音
->計(jì)算產(chǎn)生的噪音和預(yù)測(cè)的噪音的L2損失(等價(jià)于優(yōu)化負(fù)對(duì)數(shù)似然)
->計(jì)算梯度并更新網(wǎng)絡(luò)
(實(shí)際上DDPM的優(yōu)化目標(biāo)是噪聲預(yù)測(cè),而不是直接優(yōu)化生成的圖片
x
0
x_0
x0?)
一旦訓(xùn)練完成,其采樣過(guò)程也非常簡(jiǎn)單:從一個(gè)隨機(jī)噪音開始
,并用訓(xùn)練好的網(wǎng)絡(luò)預(yù)測(cè)噪音
,然后計(jì)算條件分布的均值
,然后用均值加標(biāo)準(zhǔn)差再乘以一個(gè)隨機(jī)噪音
,直至t=0完成新樣本的生成
(最后一步不加噪音)。
不過(guò)實(shí)際的代碼實(shí)現(xiàn)和上述過(guò)程略有區(qū)別(見https://github.com/hojonathanho/diffusion/issues/5:先基于預(yù)測(cè)的噪音生成 x 0 x_0 x0?,并進(jìn)行了clip處理(范圍[-1, 1],原始數(shù)據(jù)歸一化到這個(gè)范圍),然后再計(jì)算均值。我個(gè)人的理解這應(yīng)該算是一種約束,既然模型預(yù)測(cè)的是噪音,那么我們也希望用預(yù)測(cè)噪音重構(gòu)處理的原始數(shù)據(jù)也應(yīng)該滿足范圍要求。
理論推導(dǎo)
為什么假設(shè)噪聲是正態(tài)分布(高斯分布)?
中心極限定理(CLT):對(duì)于一個(gè)分布的預(yù)測(cè),若通過(guò)大量的獨(dú)立同分布采樣取均值進(jìn)行,在滿足一些條件下(工程上一般默認(rèn)滿足),它依分布逼近于正態(tài)分布,且具有與未知分布相同的均值和方差。因此,在前向過(guò)程中,不斷給樣本加高斯分布的噪聲,最后樣本也變成一個(gè)高斯噪聲了。
前向擴(kuò)散過(guò)程重要公式:
x
t
x_t
xt?是t時(shí)刻的圖像分布,
z
i
z_i
zi?是噪聲,我們可以通過(guò)初始的分布
x
0
x_0
x0?和噪聲
z
i
z_i
zi?,進(jìn)行N步擴(kuò)散,得到最終的噪聲圖像
x
n
x_n
xn?
反向生成過(guò)程重要公式:
學(xué)習(xí)到噪聲預(yù)估模型
?
θ
(
x
n
,
n
)
\epsilon_\theta( x_n,n)
?θ?(xn?,n)后,隨機(jī)生成一個(gè)初始噪聲
x
n
x_n
xn?,通過(guò)該模型,做N步生成去噪聲,恢復(fù)到
x
0
x_0
x0?圖片。
Diffusion起作用的關(guān)鍵:
隱變量模型、兩個(gè)過(guò)程都是一個(gè)參數(shù)化的馬爾可夫鏈、變分推斷來(lái)進(jìn)行建模和求解
代碼解析
import io
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.datasets import make_s_curve
from PIL import Image
def diffusion_loss_fun(model, x_0):
"""輸入原圖x_0預(yù)測(cè)隨機(jī)時(shí)刻t的噪聲計(jì)算loss,t是隨機(jī)生成的,實(shí)際上是計(jì)算batch_size個(gè)原圖的噪聲預(yù)測(cè)loss"""
batch_size = x_0.shape[0]
# 隨機(jī)采樣一個(gè)時(shí)刻t,為了提高訓(xùn)練效率,保證t不重復(fù)
t = torch.randint(0, model.num_steps, size=(batch_size // 2,))
t = torch.cat([t, model.num_steps - 1 - t], dim=0) # [batch_size, 1]
t = t.unsqueeze(-1) # batch_size長(zhǎng)度的序列
# x0的系數(shù)
x_weight = model.alpha_bar_sqrt[t]
# noise_eps的系數(shù)
noise_weight = model.one_minus_alpha_bar_sqrt[t]
# 生成noise
noise = torch.randn_like(x_0)
# 構(gòu)造模型的輸入x_t
x_t = x_0 * x_weight + noise * noise_weight
# 送入模型,預(yù)測(cè)t時(shí)刻的噪聲
pred = model(x_t, t.squeeze(-1))
# 計(jì)算預(yù)測(cè)結(jié)果與真實(shí)結(jié)果的L2誤差(噪聲圖像MSE loss)
return (noise - pred).square().mean()
class DDPM(nn.Module):
def __init__(self, num_steps=100, num_units=128):
super(DDPM, self).__init__()
"""設(shè)置超參數(shù) T、alpha、beta"""
self.num_steps = num_steps # 迭代步數(shù)T
# 生成100步中,每一步的beta, 保證噪聲的權(quán)重beta比較小,且逐漸增大,來(lái)滿足每個(gè)逆擴(kuò)散過(guò)程也是高斯分布的假設(shè)
self.betas = torch.linspace(-6, 6, self.num_steps)
self.betas = torch.sigmoid(self.betas) * (0.5e-2 - 1e-5) + 1e-5
# 根據(jù)beta, 計(jì)算alpha, alpha_prod, alpha_previous, alpha_bar_sqrt
self.alpha = 1 - self.betas
self.alpha_prod = torch.cumprod(self.alpha, 0) # 每一步之前所有alpha的累乘
self.alpha_prod_p = torch.cat([torch.tensor([1]).float(), self.alpha_prod[:-1]],
0) # 去掉alpha_prod[-1],然后在前面添加1值
self.alpha_bar_sqrt = torch.sqrt(self.alpha_prod) # 原圖x_0的權(quán)重weight
# 計(jì)算log(1-alpha_bar), sqrt(1-alpha_bar)
self.one_minus_alpha_bar_log = torch.log(1 - self.alpha_prod)
self.one_minus_alpha_bar_sqrt = torch.sqrt(1 - self.alpha_prod) # 噪聲noise的權(quán)重weight
assert (self.alpha.shape == self.alpha_prod.shape == self.alpha_prod_p.shape == self.alpha_bar_sqrt.shape ==
self.one_minus_alpha_bar_sqrt.shape == self.one_minus_alpha_bar_log.shape)
print(f"all shape same:{self.betas.shape}")
"""反向去噪過(guò)程,預(yù)測(cè)噪聲的模型(一般為Unet),但此處使用MLP與直接x+t"""
# 輸入含噪圖像x_t的mlp
self.mlp = nn.ModuleList(
[
nn.Linear(2, num_units),
nn.ReLU(),
nn.Linear(num_units, num_units),
nn.ReLU(),
nn.Linear(num_units, num_units),
nn.ReLU(),
nn.Linear(num_units, 2)
]
)
# 時(shí)間步t的embedding
self.step_embeddings = nn.ModuleList(
[
nn.Embedding(num_steps, num_units),
nn.Embedding(num_steps, num_units),
nn.Embedding(num_steps, num_units),
]
)
def forward(self, x_0, t): # 用mlp模擬unet預(yù)測(cè)輸入原圖x_0在第t步生成的噪聲
x = x_0
for idx, embedding_layer in enumerate(self.step_embeddings): # 3次對(duì)x進(jìn)行t_embedding融合
t_embedding = embedding_layer(t) # 對(duì)t進(jìn)行embedding
x = self.mlp[2 * idx](x) # 對(duì)x進(jìn)行全連接計(jì)算
x += t_embedding # x+t
x = self.mlp[2 * idx + 1](x) # x經(jīng)過(guò)relu
return self.mlp[-1](x) # 經(jīng)過(guò)最后的fc層使得x形狀不變
def q_x(self, x_0, t):
"""基于 原圖x_0 計(jì)算 第t步 生成 噪聲圖片x_t"""
noise = torch.randn_like(x_0) # noise與x_0形狀相同的高斯噪聲圖像
alpha_x_0_t = self.alpha_bar_sqrt[t] # 第t步原圖x_0的權(quán)重
alpha_noise_t = self.one_minus_alpha_bar_sqrt[t] # 第t步噪聲noise的權(quán)重
return alpha_x_0_t * x_0 + alpha_noise_t * noise # 基于x_0和步驟t直接計(jì)算噪聲圖像x_t
def forward_diffusion(self, num_show: int, dataset):
"""模擬經(jīng)過(guò)num_steps步加噪聲的過(guò)程"""
fig, axs = plt.subplots(2, num_show // 2, figsize=(28, 3))
plt.rc('text', color='blue')
# 10000個(gè)點(diǎn),每個(gè)點(diǎn)2個(gè)坐標(biāo),生成num_steps=100步以內(nèi)每隔num_steps//num_show=5步加噪聲后的圖像
for i in range(num_show):
j = i // 10
k = i % 10
q_i = self.q_x(dataset, torch.tensor([i * self.num_steps // num_show])) # 生成i時(shí)刻的加噪圖像x_i
axs[j, k].scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white')
axs[j, k].set_axis_off()
axs[j, k].set_title('$q(\mathbf{x}_{' + str(i * self.num_steps // num_show) + '})$')
plt.savefig('forward_diffusion.png')
def p_sample_loop(model, shape):
"""inference:從x_t恢復(fù)x_{t-1}...x_0"""
cur_x = torch.randn(shape) # [10000,2] 10000個(gè)點(diǎn)的坐標(biāo)
x_seq = [cur_x]
for i in reversed(range(model.num_steps)): # 倒著從x_100, x_99, x_98...x_0進(jìn)行預(yù)測(cè)噪聲去噪生成,放入x_seq
cur_x = p_sample(model, cur_x, i)
x_seq.append(cur_x)
return x_seq
def p_sample(model, x, t):
"""從x_t預(yù)測(cè)t時(shí)刻的噪聲,重構(gòu)圖像x_0"""
t = torch.tensor([t])
coeff = model.betas[t] / model.one_minus_alpha_bar_sqrt[t]
eps_theta = model(x, t) # 預(yù)測(cè)第t步噪聲圖像的噪聲eps
mean = (1 / (1 - model.betas[t].sqrt()) * (x - (coeff * eps_theta))) # 將含噪圖像x減去噪聲eps得到均值mean
z = torch.rand_like(x) # 再隨機(jī)采樣作為方差sigma
sigma_t = model.betas[t].sqrt() # 方差的權(quán)重weight
sample = mean + sigma_t * z
return sample
def train():
seed = 0
batch_size = 128
num_epoch = 4000
print('Train model...')
"""導(dǎo)入數(shù)據(jù)集dataset和dataloader"""
s_curve, _ = make_s_curve(10 ** 4, noise=0.1) # 生成S曲線散點(diǎn)數(shù)據(jù)集(高斯噪聲0.1)
s_curve = s_curve[:, [0, 2]] / 10.0 # 包含10000個(gè)點(diǎn)的坐標(biāo)
dataset = torch.Tensor(s_curve).float()
dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)
plt.rc('text', color='blue')
"""實(shí)例化DDPM模型"""
model = DDPM() # 輸出維度是2,輸入是x和step
"""實(shí)例化optimizer優(yōu)化器,但不能實(shí)例化loss,因?yàn)橐徊⑤斎隻atch_size個(gè)x_0"""
optimizer = torch.optim.Adam(model.parameters())
x_seq = None
for t in range(num_epoch):
for idx, batch_x in enumerate(dataloader): # dataloader這里只能得到batch_size個(gè)x_0,不能得到y(tǒng)噪聲標(biāo)簽,因?yàn)槟鞘窃趌oss里生成的
loss = diffusion_loss_fun(model, batch_x)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 這里遍歷model每個(gè)參數(shù)用EMA
if t % 100 == 0: # 每訓(xùn)練100個(gè)epoch,就打印一次loss,并進(jìn)行一次推理預(yù)測(cè)p_sample_loop
print(f"epoch:{t}, loss:{loss}")
# 進(jìn)行一次推理預(yù)測(cè)p_sample_loop,預(yù)測(cè)100個(gè)圖像的序列x_seq的
x_seq = p_sample_loop(model, dataset.shape) # 共100個(gè)元素
fig, axs = plt.subplots(1, 10, figsize=(28, 3)) # 對(duì)100個(gè)元素進(jìn)行推理生成
for i in range(1, 11):
cur_x = x_seq[i * 10].detach() # 間隔10,共取10個(gè)元素進(jìn)行可視化
axs[i - 1].scatter(cur_x[:, 0], cur_x[:, 1], color='red', edgecolor='white')
axs[i - 1].set_axis_off()
axs[i - 1].set_title('$q(\mathbf{x}_{' + str(i * 10) + '})$')
plt.savefig(f'./logs/res{t}.png')
plt.close(fig) # 手動(dòng)關(guān)閉圖形窗口
torch.save(x_seq, './logs/final_x_seq.pt') # 保存為 PyTorch Tensor 格式
def Generating_gif():
"""正向過(guò)程(加噪聲)gif生成"""
s_curve, _ = make_s_curve(10 ** 4, noise=0.1) # 生成S曲線散點(diǎn)數(shù)據(jù)集(高斯噪聲0.1)
s_curve = s_curve[:, [0, 2]] / 10.0 # 包含10000個(gè)點(diǎn)的坐標(biāo)
dataset = torch.Tensor(s_curve).float()
ddpm = DDPM()
imgs = []
for i in range(100):
plt.clf()
q_i = ddpm.q_x(dataset, torch.tensor([i]))
plt.scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolors='white', s=5)
plt.axis('off')
img_buf = io.BytesIO()
plt.savefig(img_buf, format('png'))
img = Image.open(img_buf)
imgs.append(img)
"""正向過(guò)程(加噪聲)gif生成"""
reverse = []
for i in range(100):
plt.clf()
x_seq = torch.load('./logs/final_x_seq.pt') # 拿到訓(xùn)練階段生成的x_seq
cur_x = x_seq[i].detch()
plt.scatter(cur_x[:, 0], cur_x[:, 1], color='red', edgecolors='white', s=5)
plt.axis('off')
img_buf = io.BytesIO()
plt.savefig(img_buf, format('png'))
img = Image.open(img_buf)
reverse.append(img)
"""合并加噪和去噪"""
imgs = imgs + reverse
imgs[0].save("./logs/diffusion.gif", format='GIF', append_imges=imgs, save_all=True, duration=100, loop=0)
if __name__ == "__main__":
train()
Generating_gif()
2.2 Stable Diffusion
擴(kuò)散模型最大的問題是它的時(shí)間成本和經(jīng)濟(jì)成本都極其“昂貴”。Stable Diffusion的出現(xiàn)就是為了解決上述問題。如果我們想要生成一張 1024 × 1024 尺寸的圖像,U-Net 會(huì)使用 1024 × 1024 尺寸的噪聲,然后從中生成圖像。這里做一步擴(kuò)散的計(jì)算量就很大,更別說(shuō)要循環(huán)迭代多次直到100%。一個(gè)解決方法是將大圖片拆分為若干小分辨率的圖片進(jìn)行訓(xùn)練,然后再使用一個(gè)額外的神經(jīng)網(wǎng)絡(luò)來(lái)產(chǎn)生更大分辨率的圖像(超分辨率擴(kuò)散)。
潛在空間(Lantent Space)
潛在空間簡(jiǎn)單的說(shuō)是對(duì)壓縮數(shù)據(jù)的表示。所謂壓縮指的是用比原始表示更小的數(shù)位來(lái)編碼信息的過(guò)程。比如我們用一個(gè)顏色通道(黑白灰)來(lái)表示原來(lái)由RGB三原色構(gòu)成的圖片,此時(shí)每個(gè)像素點(diǎn)的顏色向量由3維變成了1維度。維度降低會(huì)丟失一部分信息,然而在某些情況下,降維不是件壞事。通過(guò)降維我們可以過(guò)濾掉一些不太重要的信息你,只保留最重要的信息。
假設(shè)我們像通過(guò)全連接的卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練一個(gè)圖像分類模型。當(dāng)我們說(shuō)模型在學(xué)習(xí)時(shí),我們的意思是它在學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)每一層的特定屬性,比如邊緣、角度、形狀等……每當(dāng)模型使用數(shù)據(jù)(已經(jīng)存在的圖像)學(xué)習(xí)時(shí),都會(huì)將圖像的尺寸先減小再恢復(fù)到原始尺寸。最后,模型使用解碼器從壓縮數(shù)據(jù)中重建圖像,同時(shí)學(xué)習(xí)之前的所有相關(guān)信息。因此,空間變小,以便提取和保留最重要的屬性。這就是潛在空間適用于擴(kuò)散模型的原因。
Latent Diffusion
“潛在擴(kuò)散模型”(Latent Diffusion Model)將GAN的感知能力、擴(kuò)散模型的細(xì)節(jié)保存能力和Transformer的語(yǔ)義能力三者結(jié)合,創(chuàng)造出比上述所有模型更穩(wěn)健和高效的生成模型。與其他方法相比,Latent Diffusion不僅節(jié)省了內(nèi)存,而且生成的圖像保持了多樣性和高細(xì)節(jié)度,同時(shí)圖像還保留了數(shù)據(jù)的語(yǔ)義結(jié)構(gòu)。
任何生成性學(xué)習(xí)方法都有兩個(gè)主要階段:感知壓縮和語(yǔ)義壓縮。
感知壓縮
在感知壓縮學(xué)習(xí)階段,學(xué)習(xí)方法必須去除高頻細(xì)節(jié)將數(shù)據(jù)封裝到抽象表示中。此步驟對(duì)構(gòu)建一個(gè)穩(wěn)定、魯棒的環(huán)境表示是必要的。GAN 擅長(zhǎng)感知壓縮,通過(guò)將高維冗余數(shù)據(jù)從像素空間投影到潛在空間的超空間來(lái)實(shí)現(xiàn)這一點(diǎn)。潛在空間中的潛在向量是原始像素圖像的壓縮形式,可以有效地代替原始圖像。更具體地說(shuō),用自動(dòng)編碼器 (Auto Encoder) 結(jié)構(gòu)捕獲感知壓縮。 自動(dòng)編碼器中的編碼器將高維數(shù)據(jù)投影到潛在空間,解碼器從潛在空間恢復(fù)圖像。
語(yǔ)義壓縮
在學(xué)習(xí)的第二階段,圖像生成方法必須能夠捕獲數(shù)據(jù)中存在的語(yǔ)義結(jié)構(gòu)。 這種概念和語(yǔ)義結(jié)構(gòu)提供了圖像中各種對(duì)象的上下文和相互關(guān)系的保存。 Transformer擅長(zhǎng)捕捉文本和圖像中的語(yǔ)義結(jié)構(gòu)。 Transformer的泛化能力和擴(kuò)散模型的細(xì)節(jié)保存能力相結(jié)合,提供了兩全其美的方法,并提供了一種生成細(xì)粒度的高度細(xì)節(jié)圖像的方法,同時(shí)保留圖像中的語(yǔ)義結(jié)構(gòu)。
自動(dòng)編碼器VAE
自動(dòng)編碼器 (VAE) 由兩個(gè)主要部分組成:編碼器和解碼器。編碼器會(huì)將圖像轉(zhuǎn)換為低維潛在表示(像素空間–>潛在空間),該表示將作為輸入傳遞給U_Net。解碼器做的事情剛好相反,將潛在表示轉(zhuǎn)換回圖像(潛在空間–>像素空間)。
U-Net
U-Net 也由編碼器和解碼器組成,兩者都由 ResNet 塊組成。編碼器將圖像表示壓縮為較低分辨率的圖像,解碼器將較低分辨率解碼回較高分辨率的圖像。為了防止 U-Net 在下采樣時(shí)丟失重要信息,通常在編碼器的下采樣 ResNet 和解碼器的上采樣 ResNet 之間添加快捷連接。
此外,Stable Diffusion 中的 U-Net 能夠通過(guò)交叉注意力層調(diào)節(jié)其在文本嵌入上的輸出。 交叉注意力層被添加到 U-Net 的編碼器和解碼器部分,通常在 ResNet 塊之間。
文本編碼器
文本編碼器會(huì)將輸入提示轉(zhuǎn)換為 U-Net 可以理解的嵌入空間。一般是一個(gè)簡(jiǎn)單的基于Transformer的編碼器,它將標(biāo)記序列映射到潛在文本嵌入序列。高質(zhì)量的提示(prompt)對(duì)輸出質(zhì)量直觀重要,這就是為什么現(xiàn)在大家這么強(qiáng)調(diào)提示設(shè)計(jì)(prompt design)。提示設(shè)計(jì)就是要找到某些關(guān)鍵詞或表達(dá)方式,讓提示可以觸發(fā)模型產(chǎn)生具有預(yù)期屬性或效果的輸出。
3 Consistency終結(jié)Diffusion
擴(kuò)散模型依賴于迭代生成過(guò)程,這導(dǎo)致此類方法采樣速度緩慢,進(jìn)而限制了它們?cè)趯?shí)時(shí)應(yīng)用中的潛力。
OpenAI 為了克服這個(gè)限制,提出了 Consistency Models,這是一類新的生成模型,無(wú)需對(duì)抗訓(xùn)練
即可快速獲得高質(zhì)量樣本。Consistency Models 支持快速 one-step 生成
,同時(shí)仍然允許 few-step 采樣
,以權(quán)衡計(jì)算量和樣本質(zhì)量。它們還支持零樣本(zero-shot)數(shù)據(jù)編輯
,例如圖像修復(fù)、著色和超分辨率,而無(wú)需針對(duì)這些任務(wù)進(jìn)行具體訓(xùn)練。Consistency Models 可以用蒸餾預(yù)訓(xùn)練擴(kuò)散模型的方式進(jìn)行訓(xùn)練,也可以作為獨(dú)立的生成模型進(jìn)行訓(xùn)練。
Consistency Models 作為一種生成模型,核心設(shè)計(jì)思想是支持 single-step 生成,同時(shí)仍然允許迭代生成,支持零樣本(zero-shot)數(shù)據(jù)編輯,權(quán)衡了樣本質(zhì)量與計(jì)算量
。
首先 Consistency Models 建立在連續(xù)時(shí)間擴(kuò)散模型中的概率流 (PF) 常微分方程 (ODE) 之上。如下圖 所示,給定一個(gè)將數(shù)據(jù)平滑地轉(zhuǎn)換為噪聲的 PF ODE,Consistency Models 學(xué)會(huì)在任何時(shí)間步(time step)將任意點(diǎn)映射成軌跡的初始點(diǎn)以進(jìn)行生成式建模。Consistency Models 一個(gè)顯著的特性是自洽性(self-consistency):同一軌跡上的點(diǎn)會(huì)映射到相同的初始點(diǎn)。這也是模型被命名為 Consistency Models(一致性模型)的原因。
Consistency Models 允許通過(guò)僅使用 one network 評(píng)估轉(zhuǎn)換隨機(jī)噪聲向量(ODE 軌跡的端點(diǎn),例如圖 1 中的 x_T)來(lái)生成數(shù)據(jù)樣本(ODE 軌跡的初始點(diǎn),例如圖 1 中的 x_0)。更重要的是,通過(guò)在多個(gè)時(shí)間步鏈接 Consistency Models 模型的輸出,該方法可以提高樣本質(zhì)量,并以更多計(jì)算為代價(jià)執(zhí)行零樣本數(shù)據(jù)編輯,類似于擴(kuò)散模型的迭代優(yōu)化。
在訓(xùn)練方面,研究團(tuán)隊(duì)為 Consistency Models 提供了兩種基于自洽性的方法。
-
第一種方法依賴于使用數(shù)值 ODE 求解器和預(yù)訓(xùn)練擴(kuò)散模型來(lái)生成 PF ODE 軌跡上的相鄰點(diǎn)對(duì)。通過(guò)最小化這些點(diǎn)對(duì)的模型輸出之間的差異,該研究有效地將擴(kuò)散模型蒸餾為 Consistency Models,從而允許通過(guò) one network 評(píng)估生成高質(zhì)量樣本。
-
第二種方法則是完全消除了對(duì)預(yù)訓(xùn)練擴(kuò)散模型的依賴,可獨(dú)立訓(xùn)練 Consistency Models。這種方法將 Consistency Models 定位為一類獨(dú)立的生成模型。文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-491687.html
值得注意的是,這兩種訓(xùn)練方法都不需要對(duì)抗訓(xùn)練,并且都允許 Consistency Models 靈活采用神經(jīng)網(wǎng)絡(luò)架構(gòu)。文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-491687.html
到了這里,關(guān)于Stable Diffusion擴(kuò)散模型 + Consistency一致性模型的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!