mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Add transformer forward patch
This commit is contained in:
parent
53d67088ee
commit
cc9ca67664
@ -74,6 +74,30 @@ def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
|
|||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
# Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
|
||||||
|
# Prevents a lot of unnecessary aten::copy_ calls
|
||||||
|
def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
|
||||||
|
# note: if no context is given, cross-attention defaults to self-attention
|
||||||
|
if not isinstance(context, list):
|
||||||
|
context = [context]
|
||||||
|
b, c, h, w = x.shape
|
||||||
|
x_in = x
|
||||||
|
x = self.norm(x)
|
||||||
|
if not self.use_linear:
|
||||||
|
x = self.proj_in(x)
|
||||||
|
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
|
||||||
|
if self.use_linear:
|
||||||
|
x = self.proj_in(x)
|
||||||
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
|
x = block(x, context=context[i])
|
||||||
|
if self.use_linear:
|
||||||
|
x = self.proj_out(x)
|
||||||
|
x = x.view(b, h, w, c).permute(0, 3, 1, 2)
|
||||||
|
if not self.use_linear:
|
||||||
|
x = self.proj_out(x)
|
||||||
|
return x + x_in
|
||||||
|
|
||||||
|
|
||||||
class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
class GELUHijack(torch.nn.GELU, torch.nn.Module):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
torch.nn.GELU.__init__(self, *args, **kwargs)
|
torch.nn.GELU.__init__(self, *args, **kwargs)
|
||||||
@ -95,7 +119,8 @@ def hijack_ddpm_edit():
|
|||||||
|
|
||||||
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
|
||||||
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
|
||||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)
|
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding, lambda *args, **kwargs: True)
|
||||||
|
CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward, lambda *args, **kwargs: True)
|
||||||
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
|
||||||
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
|
||||||
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
|
||||||
|
Loading…
Reference in New Issue
Block a user