import gradio as gr
import logging
import os
import re

import lora_patches
import network
import network_lora
import network_glora
import network_hada
import network_ia3
import network_lokr
import network_full
import network_norm
import network_oft

import torch
from typing import Union

from modules import shared, devices, sd_models, errors, scripts, sd_hijack
import modules.textual_inversion.textual_inversion as textual_inversion

from lora_logger import logger

# Module types tried in order by load_network(); the first whose create_module()
# returns a non-None module claims a layer's weights.
module_types = [
    network_lora.ModuleTypeLora(),
    network_hada.ModuleTypeHada(),
    network_ia3.ModuleTypeIa3(),
    network_lokr.ModuleTypeLokr(),
    network_full.ModuleTypeFull(),
    network_norm.ModuleTypeNorm(),
    network_glora.ModuleTypeGLora(),
    network_oft.ModuleTypeOFT(),
]

re_digits = re.compile(r"\d+")
# Splits "<layer>_q_proj" / "_k_proj" / "_v_proj" into the layer name and the projection suffix.
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
# Cache of compiled patterns used by convert_diffusers_name_to_compvis().
re_compiled = {}

# Per block kind, maps diffusers sub-layer names to their compvis equivalents.
suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "norm1": "in_layers_0",
        "norm2": "out_layers_0",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}


def convert_diffusers_name_to_compvis(key, is_sd2):
    """Translate a diffusers-style network key into the compvis layer name.

    ``key`` is the part of a state-dict key before the first ".". ``is_sd2``
    selects the ``model_transformer_resblocks`` naming for ``lora_te_`` keys.
    Keys that match none of the known patterns are returned unchanged.
    """

    def match(match_list, regex_text):
        # Compile each pattern once. On success, replace match_list's contents
        # with the captured groups, converting all-digit groups to int.
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = regex.match(key)
        if not r:
            return False

        match_list.clear()
        # fullmatch: a group that only *starts* with a digit must stay a string
        # (int() on it would raise ValueError).
        match_list.extend([int(x) if re_digits.fullmatch(x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, r"lora_unet_conv_out(.*)"):
        return f'diffusion_model_out_2{m[0]}'

    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key


def assign_network_names_to_compvis_modules(sd_model):
    """Build ``sd_model.network_layer_mapping`` (layer name -> module).

    Names are ``named_modules()`` paths with "." replaced by "_". For SDXL the
    text-encoder names are prefixed with the embedder index. Each module also
    gets its name stored as ``module.network_layer_name``.
    """
    network_layer_mapping = {}

    def register(network_name, module):
        network_layer_mapping[network_name] = module
        module.network_layer_name = network_name

    # Read from the model we were given (the mapping is written to it below),
    # rather than from shared.sd_model.
    if sd_model.is_sdxl:
        for i, embedder in enumerate(sd_model.conditioner.embedders):
            if not hasattr(embedder, 'wrapped'):
                continue

            for name, module in embedder.wrapped.named_modules():
                register(f'{i}_{name.replace(".", "_")}', module)
    else:
        for name, module in sd_model.cond_stage_model.wrapped.named_modules():
            register(name.replace(".", "_"), module)

    for name, module in sd_model.model.named_modules():
        register(name.replace(".", "_"), module)

    sd_model.network_layer_mapping = network_layer_mapping


def _add_bundle_embedding_tensor(bundle_embeddings, network_part, weight):
    """Store one ``bundle_emb.<emb_name>.<vec_name>`` tensor into ``bundle_embeddings[emb_name]``."""
    emb_name, vec_name = network_part.split(".", 1)
    emb_dict = bundle_embeddings.setdefault(emb_name, {})

    if vec_name.split('.')[0] == 'string_to_param':
        _, k2 = vec_name.split('.', 1)
        # Accumulate so that several string_to_param tensors do not overwrite each other.
        emb_dict.setdefault('string_to_param', {})[k2] = weight
    else:
        emb_dict[vec_name] = weight


def _match_network_key(key_network_without_network_parts, is_sd2):
    """Resolve a network key (without its parameter part) to ``(key, sd_module)``.

    Tries the diffusers->compvis conversion first, then several fallbacks for
    files that already use compvis-style names. ``sd_module`` is None when
    nothing matched; ``key`` is then the last candidate tried.
    """
    mapping = shared.sd_model.network_layer_mapping

    key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
    sd_module = mapping.get(key, None)

    if sd_module is None:
        # q/k/v projections map onto their parent attention layer. The key keeps the
        # "_x_proj" suffix; network_apply_weights looks these up by that suffix.
        m = re_x_proj.match(key)
        if m:
            sd_module = mapping.get(m.group(1), None)

    # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
    if sd_module is None and "lora_unet" in key_network_without_network_parts:
        key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
        sd_module = mapping.get(key, None)
    elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
        key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
        sd_module = mapping.get(key, None)

        # some SD1 Loras also have correct compvis keys
        if sd_module is None:
            key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
            sd_module = mapping.get(key, None)

    # kohya_ss OFT module
    elif sd_module is None and "oft_unet" in key_network_without_network_parts:
        key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
        sd_module = mapping.get(key, None)

    # KohakuBlueLeaf OFT module
    if sd_module is None and "oft_diag" in key:
        # Chain both rewrites: assigning them separately would discard the unet one.
        key = (
            key_network_without_network_parts
            .replace("lora_unet", "diffusion_model")
            .replace("lora_te1_text_model", "0_transformer_text_model")
        )
        sd_module = mapping.get(key, None)

    return key, sd_module


def load_network(name, network_on_disk):
    """Read ``network_on_disk``'s file and build a ``network.Network`` from it.

    Each state-dict key is split at its first "." into a layer part and a
    parameter part. The layer part is matched against
    ``shared.sd_model.network_layer_mapping``, and the grouped weights are turned
    into modules by the first accepting entry of ``module_types``. Keys under
    ``bundle_emb`` become textual-inversion embeddings in ``net.bundle_embeddings``.

    Raises:
        AssertionError: no module type accepts the weights of some matched layer.
    """
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)

    sd = sd_models.read_state_dict(network_on_disk.filename)

    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
    if not hasattr(shared.sd_model, 'network_layer_mapping'):
        assign_network_names_to_compvis_modules(shared.sd_model)

    keys_failed_to_match = {}
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping

    matched_networks = {}
    bundle_embeddings = {}

    for key_network, weight in sd.items():
        key_network_without_network_parts, _, network_part = key_network.partition(".")

        if key_network_without_network_parts == "bundle_emb":
            _add_bundle_embedding_tensor(bundle_embeddings, network_part, weight)
            # Embedding tensors are not model layers; don't report them as unmatched.
            continue

        key, sd_module = _match_network_key(key_network_without_network_parts, is_sd2)
        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue

        if key not in matched_networks:
            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)

        matched_networks[key].w[network_part] = weight

    for key, weights in matched_networks.items():
        net_module = None
        for nettype in module_types:
            net_module = nettype.create_module(net, weights)
            if net_module is not None:
                break

        if net_module is None:
            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")

        net.modules[key] = net_module

    embeddings = {}
    for emb_name, data in bundle_embeddings.items():
        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
        # None = never registered yet; load_networks() sets True/False.
        embedding.loaded = None
        embeddings[emb_name] = embedding

    net.bundle_embeddings = embeddings

    if keys_failed_to_match:
        logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")

    return net


def purge_networks_from_memory():
    """Evict the earliest-inserted entries of ``networks_in_memory`` until it holds
    at most ``shared.opts.lora_in_memory_limit`` networks, then run ``devices.torch_gc()``."""
    while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
        name = next(iter(networks_in_memory))
        networks_in_memory.pop(name, None)

    devices.torch_gc()


def _find_networks_on_disk(names):
    """Look up each name: by file stem when its lowercase form is a forbidden alias,
    otherwise by alias. Entries are None for names that are not found."""
    return [
        available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None)
        for name in names
    ]


def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    """Replace ``loaded_networks`` with the networks called ``names``.

    ``te_multipliers``, ``unet_multipliers`` and ``dyn_dims`` are indexed in
    parallel with ``names``. A missing (falsy) list means 1.0 for every network.
    Reuses already-active or in-memory networks unless the file on disk is
    newer. Registers bundled embeddings, and reports names that could not be found.
    """
    emb_db = sd_hijack.model_hijack.embedding_db
    already_loaded = {}

    for net in loaded_networks:
        if net.name in names:
            already_loaded[net.name] = net
        for emb_name, embedding in net.bundle_embeddings.items():
            if embedding.loaded:
                emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)

    loaded_networks.clear()

    networks_on_disk = _find_networks_on_disk(names)
    if any(x is None for x in networks_on_disk):
        # Something is missing: rescan the network directories once and retry.
        list_available_networks()

        networks_on_disk = _find_networks_on_disk(names)

    failed_to_load_networks = []

    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        net = already_loaded.get(name, None)

        if network_on_disk is not None:
            if net is None:
                net = networks_in_memory.get(name)

            if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
                try:
                    net = load_network(name, network_on_disk)

                    # Re-insert so the freshly loaded network is last in eviction order.
                    networks_in_memory.pop(name, None)
                    networks_in_memory[name] = net
                except Exception as e:
                    errors.display(e, f"loading network {network_on_disk.filename}")
                    continue

            net.mentioned_name = name

            network_on_disk.read_hash()

        if net is None:
            failed_to_load_networks.append(name)
            logging.info(f"Couldn't find network with name {name}")
            continue

        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
        loaded_networks.append(net)

        for emb_name, embedding in net.bundle_embeddings.items():
            if embedding.loaded is None and emb_name in emb_db.word_embeddings:
                logger.warning(
                    f'Skip bundle embedding: "{emb_name}"'
                    ' as it was already loaded from embeddings folder'
                )
                continue

            embedding.loaded = False
            if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
                embedding.loaded = True
                emb_db.register_embedding(embedding, shared.sd_model)
            else:
                # Key by the embedding's own name, not the network's, so several
                # skipped embeddings from one bundle don't overwrite each other.
                emb_db.skipped_embeddings[emb_name] = embedding

    if failed_to_load_networks:
        lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
        sd_hijack.model_hijack.comments.append(lora_not_found_message)
        if shared.opts.lora_not_found_warning_console:
            print(f'\n{lora_not_found_message}\n')
        if shared.opts.lora_not_found_gradio_warning:
            gr.Warning(lora_not_found_message)

    purge_networks_from_memory()


def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """Copy the backed-up original weight (and bias) back into the layer.

    When a weight backup exists but the bias backup is None, the bias is set to
    None, because a network may have created a bias that the layer did not have.
    Does nothing when neither backup exists.
    """
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)

    if weights_backup is None and bias_backup is None:
        return

    # In-place copy_ into a parameter that requires grad is an error unless grad is disabled.
    with torch.no_grad():
        if weights_backup is not None:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.in_proj_weight.copy_(weights_backup[0])
                self.out_proj.weight.copy_(weights_backup[1])
            else:
                self.weight.copy_(weights_backup)

        if bias_backup is not None:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.out_proj.bias.copy_(bias_backup)
            else:
                self.bias.copy_(bias_backup)
        else:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.out_proj.bias = None
            else:
                self.bias = None


def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of networks to the weights of torch layer self.
    If weights already have this particular set of networks applied, does nothing.
    If not, restores original weights from backup and alters weights according to networks.

    The applied set is tracked as ``self.network_current_names``: a tuple of
    (name, te_multiplier, unet_multiplier, dyn_dim) per loaded network. Originals
    are kept on CPU in ``network_weights_backup`` / ``network_bias_backup``.
    Per-network failures are logged and counted in ``extra_network_lora.errors``.

    Raises:
        RuntimeError: networks are recorded as applied but no weight backup exists.
    """

    network_layer_name = getattr(self, 'network_layer_name', None)
    if network_layer_name is None:
        return

    current_names = getattr(self, "network_current_names", ())
    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None and wanted_names != ():
        if current_names != ():
            raise RuntimeError("no backup weights found and current weights are not unchanged")

        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
        else:
            weights_backup = self.weight.to(devices.cpu, copy=True)

        self.network_weights_backup = weights_backup

    bias_backup = getattr(self, "network_bias_backup", None)
    # Only capture the bias while no network is applied. A None backup with
    # networks applied means the layer originally had no bias, and any bias it
    # has now was created by a network; it must not be treated as the original.
    if bias_backup is None and current_names == ():
        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
        elif getattr(self, 'bias', None) is not None:
            bias_backup = self.bias.to(devices.cpu, copy=True)
        else:
            bias_backup = None
        self.network_bias_backup = bias_backup

    if current_names != wanted_names:
        network_restore_weights_from_backup(self)

        for net in loaded_networks:
            module = net.modules.get(network_layer_name, None)
            if module is not None and hasattr(self, 'weight'):
                try:
                    with torch.no_grad():
                        if getattr(self, 'fp16_weight', None) is None:
                            weight = self.weight
                            bias = self.bias
                        else:
                            weight = self.fp16_weight.clone().to(self.weight.device)
                            bias = getattr(self, 'fp16_bias', None)
                            if bias is not None:
                                bias = bias.clone().to(self.bias.device)
                        updown, ex_bias = module.calc_updown(weight)

                        if len(weight.shape) == 4 and weight.shape[1] == 9:
                            # inpainting model. zero pad updown to make channel[1]  4 to 9
                            updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                        self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
                        if ex_bias is not None and hasattr(self, 'bias'):
                            if self.bias is None:
                                # Cast before wrapping: Parameter(...).to(other_dtype) returns a
                                # plain Tensor, which cannot be assigned to a parameter slot.
                                self.bias = torch.nn.Parameter(ex_bias.to(self.weight.dtype))
                            else:
                                self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            # MultiheadAttention stores q/k/v stacked in in_proj_weight, so the three
            # per-projection deltas are stacked the same way.
            module_q = net.modules.get(network_layer_name + "_q_proj", None)
            module_k = net.modules.get(network_layer_name + "_k_proj", None)
            module_v = net.modules.get(network_layer_name + "_v_proj", None)
            module_out = net.modules.get(network_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                try:
                    with torch.no_grad():
                        updown_q, _ = module_q.calc_updown(self.in_proj_weight)
                        updown_k, _ = module_k.calc_updown(self.in_proj_weight)
                        updown_v, _ = module_v.calc_updown(self.in_proj_weight)
                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                        updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)

                        self.in_proj_weight += updown_qkv
                        self.out_proj.weight += updown_out
                    if ex_bias is not None:
                        if self.out_proj.bias is None:
                            self.out_proj.bias = torch.nn.Parameter(ex_bias)
                        else:
                            self.out_proj.bias += ex_bias

                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            if module is None:
                continue

            logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

        self.network_current_names = wanted_names


def network_forward(org_module, input, original_forward):
    """
    Old way of applying Lora by executing operations during layer's forward.
    Stacking many loras this way results in big performance degradation.

    Restores the layer's original weights, runs ``original_forward``, then lets
    each loaded network's module for this layer adjust the output.
    """

    if len(loaded_networks) == 0:
        return original_forward(org_module, input)

    input = devices.cond_cast_unet(input)

    network_restore_weights_from_backup(org_module)
    network_reset_cached_weight(org_module)

    y = original_forward(org_module, input)

    network_layer_name = getattr(org_module, 'network_layer_name', None)
    for lora in loaded_networks:
        module = lora.modules.get(network_layer_name, None)
        if module is None:
            continue

        y = module.forward(input, y)

    return y


def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    """Forget which networks are applied and drop both backups. The next
    ``network_apply_weights`` call then treats the current weights as the originals."""
    self.network_current_names = ()
    self.network_weights_backup = None
    self.network_bias_backup = None


def network_Linear_forward(self, input):
    """Linear.forward replacement: functional mode if enabled, else merge networks into the weights first."""
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.Linear_forward)

    network_apply_weights(self)

    return originals.Linear_forward(self, input)


def network_Linear_load_state_dict(self, *args, **kwargs):
    """Linear.load_state_dict replacement that invalidates cached network state before loading."""
    network_reset_cached_weight(self)

    return originals.Linear_load_state_dict(self, *args, **kwargs)


def network_Conv2d_forward(self, input):
    """Conv2d.forward replacement: functional mode if enabled, else merge networks into the weights first."""
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.Conv2d_forward)

    network_apply_weights(self)

    return originals.Conv2d_forward(self, input)


def network_Conv2d_load_state_dict(self, *args, **kwargs):
    """Conv2d.load_state_dict replacement that invalidates cached network state before loading."""
    network_reset_cached_weight(self)

    return originals.Conv2d_load_state_dict(self, *args, **kwargs)


def network_GroupNorm_forward(self, input):
    """GroupNorm.forward replacement: functional mode if enabled, else merge networks into the weights first."""
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.GroupNorm_forward)

    network_apply_weights(self)

    return originals.GroupNorm_forward(self, input)


def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    """GroupNorm.load_state_dict replacement that invalidates cached network state before loading."""
    network_reset_cached_weight(self)

    return originals.GroupNorm_load_state_dict(self, *args, **kwargs)


def network_LayerNorm_forward(self, input):
    """LayerNorm.forward replacement: functional mode if enabled, else merge networks into the weights first."""
    if shared.opts.lora_functional:
        return network_forward(self, input, originals.LayerNorm_forward)

    network_apply_weights(self)

    return originals.LayerNorm_forward(self, input)


def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    """LayerNorm.load_state_dict replacement that invalidates cached network state before loading."""
    network_reset_cached_weight(self)

    return originals.LayerNorm_load_state_dict(self, *args, **kwargs)


def network_MultiheadAttention_forward(self, *args, **kwargs):
    """MultiheadAttention.forward replacement; always merges networks into the weights (no functional mode)."""
    network_apply_weights(self)

    return originals.MultiheadAttention_forward(self, *args, **kwargs)


def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    """MultiheadAttention.load_state_dict replacement that invalidates cached network state before loading."""
    network_reset_cached_weight(self)

    return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)


def list_available_networks():
    """Rescan the LoRA (and back-compat LyCORIS) directories and rebuild the lookup tables.

    Each file is registered under its file stem in ``available_networks``, and
    under both its stem and ``entry.alias`` in ``available_network_aliases``. An
    alias seen more than once is added (lowercased) to ``forbidden_network_aliases``.
    """
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    # NOTE(review): lookups test name.lower() against this dict, so the mixed-case
    # "Addams" entry can never match — confirm whether "addams" was intended.
    forbidden_network_aliases.update({"none": 1, "Addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    for filename in candidates:
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]
        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry

        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry


# Strips a trailing "(<hex hash>)" from an AddNet model name.
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")


def infotext_pasted(infotext, params):
    """Convert "AddNet ..." infotext fields for LoRA models into <lora:name:weight> prompt tags.

    Skipped when another extension already exposes "AddNet Module 1" as an infotext field.
    """
    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
        return  # if the other extension is active, it will handle those fields, no need to do anything

    added = []

    for k in params:
        if not k.startswith("AddNet Model "):
            continue

        num = k[13:]  # the index after the "AddNet Model " prefix

        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get("AddNet Model " + num)
        if name is None:
            continue

        m = re_network_name.match(name)
        if m:
            name = m.group(1)

        multiplier = params.get("AddNet Weight A " + num, "1.0")

        # Previously an empty f-string, which dropped name/multiplier entirely.
        added.append(f"<lora:{name}:{multiplier}>")

    if added:
        params["Prompt"] += "\n" + "".join(added)


# Set by other parts of the extension; None until then.
originals: lora_patches.LoraPatches = None

extra_network_lora = None

available_networks = {}
available_network_aliases = {}
loaded_networks = []
loaded_bundle_embeddings = {}
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}

list_available_networks()