diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 92874a792..47dbc1b7c 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -11,7 +11,7 @@ import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint from modules.hypernetworks import hypernetwork from modules.shared import opts, device, cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip +from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet from modules.sd_hijack_optimizations import invokeAI_mps_available @@ -35,11 +35,12 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] ldm.modules.attention.print = lambda *args: None ldm.modules.diffusionmodules.model.print = lambda *args: None + def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_hijack_optimizations.patched_unet_forward + ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 8cd4c9549..85909eb94 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -313,31 +313,3 @@ def xformers_attnblock_forward(self, x): return x + out except NotImplementedError: return cross_attention_attnblock_forward(self, x) - -def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs): - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" - hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - emb = self.time_embed(t_emb) - - if self.num_classes is not None: - assert y.shape == (x.shape[0],) - emb = emb + self.label_emb(y) - - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb, context) - hs.append(h) - h = self.middle_block(h, emb, context) - for module in self.output_blocks: - if h.shape[-2:] != hs[-1].shape[-2:]: - h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest") - h = torch.cat([h, hs.pop()], dim=1) - h = module(h, emb, context) - h = h.type(x.dtype) - if self.predict_codebook_ids: - return self.id_predictor(h) - else: - return self.out(h) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py new file mode 100644 index 000000000..1b9d7757a --- /dev/null +++ b/modules/sd_hijack_unet.py @@ -0,0 +1,30 @@ +import torch + + +class TorchHijackForUnet: + """ + This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match; + this makes it possible to create pictures with dimensions that are muliples of 8 rather than 64 + """ + + def __getattr__(self, item): + if item == 'cat': + return self.cat + + if hasattr(torch, item): + return getattr(torch, item) + + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + + def cat(self, tensors, *args, **kwargs): + if len(tensors) == 2: + a, b = tensors + if a.shape[-2:] != b.shape[-2:]: + a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest") + + tensors = (a, b) + + return torch.cat(tensors, *args, **kwargs) + + +th = TorchHijackForUnet()