alternate implementation for unet forward replacement that does not depend on hijack being applied

This commit is contained in:
AUTOMATIC1111 2023-12-02 19:35:47 +03:00
parent af5f0734c9
commit ac02216e54
2 changed files with 13 additions and 8 deletions

View File

@ -38,8 +38,11 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
def list_optimizers():

View File

@ -5,8 +5,7 @@ from modules import script_callbacks, shared, devices
unet_options = []
current_unet_option = None
current_unet = None
original_forward = None
original_forward = None # not used, only left temporarily for compatibility
def list_unets():
new_unets = script_callbacks.list_unets_callback()
@ -84,9 +83,12 @@ class SdUnet(torch.nn.Module):
pass
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
if current_unet is not None:
return current_unet.forward(x, timesteps, context, *args, **kwargs)
def create_unet_forward(original_forward):
def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
if current_unet is not None:
return current_unet.forward(x, timesteps, context, *args, **kwargs)
return original_forward(self, x, timesteps, context, *args, **kwargs)
return original_forward(self, x, timesteps, context, *args, **kwargs)
return UNetModel_forward