国产 无码 综合区,色欲AV无码国产永久播放,无码天堂亚洲国产AV,国产日韩欧美女同一区二区

Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例

這篇具有很好參考價值的文章主要介紹了Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例。希望對大家有所幫助。如果存在錯誤或未考慮完全的地方,請大家不吝賜教,您也可以點擊"舉報違法"按鈕提交疑問。

學(xué)習前言

用了很久的Stable Diffusion,但從來沒有好好解析過它內(nèi)部的結(jié)構(gòu),寫個博客記錄一下,嘿嘿。
Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例,擴散模型學(xué)習,stable diffusion,pytorch,img2img,圖生圖,SD

源碼下載地址

https://github.com/bubbliiiing/stable-diffusion

喜歡的可以點個star噢。

網(wǎng)絡(luò)構(gòu)建

一、什么是Stable Diffusion(SD)

Stable Diffusion是比較新的一個擴散模型,翻譯過來是穩(wěn)定擴散,雖然名字叫穩(wěn)定擴散,但實際上換個seed生成的結(jié)果就完全不一樣,非常不穩(wěn)定哈。

Stable Diffusion最開始的應(yīng)用應(yīng)該是文本生成圖像,即文生圖,隨著技術(shù)的發(fā)展Stable Diffusion不僅支持image2image圖生圖的生成,還支持ControlNet等各種控制方法來定制生成的圖像。

Stable Diffusion基于擴散模型,所以不免包含不斷去噪的過程,如果是圖生圖的話,還有不斷加噪的過程,此時離不開DDPM那張老圖,如下:
Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例,擴散模型學(xué)習,stable diffusion,pytorch,img2img,圖生圖,SD
Stable Diffusion相比于DDPM,使用了DDIM采樣器,使用了隱空間的擴散,另外使用了非常大的LAION-5B數(shù)據(jù)集進行預(yù)訓(xùn)練。

直接Finetune Stable Diffusion大多數(shù)同學(xué)應(yīng)該是無法cover住成本的,不過Stable Diffusion有很多輕量Finetune的方案,比如Lora、Textual Inversion等,但這是后話。

本文主要是解析一下整個SD模型的結(jié)構(gòu)組成,一次擴散,多次擴散的流程。

大模型、AIGC是當前行業(yè)的趨勢,不會的話容易被淘汰,hh。

txt2img的原理如博文
Diffusion擴散模型學(xué)習2——Stable Diffusion結(jié)構(gòu)解析-以文本生成圖像(txt2img)為例
所示。

二、Stable Diffusion的組成

Stable Diffusion由四大部分組成。
1、Sampler采樣器。
2、Variational Autoencoder (VAE) 變分自編碼器。
3、UNet 主網(wǎng)絡(luò),噪聲預(yù)測器。
4、CLIPEmbedder文本編碼器。

每一部分都很重要,我們以圖像生成圖像為例進行解析。既然是圖像生成圖像,那么我們的輸入有兩個,一個是文本,另外一個是圖片。

三、img2img生成流程

Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例,擴散模型學(xué)習,stable diffusion,pytorch,img2img,圖生圖,SD
生成流程分為四個部分:
1、對圖片進行VAE編碼,根據(jù)denoise數(shù)值進行加噪聲。
2、Prompt文本編碼。
3、根據(jù)denoise數(shù)值進行若干次采樣。
4、使用VAE進行解碼。

相比于文生圖,圖生圖的輸入發(fā)生了變化,不再以Gaussian noise作為初始化,而是以加噪后的圖像特征為初始化,這樣便以圖像的方式為模型注入了信息。

詳細來講,如上圖所示:

  • 第一步為對輸入的圖像利用VAE編碼,獲得輸入圖像的Latent特征;然后使用該Latent特征基于DDIM Sampler進行加噪,此時獲得輸入圖片加噪后的特征。假設(shè)我們設(shè)置denoise數(shù)值為0.8,總步數(shù)為20步,那么第一步中,我們會對輸入圖片進行0.8x20次的加噪聲,剩下4步不加,可理解為打亂了80%的特征,保留20%的特征。
  • 第二步是對輸入的文本進行編碼,獲得文本特征;
  • 第三步是根據(jù)denoise數(shù)值對 第一步中獲得的 加噪后的特征 進行若干次采樣。還是以第一步中denoise數(shù)值為0.8為例,我們只加了0.8x20次噪聲那么我們也只需要進行0.8x20次采樣就可以恢復(fù)出圖片了。
  • 第四步是將采樣后的圖片利用VAE的Decoder進行恢復(fù)。
with torch.no_grad():
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)

    # ----------------------- #
    #   對輸入圖片進行編碼并加噪
    # ----------------------- #
    if image_path is not None:
        img = HWC3(np.array(img, np.uint8))
        img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
        img = torch.stack([img for _ in range(num_samples)], dim=0)
        img = einops.rearrange(img, 'b h w c -> b c h w').clone()

        ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
        t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
        z = model.get_first_stage_encoding(model.encode_first_stage(img))
        z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))

    # ----------------------- #
    #   獲得編碼后的prompt
    # ----------------------- #
    cond    = {"c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {"c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    H, W    = input_shape
    shape   = (4, H // 8, W // 8)

    if image_path is not None:
        samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)
    else:
        # ----------------------- #
        #   進行采樣
        # ----------------------- #
        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                        shape, cond, verbose=False, eta=eta,
                                                        unconditional_guidance_scale=scale,
                                                        unconditional_conditioning=un_cond)

    # ----------------------- #
    #   進行解碼
    # ----------------------- #
    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

1、輸入圖片編碼

Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例,擴散模型學(xué)習,stable diffusion,pytorch,img2img,圖生圖,SD

在圖生圖中,我們首先要指定一張參考的圖像,然后在這個參考圖像上開始工作:
1、利用VAE編碼器對這張參考圖像進行編碼,使其進入隱空間,只有進入了隱空間,網(wǎng)絡(luò)才知道這個圖像是什么;
2、然后使用該Latent特征基于DDIM Sampler進行加噪,此時獲得輸入圖片加噪后的特征。加噪的邏輯如下:

  • denoise可認為是重建的比例,1代表全部重建,0代表不重建;
  • 假設(shè)我們設(shè)置denoise數(shù)值為0.8,總步數(shù)為20步;我們會對輸入圖片進行0.8x20次的加噪聲,剩下4步不加,可理解為80%的特征,保留20%的特征;不過就算加完20步噪聲,原始輸入圖片的信息還是有一點保留的,不是完全不保留。

此時我們便獲得在隱空間加噪后的圖像,后續(xù)會在這個 隱空間加噪后的圖像 的基礎(chǔ)上進行采樣。

with torch.no_grad():
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)

    # ----------------------- #
    #   對輸入圖片進行編碼并加噪
    # ----------------------- #
    if image_path is not None:
        img = HWC3(np.array(img, np.uint8))
        img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
        img = torch.stack([img for _ in range(num_samples)], dim=0)
        img = einops.rearrange(img, 'b h w c -> b c h w').clone()

        ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
        t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
        z = model.get_first_stage_encoding(model.encode_first_stage(img))
        z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))

2、文本編碼

Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例,擴散模型學(xué)習,stable diffusion,pytorch,img2img,圖生圖,SD
文本編碼的思路比較簡單,直接使用CLIP的文本編碼器進行編碼就可以了,在代碼中定義了一個FrozenCLIPEmbedder類別,使用了transformers庫的CLIPTokenizer和CLIPTextModel。

在前傳過程中,我們對輸入進來的文本首先利用CLIPTokenizer進行編碼,然后使用CLIPTextModel進行特征提取,通過FrozenCLIPEmbedder,我們可以獲得一個[batch_size, 77, 768]的特征向量。

class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = [
        "last",
        "pooled",
        "hidden"
    ]
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        # 定義文本的tokenizer和transformer
        self.tokenizer      = CLIPTokenizer.from_pretrained(version)
        self.transformer    = CLIPTextModel.from_pretrained(version)
        self.device         = device
        self.max_length     = max_length
        # 凍結(jié)模型參數(shù)
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        # 對輸入的圖片進行分詞并編碼,padding直接padding到77的長度。
        batch_encoding  = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        # 拿出input_ids然后傳入transformer進行特征提取。
        tokens          = batch_encoding["input_ids"].to(self.device)
        outputs         = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
        # 取出所有的token
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        return self(text)

3、采樣流程

Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例,擴散模型學(xué)習,stable diffusion,pytorch,img2img,圖生圖,SD

a、生成初始噪聲

在圖生圖中,我們的初始噪聲獲取于參考圖片,所以參考第一步就可以獲得圖生圖的噪聲

b、對噪聲進行N次采樣

既然Stable Diffusion是一個不斷擴散的過程,那么少不了不斷的去噪聲,那么怎么去噪聲便是一個問題。

在上一步中,我們已經(jīng)獲得了一個圖生圖的噪聲,它是一個符合正態(tài)分布的向量,我們便從它開始去噪聲。

我們會對ddim_timesteps的時間步取反,因為我們現(xiàn)在是去噪聲而非加噪聲,然后對其進行一個循環(huán),由于我們此時不再是txt2img中的采樣流程,我們使用sampler的另外一個方法decode,循環(huán)的代碼如下:

@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
           use_original_steps=False):

    # 使用ddim的時間步
    # 這里內(nèi)容看起來很多,但是其實很少,本質(zhì)上就是取了self.ddim_timesteps,然后把它reversed一下
    timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
    timesteps = timesteps[:t_start]

    time_range = np.flip(timesteps)
    total_steps = timesteps.shape[0]
    print(f"Running DDIM Sampling with {total_steps} timesteps")

    iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
    x_dec = x_latent
    for i, step in enumerate(iterator):
        index = total_steps - i - 1
        ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
        # 進行單次采樣
        x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)
    return x_dec
c、單次采樣解析
I、預(yù)測噪聲

在進行單詞采樣前,需要首先判斷是否有neg prompt,如果有,我們需要同時處理neg prompt,否則僅僅需要處理pos prompt。實際使用的時候一般都有neg prompt(效果會好一些),所以默認進入對應(yīng)的處理過程。

在處理neg prompt時,我們對輸入進來的隱向量和步數(shù)進行復(fù)制,一個屬于pos prompt,一個屬于neg prompt。torch.cat默認堆疊維度為0,所以是在batch_size維度進行堆疊,二者不會互相影響。然后我們將pos prompt和neg prompt堆疊到一個batch中,也是在batch_size維度堆疊。

# 首先判斷是否由neg prompt,unconditional_conditioning是由neg prompt獲得的
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
    e_t = self.model.apply_model(x, t, c)
else:
    # 一般都是有neg prompt的,所以進入到這里
    # 在這里我們對隱向量和步數(shù)進行復(fù)制,一個屬于pos prompt,一個屬于neg prompt
    # torch.cat默認堆疊維度為0,所以是在bs維度進行堆疊,二者不會互相影響
    x_in = torch.cat([x] * 2)
    t_in = torch.cat([t] * 2)
    # 然后我們將pos prompt和neg prompt堆疊到一個batch中
    if isinstance(c, dict):
        assert isinstance(unconditional_conditioning, dict)
        c_in = dict()
        for k in c:
            if isinstance(c[k], list):
                c_in[k] = [
                    torch.cat([unconditional_conditioning[k][i], c[k][i]])
                    for i in range(len(c[k]))
                ]
            else:
                c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
    else:
        c_in = torch.cat([unconditional_conditioning, c])

Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例,擴散模型學(xué)習,stable diffusion,pytorch,img2img,圖生圖,SD
堆疊完后,我們將隱向量、步數(shù)和prompt條件一起傳入網(wǎng)絡(luò)中,將結(jié)果在bs維度進行使用chunk進行分割。

因為我們在堆疊時,neg prompt放在了前面。因此分割好后,前半部分e_t_uncond屬于利用neg prompt得到的,后半部分e_t屬于利用pos prompt得到的,我們本質(zhì)上應(yīng)該擴大pos prompt的影響,遠離neg prompt的影響。因此,我們使用e_t-e_t_uncond計算二者的距離,使用scale擴大二者的距離。在e_t_uncond基礎(chǔ)上,得到最后的隱向量。

# 堆疊完后,隱向量、步數(shù)和prompt條件一起傳入網(wǎng)絡(luò)中,將結(jié)果在bs維度進行使用chunk進行分割
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例,擴散模型學(xué)習,stable diffusion,pytorch,img2img,圖生圖,SD
此時獲得的e_t就是通過隱向量和prompt共同獲得的預(yù)測噪聲啦。

II、施加噪聲

獲得噪聲就OK了嗎?顯然不是的,我們還要將獲得的新噪聲,按照一定的比例添加到原來的原始噪聲上。

這個地方我們最好結(jié)合ddim中的公式來看,我們需要獲得 α ˉ t \bar{\alpha}_t αˉt?、 α ˉ t ? 1 \bar{\alpha}_{t-1} αˉt?1? σ t \sigma_t σt?、 1 ? α ˉ t \sqrt{1-\bar{\alpha}_t} 1?αˉt? ?
Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例,擴散模型學(xué)習,stable diffusion,pytorch,img2img,圖生圖,SD
Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例,擴散模型學(xué)習,stable diffusion,pytorch,img2img,圖生圖,SD
代碼中,我們其實已經(jīng)預(yù)先計算好了這些參數(shù)。我們只需要直接取出即可,下方的a_t也就是公式中括號外的 α ˉ t \bar{\alpha}_t αˉt?,a_prev 就是公式中的 α ˉ t ? 1 \bar{\alpha}_{t-1} αˉt?1?,sigma_t就是公式中的 σ t \sigma_t σt?,sqrt_one_minus_at就是公式中的 1 ? α ˉ t \sqrt{1-\bar{\alpha}_t} 1?αˉt? ?。

# 根據(jù)采樣器選擇參數(shù)
alphas      = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas      = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

# 根據(jù)步數(shù)選擇參數(shù),
# 這里的index就是上面循環(huán)中的total_steps - i - 1
a_t         = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev      = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t     = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

其實這一步我們只是把公式需要用到的系數(shù)全都拿了出來,方便后面的加減乘除。然后我們便在代碼中實現(xiàn)上述的公式。

# current prediction for x_0
# 公式中的最左邊
pred_x0             = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
    pred_x0, _, *_  = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
# 公式的中間
dir_xt              = (1. - a_prev - sigma_t**2).sqrt() * e_t
# 公式最右邊
noise               = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
    noise           = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev              = a_prev.sqrt() * pred_x0 + dir_xt + noise
# 輸出添加完公式的結(jié)果
return x_prev, pred_x0

Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例,擴散模型學(xué)習,stable diffusion,pytorch,img2img,圖生圖,SD

d、預(yù)測噪聲過程中的網(wǎng)絡(luò)結(jié)構(gòu)解析
I、apply_model方法解析

在3.a的預(yù)測噪聲過程中,我們使用了model.apply_model方法進行噪聲的預(yù)測,這個方法具體做了什么被隱掉了,我們看看具體做的工作。

apply_model方法在ldm.models.diffusion.ddpm.py文件中。在apply_model中,我們將x_noisy傳入self.model中預(yù)測噪聲。

x_recon = self.model(x_noisy, t, **cond)

Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例,擴散模型學(xué)習,stable diffusion,pytorch,img2img,圖生圖,SD
self.model是一個預(yù)先構(gòu)建好的類,定義在ldm.models.diffusion.ddpm.py文件的1416行,內(nèi)部包含Stable Diffusion的Unet網(wǎng)絡(luò),self.model的功能有點類似于包裝器,根據(jù)模型選擇的特征融合方式,進行文本與上文生成的噪聲的融合。

c_concat代表使用堆疊的方式進行融合,c_crossattn代表使用attention的方式融合。

class DiffusionWrapper(pl.LightningModule):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
        # stable diffusion的unet網(wǎng)絡(luò)
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            if not self.sequential_cross_attn:
                cc = torch.cat(c_crossattn, 1)
            else:
                cc = c_crossattn
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'hybrid-adm':
            assert c_adm is not None
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'crossattn-adm':
            assert c_adm is not None
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out

Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例,擴散模型學(xué)習,stable diffusion,pytorch,img2img,圖生圖,SD
代碼中的self.diffusion_model便是Stable Diffusion的Unet網(wǎng)絡(luò),網(wǎng)絡(luò)結(jié)構(gòu)位于ldm.modules.diffusionmodules.openaimodel.py文件中的UNetModel類。

II、UNetModel模型解析

UNetModel主要做的工作是結(jié)合時間步t和文本Embedding計算這一時刻的噪聲。盡管UNet的思路非常簡單,但是在StableDiffusion中,UNetModel由ResBlock和Transformer模塊組成,整體來講相比于普通的UNet復(fù)雜一些。

Prompt通過Frozen CLIP Text Encoder獲得Text Embedding,Timesteps通過全連接(MLP)獲得Timesteps Embedding;

ResBlock用于結(jié)合時間步Timesteps Embedding,Transformer模塊用于結(jié)合文本Text Embedding。

我在這里放一張大圖,同學(xué)們可以看到內(nèi)部shape的變化。
Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例,擴散模型學(xué)習,stable diffusion,pytorch,img2img,圖生圖,SD

Unet代碼如下所示:

class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
                                    increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # 用于計算當前采樣時間t的embedding
        time_embed_dim  = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)
        
        # 定義輸入模塊的第一個卷積
        # TimestepEmbedSequential也可以看作一個包裝器,根據(jù)層的種類進行時間或者文本的融合。
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size  = model_channels
        input_block_chans   = [model_channels]
        ch                  = model_channels
        ds                  = 1
        # 對channel_mult進行循環(huán),channel_mult一共有四個值,代表unet四個部分通道的擴張比例
        # [1, 2, 4, 4]
        for level, mult in enumerate(channel_mult):
            # 每個部分循環(huán)兩次
            # 添加一個ResBlock和一個AttentionBlock
            for _ in range(num_res_blocks):
                # 先添加一個ResBlock
                # 用于對輸入的噪聲進行通道數(shù)的調(diào)整,并且融合t的特征
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                # ch便是上述ResBlock的輸出通道數(shù)
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # num_heads=8
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    # 使用了SpatialTransformer自注意力,加強全局特征,融合文本的特征
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # 如果不是四個部分中的最后一個部分,那么都要進行下采樣。
            if level != len(channel_mult) - 1:
                out_ch = ch
                # 在此處進行下采樣
                # 一般直接使用Downsample模塊
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                # 為下一階段定義參數(shù)。
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        # 定義中間層
        # ResBlock + SpatialTransformer + ResBlock
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # 定義Unet上采樣過程
        self.output_blocks = nn.ModuleList([])
        # 循環(huán)把channel_mult反了過來
        for level, mult in list(enumerate(channel_mult))[::-1]:
            # 上采樣時每個部分循環(huán)三次
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                # 首先添加ResBlock層
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                # 然后進行SpatialTransformer自注意力
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                # 如果不是channel_mult循環(huán)的第一個
                # 且
                # 是num_res_blocks循環(huán)的最后一次,則進行上采樣
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # 最后在輸出部分進行一次卷積
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
            normalization(ch),
            conv_nd(dims, model_channels, n_embed, 1),
            #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs      = []
        # 用于計算當前采樣時間t的embedding
        t_emb   = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb     = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        # 對輸入模塊進行循環(huán),進行下采樣并且融合時間特征與文本特征。
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)

        # 中間模塊的特征提取
        h = self.middle_block(h, emb, context)

        # 上采樣模塊的特征提取
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        # 輸出模塊
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)

4、隱空間解碼生成圖片

Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例,擴散模型學(xué)習,stable diffusion,pytorch,img2img,圖生圖,SD
通過上述步驟,已經(jīng)可以多次采樣獲得結(jié)果,然后我們便可以通過隱空間解碼生成圖片。

隱空間解碼生成圖片的過程非常簡單,將上文多次采樣后的結(jié)果,使用decode_first_stage方法即可生成圖片。

在decode_first_stage方法中,網(wǎng)絡(luò)調(diào)用VAE對獲取到的64x64x3的隱向量進行解碼,獲得512x512x3的圖片。

@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
    if predict_cids:
        if z.dim() == 4:
            z = torch.argmax(z.exp(), dim=1).long()
        z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
        z = rearrange(z, 'b h w c -> b c h w').contiguous()

    z = 1. / self.scale_factor * z
	# 一般無需分割輸入,所以直接將x_noisy傳入self.model中,在下面else進行
    if hasattr(self, "split_input_params"):
    	......
    else:
        if isinstance(self.first_stage_model, VQModelInterface):
            return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
        else:
            return self.first_stage_model.decode(z)

圖像到圖像預(yù)測過程代碼

整體預(yù)測代碼如下:文章來源地址http://www.zghlxwxcb.cn/news/detail-617973.html

import os
import random

import cv2
import einops
import numpy as np
import torch
from PIL import Image
from pytorch_lightning import seed_everything

from ldm_hacked import *

# ----------------------- #
#   使用的參數(shù)
# ----------------------- #
# config的地址
config_path = "model_data/sd_v15.yaml"
# 模型的地址
model_path  = "model_data/v1-5-pruned-emaonly.safetensors"
# fp16,可以加速與節(jié)省顯存
sd_fp16     = True
vae_fp16    = True

# ----------------------- #
#   生成圖片的參數(shù)
# ----------------------- #
# 生成的圖像大小為input_shape,對于img2img會進行Centter Crop
input_shape = [512, 512]
# 一次生成幾張圖像
num_samples = 1
# 采樣的步數(shù)
ddim_steps  = 20
# 采樣的種子,為-1的話則隨機。
seed        = 12345
# eta
eta         = 0
# denoise強度,for img2img
denoise_strength = 1.0

# ----------------------- #
#   提示詞相關(guān)參數(shù)
# ----------------------- #
# 提示詞
prompt      = "a cute cat, with yellow leaf, trees"
# 正面提示詞
a_prompt    = "best quality, extremely detailed"
# 負面提示詞
n_prompt    = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
# 正負擴大倍數(shù)
scale       = 9
# img2img使用,如果不想img2img這設(shè)置為None。
image_path  = None

# ----------------------- #
#   保存路徑
# ----------------------- #
save_path   = "imgs/outputs_imgs"

# ----------------------- #
#   創(chuàng)建模型
# ----------------------- #
model   = create_model(config_path).cpu()
model.load_state_dict(load_state_dict(model_path, location='cuda'), strict=False)
model   = model.cuda()
ddim_sampler = DDIMSampler(model)
if sd_fp16:
    model = model.half()

if image_path is not None:
    img = Image.open(image_path)
    img = crop_and_resize(img, input_shape[0], input_shape[1])

with torch.no_grad():
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)
    
    # ----------------------- #
    #   對輸入圖片進行編碼并加噪
    # ----------------------- #
    if image_path is not None:
        img = HWC3(np.array(img, np.uint8))
        img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
        img = torch.stack([img for _ in range(num_samples)], dim=0)
        img = einops.rearrange(img, 'b h w c -> b c h w').clone()
        if vae_fp16:
            img = img.half()
            model.first_stage_model = model.first_stage_model.half()
        else:
            model.first_stage_model = model.first_stage_model.float()

        ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
        t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
        z = model.get_first_stage_encoding(model.encode_first_stage(img))
        z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
        z_enc = z_enc.half() if sd_fp16 else z_enc.float()

    # ----------------------- #
    #   獲得編碼后的prompt
    # ----------------------- #
    cond    = {"c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {"c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    H, W    = input_shape
    shape   = (4, H // 8, W // 8)

    if image_path is not None:
        samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)
    else:
        # ----------------------- #
        #   進行采樣
        # ----------------------- #
        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                        shape, cond, verbose=False, eta=eta,
                                                        unconditional_guidance_scale=scale,
                                                        unconditional_conditioning=un_cond)

    # ----------------------- #
    #   進行解碼
    # ----------------------- #
    x_samples = model.decode_first_stage(samples.half() if vae_fp16 else samples.float())

    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

# ----------------------- #
#   保存圖片
# ----------------------- #
if not os.path.exists(save_path):
    os.makedirs(save_path)
for index, image in enumerate(x_samples):
    cv2.imwrite(os.path.join(save_path, str(index) + ".jpg"), cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

到了這里,關(guān)于Diffusion擴散模型學(xué)習3——Stable Diffusion結(jié)構(gòu)解析-以圖像生成圖像(圖生圖,img2img)為例的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!

本文來自互聯(lián)網(wǎng)用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務(wù),不擁有所有權(quán),不承擔相關(guān)法律責任。如若轉(zhuǎn)載,請注明出處: 如若內(nèi)容造成侵權(quán)/違法違規(guī)/事實不符,請點擊違法舉報進行投訴反饋,一經(jīng)查實,立即刪除!

領(lǐng)支付寶紅包贊助服務(wù)器費用

相關(guān)文章

覺得文章有用就打賞一下文章作者

支付寶掃一掃打賞

博客贊助

微信掃一掃打賞

請作者喝杯咖啡吧~博客贊助

支付寶掃一掃領(lǐng)取紅包,優(yōu)惠每天領(lǐng)

二維碼1

領(lǐng)取紅包

二維碼2

領(lǐng)紅包