diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 0dabbe0e4..c680367eb 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -74,6 +74,30 @@ def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False): return embedding +# Monkey patch to SpatialTransformer removing unnecessary contiguous calls. +# Prevents a lot of unnecessary aten::copy_ calls +def spatial_transformer_forward(_, self, x: torch.Tensor, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = x.view(b, h, w, c).permute(0, 3, 1, 2) + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + class GELUHijack(torch.nn.GELU, torch.nn.Module): def __init__(self, *args, **kwargs): torch.nn.GELU.__init__(self, *args, **kwargs) @@ -95,7 +119,8 @@ def hijack_ddpm_edit(): unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) -CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding, lambda *args, **kwargs: True) +CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward, lambda *args, **kwargs: True) CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)