1 從AE談起
說到編碼器這塊,不可避免地要講起AE(AutoEncoder)自編碼器。它的結(jié)構(gòu)下圖所示:
據(jù)圖可知,AE通過自監(jiān)督的訓(xùn)練方式,能夠?qū)⑤斎氲脑继卣魍ㄟ^編碼encoder后得到潛在的特征編碼,實現(xiàn)了自動化的特征工程,并且達到了降維和泛化的目的。而后通過對進行decoder后,我們可以重構(gòu)輸出。一個良好的AE最好的狀態(tài)就是解碼器的輸出能夠完美地或者近似恢復(fù)出原來的輸入, 即。為此,訓(xùn)練AE所需要的損失函數(shù)是:
∣
∣
x
?
x
^
∣
∣
||x-\hat{x}||
∣∣x?x^∣∣
AE的重點在于編碼,而解碼的結(jié)果,基于訓(xùn)練目標,如果損失足夠小的話,將會與輸入相同。從這一點上看解碼的值沒有任何實際意義,除了通過增加誤差來補充平滑一些初始的零值或有些許用處。
易知,從輸入到輸出的整個過程,AE都是基于已有的訓(xùn)練數(shù)據(jù)的映射,盡管隱藏層的維度通常比輸入層小很多,但隱藏層的概率分布依然只取決于訓(xùn)練數(shù)據(jù)的分布,這就導(dǎo)致隱藏狀態(tài)空間的分布并不是連續(xù)的,它只是稀疏地記錄下來你的輸入樣本和生成圖像的一一對應(yīng)關(guān)系。 因此如果我們隨機生成隱藏層的狀態(tài),那么它經(jīng)過解碼將很可能不再具備輸入特征的特點,因此想通過解碼器來生成數(shù)據(jù)就有點強模型所難了。
如下圖所示,僅通過AE,我們在碼空間里隨機采樣的點并不能生成我們所希望的相應(yīng)圖像。這就使得我的不能夠達到AIGC的效果。
據(jù)此,我們對AE的隱藏層z作出改動(讓隱空間連續(xù)光滑),得到了VAE。
2 變分自編碼器(Variational AutoEncoder,VAE)
關(guān)于變分推斷,請查看本人的另一篇博文:變分推斷(Variational Inference)
這里只做一個總結(jié):
- 變分推斷是使用另一個分布 q ( z ) q(z) q(z)近似 p ( z ∣ x ) p(z|x) p(z∣x)
- 用KL距離衡量分布的近似程度: K L ( q ( z ) ∣ ∣ p ( z ∣ x ) ) KL(q(z)||p(z|x)) KL(q(z)∣∣p(z∣x)),所以最優(yōu)的 q ? ( z ) = a r g m i n q ( z ) ∈ Q K L ( q ( z ) ∣ ∣ p ( z ∣ x ) ) q^*(z)=argmin_{q(z) \in Q}KL(q(z)||p(z|x)) q?(z)=argminq(z)∈Q?KL(q(z)∣∣p(z∣x))
- 對 K L ( q ( z ) ∣ ∣ p ( z ∣ x ) ) KL(q(z)||p(z|x)) KL(q(z)∣∣p(z∣x))的最小化轉(zhuǎn)化為對ELBO的最大化,也就是 q ? ( z ) = a r g m i n q ( z ) ∈ Q K L ( q ( z ) ∣ ∣ p ( z ∣ x ) ) = a r g m a x q ( z ) ∈ Q E L B O = a r g m a x q ( z ) ∈ Q E q ( l o g ( p ( x , z ) ? l o g q ( z ) ) ) q^*(z)=argmin_{q(z) \in Q}KL(q(z)||p(z|x))=argmax_{q(z)\in Q}ELBO=argmax_{q(z)\in Q}E_q(log(p(x,z)-logq(z))) q?(z)=argminq(z)∈Q?KL(q(z)∣∣p(z∣x))=argmaxq(z)∈Q?ELBO=argmaxq(z)∈Q?Eq?(log(p(x,z)?logq(z)))
VAE全稱是Variational AutoEncoder,即變分自編碼器。
在VAE中 q ( z ) q(z) q(z)用一個編碼器神經(jīng)網(wǎng)絡(luò)表示,假如其參數(shù)是 θ \theta θ,那么我們用 q θ ( z ) q_{\theta}(z) qθ?(z)或者 q θ ( z ∣ x ) q_{\theta}(z|x) qθ?(z∣x)表示。 p ( z ∣ x ) p(z|x) p(z∣x)可以認為是自然界真實存在的一個概率分布,但是我們不知道,所以需要用一個神經(jīng)網(wǎng)絡(luò)把他近似出來。
2.1 VAE的目的
VAE的目的:
(1)用神經(jīng)網(wǎng)絡(luò)去逼近和模擬
p
(
z
∣
x
)
p(z|x)
p(z∣x)近似
p
(
x
∣
z
)
p(x|z)
p(x∣z)這兩個概率分布
(2)并盡量保證隱空間是連續(xù)和平滑的,即
p
(
z
)
p(z)
p(z)和
p
(
z
∣
x
)
p(z|x)
p(z∣x)是平滑的
2.2 VAE方法與損失函數(shù)
作者方法“
(1)定義:
p
(
z
)
~
N
(
0
,
1
)
p(z) \sim N(0,1)
p(z)~N(0,1)
(2)定義:
q
θ
(
z
∣
x
)
~
N
(
g
(
x
)
,
h
(
x
)
)
q_{\theta}(z|x) \sim N(g(x),h(x))
qθ?(z∣x)~N(g(x),h(x)),也就是
q
θ
(
z
∣
x
)
q_{\theta}(z|x)
qθ?(z∣x)的期望和方差是用兩個神經(jīng)網(wǎng)絡(luò)計算出來的
(3)定義:
p
θ
′
(
x
∣
z
)
~
N
(
f
(
z
)
,
c
I
)
p_{\theta'}(x|z) \sim N(f(z),cI)
pθ′?(x∣z)~N(f(z),cI),所以解碼器的輸出的是
p
θ
′
(
x
∣
z
)
p_{\theta'}(x|z)
pθ′?(x∣z)的期望
這樣直接定義好嗎?為這么直接這樣定義出來?看下面的一個slide
對ELBO做一個推導(dǎo):
因為
p
(
x
∣
z
)
=
1
2
π
c
e
∣
∣
x
?
f
(
z
)
∣
∣
2
2
c
p(x|z) = \frac{1}{\sqrt{2\pi c}}e^{\frac{||x-f(z)||^2}{2c}}
p(x∣z)=2πc?1?e2c∣∣x?f(z)∣∣2?,所以有:
也就是找到這樣的三個神經(jīng)網(wǎng)絡(luò)使得上面的式子最大。
對于上面的第二項:
所以損失函數(shù)可以寫成:
l
o
s
s
=
1
2
(
?
l
o
g
h
(
x
)
2
+
h
(
x
)
2
+
g
(
x
)
2
?
1
)
+
C
∣
∣
x
?
f
(
z
)
∣
∣
2
loss=\frac{1}{2}(-logh(x)^2+h(x)^2+g(x)^2-1)+C||x-f(z)||^2
loss=21?(?logh(x)2+h(x)2+g(x)2?1)+C∣∣x?f(z)∣∣2
2.3 重參數(shù)技巧
從高斯分布
N
(
μ
,
σ
)
N(μ,σ)
N(μ,σ)中采樣的操作被巧妙轉(zhuǎn)換為了從
N
(
0
,
1
)
N(0,1)
N(0,1)中采樣得到
?
?
?后,再通過
z
=
μ
+
σ
×
?
z=μ+σ \times ?
z=μ+σ×?變換得到。
而在重參數(shù)后,我們計算反向傳播的過程 如下圖所示:
2.4 整合起來
(1)從樣本庫中取圖片x
(2)g(x)計算均值,h(x)計算方差,從標準正太分布中采樣一個數(shù)
ζ
\zeta
ζ,然后計算
z
=
ζ
h
(
x
)
+
g
(
x
)
z=\zeta h(x)+g(x)
z=ζh(x)+g(x),然后計算
f
(
z
)
f(z)
f(z)
(3)計算損失
(4)反向傳播文章來源:http://www.zghlxwxcb.cn/news/detail-801258.html
3 代碼實現(xiàn)
3.1 VAE.py
import torch
from torch import nn
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
# [b, 784] =>[b,20]
# u: [b, 10]
# sigma: [b, 10]
self.encoder = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Linear(256, 64),
nn.ReLU(),
nn.Linear(64, 20),
nn.ReLU()
)
# [b,10] => [b, 784]
# sigmoid函數(shù)把結(jié)果壓縮到0~1
self.decoder = nn.Sequential(
nn.Linear(10, 64),
nn.ReLU(),
nn.Linear(64, 256),
nn.ReLU(),
nn.Linear(256, 784),
nn.Sigmoid()
)
def forward(self, x):
"""
:param x:
:return:
"""
batchsz = x.size(0)
# flatten
x = x.view(batchsz, 784)
# encoder
# [b, 20], including mean and sigma
h_ = self.encoder(x)
# chunk 在第二維上拆分成兩部分
# [b, 20] => [b,10] and [b, 10]
mu, sigma = h_.chunk(2, dim=1)
# reparametrize tirchk, epison~N(0, 1)
# torch.randn_like(sigma)表示正態(tài)分布
h = mu + sigma * torch.randn_like(sigma)
# decoder
x_hat = self.decoder(h)
# reshape
x_hat = x_hat.view(batchsz, 1, 28, 28)
# KL
# 1e-8是防止σ^2接近于零時該項負無窮大
# (batchsz*28*28)是讓kld變小
kld = 0.5 * torch.sum(
torch.pow(mu, 2) +
torch.pow(sigma, 2) -
torch.log(1e-8 + torch.pow(sigma, 2)) - 1
) / (batchsz*28*28)
return x, kld
3.2 main.py
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision import transforms, datasets
from ae_1 import AE
from vae import VAE
from vq-vae import VQVAE
import visdom
def main():
mnist_train = datasets.MNIST('mnist', True, transform=transforms.Compose([
transforms.ToTensor()
]), download=True)
mnist_train = DataLoader(mnist_train, batch_size=32, shuffle=True)
mnist_test = datasets.MNIST('mnist', False, transform=transforms.Compose([
transforms.ToTensor()
]), download=True)
mnist_test = DataLoader(mnist_test, batch_size=32, shuffle=True)
#無監(jiān)督學(xué)習(xí),不能使用label
x, _ = iter(mnist_train).next()
print('x:', x.shape)
device = torch.device('cuda')
#model = AE().to(device)
model = VAE().to(device)
criteon = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
print(model)
viz = visdom.Visdom()
for epoch in range(1000):
for batchidx, (x, _) in enumerate(mnist_train):
# [b, 1, 28, 28]
x = x.to(device)
x_hat, kld = model(x)
loss = criteon(x_hat, x)
if kld is not None:
elbo = - loss - 1.0 * kld
loss = - elbo
# backprop
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(epoch, 'loss', loss.item(), kld.item())
x, _ = iter(mnist_test).next()
x = x.to(device)
with torch.no_grad():
x_hat = model(x)
# nrow表示一行的圖片
viz.images(x, nrow=8, win='x', optis=dic(title='x'))
iz.images(x_hat, nrow=8, win='x_hat', optis=dic(title='x_hat'))
if __name__ == '__main__':
main()
參考
講解變分自編碼器-VAE(附代碼)
VAE到底在做什么?VAE原理講解系列#1
VAE的神經(jīng)網(wǎng)絡(luò)是如何搭建的?VAE原理講解系列#3
從零推導(dǎo):變分自編碼器(VAE)文章來源地址http://www.zghlxwxcb.cn/news/detail-801258.html
到了這里,關(guān)于變分自編碼器(Variational AutoEncoder,VAE)的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!