Preface
I have been using Stable Diffusion for a long time, but I have never properly dissected its internal structure, so I am writing this post to document it.
Source code download
https://github.com/bubbliiiing/stable-diffusion
Feel free to give the repo a star if you find it useful.
Network Construction
I. What is Stable Diffusion (SD)?
Stable Diffusion is a relatively recent diffusion model. The name means "stable diffusion", although in practice changing the seed produces a completely different result, so it is not all that stable.
Stable Diffusion was first applied to text-to-image generation (txt2img). As the technology developed it came to support image-to-image generation (img2img) as well as control methods such as ControlNet for customizing the generated image.
Stable Diffusion is based on diffusion models, so it necessarily involves an iterative denoising process; for img2img there is also a noising process. This is where the classic DDPM figure comes in, shown below:
Compared with DDPM, Stable Diffusion uses a DDIM sampler, performs the diffusion in latent space, and is pretrained on the very large LAION-5B dataset.
Most readers cannot cover the cost of finetuning Stable Diffusion directly, but there are lightweight finetuning options such as LoRA and Textual Inversion; that is a topic for another post.
This article mainly analyzes the overall structure of the SD model and the flow of a single diffusion step and of repeated diffusion steps.
Large models and AIGC are the current industry trend, and it pays to keep up.
The principle of txt2img is covered in the companion post Diffusion擴散模型學(xué)習2——Stable Diffusion結(jié)構(gòu)解析-以文本生成圖像(txt2img)為例.
II. The Components of Stable Diffusion
Stable Diffusion consists of four main parts:
1. The sampler.
2. The Variational Autoencoder (VAE).
3. The UNet, the main network and noise predictor.
4. The CLIPEmbedder, the text encoder.
Every part matters. Here we analyze the image-to-image case; since we are generating an image from an image, there are two inputs, a text prompt and an image.
III. The img2img Generation Pipeline
The generation pipeline has four stages:
1. Encode the image with the VAE and add noise according to the denoise value.
2. Encode the prompt text.
3. Run a number of sampling steps determined by the denoise value.
4. Decode with the VAE.
Compared with txt2img, the input changes: instead of initializing from Gaussian noise, img2img initializes from the noised image features, which is how image information is injected into the model.
In detail, as shown in the figure above:
- Step 1: encode the input image with the VAE to obtain its latent features, then add noise to this latent with the DDIM sampler to obtain the noised features of the input image. If the denoise value is 0.8 and the total number of steps is 20, we add noise for 0.8 x 20 = 16 steps and skip the remaining 4; this can be read as scrambling 80% of the features while keeping 20%.
- Step 2: encode the input text to obtain the text features.
- Step 3: run a number of sampling steps, determined by the denoise value, on the noised features from step 1. With denoise = 0.8 as in step 1, only 0.8 x 20 = 16 noising steps were applied, so only 0.8 x 20 = 16 sampling steps are needed to recover the image.
- Step 4: decode the sampled latent back into an image with the VAE decoder.
with torch.no_grad():
if seed == -1:
seed = random.randint(0, 65535)
seed_everything(seed)
# ----------------------- #
    # Encode the input image and add noise
# ----------------------- #
if image_path is not None:
img = HWC3(np.array(img, np.uint8))
img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
img = torch.stack([img for _ in range(num_samples)], dim=0)
img = einops.rearrange(img, 'b h w c -> b c h w').clone()
ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
z = model.get_first_stage_encoding(model.encode_first_stage(img))
z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
# ----------------------- #
    # Encode the prompts
# ----------------------- #
cond = {"c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
un_cond = {"c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
H, W = input_shape
shape = (4, H // 8, W // 8)
if image_path is not None:
samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)
else:
# ----------------------- #
        # Sampling
# ----------------------- #
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
shape, cond, verbose=False, eta=eta,
unconditional_guidance_scale=scale,
unconditional_conditioning=un_cond)
# ----------------------- #
    # Decoding
# ----------------------- #
x_samples = model.decode_first_stage(samples)
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
1. Encoding the Input Image
In img2img we first specify a reference image and start working from it:
1. Encode the reference image with the VAE encoder so that it enters the latent space; the diffusion model operates entirely on this latent representation.
2. Add noise to that latent with the DDIM sampler, which gives the noised features of the input image. The noising logic is as follows:
- denoise can be read as the fraction to rebuild: 1 means rebuild everything, 0 means rebuild nothing.
- If the denoise value is 0.8 and the total number of steps is 20, we add noise for 0.8 x 20 = 16 steps and skip the remaining 4; this can be read as scrambling 80% of the features while keeping 20%. Even after all 20 noising steps a little of the original image information is still retained, so it is never discarded completely.
We now have the noised image in latent space, and the subsequent sampling starts from this noised latent.
with torch.no_grad():
if seed == -1:
seed = random.randint(0, 65535)
seed_everything(seed)
# ----------------------- #
    # Encode the input image and add noise
# ----------------------- #
if image_path is not None:
img = HWC3(np.array(img, np.uint8))
img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
img = torch.stack([img for _ in range(num_samples)], dim=0)
img = einops.rearrange(img, 'b h w c -> b c h w').clone()
ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
z = model.get_first_stage_encoding(model.encode_first_stage(img))
z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
2. Text Encoding
Text encoding is straightforward: we simply use CLIP's text encoder. The code defines a FrozenCLIPEmbedder class built on the CLIPTokenizer and CLIPTextModel from the transformers library.
In the forward pass, the input text is first tokenized with CLIPTokenizer and then passed through CLIPTextModel for feature extraction. FrozenCLIPEmbedder therefore returns a feature tensor of shape [batch_size, 77, 768].
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
"last",
"pooled",
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
        # Define the text tokenizer and transformer
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
        # Freeze the model parameters
if freeze:
self.freeze()
self.layer = layer
self.layer_idx = layer_idx
if layer == "hidden":
assert layer_idx is not None
assert 0 <= abs(layer_idx) <= 12
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
        # Tokenize the input text; shorter prompts are padded up to a length of 77.
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        # Take the input_ids and pass them through the transformer for feature extraction.
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
        # Take the token features from the requested layer
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
return z
def encode(self, text):
return self(text)
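To make the output shape concrete, here is a minimal usage sketch. It is a hypothetical standalone snippet rather than part of the repository: it assumes the FrozenCLIPEmbedder class above is importable, that the openai/clip-vit-large-patch14 weights can be downloaded, and that a GPU is available.
import torch

# Encode one prompt and inspect the shape of the text features (sketch only).
embedder = FrozenCLIPEmbedder(version="openai/clip-vit-large-patch14", device="cuda").cuda()
with torch.no_grad():
    z = embedder.encode(["a cute cat, with yellow leaf, trees"])
print(z.shape)  # torch.Size([1, 77, 768])
The 77 comes from max_length (CLIP's context length) and the 768 is the hidden size of CLIP ViT-L/14.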
3. The Sampling Process
a. Generating the initial noise
In img2img the initial noise comes from the reference image, so step 1 above already gives us the starting latent for img2img sampling.
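Conceptually, stochastic_encode uses the closed-form forward process to jump straight to step t_enc instead of adding noise one step at a time. Below is a minimal sketch of that idea rather than the repository's exact implementation; the helper name noise_latent and its arguments are made up for illustration.
import torch

def noise_latent(z0, alphas_cumprod, t_enc):
    # z0: clean latent from the VAE encoder, shape [B, 4, H/8, W/8]
    # alphas_cumprod: 1-D tensor of cumulative alpha products for the chosen schedule
    # q(z_t | z_0) = N(sqrt(alpha_bar_t) * z0, (1 - alpha_bar_t) * I)
    a_bar = alphas_cumprod[t_enc]
    eps = torch.randn_like(z0)                          # fresh Gaussian noise
    return a_bar.sqrt() * z0 + (1.0 - a_bar).sqrt() * eps
With denoise_strength = 0.8 and ddim_steps = 20, t_enc = 16, so the latent is pushed 16 of the 20 DDIM steps toward pure noise.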
b. Running N sampling steps on the noised latent
Since Stable Diffusion is a diffusion process, it inevitably involves repeated denoising, and the question is how that denoising is done.
In the previous step we obtained the noised img2img latent, and we start denoising from there.
We reverse the ddim_timesteps, because we are now removing noise rather than adding it, and loop over them. Since this is no longer the txt2img sampling path, we use the sampler's other method, decode. The loop is shown below:
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False):
    # Use the DDIM timesteps.
    # This looks like a lot of code, but it essentially just takes self.ddim_timesteps and reverses it.
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
        # Perform a single sampling step
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
return x_dec
c. A Single Sampling Step
I. Predicting the noise
Before a single sampling step we first check whether there is a neg prompt. If there is, the neg prompt must be processed together with the pos prompt; otherwise only the pos prompt is processed. In practice a neg prompt is almost always used (it gives better results), so by default we take that path.
When handling the neg prompt, we duplicate the input latent and the timestep: one copy belongs to the pos prompt and one to the neg prompt. torch.cat stacks along dim 0 by default, i.e. along the batch dimension, so the two copies do not affect each other. The pos prompt and neg prompt are then stacked into a single batch, again along the batch dimension.
# First check whether there is a neg prompt; unconditional_conditioning is obtained from the neg prompt
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
    # There is usually a neg prompt, so we end up here.
    # Duplicate the latent and the timestep: one copy for the pos prompt, one for the neg prompt.
    # torch.cat stacks along dim 0 by default, i.e. the batch dimension, so the two do not affect each other.
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
    # Stack the pos prompt and the neg prompt into a single batch
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [
torch.cat([unconditional_conditioning[k][i], c[k][i]])
for i in range(len(c[k]))
]
else:
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
else:
c_in = torch.cat([unconditional_conditioning, c])
Once stacked, the latent, the timestep and the prompt conditioning are passed into the network together, and the result is split along the batch dimension with chunk.
Because the neg prompt was placed first when stacking, after the split the first half, e_t_uncond, is the prediction obtained from the neg prompt and the second half, e_t, is the prediction obtained from the pos prompt. What we really want is to amplify the influence of the pos prompt and move away from the neg prompt, so we compute the difference e_t - e_t_uncond, scale it by the guidance scale, and add it back onto e_t_uncond to obtain the final prediction.
# Once stacked, the latent, timestep and prompt conditioning are passed into the network together; the result is split along the batch dimension with chunk
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
The resulting e_t is the noise prediction obtained jointly from the latent and the prompt.
II. Applying the noise
Are we done once we have the predicted noise? Not quite: the newly predicted noise still has to be combined, in the right proportions, with the current noised latent.
This step is best read alongside the DDIM update formula. We need $\bar{\alpha}_t$, $\bar{\alpha}_{t-1}$, $\sigma_t$ and $\sqrt{1-\bar{\alpha}_t}$.
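For reference, the DDIM update rule that the code below implements (Eq. 12 of the DDIM paper) can be written as
$$
x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\underbrace{\left(\frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}\right)}_{\texttt{pred\_x0}} + \underbrace{\sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t, t)}_{\texttt{dir\_xt}} + \underbrace{\sigma_t\,\epsilon_t}_{\texttt{noise}}
$$
where $\epsilon_\theta(x_t, t)$ is the predicted noise e_t from above and $\epsilon_t$ is fresh Gaussian noise; with $\eta = 0$ we have $\sigma_t = 0$ and the update becomes deterministic.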
In the code these coefficients are precomputed, so we only need to look them up: a_t below is the $\bar{\alpha}_t$ outside the parentheses in the formula, a_prev is $\bar{\alpha}_{t-1}$, sigma_t is $\sigma_t$, and sqrt_one_minus_at is $\sqrt{1-\bar{\alpha}_t}$.
# Select the coefficient tables according to the sampler
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# Select the coefficients for the current step;
# the index here is the total_steps - i - 1 from the loop above
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
This step simply gathers all of the coefficients the formula needs so that the arithmetic below is straightforward. The formula is then implemented directly in code.
# current prediction for x_0
# the first term of the formula
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
# the middle term of the formula
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
# the last term of the formula
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
# return the result of applying the update formula
return x_prev, pred_x0
d. The Network Structure Behind the Noise Prediction
I. The apply_model method
In the noise-prediction step of section 3.c we used model.apply_model to predict the noise; what this method actually does was glossed over, so let's take a look.
apply_model is defined in ldm.models.diffusion.ddpm.py. Inside apply_model, x_noisy is passed into self.model to predict the noise:
x_recon = self.model(x_noisy, t, **cond)
self.model is a pre-built class, defined at line 1416 of ldm.models.diffusion.ddpm.py, that contains the Stable Diffusion UNet. It acts somewhat like a wrapper: depending on the feature-fusion mode chosen for the model, it fuses the text features with the noised latent generated above.
c_concat means the fusion is done by channel concatenation, while c_crossattn means it is done with cross-attention.
class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config, conditioning_key):
super().__init__()
self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
        # The Stable Diffusion UNet
self.diffusion_model = instantiate_from_config(diff_model_config)
self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
if self.conditioning_key is None:
out = self.diffusion_model(x, t)
elif self.conditioning_key == 'concat':
xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t)
elif self.conditioning_key == 'crossattn':
if not self.sequential_cross_attn:
cc = torch.cat(c_crossattn, 1)
else:
cc = c_crossattn
out = self.diffusion_model(x, t, context=cc)
elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc)
elif self.conditioning_key == 'hybrid-adm':
assert c_adm is not None
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc, y=c_adm)
elif self.conditioning_key == 'crossattn-adm':
assert c_adm is not None
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc, y=c_adm)
elif self.conditioning_key == 'adm':
cc = c_crossattn[0]
out = self.diffusion_model(x, t, y=cc)
else:
raise NotImplementedError()
return out
self.diffusion_model in the code above is the Stable Diffusion UNet, implemented by the UNetModel class in ldm.modules.diffusionmodules.openaimodel.py.
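Before diving into UNetModel, here is a hedged shape trace of a single apply_model call with crossattn conditioning. The tensors are random placeholders that only illustrate the shapes (assuming a 512x512 image, batch size 1 and CLIP ViT-L/14 text features), and the actual UNet call is left as a comment because it needs the real weights.
import torch

x_noisy = torch.randn(1, 4, 64, 64)                 # latent from the VAE encoder (512 / 8 = 64)
t       = torch.full((1,), 981, dtype=torch.long)   # current diffusion timestep, one per sample
text    = torch.randn(1, 77, 768)                   # stand-in for the FrozenCLIPEmbedder output
cond    = {"c_crossattn": [text]}                   # same structure as cond in the img2img script
# apply_model unpacks this dict, and DiffusionWrapper (conditioning_key='crossattn') runs:
context = torch.cat(cond["c_crossattn"], 1)         # [1, 77, 768]
# eps = unet(x_noisy, t, context=context)           # would return a [1, 4, 64, 64] noise prediction
print(x_noisy.shape, t.shape, context.shape)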
II. The UNetModel
The main job of UNetModel is to combine the timestep t and the text embedding to compute the noise at the current step. Although the UNet idea itself is simple, in Stable Diffusion the UNetModel is built from ResBlocks and Transformer blocks, so overall it is somewhat more complex than an ordinary UNet.
The prompt is turned into a text embedding by the frozen CLIP text encoder, and the timestep is turned into a timestep embedding by a small MLP;
the ResBlocks fuse in the timestep embedding, while the Transformer blocks fuse in the text embedding (a simplified sketch of both mechanisms is given after the code below).
Here is a large figure in which you can follow how the internal shapes change.
The UNet code is as follows:
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
def __init__(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None
        # Computes the embedding of the current sampling timestep t
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
        # The first convolution of the input blocks.
        # TimestepEmbedSequential can also be seen as a wrapper: depending on the layer type it passes in the timestep or the text embedding.
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
        # Loop over channel_mult; its four entries are the channel expansion ratios of the four UNet stages,
        # e.g. [1, 2, 4, 4]
for level, mult in enumerate(channel_mult):
            # Each stage runs this inner loop twice,
            # adding one ResBlock and one attention block each time
for _ in range(num_res_blocks):
                # First add a ResBlock,
                # which adjusts the channel count of the noisy input and fuses in the timestep features
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
                # ch is the output channel count of the ResBlock above
ch = mult * model_channels
if ds in attention_resolutions:
# num_heads=8
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    # Use a SpatialTransformer: self-attention strengthens global features and cross-attention fuses in the text features
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
            # Downsample at the end of every stage except the last one.
if level != len(channel_mult) - 1:
out_ch = ch
                # Downsampling happens here,
                # normally via the Downsample module
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
                # Bookkeeping for the next stage.
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        # The middle block:
        # ResBlock + SpatialTransformer + ResBlock
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
        # The upsampling (decoder) part of the UNet
self.output_blocks = nn.ModuleList([])
        # Loop over channel_mult in reverse
for level, mult in list(enumerate(channel_mult))[::-1]:
            # During upsampling each stage runs this inner loop three times
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
                # First add a ResBlock
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
out_channels=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = model_channels * mult
                # Then a SpatialTransformer attention block
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
)
)
                # If this is not the first entry of channel_mult (the highest-resolution stage)
                # and
                # this is the last iteration of the num_res_blocks loop, then upsample
if level and i == num_res_blocks:
out_ch = ch
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
        # Finally, one convolution on the output
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
def convert_to_fp16(self):
"""
Convert the torso of the model to float16.
"""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
self.output_blocks.apply(convert_module_to_f16)
def convert_to_fp32(self):
"""
Convert the torso of the model to float32.
"""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param context: conditioning plugged in via crossattn
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
        # Compute the embedding of the current sampling timestep t
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
        # Loop over the input blocks: downsample while fusing in the timestep and text features.
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
        # Feature extraction in the middle block
h = self.middle_block(h, emb, context)
        # Feature extraction in the upsampling blocks (with skip connections)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
        # Output head
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)
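As promised above, here is a heavily simplified sketch of the two conditioning mechanisms: the ResBlock adds a projected timestep embedding onto the feature map, and the Transformer block injects the text embedding through cross-attention. This is an illustrative toy, not the repository's ResBlock or SpatialTransformer (which add GroupNorm, dropout, zero-initialized output layers, optional scale-shift conditioning, and so on).
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyResBlock(nn.Module):
    """Simplified ResBlock: project the timestep embedding and add it to the feature map."""
    def __init__(self, ch, emb_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, 3, padding=1)
        self.emb_proj = nn.Linear(emb_dim, ch)
        self.conv2 = nn.Conv2d(ch, ch, 3, padding=1)

    def forward(self, h, emb):
        out = self.conv1(F.silu(h))
        out = out + self.emb_proj(F.silu(emb))[:, :, None, None]  # broadcast over H and W
        out = self.conv2(F.silu(out))
        return h + out                                            # residual connection

class TinyCrossAttention(nn.Module):
    """Simplified cross-attention: queries come from image tokens, keys/values from text tokens."""
    def __init__(self, ch, context_dim, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(ch, heads, kdim=context_dim, vdim=context_dim,
                                          batch_first=True)

    def forward(self, h, context):
        b, c, H, W = h.shape
        tokens = h.flatten(2).transpose(1, 2)          # [B, H*W, C]: one query per spatial position
        out, _ = self.attn(tokens, context, context)   # the text embedding provides keys and values
        return (tokens + out).transpose(1, 2).reshape(b, c, H, W)

# Shape check with placeholder tensors
h       = torch.randn(1, 320, 64, 64)   # a UNet feature map
emb     = torch.randn(1, 1280)          # timestep embedding (model_channels * 4)
context = torch.randn(1, 77, 768)       # CLIP text embedding
h = TinyResBlock(320, 1280)(h, emb)
h = TinyCrossAttention(320, 768)(h, context)
print(h.shape)  # torch.Size([1, 320, 64, 64])
In the real UNetModel, TimestepEmbedSequential routes emb to every ResBlock and context to every SpatialTransformer, which is exactly what the h = module(h, emb, context) calls in forward do.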
4. Decoding the Latent into an Image
With the steps above we can obtain a result through repeated sampling, and the latent can then be decoded to produce the image.
Decoding is very simple: the result of the repeated sampling is passed to the decode_first_stage method to generate the image.
Inside decode_first_stage the network calls the VAE to decode the 64x64x4 latent into a 512x512x3 image.
@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
if predict_cids:
if z.dim() == 4:
z = torch.argmax(z.exp(), dim=1).long()
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
z = rearrange(z, 'b h w c -> b c h w').contiguous()
z = 1. / self.scale_factor * z
    # Normally the input does not need to be split, so the latent goes straight into the VAE decoder in the else branch below
if hasattr(self, "split_input_params"):
......
else:
if isinstance(self.first_stage_model, VQModelInterface):
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
else:
return self.first_stage_model.decode(z)
Full img2img Prediction Code
The complete prediction script is as follows:
import os
import random
import cv2
import einops
import numpy as np
import torch
from PIL import Image
from pytorch_lightning import seed_everything
from ldm_hacked import *
# ----------------------- #
# Parameters
# ----------------------- #
# Path to the config file
config_path = "model_data/sd_v15.yaml"
# Path to the model weights
model_path = "model_data/v1-5-pruned-emaonly.safetensors"
# fp16 speeds things up and saves VRAM
sd_fp16 = True
vae_fp16 = True
# ----------------------- #
# Image generation parameters
# ----------------------- #
# The generated image size is input_shape; for img2img the input is center-cropped to it
input_shape = [512, 512]
# Number of images to generate at once
num_samples = 1
# Number of sampling steps
ddim_steps = 20
# Sampling seed; -1 means random.
seed = 12345
# eta
eta = 0
# Denoise strength, for img2img
denoise_strength = 1.0
# ----------------------- #
# Prompt-related parameters
# ----------------------- #
# Main prompt
prompt = "a cute cat, with yellow leaf, trees"
# Positive (quality) prompt appended to the main prompt
a_prompt = "best quality, extremely detailed"
# Negative prompt
n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
# Guidance scale: how strongly the positive prompt is amplified relative to the negative
scale = 9
# Used for img2img; set to None if you do not want img2img.
image_path = None
# ----------------------- #
# Save path
# ----------------------- #
save_path = "imgs/outputs_imgs"
# ----------------------- #
# Build the model
# ----------------------- #
model = create_model(config_path).cpu()
model.load_state_dict(load_state_dict(model_path, location='cuda'), strict=False)
model = model.cuda()
ddim_sampler = DDIMSampler(model)
if sd_fp16:
model = model.half()
if image_path is not None:
img = Image.open(image_path)
img = crop_and_resize(img, input_shape[0], input_shape[1])
with torch.no_grad():
if seed == -1:
seed = random.randint(0, 65535)
seed_everything(seed)
# ----------------------- #
    # Encode the input image and add noise
# ----------------------- #
if image_path is not None:
img = HWC3(np.array(img, np.uint8))
img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
img = torch.stack([img for _ in range(num_samples)], dim=0)
img = einops.rearrange(img, 'b h w c -> b c h w').clone()
if vae_fp16:
img = img.half()
model.first_stage_model = model.first_stage_model.half()
else:
model.first_stage_model = model.first_stage_model.float()
ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
z = model.get_first_stage_encoding(model.encode_first_stage(img))
z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
z_enc = z_enc.half() if sd_fp16 else z_enc.float()
# ----------------------- #
    # Encode the prompts
# ----------------------- #
cond = {"c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
un_cond = {"c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
H, W = input_shape
shape = (4, H // 8, W // 8)
if image_path is not None:
samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)
else:
# ----------------------- #
        # Sampling
# ----------------------- #
samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
shape, cond, verbose=False, eta=eta,
unconditional_guidance_scale=scale,
unconditional_conditioning=un_cond)
# ----------------------- #
    # Decoding
# ----------------------- #
x_samples = model.decode_first_stage(samples.half() if vae_fp16 else samples.float())
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
# ----------------------- #
    # Save the images
# ----------------------- #
if not os.path.exists(save_path):
os.makedirs(save_path)
for index, image in enumerate(x_samples):
cv2.imwrite(os.path.join(save_path, str(index) + ".jpg"), cv2.cvtColor(image, cv2.COLOR_BGR2RGB))