一、 數(shù)學基礎
擴散模型和一般的機器學習的神經網絡不太一樣!一般的神經網絡旨在構造一個網絡模型來擬合輸入數(shù)據與希望得到的輸出結果,可以把一般的神經網絡當作一個黑盒,這個黑盒通過訓練使其輸入數(shù)據后就可以得到我們想要的結果。而擴散模型包含了大量的統(tǒng)計學和數(shù)學相關的知識,我愿把它看作是數(shù)學和AI完美結合的產物!由于擴散模型相較于普通的深度學習模型,數(shù)學難度大很多,因此學習擴散模型有必要復習(預習)一下相關的數(shù)學知識。
1.1 一般的條件概率形式
p(降溫/下雨)=0.9 : 表示的意思是在“下雨”的條件下,會“降溫”的概率為0.9
p(x∣y)= p(y)p(x,y), p(x,y)=p(x∣y)p(y)=p(y∣x)p(x)
?P(x,y,z)=P(z∣y,x)P(y,x)=P(z∣y,x)P(y∣x)P(x)
P(y,z∣x)= P(x)P(x,y,z) =P(y∣x)P(z∣x,y)
?
1.2 馬爾可夫鏈條件概率形式
馬爾科夫鏈指的是當前狀態(tài)的概率只與上一個時刻有關,例如有A->B->C滿足馬爾可夫關系?
,則有:
?P(x,y,z)=P(z∣y,x)P(y,x)=P(z∣y)P(y∣x)P(x)
P(y,z∣x)=P(y∣x)P(z∣y)
1.3 先驗概率和后驗概率
在介紹先驗概率之前我們先來復習一下全概率公式。
?
可以看出,全概率公式是“由因推果”的思想,當知道某件事的原因后,推斷由某個原因導致這件事發(fā)生的概率為多少。
先驗概率(prior probability):指根據以往經驗和分析。在實驗或采樣前就可以得到的概率。它往往作為“由因求果”問題中的“因”出現(xiàn)。
在介紹后驗概率之前我們先來復習一下貝葉斯公式。
我們可以發(fā)現(xiàn)貝葉斯公式就是一種“由果求因”的思想,當知道某系列些事情的結果后,我們可以根據這類事情推斷出發(fā)這類事情是某種原因的概率。
后驗概率(posterior probability):指某件事已經發(fā)生,想要計算這件事發(fā)生的原因是由某個因素引起的概率。指在得到“結果”的信息后重新修正的概率, 是“執(zhí)果尋因”問題中的“因”。
我們舉個例子來更好理解先驗、后驗概率。
假設我們現(xiàn)在有兩個盒子,分別為紅色和藍色。在紅色盒子中放著2個蘋果和6個橙子,在藍色盒子中放著1個橙子和3個蘋果,如下圖所示:
圖中綠色表示蘋果,橙色代表橙子。
假設我們每次實驗的時候會隨機從某個盒子里挑出一個水果,
隨機變量B(box)表示挑出的是哪個盒子,并且P(B=blue) = 0.6(藍色盒子被選中的概率),P(B=red) = 0.4(紅色盒子被選中的概率)。
隨機變量F(fruit)表示挑中的是哪種水果,F(xiàn)的取值為"a (apple)“和"o (orange)”。
現(xiàn)在假設我們已經得知某次實驗中挑出的水果是orange,那么這個orange是從紅色盒子里挑出的概率是多大呢?依據貝葉斯公式有:
P(F=o)的概率是根據全概率公式算出來的,
P(F=o)=P(B=blue)* P(F=o|B=blue)+P(B=red) P(F=o|B=red)=0.61/4+0.4*3/4=9/20
同時,由概率的加法規(guī)則我們可以得到:
在上面的計算過程中,我們將P(B=red)或者說P(B)稱為先驗概率(prior probability),因為我們在得到F是“a”或者“o”之前,就可以得到P(B)。
同理,將P(B=red|F=o)和P(B=blue|F=o)稱為后驗概率,因為我們在完整的一次實驗之后也就是得到了F的具體取值之后才能得到這個概率。
1.4 重參數(shù)化技巧
若希望從高斯分布N (μ,σ) 中采樣,可以先從標準正態(tài)分布N (0,I)中采樣出z 再得到σ ? z + μ。這樣做的好處是將隨機性轉移到了z 這個常量上,而σ則是仿射變換網絡的一部分。
1.5 KL散度公式
對于兩個單一變量的高斯分布 p ?和 q ?而言,它們的KL散度為:
二、擴散模型的整體邏輯(以DDPM為例)
如上圖所示。DDPM模型主要分為兩個過程:forward加噪過程(從右往左)和reverse去噪過程(從左往右)。加噪過程意思是指向數(shù)據集的真實圖片中逐步加入高斯噪聲,而去噪過程是指對加了噪聲的圖片逐步去噪,從而還原出真實圖片。加噪過程滿足一定的數(shù)學規(guī)律,而去噪過程則采用神經網絡來學習。這么一來,神經網絡就可以從一堆雜亂無章的噪聲圖片中生成真實圖片了。
2.1 Diffusion擴散過程(Forward加噪過程)
這里Forward加噪過程是一個馬爾科夫鏈過程,我們可以看到最終通過不斷的加入噪聲,原始的圖片變成了一個完全混亂的圖片,這個完全混亂的圖片就可以當成一個隨機生成的噪聲圖片。
擴散(Diffusion)在熱力學中指細小顆粒從高密度區(qū)域擴散至低密度區(qū)域,在統(tǒng)計領域,擴散則指將復雜的分布轉換為一個簡單的分布的過程。擴散模型為什么能夠起作用是因為它的一個關鍵性的性質:平穩(wěn)性。一個概率分布如果隨時間變化,那么在馬爾可夫鏈的作用下,它一定會趨于某種平穩(wěn)分布(例如高斯分布)。只要終止時間足夠長,概率分布就會趨近于這個平穩(wěn)分布。
馬爾可夫鏈每一步的轉移概率,本質上都是在加噪聲。這就是擴散模型中“擴散”的由來:噪聲在馬爾可夫鏈演化的過程中,逐漸進入diffusion體系。隨著時間的推移,加入的噪聲(加入的溶質)越來越少,而體系中的噪聲(這個時刻前的所有溶質)逐漸在diffussion體系中擴散,直至均勻。
Diffusion模型定義了一個概率分布轉換模型T(注意:這不是"t ∈ { 1 , 2 , 3… T }"中的T),能將原始數(shù)據x0構成的復雜分布qcomplex 轉換為一個簡單的已知參數(shù)的先驗分布pprior:
具體來說,Diffusion模型提出可以用馬爾科夫鏈(Markov Chain)來構造T,即定義一系列條件概率分布q(x t∣ xt-1 ) t ∈ { 1 , 2 , 3… T } , 將x 0 依次轉換為x 1 、x 2 、x 3 …x T
,希望當T 足夠大時:
為了簡潔和有效,此處的pprior選擇高斯分布,因此整個前向擴散過程可以被看作是,在T步內,不斷添加少量的高斯噪聲到樣本中。
具體來說,在馬爾科夫鏈的每一步,我們向 xt-1添加方差為βt的高斯噪聲,產生一個新的隱變量
xt,其分布為 q(x t∣ xt-1 ) 。這個擴散過程可以表述如下:
由于我們處于多維情況下,I是單位矩陣,表明每個維度有相同的標準偏差 βt。注意到, q(x t∣ xt-1 ) 是一個正態(tài)分布,其均值是 μ t,方差為 ∑ t,其中 ∑是一個對角矩陣的方差(這里就是 βt )。
因此,我們可以自 x 0 到 x T 以一種可操作的方式來近似輸入。在數(shù)學上,這種后驗概率定義如下:
其中, x 1 :T 意味著我們從時間 1 到 T 重復應用 q(x t∣ xt-1 ) 。
這種累乘的方式過于繁瑣,利用重參數(shù)化技巧,可以得到:
β不斷增大,論文中是0.0001~0.002,所以之后α越來越小。則:當前向時刻越往后,噪音影響的權重越來越大,z是服從高斯分布的噪音,當 t 趨近于正無窮時, x t等同于各向同性的高斯分布。
這樣我們就可以直接得到任意時刻的 x t 。
2.2 逆向過程(reverse去噪過程)
Diffusion Model的逆向過程就是與正向Forward加噪過程相反不斷去除圖像中的噪聲的過程。不幸的是, q(x t∣ xt-1 ) 雖然知道但是 q(x t-1∣ xt ) 卻是未知的。但有相關研究表明:連續(xù)擴散過程的逆轉具有與正向過程相同的分布形式。即,即當擴散率βt足夠小,擴散次數(shù)足夠多時,離散擴散過程接近于連續(xù)擴散過程 q(x t∣ xt-1 ) 的分布形式同 q(x t-1∣ xt ) 一致,同樣是高斯分布。
盡管如此,我們依然不能夠直接得到 q(x t-1∣ xt ),因此我們就需要學習一個網絡模型 p(x t-1∣ xt)擬合 q(x t-1∣ xt ):
在DDPM中不學習方差,方差設置為βt。
這樣,逆向過程中高斯的后驗概率定義為:
使用貝葉斯公式可以得到:
利用公式:
將上面由貝葉斯公式得到的結果湊成高斯分布概率密度的形式:
因此,我們可以得到q(x t-1∣ xt ,x 0)的高斯概率密度表示為:
用x t替換x 0得:
到此,我們在逆向過程中的目標就變成了拉近以下兩個高斯分布的距離,這可以通過計算兩個分布的KL散度實現(xiàn),其中q(x t-1∣ xt ,x 0)的均值和方差都是已知的:
這就是我們訓練網絡的損失函數(shù)。
三、訓練過程和采樣過程
我們重新梳理一下擴散模型的整個流程。
前向傳播過程(q過程):從x0開始不斷加入噪聲到 xt, xt只是一個帶有噪聲的圖片,再逐漸加入更多噪聲,到 xT的時候圖片已經完全變成一個噪聲圖片了。
逆向過程(p過程):在一張完全混亂的噪聲圖片當中不斷拿去剛剛加入的噪聲,讓其變得不混亂,逐步更加接近真實圖片,就可以得到最開始的圖片。
前向過程是一個完全的馬爾科夫鏈加入噪聲過程實通過固定計算完成的,逆向過程里面如何預測噪聲就成了我們的關鍵需求,人是算不出來的,所以我們需要借助網絡來幫忙。
3.1 訓練過程
我們在逆向降噪過程中由于沒辦法得到q(x t-1∣ xt ),因此定義了一個 需要學習的模型p(x t-1∣ xt ) 來對其進行近似,并且在訓練階段我們可以利用后驗q(x t-1∣ xt ,x 0)來對p進行優(yōu)化(就是計算損失不斷訓練的過程)。
那么,要怎么優(yōu)化這個p呢?即如何訓練模型預測到靠譜的均值和方差根據分布進行計算呢?
我們可以最大化模型預測分布的對數(shù)似然,優(yōu)化模型真實分布和預測分布的交叉熵,優(yōu)化 x0 ~ q(x 0)下的 Pθ(x0)交叉熵:
使用變分下限優(yōu)化負對數(shù)似然,因為KL散度非負:
上式中q (x0)是真實的數(shù)據分布,而Pθ(x0)是模型。
為了最小化這個損失,可以將其轉化為最小化其上界LVLB:
由于前向q沒有可學習參數(shù),而xT則是純高斯噪聲,LT可以當做常量忽略。因此我們只要研究L0和Lt(t和t-1其實意思是一樣的)。
Lt可以看作是拉近2個高斯分布q(x t-1∣ xt ,x 0)和p(x t-1∣ xt ) ,可以根據多元高斯分布的KL散度求解:
把前面得到的公式:
代入得:
我們可以看出,擴散模型訓練的核心是學習真實噪聲 zt和預測噪聲z θ的均方誤差MSE, DDPM (Ho et al 2020)使用了不帶權重項的簡化損失, 使得訓練更加穩(wěn)定:
其中C是一個常數(shù)。
對于L0:
因為:
實際上L0是一個多元高斯分布的負對數(shù)似然期望,即其熵:
多元高斯分布的熵僅與其協(xié)方差有關,即L0僅與σ1^2 I有關,L0是個常數(shù)。
綜上,擴散模型(DDPM)的訓練過程可以看做是最小化預測噪聲和真實采樣的?之間的距離的過程。
DDPM論文里面訓練過程的偽代碼如下:
可以理解為:
重復這一過程直到網絡收斂。
3.2 采樣過程
DDPM論文中對采樣過程的描述:
因為我們通過訓練已經得到了一個用于擬合 q(x t-1∣ xt )的網絡p(x t-1∣ xt),因此我們可以從 xT一步步得到 x0。具體的步驟可以為:
3.3 模型訓練的一些細節(jié)
3.3.1 網絡的選擇
擴散模型的網絡的輸入和輸出都是同等規(guī)格的,因此理論上只要網絡的input的規(guī)格和output規(guī)格一樣就可以。比如你可以選擇Unet作為擬合的網絡:
3.3.2 一些超參數(shù)的選擇
在前向傳播的過程中,我們不知道噪聲到底要添加到什么時候才合適,每次添加噪聲的方差怎么設置也是很重要。這些都需要不斷的嘗試調優(yōu)才能得到。
DDPM中T設置為1000,βt被設置為從β1 = 0.0001到βT=0.02線性增加。當然別的擴散模型也有不同的策略,只要能夠調試網絡到最好就是最好的方法。不同任務不同的網絡策略可能也會不同。
四、DDPM案例代碼實現(xiàn)
為了更好的掌握擴散模型的工作過程,我參考網上的代碼一步步編寫調試了一個簡單擴散模型案例-
DDPM S_curve
4.1 數(shù)據集準備
這里需要注意的是,這里的整個數(shù)據集就是上面可視化的這張圖片中的點,一共有10000個數(shù)據,每個數(shù)據就是構成上面這張圖中S的一個個點,一共有10000個點,這些點滿足這是“s”形的分布。
構建數(shù)據集的代碼:
import numpy as np
from sklearn.datasets import make_s_curve
import torch
s_curve,_ = make_s_curve(10**4, noise=0.1)
s_curve = s_curve[:, [0, 2]]/10.0 # 得到的是一個三維的點我們只需要二維的
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset = torch.Tensor(s_curve).float().to(device)
對數(shù)據集進行可視化:
data = s_curve.T
fig,ax = plt.subplots()
ax.scatter(*data,color='blue',edgecolor='white');
ax.axis('off')
plt.show()
4.2 前向傳播過程
先確定兩個超參數(shù)β(betas)和T(num_steps),我們T設置為100,β先從(-6,6)取100個數(shù),然后用sigmoid得到100個非線性增加的數(shù)。
num_steps = 100
betas = torch.linspace(-6, 6, num_steps).to(device)
betas = torch.sigmoid(betas)*(0.5e-2 - 1e-5)+1e-5
提前計算好前向傳播公式中需要用到的表達:
alphas = 1-betas
alphas_prod = torch.cumprod(alphas, dim=0)
αt-1
alphas_prod_p = torch.cat([torch.tensor([1]).float().to(device),alphas_prod[:-1]],0)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)
根據公式
編寫一個可以獲得任意時刻t的狀態(tài)圖Xt的函數(shù):
def q_x(x_0,t):
noise = torch.randn_like(x_0).to(device) #隨機獲得的一個和x_0一樣規(guī)格的噪聲
alphas_t = alphas_bar_sqrt[t]
alphas_1_m_t = one_minus_alphas_bar_sqrt[t]
return (alphas_t * x_0 + alphas_1_m_t * noise)#在x[0]的基礎上添加噪聲
可視化每5步添加噪聲后的數(shù)據集:
num_shows = 20
fig,axs = plt.subplots(2,10,figsize=(28,3))
plt.rc('text',color='black')
for i in range(num_shows):
j = i//10
k = i%10
q_i = q_x(dataset, torch.tensor([i*num_steps//num_shows]).to(device))#生成t時刻的采樣數(shù)據
q_i = q_i.to('cpu')
axs[j,k].scatter(q_i[:,0],q_i[:,1],color='red',edgecolor='white')
axs[j,k].set_axis_off()
axs[j,k].set_title('$q(\mathbf{x}_{'+str(i*num_steps//num_shows)+'})$')
定義損失函數(shù):
def diffusion_loss_fn(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps):
"""對任意時刻t進行采樣計算loss"""
batch_size = x_0.shape[0]
# 對一個batchsize樣本生成隨機的時刻t
t = torch.randint(0, n_steps, size=(batch_size // 2,)).to(device)
t = torch.cat([t, n_steps - 1 - t], dim=0)
t = t.unsqueeze(-1)
# x0的系數(shù)
a = alphas_bar_sqrt[t]
# eps的系數(shù)
aml = one_minus_alphas_bar_sqrt[t]
# 生成隨機噪音eps
e = torch.randn_like(x_0).to(device)
# 構造模型的輸入
x = x_0 * a + e * aml
# 送入模型,得到t時刻的隨機噪聲預測值
output = model(x, t.squeeze(-1))
# 與真實噪聲一起計算誤差,求平均值
return (e - output).square().mean()
該損失函數(shù)計算的就是網絡預測的噪聲與真實噪聲的損失。x = x_0 * a + e * aml就是公式:
4.3 逆向過程(模型訓練過程)
這里需要定義一個從XT恢復到X0的一個函數(shù):
def p_sample_loop(model,shape,n_steps,betas,one_minus_alphas_bar_sqrt):
"""從x[T]恢復x[T-1]、x[T-2]|...x[0]"""
cur_x = torch.randn(shape).to(device)
x_seq = [cur_x]
for i in reversed(range(n_steps)):
cur_x = p_sample(model,cur_x,i,betas,one_minus_alphas_bar_sqrt)
x_seq.append(cur_x)
return x_seq
def p_sample(model,x,t,betas,one_minus_alphas_bar_sqrt):
"""從x[T]采樣t時刻的重構值"""
t = torch.tensor([t]).to(device)
coeff = betas[t] / one_minus_alphas_bar_sqrt[t]
eps_theta = model(x,t)
mean = (1/(1-betas[t]).sqrt())*(x-(coeff*eps_theta))
z = torch.randn_like(x).to(device)
sigma_t = betas[t].sqrt()
sample = mean + sigma_t * z
return (sample)
然后就是定義一個網絡模型用于擬合q,這里定義的全是線性層連接的網絡:
這里回推運用到了一個公式:
# 定義擬合的網絡
class MLPDiffusion(nn.Module):
def __init__(self, n_steps, num_units=128):
super(MLPDiffusion, self).__init__()
self.linears = nn.ModuleList(
[
nn.Linear(2, num_units),
nn.ReLU(),
nn.Linear(num_units, num_units),
nn.ReLU(),
nn.Linear(num_units, num_units),
nn.ReLU(),
nn.Linear(num_units, 2),
]
)
self.step_embeddings = nn.ModuleList(
[
nn.Embedding(n_steps, num_units),
nn.Embedding(n_steps, num_units),
nn.Embedding(n_steps, num_units),
]
)
def forward(self, x, t):
# x = x_0
for idx, embedding_layer in enumerate(self.step_embeddings):
t_embedding = embedding_layer(t)
x = self.linears[2 * idx](x)
x += t_embedding
x = self.linears[2 * idx + 1](x)
x = self.linears[-1](x)
return x
最后就是常規(guī)的網絡訓練過程,我們的batch_size設置為128,訓練4000個輪次,因為網絡很簡單,我的電腦不到20分鐘就訓練完了。過程中每100輪次可視化一次。文章來源:http://www.zghlxwxcb.cn/news/detail-596923.html
seed = 1234
print('Training model...')
batch_size = 128
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=True)
num_epoch = 4000
plt.rc('text',color='blue')
model = MLPDiffusion(num_steps)#輸出維度是2,輸入是x和step
model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)
for t in range(num_epoch):
for idx,batch_x in enumerate(dataloader):
loss = diffusion_loss_fn(model,batch_x,alphas_bar_sqrt,one_minus_alphas_bar_sqrt,num_steps)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(),1.)
optimizer.step()
if(t%100==0):
print(loss)
x_seq = p_sample_loop(model,dataset.shape,num_steps,betas,one_minus_alphas_bar_sqrt)
x_seq = [item.to('cpu') for item in x_seq]
fig,axs = plt.subplots(1,10,figsize=(28,3))
for i in range(1,11):
cur_x = x_seq[i*10].detach()
axs[i-1].scatter(cur_x[:,0],cur_x[:,1],color='red',edgecolor='white');
axs[i-1].set_axis_off();
axs[i-1].set_title('$q(\mathbf{x}_{'+str(i*10)+'})$')
下面展示的是訓練過程中的一部分可視化輸出:
epoch=0
epoch=200
epoch =600
epoch=1500
epoch =3000
epoch = 4000
參考文獻
[1]: https://zhuanlan.zhihu.com/p/415487792
[2]: https://zhuanlan.zhihu.com/p/499206074
[3]: https://blog.csdn.net/weixin_42363544/article/details/127495570
[4]:https://blog.csdn.net/weixin_43850253/article/details/128275723
[5]:Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models[J]. Advances in Neural Information Processing Systems, 2020, 33: 6840-6851.文章來源地址http://www.zghlxwxcb.cn/news/detail-596923.html
到了這里,關于擴散模型原理+DDPM案例代碼解析的文章就介紹完了。如果您還想了解更多內容,請在右上角搜索TOY模板網以前的文章或繼續(xù)瀏覽下面的相關文章,希望大家以后多多支持TOY模板網!