From 4738486d8f528a98a525970ac06a109431fd7344 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 6 Feb 2023 18:10:55 -0500 Subject: [PATCH 1/2] Support for hypernetworks with --upcast-sampling --- modules/hypernetworks/hypernetwork.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 825a93b28..a15bae18c 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -380,8 +380,8 @@ def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None): layer.hyper_k = hypernetwork_layers[0] layer.hyper_v = hypernetwork_layers[1] - context_k = hypernetwork_layers[0](context_k) - context_v = hypernetwork_layers[1](context_v) + context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k))) + context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v))) return context_k, context_v From 2016733814433ca2b69d10764bfa0ab4c7088782 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 7 Feb 2023 00:05:54 -0500 Subject: [PATCH 2/2] Apply hijacks in ddpm_edit for upcast sampling To avoid import errors, ddpm_edit hijacks are done after an instruct pix2pix model is loaded. --- modules/sd_hijack.py | 3 +++ modules/sd_hijack_unet.py | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 8fdc59909..fca418cdf 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -104,6 +104,9 @@ class StableDiffusionModelHijack: m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self) m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) + if m.cond_stage_key == "edit": + sd_hijack_unet.hijack_ddpm_edit() + self.optimization_method = apply_optimizations() self.clip = m.cond_stage_model diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 45cf2b18e..843ab66cf 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -44,6 +44,7 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): with devices.autocast(): return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() + class GELUHijack(torch.nn.GELU, torch.nn.Module): def __init__(self, *args, **kwargs): torch.nn.GELU.__init__(self, *args, **kwargs) @@ -53,6 +54,16 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module): else: return torch.nn.GELU.forward(self, x) + +ddpm_edit_hijack = None +def hijack_ddpm_edit(): + global ddpm_edit_hijack + if not ddpm_edit_hijack: + CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) + CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) + ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) + + unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)