"""Discovery, loading and application of LoRA-style networks (Lora / Hada module types) to model layers."""

import os
import re

import network
import network_lora
import network_hada

import torch
from typing import Union

from modules import shared, devices, sd_models, errors, scripts, sd_hijack

# Module types tried, in order, when turning a group of state-dict weights into a network module.
module_types = [
    network_lora.ModuleTypeLora(),
    network_hada.ModuleTypeHada(),
]


re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}  # cache of regex text -> compiled pattern used by convert_diffusers_name_to_compvis

# Per block type, maps a diffusers sub-layer name to the compvis sub-layer name.
suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}


def convert_diffusers_name_to_compvis(key, is_sd2):
    """Translate a diffusers-style network key prefix into the compvis layer name.

    For example ``lora_unet_down_blocks_0_attentions_1_proj_in`` becomes
    ``diffusion_model_input_blocks_2_1_proj_in``.

    Args:
        key: The part of a state-dict key before the first ``.``.
        is_sd2: When true, ``lora_te_`` text-encoder keys are mapped to the
            ``model_transformer_resblocks_*`` naming instead of
            ``transformer_text_model_encoder_layers_*``.

    Returns:
        The converted name, or ``key`` unchanged when no known pattern matches.
    """

    def match(match_list, regex_text):
        # Match `key` against regex_text; on success replace match_list's contents with the
        # captured groups, converting purely numeric groups to int.
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = regex.match(key)
        if not r:
            return False

        match_list.clear()
        # fullmatch: a group such as "2_x" starts with a digit but is not a number; int() on it would raise.
        match_list.extend([int(x) if re_digits.fullmatch(x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        # second text encoder: names are prefixed with its embedder index ("1_")
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key


def assign_network_names_to_compvis_modules(sd_model):
    """Build ``sd_model.network_layer_mapping`` (layer name -> torch module).

    Each module also gets a ``network_layer_name`` attribute holding its key in the
    mapping. Names are ``named_modules()`` paths with ``.`` replaced by ``_``.

    Modules are collected from:
    - SDXL: each text embedder that has a ``wrapped`` attribute, prefixed with the
      embedder index (``0_``, ``1_``, ...);
    - other models: ``cond_stage_model.wrapped``;
    - in both cases: the diffusion model ``sd_model.model``.
    """
    network_layer_mapping = {}

    # Use the model we were given (previously this read shared.sd_model and could
    # build the mapping from a different model than the one it is stored on).
    if sd_model.is_sdxl:
        for i, embedder in enumerate(sd_model.conditioner.embedders):
            if not hasattr(embedder, 'wrapped'):
                continue

            for name, module in embedder.wrapped.named_modules():
                network_name = f'{i}_{name.replace(".", "_")}'
                network_layer_mapping[network_name] = module
                module.network_layer_name = network_name
    else:
        for name, module in sd_model.cond_stage_model.wrapped.named_modules():
            network_name = name.replace(".", "_")
            network_layer_mapping[network_name] = module
            module.network_layer_name = network_name

    for name, module in sd_model.model.named_modules():
        network_name = name.replace(".", "_")
        network_layer_mapping[network_name] = module
        module.network_layer_name = network_name

    sd_model.network_layer_mapping = network_layer_mapping


def load_network(name, network_on_disk):
    """Read a network file and bind its weights to layers of ``shared.sd_model``.

    Args:
        name: Name given to the resulting ``network.Network``.
        network_on_disk: Entry whose ``filename`` is read.

    Returns:
        The ``network.Network`` with ``mtime`` set and one module per matched layer
        key in ``net.modules``.

    Raises:
        AssertionError: If no entry in ``module_types`` accepts a matched group of weights.

    Keys that cannot be matched to a layer are skipped and reported with ``print``.
    """
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)

    sd = sd_models.read_state_dict(network_on_disk.filename)

    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
    if not hasattr(shared.sd_model, 'network_layer_mapping'):
        assign_network_names_to_compvis_modules(shared.sd_model)

    keys_failed_to_match = {}
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping

    matched_networks = {}

    for key_network, weight in sd.items():
        key_network_without_network_parts, sep, network_part = key_network.partition(".")
        if not sep:
            # No "." means no layer/weight-part split is possible. Previously the
            # tuple unpacking raised ValueError here; now the key is just reported.
            keys_failed_to_match[key_network] = key_network
            continue

        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
        sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            # separate q/k/v projection keys map onto one combined layer; `key` keeps its
            # *_proj suffix so network_apply_weights can look each projection up by it
            m = re_x_proj.match(key)
            if m:
                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)

        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
        if sd_module is None and "lora_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue

        if key not in matched_networks:
            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)

        matched_networks[key].w[network_part] = weight

    for key, weights in matched_networks.items():
        # the first module type that accepts this set of weight parts wins
        net_module = None
        for nettype in module_types:
            net_module = nettype.create_module(net, weights)
            if net_module is not None:
                break

        if net_module is None:
            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")

        net.modules[key] = net_module

    if keys_failed_to_match:
        print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")

    return net


def load_networks(names, multipliers=None):
    """Replace the contents of ``loaded_networks`` with the networks named in ``names``.

    A previously loaded network is reused unless its file has a newer mtime. If any
    name is unknown, the network directory is rescanned once.

    Args:
        names: Network names or aliases, as keys of ``available_network_aliases``.
        multipliers: Optional strengths, indexed in parallel with ``names``;
            defaults to 1.0 when falsy.

    Side effects:
        - Load errors are shown via ``errors.display`` and that network is skipped.
        - Names that cannot be found are reported in ``sd_hijack.model_hijack.comments``.
    """
    already_loaded = {}

    for net in loaded_networks:
        if net.name in names:
            already_loaded[net.name] = net

    loaded_networks.clear()

    networks_on_disk = [available_network_aliases.get(name, None) for name in names]
    if any(x is None for x in networks_on_disk):
        list_available_networks()

        networks_on_disk = [available_network_aliases.get(name, None) for name in names]

    failed_to_load_networks = []

    for i, name in enumerate(names):
        net = already_loaded.get(name, None)

        network_on_disk = networks_on_disk[i]

        if network_on_disk is not None:
            # (re)load when not cached or when the file changed since it was loaded
            if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
                try:
                    net = load_network(name, network_on_disk)
                except Exception as e:
                    errors.display(e, f"loading network {network_on_disk.filename}")
                    continue

            net.mentioned_name = name

            network_on_disk.read_hash()

        if net is None:
            failed_to_load_networks.append(name)
            print(f"Couldn't find network with name {name}")
            continue

        net.multiplier = multipliers[i] if multipliers else 1.0
        loaded_networks.append(net)

    if failed_to_load_networks:
        sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))


def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
    """Copy the backed-up weights back into the layer in place; no-op when there is no backup.

    For MultiheadAttention the backup is a tuple ``(in_proj_weight, out_proj.weight)``.
    """
    weights_backup = getattr(self, "network_weights_backup", None)

    if weights_backup is None:
        return

    if isinstance(self, torch.nn.MultiheadAttention):
        self.in_proj_weight.copy_(weights_backup[0])
        self.out_proj.weight.copy_(weights_backup[1])
    else:
        self.weight.copy_(weights_backup)


def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of networks to the weights of torch layer self.
    If weights already have this particular set of networks applied, does nothing.
    If not, restores original weights from backup and alters weights according to networks.

    The applied set is identified by the tuple of (name, multiplier) of loaded_networks,
    stored on the layer as network_current_names. On first use, a CPU copy of the
    original weights is stored as network_weights_backup.
    """

    network_layer_name = getattr(self, 'network_layer_name', None)
    if network_layer_name is None:
        return

    current_names = getattr(self, "network_current_names", ())
    wanted_names = tuple((x.name, x.multiplier) for x in loaded_networks)

    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None:
        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
        else:
            weights_backup = self.weight.to(devices.cpu, copy=True)

        self.network_weights_backup = weights_backup

    if current_names != wanted_names:
        network_restore_weights_from_backup(self)

        for net in loaded_networks:
            module = net.modules.get(network_layer_name, None)
            if module is not None and hasattr(self, 'weight'):
                with torch.no_grad():
                    updown = module.calc_updown(self.weight)

                    if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
                        # inpainting model. zero pad updown to make channel[1]  4 to 9
                        updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                    self.weight += updown
                continue

            # MultiheadAttention: separate q/k/v deltas are stacked into in_proj_weight
            module_q = net.modules.get(network_layer_name + "_q_proj", None)
            module_k = net.modules.get(network_layer_name + "_k_proj", None)
            module_v = net.modules.get(network_layer_name + "_v_proj", None)
            module_out = net.modules.get(network_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and all(x is not None for x in (module_q, module_k, module_v, module_out)):
                with torch.no_grad():
                    updown_q = module_q.calc_updown(self.in_proj_weight)
                    updown_k = module_k.calc_updown(self.in_proj_weight)
                    updown_v = module_v.calc_updown(self.in_proj_weight)
                    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])

                    self.in_proj_weight += updown_qkv
                    self.out_proj.weight += module_out.calc_updown(self.out_proj.weight)
                continue

            if module is None:
                continue

            # a module exists for this layer but the layer has no `weight` to patch
            print(f'failed to calculate network weights for layer {network_layer_name}')

        self.network_current_names = wanted_names


def network_forward(module, input, original_forward):
    """
    Old way of applying Lora by executing operations during layer's forward.
    Stacking many loras this way results in big performance degradation.

    The layer's original weights are restored and its cache reset. The original
    forward runs, then each loaded network's module for this layer transforms
    the output in turn.
    """

    if len(loaded_networks) == 0:
        return original_forward(module, input)

    input = devices.cond_cast_unet(input)

    network_restore_weights_from_backup(module)
    network_reset_cached_weight(module)

    y = original_forward(module, input)

    network_layer_name = getattr(module, 'network_layer_name', None)
    for lora in loaded_networks:
        # distinct name so the `module` parameter (the torch layer) is not shadowed
        net_module = lora.modules.get(network_layer_name, None)
        if net_module is None:
            continue

        y = net_module.forward(y, input)

    return y


def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    """Forget which networks are applied and drop the weight backup, so the next apply re-backs-up."""
    self.network_current_names = ()
    self.network_weights_backup = None


def network_Linear_forward(self, input):
    """Linear forward replacement: functional mode if enabled, otherwise patch weights then call the original forward."""
    if shared.opts.lora_functional:
        return network_forward(self, input, torch.nn.Linear_forward_before_network)

    network_apply_weights(self)

    return torch.nn.Linear_forward_before_network(self, input)


def network_Linear_load_state_dict(self, *args, **kwargs):
    """Reset the network cache before delegating to the original Linear.load_state_dict."""
    network_reset_cached_weight(self)

    return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)


def network_Conv2d_forward(self, input):
    """Conv2d forward replacement: functional mode if enabled, otherwise patch weights then call the original forward."""
    if shared.opts.lora_functional:
        return network_forward(self, input, torch.nn.Conv2d_forward_before_network)

    network_apply_weights(self)

    return torch.nn.Conv2d_forward_before_network(self, input)


def network_Conv2d_load_state_dict(self, *args, **kwargs):
    """Reset the network cache before delegating to the original Conv2d.load_state_dict."""
    network_reset_cached_weight(self)

    return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)


def network_MultiheadAttention_forward(self, *args, **kwargs):
    """MultiheadAttention forward replacement: patch weights, then call the original forward."""
    network_apply_weights(self)

    return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)


def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    """Reset the network cache before delegating to the original MultiheadAttention.load_state_dict."""
    network_reset_cached_weight(self)

    return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)


def list_available_networks():
    """Rescan ``shared.cmd_opts.lora_dir`` and rebuild the module-level lookup tables.

    - The directory is created if missing.
    - Files with extension .pt, .ckpt or .safetensors are registered in
      ``available_networks`` (by file stem) and in ``available_network_aliases``
      (by stem and by alias).
    - An alias already used by an earlier entry is recorded, lowercased, in
      ``forbidden_network_aliases``.
    - ``available_network_hash_lookup`` is only cleared here; presumably it is
      filled elsewhere (TODO confirm).
    """
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    # NOTE(review): aliases added below are lowercased, but "Addams" is not; confirm how lookups compare case.
    forbidden_network_aliases.update({"none": 1, "Addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    for filename in candidates:
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]
        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry

        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry


re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")


def infotext_pasted(infotext, params):
    """Convert pasted AddNet LoRA fields into ``<lora:name:weight>`` prompt tags.

    For every ``AddNet Model N`` whose ``AddNet Module N`` is ``"LoRA"``, a tag with
    the model name (any trailing ``(hexhash)`` removed) and ``AddNet Weight A N``
    (default ``"1.0"``) is appended to ``params["Prompt"]``. Nothing is done when
    the AddNet extension's own ``AddNet Module 1`` field is registered.
    """
    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
        return  # if the other extension is active, it will handle those fields, no need to do anything

    added = []
    prefix = "AddNet Model "

    for k in params:
        if not k.startswith(prefix):
            continue

        num = k[len(prefix):]

        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get("AddNet Model " + num)
        if name is None:
            continue

        m = re_network_name.match(name)
        if m:
            # the greedy (.*) keeps the whitespace before "(hash)"; strip it so the name resolves
            name = m.group(1).rstrip()

        multiplier = params.get("AddNet Weight A " + num, "1.0")

        # previously an empty f-string was appended here, discarding name and multiplier
        added.append(f"<lora:{name}:{multiplier}>")

    if added:
        params["Prompt"] += "\n" + "".join(added)


available_networks = {}  # file stem -> NetworkOnDisk
available_network_aliases = {}  # file stem or alias -> NetworkOnDisk
loaded_networks = []  # networks currently selected by load_networks, in order
available_network_hash_lookup = {}
forbidden_network_aliases = {}  # lowercased alias -> 1 when the alias is ambiguous or reserved

list_available_networks()