From 6f0abbb71a3f29d6df63fed82d5d5e196ca0d4de Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 29 Jul 2023 15:15:06 +0300 Subject: [PATCH] textual inversion support for SDXL --- modules/sd_hijack.py | 8 +++++--- modules/sd_hijack_clip.py | 2 +- modules/sd_models_xl.py | 9 +++++++++ .../textual_inversion/textual_inversion.py | 19 ++++++++++++++----- 4 files changed, 29 insertions(+), 9 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index c8fdd4f16..cfa5f0ebb 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -197,7 +197,7 @@ class StableDiffusionModelHijack: conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) text_cond_models.append(conditioner.embedders[i]) if typename == 'FrozenOpenCLIPEmbedder2': - embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) + embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g') conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self) text_cond_models.append(conditioner.embedders[i]) @@ -292,10 +292,11 @@ class StableDiffusionModelHijack: class EmbeddingsWithFixes(torch.nn.Module): - def __init__(self, wrapped, embeddings): + def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'): super().__init__() self.wrapped = wrapped self.embeddings = embeddings + self.textual_inversion_key = textual_inversion_key def forward(self, input_ids): batch_fixes = self.embeddings.fixes @@ -309,7 +310,8 @@ class EmbeddingsWithFixes(torch.nn.Module): vecs = [] for fixes, tensor in zip(batch_fixes, inputs_embeds): for offset, embedding in fixes: - emb = devices.cond_cast_unet(embedding.vec) + vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec + emb = devices.cond_cast_unet(vec) emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 16a5500e3..2f9d569b1 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -161,7 +161,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): position += 1 continue - emb_len = int(embedding.vec.shape[0]) + emb_len = int(embedding.vectors) if len(chunk.tokens) + emb_len > self.chunk_length: next_chunk() diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 40559208b..bc2195087 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -56,6 +56,14 @@ def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, return torch.cat(res, dim=1) +def tokenize(self: sgm.modules.GeneralConditioner, texts): + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]: + return embedder.tokenize(texts) + + raise AssertionError('no tokenizer available') + + + def process_texts(self, texts): for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: return embedder.process_texts(texts) @@ -68,6 +76,7 @@ def get_target_prompt_token_count(self, token_count): # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text +sgm.modules.GeneralConditioner.tokenize = tokenize sgm.modules.GeneralConditioner.process_texts = process_texts sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 6166c76f6..4713bc2d9 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -181,29 +181,38 @@ class EmbeddingDatabase: else: return + # textual inversion embeddings if 'string_to_param' in data: param_dict = data['string_to_param'] param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 assert len(param_dict) == 1, 'embedding file has multiple terms in it' emb = next(iter(param_dict.items()))[1] - # diffuser concepts - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding + vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} + shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] + vectors = data['clip_g'].shape[0] + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts assert len(data.keys()) == 1, 'embedding file has multiple terms in it' emb = next(iter(data.values())) if len(emb.shape) == 1: emb = emb.unsqueeze(0) + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] else: raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") - vec = emb.detach().to(devices.device, dtype=torch.float32) embedding = Embedding(vec, name) embedding.step = data.get('step', None) embedding.sd_checkpoint = data.get('sd_checkpoint', None) embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) - embedding.vectors = vec.shape[0] - embedding.shape = vec.shape[-1] + embedding.vectors = vectors + embedding.shape = shape embedding.filename = path embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')