Scalable Diffusion Models with Transformers
論文地址:
https://arxiv.org/pdf/2212.09748.pdf
項(xiàng)目地址:
https://github.com/facebookresearch/DiT
論文主頁:
https://www.wpeebles.com/DiT
摘要
我們探索了一類新的基于Transformer結(jié)構(gòu)的擴(kuò)散模型。我們訓(xùn)練圖像的潛在擴(kuò)散模型,用一個(gè)對潛在補(bǔ)丁操作的Transformer取代常用的U-Net骨干網(wǎng)。我們通過Gflops測量的前向傳遞復(fù)雜性來分析擴(kuò)散Transformer(dit)的可伸縮性。我們發(fā)現(xiàn),具有較高gflop的dit(通過增加Transformer深度/寬度或增加輸入令牌數(shù)量)始終具有較低的FID。除了具有良好的可擴(kuò)展性屬性外,我們最大的DiT-XL/2模型在類條件ImageNet 512×512和256×256基準(zhǔn)上優(yōu)于所有先前的擴(kuò)散模型,在后者上實(shí)現(xiàn)了2.27的最先進(jìn)的FID
1. 介紹
在Transformer的推動下,機(jī)器學(xué)習(xí)正在經(jīng)歷復(fù)興。在過去的五年中,用于自然語言處理[8,39]、視覺[10]和其他幾個(gè)領(lǐng)域的神經(jīng)架構(gòu)在很大程度上已經(jīng)被Transformer[57]所包含。許多圖像級生成模型仍然不受這一趨勢的影響,盡管Transformer在自回歸模型中得到了廣泛的應(yīng)用[3,6,40,44],但它們在其他生成建??蚣苤袘?yīng)用較少。例如,擴(kuò)散模型一直處于圖像級生成模型最新進(jìn)展的前沿[9,43];然而,它們都采用卷積U-Net架構(gòu)作為骨干網(wǎng)的事實(shí)上選擇。
Ho等人的開創(chuàng)性工作[19]首次為擴(kuò)散模型引入了U-Net骨干網(wǎng)。設(shè)計(jì)選擇繼承自pixelcnn++[49,55],一個(gè)自回歸生成模型,有一些架構(gòu)上的變化。該模型是卷積的,主要由ResNet[15]塊組成。與標(biāo)準(zhǔn)U-Net[46]相比,附加的空間自注意塊是Transformer中的基本組件,以較低的分辨率散布。Dhariwal和Nichol[9]消除了U-Net的幾種架構(gòu)選擇,例如使用自適應(yīng)歸一化層[37]為卷積層注入條件信息和通道計(jì)數(shù)。然而,Ho等人的U-Net的高級設(shè)計(jì)在很大程度上保持完整。
通過這項(xiàng)工作,我們的目標(biāo)是揭開擴(kuò)散模型中建筑選擇的意義,并為未來的生成建模研究提供經(jīng)驗(yàn)基線。我們表明,U-Net電感偏差對擴(kuò)散模型的性能不是至關(guān)重要的,并且它們可以很容易地用標(biāo)準(zhǔn)設(shè)計(jì)(如Transformer)取代。因此,擴(kuò)散模型很好地準(zhǔn)備從架構(gòu)統(tǒng)一的最近趨勢中獲益。通過繼承其他領(lǐng)域的最佳實(shí)踐和培訓(xùn)方法,以及保留可伸縮性、健壯性和效率等有利屬性。標(biāo)準(zhǔn)化的體系結(jié)構(gòu)也將為跨領(lǐng)域研究開辟新的可能性。
本文討論了一類新的基于Transformer的擴(kuò)散模型。我們稱之為擴(kuò)散Transformer,簡稱dit。dit堅(jiān)持視覺Transformer(ViTs)[10]的最佳實(shí)踐,它已被證明比傳統(tǒng)的卷積網(wǎng)絡(luò)(例如ResNet[15])更有效地?cái)U(kuò)展視覺識別。
更具體地說,我們研究了Transformer的標(biāo)度行為與網(wǎng)絡(luò)復(fù)雜度和樣本質(zhì)量的關(guān)系。我們表明,通過在潛伏擴(kuò)散模型(ldm)[45]框架下構(gòu)建和對標(biāo)DiT設(shè)計(jì)空間,其中擴(kuò)散模型在V AE的潛伏空間內(nèi)訓(xùn)練,我們可以成功地用Transformer取代U-Net骨干。我們進(jìn)一步表明dit是擴(kuò)散模型的可擴(kuò)展架構(gòu):網(wǎng)絡(luò)復(fù)雜性(由Gflops測量)與樣本質(zhì)量(由FID測量)之間存在很強(qiáng)的相關(guān)性。通過簡單地?cái)U(kuò)展DiT并訓(xùn)練具有高容量主干(118.6 Gflops)的LDM,我們能夠在256 × 256類條件ImageNet生成基準(zhǔn)上實(shí)現(xiàn)2.27 FID的最先進(jìn)結(jié)果。
2. 相關(guān)工作
Transformer。Transformer[57]已經(jīng)取代了跨語言、視覺[10]、強(qiáng)化學(xué)習(xí)[5,23]和元學(xué)習(xí)[36]的特定領(lǐng)域架構(gòu)。它們在增加模型大小、訓(xùn)練語言域[24]的計(jì)算和數(shù)據(jù)時(shí)表現(xiàn)出了顯著的縮放特性,如通用自回歸模型[17]和ViTs[60]。除了語言,Transformer已經(jīng)被訓(xùn)練為自回歸預(yù)測像素[6,7,35]。它們還在離散碼本[56]上進(jìn)行了訓(xùn)練,作為自回歸模型[11,44]和掩模生成模型[4,14];前者在20B參數(shù)[59]下表現(xiàn)出優(yōu)異的縮放性能。最后,在DDPMs中探索了Transformer對非空間數(shù)據(jù)的綜合;例如,在DALL·E 2中生成CLIP圖像嵌入[38,43]。本文研究了Transformer作為圖像擴(kuò)散模型主干時(shí)的標(biāo)度特性。
去噪擴(kuò)散概率模型(DDPM)。擴(kuò)散[19,51]和基于分?jǐn)?shù)的生成模型[22,53]作為圖像的生成模型特別成功[32,43,45,47],在許多情況下優(yōu)于以前最先進(jìn)的生成對抗網(wǎng)絡(luò)(GANs)[12]。過去兩年DDPM的改進(jìn)主要是由改進(jìn)的采樣技術(shù)驅(qū)動的[19,25,52],最顯著的是無分類引導(dǎo)[21],重新定義擴(kuò)散模型以預(yù)測噪聲而不是像素[19],并使用級聯(lián)DDPM管道,其中低分辨率基礎(chǔ)擴(kuò)散模型與上采樣器并行訓(xùn)練[9,20]。對于上面列出的所有擴(kuò)散模型,卷積U-Nets[46]實(shí)際上是骨干架構(gòu)的選擇。
體系結(jié)構(gòu)的復(fù)雜性。在圖像生成文獻(xiàn)中評估體系結(jié)構(gòu)復(fù)雜性時(shí),使用參數(shù)計(jì)數(shù)是相當(dāng)常見的實(shí)踐。一般來說,參數(shù)計(jì)數(shù)不能很好地代表圖像模型的復(fù)雜性,因?yàn)樗鼈儾荒苷f明圖像分辨率等顯著影響性能的因素[41,42]。相反,本文中的大部分模型復(fù)雜性分析都是通過理論Gflops的透鏡進(jìn)行的。這使我們與架構(gòu)設(shè)計(jì)文獻(xiàn)保持一致,在這些文獻(xiàn)中,gflop被廣泛用于衡量復(fù)雜性。在實(shí)踐中,黃金復(fù)雜度度量仍然存在爭議,因?yàn)樗?jīng)常依賴于特定的應(yīng)用場景。Nichol和Dhariwal改進(jìn)擴(kuò)散模型的開創(chuàng)性工作[9,33]與我們最相關(guān),在那里,他們分析了U-Net架構(gòu)類的可伸縮性和Gflop屬性。在本文中,我們主要關(guān)注transformer類。
3. Diffusion Transformers
3.1. Preliminaries
擴(kuò)散公式。在介紹我們的體系結(jié)構(gòu)之前,我們簡要回顧了理解擴(kuò)散模型(DDPM)所需的一些基本概念[19,51]。高斯擴(kuò)散模型假設(shè)一個(gè)前向噪聲處理過程,逐步將噪聲應(yīng)用于真實(shí)數(shù)據(jù) x 0 : q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 ? α ˉ t ) I ) x_0:{q\left(\mathbf{x}_{t} \mid \mathbf{x}_{0}\right)=N\left(\mathbf{x}_{t} ; \sqrt{\bar{\alpha}_{t}} \mathbf{x}_{0},\left(1-\bar{\alpha}_{t}\right) \mathbf{I}\right)} x0?:q(xt?∣x0?)=N(xt?;αˉt??x0?,(1?αˉt?)I)
其中常量 α ˉ t \bar{\alpha}_{t} αˉt?是超參數(shù)。通過應(yīng)用重新參數(shù)化技巧,我們可以對 x t = α ˉ t x 0 + 1 ? α l  ̄ ? {\mathrm{x}_{t}=\sqrt{\bar{\alpha}_{t}} \mathrm{x}_{0}+\sqrt{1-\overline{\alpha_{l}}} \epsilon} xt?=αˉt??x0?+1?αl????,其中${\epsilon}{\in} {\mathcal{N}(0,I)} $。
擴(kuò)散模型被訓(xùn)練來學(xué)習(xí)反轉(zhuǎn)正向過程破壞的反向過程: p θ ( x t ? 1 ∣ x t ) = N ( x t ? 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_{\theta}\left(\mathbf{x}_{t-1} \mid \mathbf{x}_{t}\right)=\mathcal{N}\left(\mathbf{x}_{t-1} ; \mu_{\theta}\left(\mathbf{x}_{t}, t\right), \Sigma_{\theta}\left(\mathbf{x}_{t}, t\right)\right) pθ?(xt?1?∣xt?)=N(xt?1?;μθ?(xt?,t),Σθ?(xt?,t)),其中神經(jīng)網(wǎng)絡(luò)用于預(yù)測 p θ p_θ pθ?的統(tǒng)計(jì)量。用 x 0 x_0 x0?的對數(shù)似然的變分下界[27]來訓(xùn)練逆向過程模型,可簡化為: L ( θ ) = ? p ( x 0 ∣ x 1 ) + ∑ t D K L ( q ? ( x t ? 1 ∣ x t , η ) ∥ p θ ( x t ? 1 ∣ x t ) ) \mathcal{L}(\theta)=-p(x_0|x_1)+\sum_t \mathrm{D_KL}\left(q^*\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \boldsymbol{\eta}\right) \| p_\theta\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t\right)\right) L(θ)=?p(x0?∣x1?)+∑t?DK?L(q?(xt?1?∣xt?,η)∥pθ?(xt?1?∣xt?))。排除了與訓(xùn)練無關(guān)的附加術(shù)語。由于 q ? q^ * q?和 p θ p_θ pθ?都是高斯分布, D K L D_{KL} DKL?可以用兩個(gè)分布的均值和協(xié)方差來計(jì)算。通過重新參數(shù)化 μ θ \mu_{\theta} μθ?作為噪聲預(yù)測網(wǎng)絡(luò) ? θ {\epsilon}_{\theta} ?θ?,可以使用預(yù)測噪聲 ? θ ( x t ) {\epsilon}_{\theta}(x_t) ?θ?(xt?)與地面真理采樣高斯噪聲 ? t {\epsilon}_t ?t?之間的簡單均方誤差來訓(xùn)練模型: L simple? = E t , x 0 , ? [ ∥ ? ? ? θ ( x t , t ) ∥ 2 ] L_{\text {simple }}=E_{t, x_{0}, \epsilon}\left[\left\|\epsilon-\epsilon_{\theta}\left(x_{t}, t\right)\right\|^{2}\right] Lsimple??=Et,x0?,??[∥???θ?(xt?,t)∥2]。但是,為了用學(xué)習(xí)到的反向過程協(xié)方差 Σ θ \Sigma_{\theta} Σθ?訓(xùn)練擴(kuò)散模型,需要優(yōu)化完整的 D K L D_{KL} DKL?項(xiàng)。我們遵循Nichol和Dhariwal的方法[33]:用 L s i m p l e L_{simple} Lsimple?訓(xùn)練 ? θ {\epsilon}_{\theta} ?θ?,用完整的 L L L訓(xùn)練 Σ θ \Sigma_{\theta} Σθ?。一旦訓(xùn)練 p θ p_θ pθ?,就可以通過初始化 x t m a x ~ N ( 0 , I ) x_{t_{max}}~N(0,I) xtmax??~N(0,I)和通過重新參數(shù)化技巧采樣 x t ? 1 ~ p θ ( x t ? 1 ∣ x t ) x_{t?1}~p_θ(x_{t?1}|x_t) xt?1?~pθ?(xt?1?∣xt?)來對新圖像進(jìn)行采樣。
Classifier-free指導(dǎo)。條件擴(kuò)散模型將額外的信息作為輸入,例如類標(biāo)簽c。在這種情況下,相反的過程變成 p θ ( x t ? 1 ∣ x t , c ) p_θ(x_{t?1}|x_t, c) pθ?(xt?1?∣xt?,c),其中 ? θ {\epsilon}_{\theta} ?θ?和 Σ θ \Sigma_{\theta} Σθ?以 c c c為條件。在這種設(shè)置下,可以使用無分類器引導(dǎo)來鼓勵(lì)抽樣過程找到x,使 l o g p ( c ∣ x ) log p(c|x) logp(c∣x)高[21]。根據(jù)貝葉斯規(guī)則, log ? p ( c ∣ x ) ∝ log ? p ( x ∣ c ) ? log ? p ( x ) \log p(c \mid x) \propto \log p(x \mid c)-\log p(x) logp(c∣x)∝logp(x∣c)?logp(x),因此 ? σ log ? p ( c ∣ x ) ∝ ? σ log ? p ( x ∣ c ) ? ? τ log ? p ( x ) \nabla_\sigma \log p(c \mid x) \propto \nabla_\sigma \log p(x \mid c)-\nabla_\tau \log p(x) ?σ?logp(c∣x)∝?σ?logp(x∣c)??τ?logp(x)。通過將擴(kuò)散模型的輸出解釋為得分函數(shù),DDPM抽樣程序可以通過以下方法指導(dǎo) x x x到 p ( x ∣ c ) p(x|c) p(x∣c)高的樣本: p ( x ∣ c ) p(x \mid c) p(x∣c) by: ? ^ θ ( x t , c ) = ? θ ( x t , ? ) + s \hat{\epsilon}_\theta\left(x_t, c\right)=\epsilon_\theta\left(x_t, \emptyset\right)+s ?^θ?(xt?,c)=?θ?(xt?,?)+s. ? x log ? p ( x ∣ c ) ∝ ? θ ( x t , ? ) + s ? ( ? θ ( x t , c ) ? ? θ ( x t , ? ) ) \nabla_x \log p(x \mid c) \propto \epsilon_\theta\left(x_t, \emptyset\right)+s \cdot\left(\epsilon_\theta\left(x_t, c\right)-\epsilon_\theta\left(x_t, \emptyset\right)\right) ?x?logp(x∣c)∝?θ?(xt?,?)+s?(?θ?(xt?,c)??θ?(xt?,?)),其中s > 1為指導(dǎo)尺度(注意s = 1恢復(fù)標(biāo)準(zhǔn)抽樣)。對c =?的擴(kuò)散模型的評估是在訓(xùn)練時(shí)隨機(jī)剔除c,代之以一個(gè)習(xí)得的“null”嵌入?。眾所周知,與通用抽樣技術(shù)相比,無分類器指導(dǎo)可以產(chǎn)生顯著改善的樣本[21,32,43],而且我們的DiT模型也具有這一趨勢。
潛在擴(kuò)散模型。在高分辨率像素空間中直接訓(xùn)練擴(kuò)散模型在計(jì)算上是非常困難的。潛擴(kuò)散模型(ldm)[45]用兩階段方法解決這個(gè)問題:(1)學(xué)習(xí)一個(gè)自編碼器,用學(xué)習(xí)的編碼器E將圖像壓縮成更小的空間表示;(2)訓(xùn)練一個(gè)表示 z = E ( x ) z = E(x) z=E(x)的擴(kuò)散模型,而不是圖像x (E被凍結(jié))的擴(kuò)散模型。然后,可以通過從擴(kuò)散模型中采樣表示z來生成新的圖像,然后使用學(xué)習(xí)的解碼器 x = D ( z ) x = D(z) x=D(z)將其解碼為圖像。如圖2所示,LDM在使用像素空間擴(kuò)散模型(如ADM)的一小部分gflop的情況下實(shí)現(xiàn)了良好的性能。由于我們關(guān)心計(jì)算效率,這使它們成為架構(gòu)探索的一個(gè)吸引人的起點(diǎn)。在本文中,我們將dit應(yīng)用于潛在空間,盡管它們也可以應(yīng)用于像素空間而無需修改。這使得我們的圖像生成管道成為一種基于混合的方法;我們使用現(xiàn)成的卷積vav和基于Transformer的DDPM。
3.2. Diffusion Transformer Design Space
我們介紹了擴(kuò)散Transformer(DiTs),一種用于擴(kuò)散模型的新架構(gòu)。我們的目標(biāo)是盡可能忠實(shí)于標(biāo)準(zhǔn)Transformer架構(gòu),以保留其縮放特性。由于我們的重點(diǎn)是訓(xùn)練圖像的DDPM(特別是圖像的空間表示),DiT基于視覺轉(zhuǎn)換器(ViT)架構(gòu),該架構(gòu)操作補(bǔ)丁[10]的序列。DiT保留了ViT的許多最佳實(shí)踐。圖3顯示了完整DiT體系結(jié)構(gòu)的概述。在本節(jié)中,我們將描述DiT的正向傳遞,以及DiT類的設(shè)計(jì)空間的組件。
Patchify。DiT的輸入是一個(gè)空間表示z(對于256 × 256 × 3圖像,z的形狀為32 × 32 × 4)。DiT的第一層是“patchify”,它通過將每個(gè)patch線性嵌入到輸入中,將空間輸入轉(zhuǎn)換為一個(gè)T標(biāo)記序列,每個(gè)標(biāo)記的維度為d。在patchify之后,我們將標(biāo)準(zhǔn)的基于ViT頻率的位置嵌入(sin -cos版本)應(yīng)用于所有輸入令牌。
patchify創(chuàng)建的令牌T的數(shù)量由補(bǔ)丁大小超參數(shù)p決定。如圖4所示,將p減半將使T翻四倍,因此至少使總Transformergflop翻四倍。盡管它對Gflops有重大影響,但請注意,更改p對下游參數(shù)計(jì)數(shù)沒有重大影響。
我們將p = 2,4,8添加到DiT設(shè)計(jì)空間。
DiT塊設(shè)計(jì)。在patchify之后,輸入令牌由一系列Transformer塊處理。除了有噪聲的圖像輸入外,擴(kuò)散模型有時(shí)還處理附加的條件信息,如噪聲時(shí)間步長t、類標(biāo)簽c、自然語言等。我們探索了四種不同的Transformer塊,它們以不同的方式處理?xiàng)l件輸入。該設(shè)計(jì)對標(biāo)準(zhǔn)ViT塊設(shè)計(jì)進(jìn)行了微小但重要的修改。各塊的設(shè)計(jì)如圖3所示。
? -情境條件反射。我們簡單地將t和c的向量嵌入作為輸入序列中的兩個(gè)附加標(biāo)記,將它們與圖像標(biāo)記區(qū)別對待。這類似于ViT中的cls令牌,它允許我們無需修改就使用標(biāo)準(zhǔn)ViT塊。在最后一個(gè)塊之后,我們從序列中刪除條件令牌。這種方法為模型引入了可以忽略不計(jì)的新Gflops。
? -交叉注意模塊。我們將t和c的嵌入連接到一個(gè)長度為2的序列中,與圖像標(biāo)記序列分開。Transformer塊經(jīng)過修改,在多頭自注意塊之后增加了一個(gè)多頭交叉注意層,類似于Vaswani等人的原始設(shè)計(jì)[57],也類似于LDM用于類標(biāo)簽的條件調(diào)節(jié)。交叉注意為模型增加了最多的gflop,大約15%的開銷。
? -Adaptive layer norm (adaLN)塊。在GANs中廣泛使用自適應(yīng)歸一化層[37][2,26]和使用UNet主干[9]的擴(kuò)散模型之后,我們探索用自適應(yīng)層范數(shù)(adaLN)取代Transformer塊中的標(biāo)準(zhǔn)層范數(shù)層。我們不是直接學(xué)習(xí)量維尺度和移位參數(shù)γ和β,而是從t和c的嵌入向量的和中回歸它們。在我們探索的三個(gè)塊設(shè)計(jì)中,adaLN添加的Gflops最少,因此計(jì)算效率最高。它也是唯一限制將相同函數(shù)應(yīng)用于所有令牌的調(diào)節(jié)機(jī)制。
? -adaln -零塊。先前關(guān)于ResNets的工作已經(jīng)發(fā)現(xiàn),將每個(gè)剩余塊初始化為恒等函數(shù)是有益的。例如,Goyal等人發(fā)現(xiàn),在每個(gè)塊中對最后一批范數(shù)尺度因子γ進(jìn)行零初始化可以加速監(jiān)督學(xué)習(xí)設(shè)置[13]下的大規(guī)模訓(xùn)練。擴(kuò)散U-Net模型使用類似的初始化策略,在任何剩余連接之前對每個(gè)塊中的最終卷積層進(jìn)行零初始化。我們將探索adaLN DiT塊的修改,它具有相同的功能。除了回歸γ和β,我們還回歸在DiT塊內(nèi)任何剩余連接之前立即應(yīng)用的維度縮放參數(shù)α。我們初始化MLP以輸出所有α的零向量;這將整個(gè)DiT塊初始化為恒等函數(shù)。與vanilla adaLN塊一樣,adaLNZero為模型添加了可以忽略不計(jì)的gflop。
我們在DiT設(shè)計(jì)空間中包括上下文內(nèi)、交叉注意、自適應(yīng)層規(guī)范和adaLN-Zero塊。
模型的尺寸。我們應(yīng)用了N個(gè)DiT塊序列,每個(gè)塊的隱藏維度大小為d。在ViT之后,我們使用標(biāo)準(zhǔn)Transformer配置,共同縮放N, d和注意頭[10,60]。具體來說,我們使用四種配置:DiT-S, DiT-B, DiT-L和DiT-XL。它們涵蓋了廣泛的模型大小和觸發(fā)器分配,從0.3到118.6 gflop,允許我們衡量縮放性能。表1給出了配置的詳細(xì)信息。我們將B、S、L和XL配置添加到DiT設(shè)計(jì)空間。
Transformer解碼器。在最后一個(gè)DiT塊之后,我們需要將我們的圖像標(biāo)記序列解碼為輸出噪聲預(yù)測和輸出對角線協(xié)方差預(yù)測。這兩個(gè)輸出的形狀都等于原始的空間輸入。我們使用標(biāo)準(zhǔn)的線性解碼器來做到這一點(diǎn);我們應(yīng)用最后一層范數(shù)(如果使用adaLN則自適應(yīng)),并將每個(gè)令牌線性解碼為p×p×2C張量,其中C是DiT空間輸入中的通道數(shù)。最后,我們將解碼后的標(biāo)記重新排列到原始的空間布局中,得到預(yù)測的噪聲和協(xié)方差。我們探索的完整DiT設(shè)計(jì)空間是補(bǔ)丁大小、Transformer塊架構(gòu)和模型大小。
4. 實(shí)驗(yàn)設(shè)置
我們探索了DiT設(shè)計(jì)空間,并研究了模型類的縮放屬性。我們的模型是根據(jù)它們的配置和潛在補(bǔ)丁大小p命名的;例如,DiT-XL/2指的是XLarge配置,p = 2。
訓(xùn)練。我們在ImageNet數(shù)據(jù)集[28]上訓(xùn)練256 × 256和512 × 512圖像分辨率的類條件潛在DiT模型,這是一個(gè)高度競爭的生成建?;鶞?zhǔn)。我們用零初始化最后的線性層,否則使用ViT的標(biāo)準(zhǔn)權(quán)重初始化技術(shù)。我們用AdamW[27,30]訓(xùn)練所有模型。我們使用1 × 10?4的恒定學(xué)習(xí)速率,沒有權(quán)重衰減,批處理大小為256。我們使用的唯一數(shù)據(jù)增強(qiáng)是水平翻轉(zhuǎn)。與之前對vit的許多工作不同[54,58],我們沒有發(fā)現(xiàn)學(xué)習(xí)率熱身或正則化對訓(xùn)練dit達(dá)到高性能是必要的。即使沒有這些技術(shù),訓(xùn)練在所有模型配置中都是高度穩(wěn)定的,我們沒有觀察到訓(xùn)練Transformer時(shí)常見的任何損失峰值。根據(jù)生成建模文獻(xiàn)中的常見做法,我們在訓(xùn)練中保持DiT權(quán)重的指數(shù)移動平均(EMA),衰減為0.9999。所有報(bào)告的結(jié)果均使用EMA模型。我們在所有DiT模型大小和補(bǔ)丁大小上使用相同的訓(xùn)練超參數(shù)。我們的訓(xùn)練超參數(shù)幾乎完全從adm中保留下來。我們沒有調(diào)整學(xué)習(xí)率、衰減/熱身計(jì)劃、Adam β1/β2或權(quán)重衰減。
擴(kuò)散。我們使用來自穩(wěn)定擴(kuò)散[45]的現(xiàn)成的預(yù)訓(xùn)練變分自編碼器(VAE)模型[27]。給定形狀為256 × 256 × 3, z = E(x)形狀為32 × 32 × 4的RGB圖像x, VAE編碼器的下采樣因子為8。在本節(jié)的所有實(shí)驗(yàn)中,我們的擴(kuò)散模型都在這個(gè)z空間中運(yùn)行。在從我們的擴(kuò)散模型中采樣一個(gè)新的潛伏后,我們使用VAE解碼器x = D(z)將其解碼為像素。我們保留ADM[9]的擴(kuò)散超參數(shù):具體來說,我們使用 t m a x = 1000 t_{max} = 1000 tmax?=1000線性方差表,范圍為1×10?4到2 ×10?2,ADM的協(xié)方差參數(shù)化 Σ θ \Sigma_{\theta} Σθ?和他們嵌入輸入時(shí)間步和標(biāo)簽的方法。
評價(jià)指標(biāo)。我們用Fréchet初始距離(FID)[18]來衡量縮放性能,這是評估圖像生成模型的標(biāo)準(zhǔn)度量。在與之前的工作進(jìn)行比較時(shí),我們遵循慣例,并使用250 DDPM采樣步驟報(bào)告FID-50K。眾所周知,F(xiàn)ID對小的實(shí)現(xiàn)細(xì)節(jié)[34]非常敏感;為了保證比較的準(zhǔn)確性,本文報(bào)告的所有值都是通過導(dǎo)出樣本和使用ADM的TensorFlow評估套件[9]得到的。除非另有說明,本節(jié)中報(bào)告的FID編號不使用無分類器指導(dǎo)。我們還報(bào)告盜夢評分[48],sFID[31]和精度/召回[29]作為次要指標(biāo)。
計(jì)算。我們在JAX[1]中實(shí)現(xiàn)所有模型,并使用TPU-v3 pod進(jìn)行訓(xùn)練。DiT-XL/2是我們最密集的計(jì)算模型,在全球批處理大小為256的TPU v3-256 pod上以大約5.7次/秒的速度訓(xùn)練。
5. 實(shí)驗(yàn)結(jié)果
DiT塊設(shè)計(jì)。我們訓(xùn)練了四個(gè)最高的Gflop DiT-XL/2模型,每個(gè)模型都使用了不同的塊設(shè)計(jì)——情境(119.4 Gflops)、交叉注意(137.6 Gflops)、自適應(yīng)層范數(shù)(adaLN, 118.6 Gflops)或adaLN- 0 (118.6 Gflops)。我們在整個(gè)訓(xùn)練過程中測量FID。圖5顯示了結(jié)果。adaLN-Zero塊產(chǎn)生的FID比交叉注意和上下文條件反射都要低,但計(jì)算效率最高。在400K訓(xùn)練迭代中,使用adaLN-Zero模型實(shí)現(xiàn)的FID幾乎是上下文模型的一半,這表明調(diào)節(jié)機(jī)制嚴(yán)重影響模型質(zhì)量。初始化也很重要——adalnzero將每個(gè)DiT塊初始化為恒等函數(shù),顯著優(yōu)于普通adaLN。F或其余的紙張,所有模型將使用adaLN-Zero DiT塊。
縮放模型大小和補(bǔ)丁大小。我們訓(xùn)練了12個(gè)DiT模型,覆蓋了模型配置(S, B, L, XL)和補(bǔ)丁大小(8,4,2)。注意,DiT-L和DiT-XL在相對gflop方面比其他配置明顯更接近。圖2(左)給出了每個(gè)模型的Gflops和它們在400K訓(xùn)練迭代時(shí)的FID的概述。在所有情況下,我們發(fā)現(xiàn)增加模型大小和減少補(bǔ)丁大小產(chǎn)生顯著改善的擴(kuò)散模型。圖6(上)展示了FID如何隨著模型大小的增加和補(bǔ)丁大小保持不變而變化。在所有四種配置中,通過使Transformer更深更寬,F(xiàn)ID在所有訓(xùn)練階段都得到了顯著的改進(jìn)。類似地,圖6(下)顯示了當(dāng)補(bǔ)丁大小減小且模型大小保持不變時(shí)的FID。我們再次觀察到在整個(gè)訓(xùn)練過程中,通過簡單地?cái)U(kuò)大DiT處理的令牌數(shù)量,保持參數(shù)大約固定,F(xiàn)ID有了很大的提高。
DiT gflop是提高性能的關(guān)鍵。圖6的結(jié)果表明,參數(shù)計(jì)數(shù)在決定DiT模型的質(zhì)量方面最終并不重要。當(dāng)模型尺寸保持不變,patch尺寸減小時(shí),Transformer的總參數(shù)實(shí)際上是不變的,只有Gflops增加。這些結(jié)果表明,縮放模型Gflops實(shí)際上是提高性能的關(guān)鍵。為了進(jìn)一步研究這一點(diǎn),我們將FID-50K在400K訓(xùn)練步驟中與模型Gflops繪制在圖8中。結(jié)果表明,具有不同大小和令牌的DiT模型在其總Gflops相似(例如DiT- s /2和DiT- b /4)時(shí)最終獲得相似的FID值。事實(shí)上,我們發(fā)現(xiàn)模型Gflops和FID-50K之間存在很強(qiáng)的負(fù)相關(guān),這表明額外的模型計(jì)算是改進(jìn)DiT模型的關(guān)鍵因素。在圖12(附錄)中,我們發(fā)現(xiàn)這一趨勢也適用于其他指標(biāo),如Inception Score。
較大的DiT模型計(jì)算效率更高。在圖9中,我們將FID繪制為所有DiT模型的總訓(xùn)練計(jì)算的函數(shù)。我們估計(jì)訓(xùn)練計(jì)算為模型Gflops·批大小·訓(xùn)練步驟·3,其中因子3大致近似于向后傳遞的計(jì)算量是向前傳遞的兩倍。我們發(fā)現(xiàn),即使訓(xùn)練時(shí)間較長,與訓(xùn)練步驟較少的大型DiT模型相比,小型DiT模型最終也會變得計(jì)算效率低下。類似地,我們發(fā)現(xiàn)除了補(bǔ)丁大小之外相同的模型即使在控制訓(xùn)練Gflops時(shí)也具有不同的性能配置文件。例如,XL/4在大約1010 Gflops后的性能優(yōu)于XL/2。
可視化擴(kuò)展。我們在圖7中可視化縮放對樣本質(zhì)量的影響。在400K訓(xùn)練步驟中,我們使用相同的起始噪聲xtmax、采樣噪聲和類標(biāo)簽從我們的12個(gè)DiT模型中采樣一張圖像。這讓我們可以直觀地解釋縮放如何影響DiT樣本質(zhì)量。事實(shí)上,縮放模型大小和令牌數(shù)量在視覺質(zhì)量上都有顯著的改善。
5.1. State-of-the-Art Diffusion Models
256×256 ImageNet。在我們的縮放分析之后,我們繼續(xù)訓(xùn)練我們最高的Gflop模型DiT-XL/2,用于7M步長。我們在圖1中展示了該模型的樣本,并將其與最先進(jìn)的類條件生成模型進(jìn)行了比較。我們在表2中報(bào)告結(jié)果。當(dāng)使用無分類器制導(dǎo)時(shí),DiT-XL/2優(yōu)于所有先前的擴(kuò)散模型,將先前由LDM實(shí)現(xiàn)的最佳FID-50K從3.60降低到2.27。圖2(右)顯示DiT-XL/2 (118.6 Gflops)相對于LDM-4 (103.6 Gflops)等潛在空間U-Net模型的計(jì)算效率更高,并且比ADM (1120 Gflops)或ADM- u (742 Gflops)等像素空間U-Net模型的效率更高。我們的方法實(shí)現(xiàn)了所有先前生成模型中最低的FID,包括先前最先進(jìn)的StyleGANXL[50]。最后,我們還觀察到,與LDM-4和LDM-8相比,DiT-XL/2在所有測試的無分類器引導(dǎo)量表上獲得了更高的召回值。當(dāng)只訓(xùn)練2.35M步長(類似于ADM)時(shí),XL/2仍然優(yōu)于所有先前的擴(kuò)散模型,其FID為2.55。
512×512 ImageNet。我們在ImageNet上以512 × 512分辨率訓(xùn)練一個(gè)新的DiT-XL/2模型,用于3M迭代,其超參數(shù)與256 × 256模型相同。在補(bǔ)丁大小為2的情況下,這個(gè)XL/2模型在對64 × 64 × 4輸入潛在值(524.6 Gflops)進(jìn)行補(bǔ)丁后,總共處理1024個(gè)令牌。表3顯示了與最先進(jìn)方法的比較。XL/2在此分辨率下再次優(yōu)于所有先前的擴(kuò)散模型,將ADM的最佳FID從3.85提高到3.04。即使增加了令牌數(shù)量,XL/2仍然保持計(jì)算效率。例如,ADM使用1983個(gè)gflop, ADM- u使用2813個(gè)gflop;XL/2使用524.6 gflop。我們在圖1和附錄中展示了來自高分辨率XL/2模型的樣本。
5.2. Model Compute vs. Sampling Compute
與大多數(shù)生成模型不同,擴(kuò)散模型的獨(dú)特之處在于,它們可以在生成圖像時(shí)通過增加采樣步驟的數(shù)量來訓(xùn)練后使用額外的計(jì)算??紤]到模型Gflops在樣本質(zhì)量中的重要性,在本節(jié)中,我們將研究較小的模型計(jì)算dit是否可以通過使用更多的抽樣計(jì)算來優(yōu)于較大的dit。我們在400K訓(xùn)練步驟后計(jì)算所有12個(gè)DiT模型的FID,每張圖像使用[16,32,64,128,256,1000]采樣步驟。主要結(jié)果如圖10所示??紤]使用1000個(gè)采樣步驟的DiT-L/2與使用128個(gè)采樣步驟的DiT-XL/2。在這種情況下,L/2使用80.7 tflop對每張圖像進(jìn)行采樣;XL/2使用5倍少的計(jì)算量- 15.2 tflop -對每張圖像進(jìn)行采樣。盡管如此,XL/2具有更好的FID-10K (23.7 vs 25.9)。一般來說,抽樣計(jì)算不能彌補(bǔ)模型計(jì)算的不足。
6. 結(jié)論
我們介紹了擴(kuò)散Transformer(DiTs),這是一種簡單的基于Transformer的擴(kuò)散模型骨干,優(yōu)于先前的U-Net模型,并繼承了Transformer模型類的優(yōu)秀縮放特性。鑒于本文中有希望的擴(kuò)展結(jié)果,未來的工作應(yīng)該繼續(xù)將dit擴(kuò)展到更大的模型和令牌數(shù)量。DiT還可以作為文本-圖像模型(如DALL·e2和穩(wěn)定擴(kuò)散模型)的主干進(jìn)行探索。
A. 其他實(shí)施細(xì)節(jié)
我們在表4中包含了關(guān)于所有DiT模型的信息,包括256 × 256和512 × 512模型。我們包括Gflop計(jì)數(shù),參數(shù),訓(xùn)練細(xì)節(jié),fid等。我們還在表6中包括了來自ADM和LDM的DDPM U-Net模型的Gflop計(jì)數(shù)。DiT模型細(xì)節(jié)。為了嵌入輸入時(shí)間步長,我們使用256維頻率嵌入[9],然后使用兩層MLP,其維度等于Transformer的隱藏大小和SiLU激活。每個(gè)adaLN層將時(shí)間步長和類嵌入的和饋送到一個(gè)SiLU非線性層和一個(gè)線性層,輸出神經(jīng)元等于Transformer的隱藏大小的4× (adaLN)或6× (adaLN- 0)。我們在核心Transformer[16]中使用了GELU非線性(近似于tanh)。
B. VAE解碼器消融
我們在實(shí)驗(yàn)中使用了現(xiàn)成的、預(yù)先訓(xùn)練好的V AEs。V AE模型(ft-MSE和ft-EMA)是原始LDM“f8”模型的微調(diào)版本(只有解碼器權(quán)重進(jìn)行了微調(diào))。在第5節(jié)中,我們使用ft-MSE解碼器來監(jiān)控縮放分析的指標(biāo),我們使用ft-EMA解碼器來處理表2和表3中報(bào)告的最終指標(biāo)。在本節(jié)中,我們?nèi)サ袅巳N不同的V AE解碼器的選擇;LDM使用的原始譯碼器和Stable Diffusion使用的兩個(gè)微調(diào)譯碼器。因?yàn)榫幋a器在模型中是相同的,解碼器可以在不重新訓(xùn)練擴(kuò)散模型的情況下被替換。表5顯示了結(jié)果;當(dāng)使用LDM解碼器時(shí),XL/2優(yōu)于所有先前的擴(kuò)散模型。
C. 模型樣本
我們展示了來自兩個(gè)DiT-XL/2模型的樣本,分別在512 × 512和256 × 256分辨率下訓(xùn)練3M和7M步長。圖1和11顯示了從兩個(gè)模型中選取的樣本。圖13到32顯示了兩個(gè)模型在一系列分類器自由引導(dǎo)尺度和輸入類別標(biāo)簽(使用250 DDPM采樣步驟和ft-EMA VAE解碼器生成)上的非策展樣本。與之前使用指導(dǎo)的工作一樣,我們觀察到更大的尺度增加了視覺保真度,減少了樣本多樣性。
256 × 256分辨率下訓(xùn)練3M和7M步長。圖1和11顯示了從兩個(gè)模型中選取的樣本。圖13到32顯示了兩個(gè)模型在一系列分類器自由引導(dǎo)尺度和輸入類別標(biāo)簽(使用250 DDPM采樣步驟和ft-EMA VAE解碼器生成)上的非策展樣本。與之前使用指導(dǎo)的工作一樣,我們觀察到更大的尺度增加了視覺保真度,減少了樣本多樣性。文章來源:http://www.zghlxwxcb.cn/news/detail-650403.html
文章來源地址http://www.zghlxwxcb.cn/news/detail-650403.html
到了這里,關(guān)于Scalable Diffusion Models with Transformers的文章就介紹完了。如果您還想了解更多內(nèi)容,請?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!