diff --git a/extensions-builtin/Lora/lora_logger.py b/extensions-builtin/Lora/lora_logger.py new file mode 100644 index 000000000..d51de2970 --- /dev/null +++ b/extensions-builtin/Lora/lora_logger.py @@ -0,0 +1,33 @@ +import sys +import copy +import logging + + +class ColoredFormatter(logging.Formatter): + COLORS = { + "DEBUG": "\033[0;36m", # CYAN + "INFO": "\033[0;32m", # GREEN + "WARNING": "\033[0;33m", # YELLOW + "ERROR": "\033[0;31m", # RED + "CRITICAL": "\033[0;37;41m", # WHITE ON RED + "RESET": "\033[0m", # RESET COLOR + } + + def format(self, record): + colored_record = copy.copy(record) + levelname = colored_record.levelname + seq = self.COLORS.get(levelname, self.COLORS["RESET"]) + colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" + return super().format(colored_record) + + +logger = logging.getLogger("lora") +logger.propagate = False + + +if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s") + ) + logger.addHandler(handler) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index d8e8dfb7f..6021fd8de 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -93,6 +93,7 @@ class Network: # LoraModule self.unet_multiplier = 1.0 self.dyn_dim = None self.modules = {} + self.bundle_embeddings = {} self.mtime = None self.mentioned_name = None diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index ddab3c55e..60d8dec4c 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -16,6 +16,9 @@ import torch from typing import Union from modules import shared, devices, sd_models, errors, scripts, sd_hijack +import modules.textual_inversion.textual_inversion as textual_inversion + +from lora_logger import logger module_types = [ network_lora.ModuleTypeLora(), @@ -151,9 +154,19 @@ def load_network(name, network_on_disk): is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping matched_networks = {} + bundle_embeddings = {} for key_network, weight in sd.items(): key_network_without_network_parts, network_part = key_network.split(".", 1) + if key_network_without_network_parts == "bundle_emb": + emb_name, vec_name = network_part.split(".", 1) + emb_dict = bundle_embeddings.get(emb_name, {}) + if vec_name.split('.')[0] == 'string_to_param': + _, k2 = vec_name.split('.', 1) + emb_dict['string_to_param'] = {k2: weight} + else: + emb_dict[vec_name] = weight + bundle_embeddings[emb_name] = emb_dict key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) sd_module = shared.sd_model.network_layer_mapping.get(key, None) @@ -197,6 +210,14 @@ def load_network(name, network_on_disk): net.modules[key] = net_module + embeddings = {} + for emb_name, data in bundle_embeddings.items(): + embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name) + embedding.loaded = None + embeddings[emb_name] = embedding + + net.bundle_embeddings = embeddings + if keys_failed_to_match: logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}") @@ -212,11 +233,15 @@ def purge_networks_from_memory(): def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): + emb_db = sd_hijack.model_hijack.embedding_db already_loaded = {} for net in loaded_networks: if net.name in names: already_loaded[net.name] = net + for emb_name, embedding in net.bundle_embeddings.items(): + if embedding.loaded: + emb_db.register_embedding_by_name(None, shared.sd_model, emb_name) loaded_networks.clear() @@ -259,6 +284,21 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0 loaded_networks.append(net) + for emb_name, embedding in net.bundle_embeddings.items(): + if embedding.loaded is None and emb_name in emb_db.word_embeddings: + logger.warning( + f'Skip bundle embedding: "{emb_name}"' + ' as it was already loaded from embeddings folder' + ) + continue + + embedding.loaded = False + if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape: + embedding.loaded = True + emb_db.register_embedding(embedding, shared.sd_model) + else: + emb_db.skipped_embeddings[name] = embedding + if failed_to_load_networks: sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks)) @@ -567,6 +607,7 @@ extra_network_lora = None available_networks = {} available_network_aliases = {} loaded_networks = [] +loaded_bundle_embeddings = {} networks_in_memory = {} available_network_hash_lookup = {} forbidden_network_aliases = {} diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 401a0a2ab..04dda585c 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -181,40 +181,7 @@ class EmbeddingDatabase: else: return - - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding - vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} - shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] - vectors = data['clip_g'].shape[0] - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - else: - raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") - - embedding = Embedding(vec, name) - embedding.step = data.get('step', None) - embedding.sd_checkpoint = data.get('sd_checkpoint', None) - embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) - embedding.vectors = vectors - embedding.shape = shape - embedding.filename = path - embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '') + embedding = create_embedding_from_data(data, name, filename=filename, filepath=path) if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) @@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): return fn +def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None): + if 'string_to_param' in data: # textual inversion embeddings + param_dict = data['string_to_param'] + param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding + vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} + shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] + vectors = data['clip_g'].shape[0] + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + embedding.sd_checkpoint = data.get('sd_checkpoint', None) + embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) + embedding.vectors = vectors + embedding.shape = shape + + if filepath: + embedding.filename = filepath + embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '') + + return embedding + + def write_loss(log_directory, filename, step, epoch_len, values): if shared.opts.training_write_csv_every == 0: return