DALL·E 2【論文精讀】(跟李沐學(xué)AI): https://www.bilibili.com/video/BV17r4y1u77B?spm_id_from=333.999.0.0&vd_source=4aed82e35f26bb600bc5b46e65e25c22 (more papers: https://github.com/mli/paper-reading). Most of the DALL·E 2 explanations floating around do not actually make it clear. Generative modeling has three main families: VAEs, GANs, and diffusion models. One lineage runs AE -> DAE -> VAE -> VQ-VAE -> diffusion; on the diffusion side it runs DDPM -> Improved DDPM -> Diffusion Models Beat GANs -> GLIDE -> DALL·E 2.
1. Introduction
CLIP is robust to shifts in image distribution and works zero-shot; diffusion models deliver diverse samples with good fidelity. DALL·E 2 combines the strengths of both.
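To make "zero-shot" concrete: CLIP classifies by cosine similarity between its image and text embeddings. A minimal sketch with the openai/clip package (the model name and image path are illustrative choices, not from the video):

import torch
import clip
from PIL import Image

# load a pretrained CLIP (ViT-L/14 is the backbone unCLIP builds on)
model, preprocess = clip.load("ViT-L/14", device="cpu")

image = preprocess(Image.open("cat.jpg")).unsqueeze(0)           # hypothetical image path
texts = clip.tokenize(["a photo of a cat", "a photo of a dog"])  # candidate labels as prompts

with torch.no_grad():
    img_f = model.encode_image(image)
    txt_f = model.encode_text(texts)
    img_f = img_f / img_f.norm(dim=-1, keepdim=True)             # normalize so the dot product is cosine similarity
    txt_f = txt_f / txt_f.norm(dim=-1, keepdim=True)
    probs = (100.0 * img_f @ txt_f.T).softmax(dim=-1)            # zero-shot class probabilities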
2. Method
The figure above is worth reading alongside the text. Above the dashed line sits CLIP, which is pretrained in advance and kept frozen: its weights are never updated while DALL·E 2 trains. Training still consumes text-image pairs. The caption goes through CLIP's text encoder (a Transformer; CLIP encodes images with a ViT and text with a Transformer, and since CLIP is plain contrastive learning, the two modalities' embeddings are compared directly with cosine similarity), and the image goes through CLIP's image encoder to produce an image embedding, which acts as the ground truth. The text embedding is fed to the first module, the prior, a diffusion model here (an autoregressive Transformer also works); it outputs a predicted image embedding that is supervised against the CLIP image embedding, so this stage is explicitly supervised. The second module is the decoder. In the original DALL·E the encoder and decoder were trained jointly inside the dVAE, but here the decoder is trained separately and is itself a diffusion model. Below the dashed line, what used to be a single generation step is thus split into an explicit two-stage process (text embedding -> image embedding -> image), and the authors find this explicit factorization generates better images.
The paper calls the model unCLIP: CLIP maps text and images to features, while DALL·E 2 runs the other way, turning a text feature into an image feature and then into an image; the image-feature-to-image step is handled by the diffusion decoder. The decoder uses both classifier-free guidance and CLIP guidance. Guidance here means the following: during decoding the input at step t is a noisy image and the final output is a clean image; the intermediate result the U-Net produces at each step can be scored by an image classifier (often just a cross-entropy objective) or by CLIP similarity, and the gradient of that score is used to steer the denoising toward a better decode.
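To see the explicit two-stage generation as code, here is a minimal sketch wired together with the DALLE2-pytorch API discussed in section 3 (all dimensions and step counts are toy values, the argument names follow that repo's README and may differ between versions, and the CLIP here is a tiny stand-in rather than the pretrained frozen one used in practice):

import torch
from dalle2_pytorch import CLIP, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, DALLE2

# toy CLIP stand-in; the real setup plugs in a pretrained, frozen OpenAI CLIP
clip = CLIP(
    dim_text = 512, dim_image = 512, dim_latent = 512,
    num_text_tokens = 49408, text_enc_depth = 1, text_seq_len = 256, text_heads = 8,
    visual_enc_depth = 1, visual_image_size = 256, visual_patch_size = 32, visual_heads = 8)

text   = torch.randint(0, 49408, (4, 256))   # toy token ids
images = torch.randn(4, 3, 256, 256)         # toy images

# stage 1: the prior maps a CLIP text embedding to a CLIP image embedding
prior_network   = DiffusionPriorNetwork(dim = 512, depth = 6, dim_head = 64, heads = 8)
diffusion_prior = DiffusionPrior(net = prior_network, clip = clip, timesteps = 100, cond_drop_prob = 0.2)
diffusion_prior(text, images).backward()     # supervised against the CLIP image embedding

# stage 2: the decoder diffuses an image embedding back into pixels with a U-Net
unet    = Unet(dim = 128, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))
decoder = Decoder(unet = unet, clip = clip, timesteps = 100,
                  image_cond_drop_prob = 0.1, text_cond_drop_prob = 0.5)   # dropped conditions enable classifier-free guidance
decoder(images).backward()

# inference: text -> text embedding -> (prior) image embedding -> (decoder) image
dalle2 = DALLE2(prior = diffusion_prior, decoder = decoder)
generated = dalle2(['a corgi playing a flame throwing trumpet'], cond_scale = 2.)

cond_scale > 1 at sampling time is the classifier-free guidance knob: the conditional and unconditional predictions are mixed, which is why training drops the condition with some probability (cond_drop_prob / *_cond_drop_prob).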
3. Code
GitHub - lucidrains/DALLE2-pytorch (https://github.com/lucidrains/DALLE2-pytorch): Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
核心是訓(xùn)練一個先驗?zāi)P秃鸵粋€decoder模型,這兩個都是擴(kuò)散模型,當(dāng)然先驗?zāi)P鸵部梢允亲曰貧wAE,如果是自回歸AE就是dalle的思路,clip自己選一個訓(xùn)練好的即可,clip的本質(zhì)是提供良好的圖像和文本的vector。
The input to train_diffusion_prior has already been through img2dataset and clip-retrieval: img2dataset downloads the data, and clip-retrieval produces the img_emb / text_emb files and the metadata that the prior training needs. train_decoder is trained on the .tar webdataset shards produced by img2dataset; the tars only need to contain images, since DALLE2-pytorch checks the inputs and carries a built-in CLIP to extract features from them. In other words, the prior and the generative decoder are explicitly separated and trained independently.
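For intuition about what the prior training actually consumes, here is a rough sketch of reading one shard of clip-retrieval output and pairing the embeddings (the embeddings/ path is hypothetical; img_emb/*.npy and text_emb/*.npy is the usual clip-retrieval layout; the commented DiffusionPrior call with precomputed embeddings reflects my reading of the training script, so treat the keyword names as an assumption):

import numpy as np
import torch

# one shard of clip-retrieval output (hypothetical paths)
img_emb  = torch.from_numpy(np.load("embeddings/img_emb/img_emb_0.npy")).float()    # (N, 768) CLIP image embeddings
text_emb = torch.from_numpy(np.load("embeddings/text_emb/text_emb_0.npy")).float()  # (N, 768) CLIP text embeddings

# the prior never sees pixels or tokens in this recipe; it learns text_emb -> img_emb row by row
batch = 16
for i in range(0, len(img_emb), batch):
    x_img, x_txt = img_emb[i:i + batch], text_emb[i:i + batch]
    # loss = diffusion_prior(text_embed = x_txt, image_embed = x_img)   # assumed kwargs for precomputed embeddings
    # loss.backward(); trainer.update()
    pass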
diffusion_prior
DiffusionPrior(
(noise_scheduler): NoiseScheduler()
(clip): OpenAIClipAdapter(
(clip): CLIP(
(visual): VisionTransformer(
(conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
(ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(transformer): Transformer(
(resblocks): Sequential(
(0): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(1)-(23): 23 more ResidualAttentionBlock layers, identical in structure to (0) above (1024-dim attention, 1024->4096->1024 MLP)
)
)
(ln_post): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(transformer): Transformer(
(resblocks): Sequential(
(0): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(1)-(11): 11 more ResidualAttentionBlock layers, identical in structure to (0) above (768-dim attention, 768->3072->768 MLP)
)
)
(token_embedding): Embedding(49408, 768)
(ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(clip_normalize): Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)
(net): DiffusionPriorNetwork(
(to_text_embeds): Sequential(
(0): Identity()
(1): Rearrange('b (n d) -> b n d', n=1)
)
(to_time_embeds): Sequential(
(0): Embedding(1000, 768)
(1): Rearrange('b (n d) -> b n d', n=1)
)
(to_image_embeds): Sequential(
(0): Identity()
(1): Rearrange('b (n d) -> b n d', n=1)
)
(causal_transformer): CausalTransformer(
(init_norm): Identity()
(rel_pos_bias): RelPosBias(
(relative_attention_bias): Embedding(32, 12)
)
(layers): ModuleList(
(0): ModuleList(
(0): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.05, inplace=False)
(to_q): Linear(in_features=768, out_features=768, bias=False)
(to_kv): Linear(in_features=768, out_features=128, bias=False)
(rotary_emb): RotaryEmbedding()
(to_out): Sequential(
(0): Linear(in_features=768, out_features=768, bias=False)
(1): LayerNorm()
)
)
(1): Sequential(
(0): LayerNorm()
(1): Linear(in_features=768, out_features=6144, bias=False)
(2): SwiGLU()
(3): LayerNorm()
(4): Dropout(p=0.05, inplace=False)
(5): Linear(in_features=3072, out_features=768, bias=False)
)
)
(1)-(11): 11 more ModuleList layers, identical in structure to (0) above (Attention + SwiGLU feed-forward, 768-dim)
)
(norm): LayerNorm()
(project_out): Linear(in_features=768, out_features=768, bias=False)
)
)
)
The prior here is a diffusion prior whose denoising network is a causal (autoregressive-style) Transformer; the training flow is as follows:
Config and model construction: TrainDiffusionPriorConfig.from_json_path() -> prior/data/train/tracker sub-configs (saved) -> train: DiffusionPriorTrainer = make_model() -> (prior_config) DiffusionPriorConfig / DiffusionPriorTrainConfig -> diffusion_prior = prior_config.create() -> clip = clip.create() -> AdapterConfig.create() -> OpenAIClipAdapter() -> diffusion_prior_network = net.create() -> DiffusionPriorNetworkConfig.create() -> DiffusionPriorNetwork -> trainer = DiffusionPriorTrainer -> tracker = create_tracker() -> TrackerConfig.create()
Data: img_reader = get_reader(), whose inputs are the three sets img_url / text_url / meta_url -> image_reader = EmbeddingReader() / text_reader = EmbeddingReader() -> train_loader / eval_loader / test_loader = make_splits
Training step: train: DiffusionPriorTrainer / Tracker / DiffusionPriorTrainConfig -> img: 16x768, txt: 16x77 -> DiffusionPrior.forward with net: DiffusionPriorNetwork -> image_embed, _ = self.clip.embed_image() / text_embed, text_encodings = self.clip.embed_text() -> times = self.noise_scheduler.sample_random_times -> self.p_losses -> image_embed_noisy = self.noise_scheduler.q_sample (NoiseScheduler) -> pred = self.net(image_embed_noisy: 16x768, text_cond: text_embed 16x768 / text_encodings 16x77x768) -> DiffusionPriorNetwork.forward() -> image_embed: 16x768, text_embed: 16x768 -> tokens: 16x81x768 -> pred_image_embed: 16x768 -> target = noise: 16x768 -> loss = self.noise_scheduler.loss_fn (l2: mse) -> trainer.update
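Stripped of the config plumbing, p_losses above is ordinary DDPM training applied to a 768-d CLIP image embedding. A self-contained sketch (not the repo's code; the small MLP stands in for the causal-transformer DiffusionPriorNetwork, and predicting the added noise mirrors the target = noise line above):

import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def q_sample(x0, t, noise):
    # forward diffusion: x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise
    a_bar = alphas_cumprod[t].unsqueeze(-1)
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise

# stand-in denoiser: conditions on the text embedding and a scalar timestep
net = torch.nn.Sequential(torch.nn.Linear(768 * 2 + 1, 1024), torch.nn.SiLU(), torch.nn.Linear(1024, 768))

def p_losses(image_embed, text_embed):
    b = image_embed.shape[0]
    t = torch.randint(0, T, (b,))                        # sample_random_times
    noise = torch.randn_like(image_embed)
    image_embed_noisy = q_sample(image_embed, t, noise)  # noise_scheduler.q_sample
    inp = torch.cat([image_embed_noisy, text_embed, t.float().unsqueeze(-1) / T], dim=-1)
    pred = net(inp)                                      # DiffusionPriorNetwork would go here
    return F.mse_loss(pred, noise)                       # l2 loss against the added noise

loss = p_losses(torch.randn(16, 768), torch.randn(16, 768))
loss.backward()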
decoder:
Decoder(
(clip): OpenAIClipAdapter( ... identical to the frozen CLIP printed above for the prior: ViT-L/14 visual tower (24 x 1024-dim ResidualAttentionBlock), 12-layer 768-dim text transformer, and clip_normalize ... )
(unets): ModuleList(
(0): Unet(
(init_conv): CrossEmbedLayer(
(convs): ModuleList(
(0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): Conv2d(3, 4, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
(2): Conv2d(3, 4, kernel_size=(15, 15), stride=(1, 1), padding=(7, 7))
)
)
(to_time_hiddens): Sequential(
(0): SinusoidalPosEmb()
(1): Linear(in_features=16, out_features=64, bias=True)
(2): GELU()
)
(to_time_tokens): Sequential(
(0): Linear(in_features=64, out_features=32, bias=True)
(1): Rearrange('b (r d) -> b r d', r=2)
)
(to_time_cond): Sequential(
(0): Linear(in_features=64, out_features=64, bias=True)
)
(image_to_tokens): Identity()
(norm_cond): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
(norm_mid_cond): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
(downs): ModuleList(
(0): ModuleList(
(0): None
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(2): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
)
(3): Identity()
(4): Conv2d(16, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)
(1): ModuleList(
(0): None
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(2): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=16, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=16, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=16, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=16, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
)
(3): EinopsToAndFrom(
(fn): Residual(
(fn): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=16, out_features=64, bias=False)
(to_kv): Linear(in_features=16, out_features=32, bias=False)
(to_out): Sequential(
(0): Linear(in_features=64, out_features=16, bias=False)
(1): LayerNorm()
)
)
)
)
(4): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)
(2): ModuleList(
(0): None
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=64, bias=True)
)
(block1): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(2): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=64, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=32, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=32, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=64, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=32, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=32, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
)
(3): EinopsToAndFrom(
(fn): Residual(
(fn): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=32, out_features=64, bias=False)
(to_kv): Linear(in_features=32, out_features=32, bias=False)
(to_out): Sequential(
(0): Linear(in_features=64, out_features=32, bias=False)
(1): LayerNorm()
)
)
)
)
(4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)
(3): ModuleList(
(0): None
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=128, bias=True)
)
(block1): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(2): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=128, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=64, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=64, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=128, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=64, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=64, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
)
(3): EinopsToAndFrom(
(fn): Residual(
(fn): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=64, out_features=64, bias=False)
(to_kv): Linear(in_features=64, out_features=32, bias=False)
(to_out): Sequential(
(0): Linear(in_features=64, out_features=64, bias=False)
(1): LayerNorm()
)
)
)
)
(4): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
)
)
(ups): ModuleList(
(0): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=256, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=128, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=128, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
)
(1): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=256, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=128, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=128, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=256, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=128, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=128, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
)
)
(2): EinopsToAndFrom(
(fn): Residual(
(fn): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=128, out_features=64, bias=False)
(to_kv): Linear(in_features=128, out_features=32, bias=False)
(to_out): Sequential(
(0): Linear(in_features=64, out_features=128, bias=False)
(1): LayerNorm()
)
)
)
)
(3): PixelShuffleUpsample(
(net): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
(1): SiLU()
(2): PixelShuffle(upscale_factor=2)
)
)
)
(1): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=128, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=64, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=64, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
)
(1): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=128, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=64, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=64, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=128, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=64, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=64, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
)
)
(2): EinopsToAndFrom(
(fn): Residual(
(fn): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=64, out_features=64, bias=False)
(to_kv): Linear(in_features=64, out_features=32, bias=False)
(to_out): Sequential(
(0): Linear(in_features=64, out_features=64, bias=False)
(1): LayerNorm()
)
)
)
)
(3): PixelShuffleUpsample(
(net): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
(1): SiLU()
(2): PixelShuffle(upscale_factor=2)
)
)
)
(2): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=64, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=32, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=32, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1))
)
(1): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=64, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=32, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=32, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=64, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=32, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=32, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1))
)
)
(2): EinopsToAndFrom(
(fn): Residual(
(fn): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=32, out_features=64, bias=False)
(to_kv): Linear(in_features=32, out_features=32, bias=False)
(to_out): Sequential(
(0): Linear(in_features=64, out_features=32, bias=False)
(1): LayerNorm()
)
)
)
)
(3): PixelShuffleUpsample(
(net): Sequential(
(0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
(1): SiLU()
(2): PixelShuffle(upscale_factor=2)
)
)
)
(3): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
)
(1): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
)
)
(2): Identity()
(3): Identity()
)
)
(mid_block1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=256, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=128, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=128, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(mid_attn): EinopsToAndFrom(
(fn): Residual(
(fn): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=128, out_features=64, bias=False)
(to_kv): Linear(in_features=128, out_features=32, bias=False)
(to_out): Sequential(
(0): Linear(in_features=64, out_features=128, bias=False)
(1): LayerNorm()
)
)
)
)
(mid_block2): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=256, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=128, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=128, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(upsample_combiner): UpsampleCombiner()
(final_resnet_block): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
)
(to_out): Conv2d(16, 3, kernel_size=(1, 1), stride=(1, 1))
)
)
(vaes): ModuleList(
(0): NullVQGanVAE()
)
(noise_schedulers): ModuleList(
(0): NoiseScheduler()
)
(lowres_conds): ModuleList(
(0): None
)
)
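For orientation, a module tree like the one printed above is what you get from building a small Unet, wrapping it in a Decoder, and printing the module. The sketch below follows the README-style API of lucidrains/DALLE2-pytorch; the specific hyperparameters (dim = 16, dim_mults = (1, 2, 4, 8), image_size = 64, timesteps, the drop probabilities) are assumptions chosen only to roughly match the channel widths visible in the dump, not the exact configuration that produced it, and argument names can shift between library versions.

import torch
from dalle2_pytorch import Unet, Decoder, OpenAIClipAdapter

# frozen CLIP, used to produce the image embedding the decoder is conditioned on
clip = OpenAIClipAdapter()

# dim = 16 with dim_mults = (1, 2, 4, 8) gives channel widths 16/32/64/128,
# roughly matching the Conv2d shapes in the printed tree (assumed values)
unet = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 16,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
)

# the decoder is a diffusion model over 64x64 images, conditioned on the CLIP image embedding
decoder = Decoder(
    unet = unet,
    clip = clip,
    image_size = 64,
    timesteps = 1000,
    image_cond_drop_prob = 0.1,  # drop the image embedding sometimes, for classifier-free guidance
    text_cond_drop_prob = 0.5,
)

images = torch.randn(4, 3, 64, 64)
loss = decoder(images)   # diffusion (MSE) loss; CLIP embeds the images internally
loss.backward()
print(decoder)           # prints a module tree like the one above

In the actual train_decoder.py path the same objects are not built by hand but from a TrainDecoderConfig, as traced below.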
The decoder training script roughly follows this call chain:

- TrainDecoderConfig is parsed into DecoderConfig, DecoderDataConfig, DecoderTrainConfig, DecoderEvaluateConfig and TrackerConfig.
- dataloader = create_dataloaders(...), which goes through create_image_embedding_dataloader to build the loader over the webdataset tars.
- decoder = config.decoder.create(): DecoderConfig.create() instantiates the Unet(s) from their configs, clip.create() returns an OpenAIClipAdapter, and everything is assembled into a Decoder.
- tracker = create_tracker(...) sets up experiment tracking.
- Training runs through a DecoderTrainer. A batch such as img: 4x3x224x224 with captions like cat/sea/tree/motel goes DecoderTrainer.forward() -> self.decoder -> Decoder.forward().
- Inside Decoder.forward(): resize_image_to brings the images down to 4x3x64x64, image = vae.encode(image) (a NullVQGanVAE here, so effectively a pass-through), then p_losses is called.
- p_losses: x_noisy = noise_scheduler.q_sample(...) adds noise at the sampled timestep, model_output = unet(...) predicts a 4x3x64x64 tensor, target and pred are both 4x3x64x64, and loss = noise_scheduler.loss_fn is the l2 (MSE) loss. A minimal sketch of this step follows below.
- When the variational-bound term is used, mean, var = noise_scheduler.q_posterior(...) and kl = normal_kl(...) give vb_loss, and the final objective is loss + vb_loss.
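To make the p_losses step concrete, here is a minimal, self-contained sketch of the simplified DDPM objective the trace describes: noise the image at a random timestep with q_sample, let a model predict that noise, and take an L2/MSE loss. This is a generic illustration, not the library's actual NoiseScheduler code; the linear schedule values and the stand-in model are assumptions.

import torch
import torch.nn.functional as F

def linear_alphas_cumprod(timesteps: int = 1000) -> torch.Tensor:
    # linear beta schedule; the library's NoiseScheduler also offers e.g. a cosine schedule
    betas = torch.linspace(1e-4, 0.02, timesteps)
    return torch.cumprod(1.0 - betas, dim=0)

def q_sample(x_start, t, alphas_cumprod, noise):
    # forward diffusion: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
    abar = alphas_cumprod[t].view(-1, 1, 1, 1)
    return abar.sqrt() * x_start + (1.0 - abar).sqrt() * noise

def p_losses(model, x_start, t, alphas_cumprod):
    # in dalle2-pytorch the model is the conditional U-Net and also receives
    # image_embed / text_encodings; here it is any noise-predicting callable
    noise = torch.randn_like(x_start)
    x_noisy = q_sample(x_start, t, alphas_cumprod, noise)
    pred = model(x_noisy, t)            # predicted noise, same shape as x_start
    return F.mse_loss(pred, noise)      # the trace's l2 / mse loss_fn

# toy usage with a stand-in "unet"
abar = linear_alphas_cumprod()
imgs = torch.randn(4, 3, 64, 64)                   # matches the 4x3x64x64 shapes in the trace
t = torch.randint(0, len(abar), (imgs.shape[0],))
dummy_unet = lambda x, t: torch.zeros_like(x)      # placeholder for the real U-Net
print(p_losses(dummy_unet, imgs, t, abar).item())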
That concludes this walkthrough of DALL·E 2 (Hierarchical Text-Conditional Image Generation with CLIP Latents).