diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 990533fed..16a5500e3 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -270,6 +270,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): z = self.encode_with_transformers(tokens) + pooled = getattr(z, 'pooled', None) + # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers = torch.asarray(batch_multipliers).to(devices.device) original_mean = z.mean() @@ -277,6 +279,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): new_mean = z.mean() z = z * (original_mean / new_mean) + if pooled is not None: + z.pooled = pooled + return z