From 8a34671fe91e142bce9e5556cca2258b3be9dd6e Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Fri, 24 Mar 2023 22:48:16 -0400 Subject: [PATCH] Add support for the Variations models (unclip-h and unclip-l) --- launch.py | 2 +- models/karlo/ViT-L-14_stats.th | Bin 0 -> 7079 bytes modules/lowvram.py | 10 +++++--- modules/processing.py | 41 +++++++++++++++++++++--------- modules/sd_models.py | 5 ++++ modules/sd_models_config.py | 7 +++++ modules/sd_samplers_compvis.py | 31 +++++++++++++++++----- modules/sd_samplers_kdiffusion.py | 19 +++++++++----- 8 files changed, 85 insertions(+), 30 deletions(-) create mode 100644 models/karlo/ViT-L-14_stats.th diff --git a/launch.py b/launch.py index b943fed22..e70df7ba7 100644 --- a/launch.py +++ b/launch.py @@ -252,7 +252,7 @@ def prepare_environment(): codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') - stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e") + stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") diff --git a/models/karlo/ViT-L-14_stats.th b/models/karlo/ViT-L-14_stats.th new file mode 100644 index 0000000000000000000000000000000000000000..a6a06e94ecaa4f2977972ff991f75db6c90403ea GIT binary patch literal 7079 zcmb7p2UJu`vo0t}k|>HHK*5ZPnDap~9urJA z22=zI6PRIuA?KtdK``9$p8x*$oOADcZ@r$?t9Q??s$RQl*H=}$9PPWRC@E=ZDE*J2 zr_@u)eT&CB@2zV_d%6d>kJ`M!XXaF0CFj45kS?x%N@gAbn-r9z+yVoWVY`6_oB?Yy7(W$)Y*aXOnxt;y?e?^y})Rte_d= zr{dIAv3hI{^i!Ru)HT3QZK1uA;$okly1nu~KaB_vk4-*4YdiwHH~IQ&F4*F^X3HAS z>E0dzntt8P7x`%h_-QTj(_W`Ib6uBpf6BD-(^>bYEZzT-)%}+&y@++cSb7Bb^<3nq zzh;!9y}HJx9wu>Fit~R>$N!^#bd|>bO+TZ@%$z#rAKLkws{YrFx^ga>p0CFBn^zLU zn}=haz>G}~d`+r$|3-e=e@lEHeb6pd*NO?3yri5^%3=;JCufeGq=x#Lbe~ByjmT`V zd%T!Wn&Ngt?SX{f-R5^1zjP82)E=dSUZ233TOQ1)=oUYH#d#uW=t>=KpCjE$-|&xa zH)J&qK`8kc!KUR7XP!ahVEiEwozd;|+D8ZCPez?5t7O!kbI$zS|AdV)f zFq`%vZ2ljU5_;=;CMNf~&Ax2}EngZXcqw$uAxwD9wwCb7u8LQ36ikanq-R($ti`Q}X| z|BwR%amKt4v*?DHF(_zv#Pa$3sX1>Td475{8E{LTN$FNNTa4%3I6RL;|JH>iZz&?8 zI9K?EK4Er`8c1o@3o>l?SQ?Y4N!BJVWJ{SRwX!?Q|4L`DWt#QWq4OLssn;;7zjZJ> zLod8&qF<AG0uC1-%-zoVveCAyazV@&Yg-WK{1_~?=5wO(xcgfIM2%bt+)F}_Uw zY6$P-Pkr<>715s=!&zedOTPNhbpB$~wV2ueIYuumrG0&llU^U*@U`=25RVtX@uh(e ziCxK94403^qZv_TUG7*KvgQjhYV~4EZ(Spa3ks=i;A3*r)Qab1zZK(3HF0KNI5Tr9 zAcKQ*__JS2>34&%I6ujar?DrV8a`|zUyTaMJH2|EyK^Fq-Ek(#Zb}JH5Iu~Ed!OTM zZ31{R7dO&`cwKg=`48$UP-PiEhLNGR%FO9@2hYSzja6H{rIsT~NT^*lITkj7_Bg3d z#;+d?SL3sEw)T6hTtA65AHPSRSDmL>YSz>}SRY5%?&E0-o>MC|6P~|EFQ)p+n!Ogz z#V7q^ygR>-CQ5HraWu`F*R%a3buTd_%fdu->nJUDW@iKGw&Vp_X|>6&WX3jj;*tk_ zax<9(h&0hddxh=VG#4U@G^Upx3P{g6bI9vX8@s}5x5ybUQ~Gw93%S?EV@ZYmP+_4< zt{iEkRmwq#oa9VYBbMRKfai2Vw2&P9noMlCE|~bSfm&sHaqpKabYpLdwSu3#oEHaJ z*qqIYLo2`1*t_RQ#akhJwYewRK3Po1>qHTcb+3_V-%PJ*aY^xgQ>cfdH9PWRJTn#E zB+sX7v%{CeAgWCvy$tWtgK2t5G*(Kq{;Gll6_0tvSHk&UFNBgqTQAT6@|27Sdq-tE zLdarYJr>(EmbYqLUzYsxY?2Y*o4!=aVs~@jlUXV8^xnv^COGbgc**!eu0zy#`jC6KirokXIHbP&APik|8dOI=*2Kug7&Z|)dF zqsHDP!}`dmw0R+dueIB*n|zV@XbopsJauMj^pXXStfAIl)#yw$K3%Z?Sd!_zWAwzZ zc>Ww-6MsSC3R+NM&UoUkNW2-zx7aX{cAXN&HX7ciB!4|Mp4Lvotd6-hRq zZKh(Qcp`ZuCbf@xVAlh}rh4{dcHZIah5972YLtM?vhRsO=X)^6(J}0qu{Z0V6Uw_; zGlqP1HYO#(XQ^;y2HDc58$6tQV6ovKj?<+@;MR`(;3_gh~kY1;>XM;kg;cPg=+zl5GBNFp|a)!CLQ zx|p|BnJuzf%PyEc1U22rRO1F?!uGj{e>0ajYG+c%%OS+1d^X!KOqJIlo6Ub9TZ$#Q zKX`q*cB3o%9AvqNqL`9k1DVd#z_{3WDiB|#;X^KC?n`aNw;pB_0!v8=V!$1Dz^yNasTXX82mh@6)P?bH~it#o5KIrpS{v=4#Pd3EAlVrZ*N|?MpOf zpF(PnXT&$<0H*ZcPjm-W@p>F`L57wWuWX?WKD!KJrN%|{?VI`hx%V;>jW3I-x8*q2 zXW3_d>#{?5Sh9dfc293R3A<{JiM_tl)O>AByS@m`k4Lai zmYaDs=F@l4+uxD5nl|kI1!FkH=JBJh@8=C{`-AtgXgKmGtVeub z!thB-hjm_{J+E7`$pwd*pTSP@ZPN~p3vxH-#sbDc!Q7{RCL z2kpIV2X(5mB%j^2*}Dba?7FH7J3rQi##(FBa|Ro^=I8y{?aLQv?fZ}Ln{|_HeMsnD z)x+3t*vX5uo4|`RaYILqGa{3}(r&YQGZXOzYC9l~7hK(y(f+d_&#a_sE>V1k@Vj(r zi5BMkRz?m4TCi1G%5a(!NetXR^SfTl;k{Tt4uPv&sa)$eiHSYMTtx$^=iVD6RJk-E zOIJos2am)qizaekOb6WW}V$9}zbvlxkF+f|b)y5_j=ED}3;P>6f1* z$MOtV|MaEc9(||7YP+!zNey|sYcjm_W}@Gh5q#e}u5dkQ!E^-b@E28+;cjLm-|0Es zcJDNKz33VlfqlFSCswi`-DuWmsPNy-79{eDC)MfyFj1pp0{_lhcc$Dj3Tk%`(3}Go zXqD~_h?jf9{jvjER}7}+RvkoAEF$7+XWE(TN=bT`DQ@4Qh<=9~9JYke3x?y!R+IU3 za=I0{YJP=WIGju$k2%gJ^P_o3O|B4CwUO|+)R*bJyG-AWdxRQ)Cw|DP^+fgLE<6+6 z=dB*M149qYB_~_EvZ$UH=x@Ek$=UB)@hl~Tgy}71D|KeDYa!uuUhF|w&dukoo*Pg4 zD%-J7CvEw+T?}BGU`UhmMv+bRL2UZ#DePkQM>;7EJ}eo)2lSv}9&uEJ+m2e& z=Hf{l8v4-f-$Urbp@Vq|%^&E%S<92oJ0-G8xnYvQtr-6H`)zc?)dBpX1bw`Jk;a$w zuBE#r=SWan4bScJ32NMWo9=3{<^Qf>!EZTwieEnO0iDxUL7bMDl0CvX=(Ww2Smx|X zOj$Y}hKFr&Epj%qxV)K$<){&dXl3xG?8Si~6{33RS;Dl6-7qy6loY&g7B)u~+nID( zN%zmaPAu2t&>#8>$fT*aS)=6z=(V+y#U7b-iMN0}nK=c&>WxhE1@kluj3P;l79IaMk6h@R85trq8Xx`_6QuOGm9d4?Da|FgK?V^JkZ#k4qj@JH@C!QxC&^#jsl`hsAIa zhVY9J7nY2>QDv}+$cD*NA+~DCFs9-?mQ85F5t~Xp9wmd9O*YCl39(^sCUno{V!mG! zR!%R)$;wK6Ir0^g&*x&DeHGleT%=dOftjNm%FjeF_|T3EO%g1aAj0isP1q{WL3Mru zLY>;6w>AyR$~Djt<)O!nDwJF0AZnZ(+h>d6_ehL0WkMWlZbOdICw!@?hbkw;gSJm7 zUX%w>sxc=!mxhw*0!*mL!i(@WIG*~1%!~C{=9rEo-(vhM5n}ltGIVLrK;M8&)Hh2o z_E|Q}bqnxvSTSyO{|2Ld6}TqIK&5gSLT4&-tsPldZkvbQx3UmhE`k1%5=@PipkZ1y zBK&eNFW8AY)z*n_#<^&*6C%d42n*wckTwgjWTP>suO@^2=VIKvUxlx|v(c{l34(hy z-1Plo*!|IpX&W0c<0iO;2aAv(E5fS8Qlyg-gzc2z!^%`VT$PPZQ!%Pnf5PzG78tB6 zht{efoq1)8CvgbKyf7py`19I`a zt^$S?Ip{cD0HgLwc-o23RwaW{^=oKKDEy7naBheI!?zScwy6ZG(#2SRtPLt-a>2qX zQO*zJ{wQ$a)XvLLQ{lt~Ud(~$bT+no2%&Ucgj)e0Vc(&_8Ed2>#v~uqCI`*AX;^EN z2kG2;s42_Q7^KP_Od8Kc)_h0WT@fNTD4unO2#ueF@Uxr3-O?z+A5SwOF|9{8YcUo| zKA`k=8mi)Da9<@tpX5SR?@YyUYlSU-FGKjNYCL-@!IBFq+_171s6V9;Tk$xQKc(I zm&Za37}kgbV+x^nSPX7JDFRY5ut$)Bc_bGNsKu#xJMP)_0`z^U!rAl{;nXUHKk8(| zDWnBC$`UMn)0=bmsemu!s5QyP&?~xJ+W|3tEV1R1_esI+&&K*OVyLT&VRbwMnTvEd zFQWpiSu29}odV<>72$qf4(x3eH1%?X?frl?;W8{PmY^vk537hA!De~5xU>$DE%~q+ zx|6HZ%*Cm?!JLazA%HxYkM~?CA1#(I|aDnnhL|01-Lz5iW4G=cc~(5 zxLb#VeM&L8UVyBK0;Ff8BV&XFuZO0==RhgSd&yD3mtt4QM>KkrLSMfGE2=Z$_`!^e zDa^vxwpN_AFyMmw2+-mrgWuDWIJ36}yo6+Y`Cf_RZh2@{lVDqG}9(s=_T) zEx@ap5?HO4z*4&g9)eV8SQetTqXfQvz};UxoV%Z%jkxXwxD;-|sS+)2Xgsi> zGvbU)v(fEADw-0Nx%=lE(f7d+E;+aiUECBleOid=dkP?!CW5p+1%X|9aLI!Uuq`ed zx`GEJLX)?JLcKGhH$5g_724lJ(hM4J0sEWOi;1DEAc z$x&eW9tpZU&x2l4Hd-bX;q^B$vS&}_v>#U>@Pa+pPdyXg&C4*=R*r|3Vmz2#i1y41 z#O#$|*|}n5zRbm|vl$4HCU6xoHr$?QDR|Bb{hdF@J&MUgO|cwX&q%SYt`IX6esJcb z0Q1uXxZ%-=y$XG%mg#drExC}iXJJTJ5wu>XV%&Izzid~;6L&c_PcK4fqvHMAl7+EL zTJU^VS1#dYHq;-8@W`fGqA+utC2Ik((N9(q7tadWu_QvMn&azhYz9mCW=wNORNr%Z9EzTuFjLR$!_ImFy z(5xS4Yfykbu60;3QiD5MnT$K9YEf!Xh04$Qcq%`|{WQ&mYJMg*{*Yp>iZ3V7$iU=< z3cvU!MBmf^?rQJ%7^8@9Yr~TuO%TAvEE}==M2IfB{hdHo^6z5Jg`y5f)mAniVC8^K;{7nTv33S^?ZGWVm=y zkz3?_!9yPb9$Ob+U9}?jNRh+qRS^y>5W;A50nU7DKzq-R=qD6Fq`>g^Jx6iJFPCFh zj3S51P~ysWRKw6W87jv0SkWaL5kustFDXY*-gm5R65&Q)W6svU94-g4aQB1^Go~u+ z_NE@^U9#Y?WCVA8>mm+2yK#;e3lLW$g#1GRG|p$>&YXPA-I#$RZeK98R|U#@S93WI z0_gp|oO@v=f^FYobiMuoQ*;Uu_P7epkMhv{U=|m(w+-Ccbi9ogB6@u$G@MHjwKWT; zT!ko@pomq9oFZaZCA75Hx%ut>>5 zMqC#3Ph>%`rU+vlzT?}-Y}oz~AumjTGmpw)?5@ZGMd_&9nF}Y4NN&GtId0|2VLe=m zgzw)VCAAnM{{Sz38U_x`$H|avyvxi&nJ|c3eLDwpR^_7KZ;E%$IS1F5%kjZogzMuQ zpt`aE`}=1?yAIs36W`!lmx;GEA2HEFjAOlXAx`c@-oi}ahXQj_Yau$khqE?wku_vE1C=R#ITynW$Pz{9?IFeQL!|#%KRlpkZvVgQ zhwe(#{#`$8U9-hs(L0$rmHYp0Sp1`lVWau0f)QNwTG~ZX!2Fty_R7DOfA^IB+D*~> zb>na2pB<)ua{sk9()}B!k)l=fSMJec{>l4Sz5Jc`n9e_Vt^UdT*C*=!w{x@l{WtHw z^yp}>rZ(bVM1#AI{5QFxv-Y3kzt-k|?HBd`^FqP*t0$&3W5A#LU(GK^dsVexI~B$6 J7yX~K{{h(9ncn~a literal 0 HcmV?d00001 diff --git a/modules/lowvram.py b/modules/lowvram.py index 042a0254a..e254cc131 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -55,12 +55,12 @@ def setup_for_low_vram(sd_model, use_medvram): if hasattr(sd_model.cond_stage_model, 'model'): sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model - # remove four big modules, cond, first_stage, depth (if applicable), and unet from the model and then + # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then # send the model to GPU. Then put modules back. the modules will be in CPU. - stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), sd_model.model - sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = None, None, None, None + stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model + sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None sd_model.to(devices.device) - sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.model = stored + sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored # register hooks for those the first three models sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) @@ -69,6 +69,8 @@ def setup_for_low_vram(sd_model, use_medvram): sd_model.first_stage_model.decode = first_stage_model_decode_wrap if sd_model.depth_model: sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu) + if sd_model.embedder: + sd_model.embedder.register_forward_pre_hook(send_me_to_gpu) parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model if hasattr(sd_model.cond_stage_model, 'model'): diff --git a/modules/processing.py b/modules/processing.py index 59717b4c6..1451811ca 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -78,22 +78,28 @@ def apply_overlay(image, paste_loc, index, overlays): def txt2img_image_conditioning(sd_model, x, width, height): - if sd_model.model.conditioning_key not in {'hybrid', 'concat'}: - # Dummy zero conditioning if we're not using inpainting model. + if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models + + # The "masked-image" in this case will just be all zeros since the entire image is masked. + image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) + image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning)) + + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) + + return image_conditioning + + elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models + + return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) + + else: + # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) - # The "masked-image" in this case will just be all zeros since the entire image is masked. - image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning)) - - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) - - return image_conditioning - class StableDiffusionProcessing: """ @@ -190,6 +196,14 @@ class StableDiffusionProcessing: return conditioning_image + def unclip_image_conditioning(self, source_image): + c_adm = self.sd_model.embedder(source_image) + if self.sd_model.noise_augmentor is not None: + noise_level = 0 # TODO: Allow other noise levels? + c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0])) + c_adm = torch.cat((c_adm, noise_level_emb), 1) + return c_adm + def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None): self.is_using_inpainting_conditioning = True @@ -241,6 +255,9 @@ class StableDiffusionProcessing: if self.sampler.conditioning_key in {'hybrid', 'concat'}: return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + if self.sampler.conditioning_key == "crossattn-adm": + return self.unclip_image_conditioning(source_image) + # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) diff --git a/modules/sd_models.py b/modules/sd_models.py index f0cb12400..c1a80d828 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -383,6 +383,11 @@ def repair_config(sd_config): elif shared.cmd_opts.upcast_sampling: sd_config.model.params.unet_config.params.use_fp16 = True + # For UnCLIP-L, override the hardcoded karlo directory + if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"): + karlo_path = os.path.join(paths.models_path, 'karlo') + sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path) + sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 91c217004..9398f5284 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -14,6 +14,8 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") +config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") +config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") @@ -65,9 +67,14 @@ def is_using_v_parameterization_for_sd2(state_dict): def guess_model_config_from_state_dict(sd, filename): sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) + sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: return config_depth_model + elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: + return config_unclip + elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024: + return config_unopenclip if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: if diffusion_model_input.shape[1] == 9: diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 083da18ca..bfcc55749 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -70,8 +70,13 @@ class VanillaStableDiffusionSampler: # Have to unwrap the inpainting conditioning here to perform pre-processing image_conditioning = None + uc_image_conditioning = None if isinstance(cond, dict): - image_conditioning = cond["c_concat"][0] + if self.conditioning_key == "crossattn-adm": + image_conditioning = cond["c_adm"] + uc_image_conditioning = unconditional_conditioning["c_adm"] + else: + image_conditioning = cond["c_concat"][0] cond = cond["c_crossattn"][0] unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] @@ -98,8 +103,12 @@ class VanillaStableDiffusionSampler: # Wrap the image conditioning back up since the DDIM code can accept the dict directly. # Note that they need to be lists because it just concatenates them later. if image_conditioning is not None: - cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + if self.conditioning_key == "crossattn-adm": + cond = {"c_adm": image_conditioning, "c_crossattn": [cond]} + unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]} + else: + cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} return x, ts, cond, unconditional_conditioning @@ -176,8 +185,12 @@ class VanillaStableDiffusionSampler: # Wrap the conditioning models with additional image conditioning for inpainting model if image_conditioning is not None: - conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + if self.conditioning_key == "crossattn-adm": + conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]} + unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]} + else: + conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) @@ -195,8 +208,12 @@ class VanillaStableDiffusionSampler: # Wrap the conditioning models with additional image conditioning for inpainting model # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape if image_conditioning is not None: - conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]} - unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]} + if self.conditioning_key == "crossattn-adm": + conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning} + unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)} + else: + conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]} + unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]} samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 93f0e55a0..e9f08518f 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -92,14 +92,21 @@ class CFGDenoiser(torch.nn.Module): batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] + if shared.sd_model.model.conditioning_key == "crossattn-adm": + image_uncond = torch.zeros_like(image_cond) + make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} + else: + image_uncond = image_cond + make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} + if not is_edit_model: x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond]) else: x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) cfg_denoiser_callback(denoiser_params) @@ -116,13 +123,13 @@ class CFGDenoiser(torch.nn.Module): cond_in = torch.cat([tensor, uncond, uncond]) if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) + x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in)) else: x_out = torch.zeros_like(x_in) for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b])) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size @@ -135,9 +142,9 @@ class CFGDenoiser(torch.nn.Module): else: c_crossattn = torch.cat([tensor[a:b]], uncond) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]}) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:])) denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps) cfg_denoised_callback(denoised_params)