2023-07-16 20:13:55 +00:00
import os
import re
import network
import network_lora
import network_hada
2023-07-16 21:12:18 +00:00
import network_ia3
2023-07-16 21:29:07 +00:00
import network_lokr
2023-07-17 06:00:47 +00:00
import network_full
2023-08-12 18:27:39 +00:00
import network_norm
2023-07-16 20:13:55 +00:00
import torch
from typing import Union
2023-07-18 17:11:30 +00:00
from modules import shared , devices , sd_models , errors , scripts , sd_hijack
2023-07-16 20:13:55 +00:00
# Network module handlers, in priority order: load_network() asks each handler in
# turn to build a module for a group of weights and keeps the first non-None result,
# so earlier entries win when several handlers could accept the same keys.
module_types = [
    module_type()
    for module_type in (
        network_lora.ModuleTypeLora,
        network_hada.ModuleTypeHada,
        network_ia3.ModuleTypeIa3,
        network_lokr.ModuleTypeLokr,
        network_full.ModuleTypeFull,
        network_norm.ModuleTypeNorm,
    )
]
# One or more decimal digits (used to decide which regex groups are numeric).
re_digits = re.compile(r"\d+")
# Splits "<layer>_q_proj" / "_k_proj" / "_v_proj" into (layer, projection).
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
# Lazily filled cache: regex text -> compiled pattern (see convert_diffusers_name_to_compvis).
re_compiled = {}

# Diffusers -> CompVis renames for the trailing part of UNet layer names, keyed by the
# diffusers block type; "attentions" blocks keep their suffix unchanged.
suffix_conversion = {
    "attentions": {},
    "resnets": dict(
        conv1="in_layers_2",
        conv2="out_layers_3",
        norm1="in_layers_0",
        norm2="out_layers_0",
        time_emb_proj="emb_layers_1",
        conv_shortcut="skip_connection",
    ),
}
def convert_diffusers_name_to_compvis(key, is_sd2):
    """
    Translate a diffusers/kohya-style LoRA layer key into the flat CompVis layer name
    used as a key of network_layer_mapping.

    Handles the UNet conv_in/conv_out, time embedding, down/mid/up block attentions and
    resnets, and down/up samplers. It also handles text encoder layers:
    "lora_te_..." maps to SD1 names, or to "model_transformer_resblocks_..." names when
    is_sd2 is true. "lora_te2_..." maps to "1_model_transformer_resblocks_..." names.

    Capture groups made only of digits are converted to int so block indices can be
    computed. Returns `key` unchanged when no pattern matches.
    """
    def match(match_list, regex_text):
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = regex.match(key)
        if not r:
            return False

        match_list.clear()
        # Only groups consisting entirely of digits are indices. A prefix-only check would
        # call int() on e.g. "0_foo" (ValueError) or read "1_2" as 12.
        match_list.extend([int(x) if re_digits.fullmatch(x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, r"lora_unet_conv_out(.*)"):
        return f'diffusion_model_out_2{m[0]}'

    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0] > 0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key
def assign_network_names_to_compvis_modules(sd_model):
    """
    Give every relevant submodule of sd_model a flat layer name and store the
    name -> module mapping as sd_model.network_layer_mapping.

    A name is the torch module path with "." replaced by "_". Each module also gets
    its name as module.network_layer_name.

    For SDXL models (sd_model.is_sdxl), text encoder names are prefixed with the index
    of their embedder in sd_model.conditioner.embedders ("0_", "1_", ...). Embedders
    without a `wrapped` attribute are skipped. Otherwise
    sd_model.cond_stage_model.wrapped is walked.

    The modules of sd_model.model are registered last, so their names overwrite any
    text encoder name they collide with. Everything is read from the sd_model argument,
    so the mapping always describes the model it is attached to.
    """
    network_layer_mapping = {}

    def register(network_name, module):
        # record the name in both directions: mapping -> module and module -> name
        network_layer_mapping[network_name] = module
        module.network_layer_name = network_name

    if sd_model.is_sdxl:
        for i, embedder in enumerate(sd_model.conditioner.embedders):
            if not hasattr(embedder, 'wrapped'):
                continue

            for name, module in embedder.wrapped.named_modules():
                register(f'{i}_{name.replace(".", "_")}', module)
    else:
        for name, module in sd_model.cond_stage_model.wrapped.named_modules():
            register(name.replace(".", "_"), module)

    for name, module in sd_model.model.named_modules():
        register(name.replace(".", "_"), module)

    sd_model.network_layer_mapping = network_layer_mapping
def _match_network_key(layer_mapping, key_network_without_network_parts, is_sd2):
    """
    Find the model layer targeted by one network key prefix.

    Candidates are tried in this order:
    1. the diffusers -> compvis converted name;
    2. that name with a trailing _q_proj/_k_proj/_v_proj stripped;
    3. for keys already in compvis form, a plain "lora_unet" -> "diffusion_model"
       substitution, or "lora_te1_text_model" -> "0_transformer_text_model", then
       "transformer_text_model".

    Returns (key, sd_module). When nothing matched, sd_module is None and key is the
    last candidate tried.
    """
    key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
    sd_module = layer_mapping.get(key, None)

    if sd_module is None:
        # q/k/v projections may target one fused layer. The suffixed key is kept so
        # network_apply_weights can look them up as "<layer>_q_proj" etc.
        m = re_x_proj.match(key)
        if m:
            sd_module = layer_mapping.get(m.group(1), None)

    # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
    if sd_module is None and "lora_unet" in key_network_without_network_parts:
        key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
        sd_module = layer_mapping.get(key, None)
    elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
        key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
        sd_module = layer_mapping.get(key, None)

        # some SD1 Loras also have correct compvis keys
        if sd_module is None:
            key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
            sd_module = layer_mapping.get(key, None)

    return key, sd_module


def load_network(name, network_on_disk):
    """
    Read a network file from disk and build a network.Network from it.

    Each state-dict key has the form "<layer key>.<weight part>". Layer keys are
    matched against shared.sd_model.network_layer_mapping and weights are grouped per
    matched layer. Each group becomes a module built by the first entry of
    module_types that accepts it.

    Keys that match no layer, including keys containing no "." at all, are collected
    and reported with a single print instead of aborting the load.

    Raises AssertionError if a matched group of weights is accepted by no module type.
    """
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)

    sd = sd_models.read_state_dict(network_on_disk.filename)

    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
    if not hasattr(shared.sd_model, 'network_layer_mapping'):
        assign_network_names_to_compvis_modules(shared.sd_model)

    layer_mapping = shared.sd_model.network_layer_mapping
    keys_failed_to_match = {}
    is_sd2 = 'model_transformer_resblocks' in layer_mapping

    matched_networks = {}

    for key_network, weight in sd.items():
        key_network_without_network_parts, separator, network_part = key_network.partition(".")
        if not separator:
            # No "<layer>.<part>" structure to group on. split(".", 1) would raise
            # ValueError here and abort the whole load; report it as unmatched instead.
            keys_failed_to_match[key_network] = key_network
            continue

        key, sd_module = _match_network_key(layer_mapping, key_network_without_network_parts, is_sd2)

        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue

        if key not in matched_networks:
            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)

        matched_networks[key].w[network_part] = weight

    for key, weights in matched_networks.items():
        net_module = None
        for nettype in module_types:
            net_module = nettype.create_module(net, weights)
            if net_module is not None:
                break

        if net_module is None:
            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")

        net.modules[key] = net_module

    if keys_failed_to_match:
        print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")

    return net
def purge_networks_from_memory():
    """
    Trim the networks_in_memory cache to at most shared.opts.lora_in_memory_limit
    entries, then call devices.torch_gc().

    Entries are dropped in dict order, so the earliest-inserted networks go first.
    """
    while len(networks_in_memory) > shared.opts.lora_in_memory_limit and networks_in_memory:
        oldest_name = next(iter(networks_in_memory))
        del networks_in_memory[oldest_name]

    devices.torch_gc()
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    """
    Make the networks called `names` the active ones, in the given order.

    Each name is resolved through available_network_aliases. If any name is unknown,
    the disk is rescanned once with list_available_networks().

    A network is reused when it was already active or is cached in networks_in_memory.
    It is (re)loaded from disk when missing, or when its file is newer than the loaded
    copy. te_multipliers, unet_multipliers and dyn_dims are per-name lists; when one
    is not given (or is empty), 1.0 is used.

    Names for which no network is found are printed and appended to
    sd_hijack.model_hijack.comments. A network whose load raises is reported through
    errors.display and skipped. The in-memory cache is trimmed at the end.
    """
    previously_active = {net.name: net for net in loaded_networks if net.name in names}
    loaded_networks.clear()

    def resolve_on_disk():
        return [available_network_aliases.get(name, None) for name in names]

    networks_on_disk = resolve_on_disk()
    if any(entry is None for entry in networks_on_disk):
        list_available_networks()
        networks_on_disk = resolve_on_disk()

    failed_to_load_networks = []

    for index, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        net = previously_active.get(name, None)

        if network_on_disk is not None:
            if net is None:
                net = networks_in_memory.get(name)

            needs_reload = net is None or os.path.getmtime(network_on_disk.filename) > net.mtime
            if needs_reload:
                try:
                    net = load_network(name, network_on_disk)
                    # re-insert so this entry becomes the newest one in the cache
                    networks_in_memory.pop(name, None)
                    networks_in_memory[name] = net
                except Exception as e:
                    errors.display(e, f"loading network {network_on_disk.filename}")
                    continue

            net.mentioned_name = name
            network_on_disk.read_hash()

        if net is None:
            failed_to_load_networks.append(name)
            print(f"Couldn't find network with name {name}")
            continue

        net.te_multiplier = te_multipliers[index] if te_multipliers else 1.0
        net.unet_multiplier = unet_multipliers[index] if unet_multipliers else 1.0
        net.dyn_dim = dyn_dims[index] if dyn_dims else 1.0
        loaded_networks.append(net)

    if failed_to_load_networks:
        sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))

    purge_networks_from_memory()
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
    Copy the backed-up original parameters back into layer `self`.

    self.network_weights_backup is a (in_proj_weight, out_proj.weight) pair for
    MultiheadAttention and a single tensor otherwise. self.network_bias_backup holds
    the bias. Either attribute may be missing or None; that parameter is then left as is.

    The copies run under torch.no_grad(). They are in-place writes into leaf
    parameters, which autograd rejects for tensors that require grad while gradients
    are enabled. network_apply_weights wraps its in-place updates the same way.
    """
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)

    if weights_backup is None and bias_backup is None:
        return

    with torch.no_grad():
        if weights_backup is not None:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.in_proj_weight.copy_(weights_backup[0])
                self.out_proj.weight.copy_(weights_backup[1])
            else:
                self.weight.copy_(weights_backup)

        if bias_backup is not None:
            self.bias.copy_(bias_backup)
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of networks to the weights of torch layer self.
    If weights already have this particular set of networks applied, does nothing.
    If not, restores original weights from backup and alters weights according to networks.

    The "set of networks" is the (name, te_multiplier, unet_multiplier, dyn_dim) tuple
    of every entry in loaded_networks. Layers without a network_layer_name are left
    untouched.

    On first use, CPU copies are stored as backups: the weight (in_proj_weight and
    out_proj.weight for MultiheadAttention), plus self.bias if the layer has one.
    """

    network_layer_name = getattr(self, 'network_layer_name', None)
    if network_layer_name is None:
        return

    current_names = getattr(self, "network_current_names", ())
    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None:
        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
        else:
            weights_backup = self.weight.to(devices.cpu, copy=True)

        self.network_weights_backup = weights_backup

    bias_backup = getattr(self, "network_bias_backup", None)
    if bias_backup is None and getattr(self, 'bias', None) is not None:
        bias_backup = self.bias.to(devices.cpu, copy=True)
        self.network_bias_backup = bias_backup

    if current_names != wanted_names:
        network_restore_weights_from_backup(self)

        for net in loaded_networks:
            module = net.modules.get(network_layer_name, None)
            if module is not None and hasattr(self, 'weight'):
                with torch.no_grad():
                    updown, ex_bias = module.calc_updown(self.weight)

                    if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
                        # inpainting model. zero pad updown to make channel[1]  4 to 9
                        updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                    self.weight += updown
                    if ex_bias is not None and getattr(self, 'bias', None) is not None:
                        self.bias += ex_bias
                continue

            module_q = net.modules.get(network_layer_name + "_q_proj", None)
            module_k = net.modules.get(network_layer_name + "_k_proj", None)
            module_v = net.modules.get(network_layer_name + "_v_proj", None)
            module_out = net.modules.get(network_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                with torch.no_grad():
                    # calc_updown returns (updown, ex_bias), as unpacked in the branch above.
                    # The tuples must be unpacked before the deltas can be stacked or added.
                    updown_q, _ = module_q.calc_updown(self.in_proj_weight)
                    updown_k, _ = module_k.calc_updown(self.in_proj_weight)
                    updown_v, _ = module_v.calc_updown(self.in_proj_weight)
                    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                    # The bias backup taken above only covers self.bias, not self.out_proj.bias,
                    # so an extra out_proj bias could not be undone on restore and is dropped.
                    updown_out, _ = module_out.calc_updown(self.out_proj.weight)

                    self.in_proj_weight += updown_qkv
                    self.out_proj.weight += updown_out
                continue

            if module is None:
                continue

            print(f'failed to calculate network weights for layer {network_layer_name}')

        self.network_current_names = wanted_names
def network_forward(module, input, original_forward):
    """
    Legacy "functional" path: run the layer unmodified, then let each active network
    adjust its output.

    The layer's original weights are restored and its cached network state is reset
    before original_forward runs. Every network in loaded_networks that has a module
    for this layer's network_layer_name then transforms the output in turn. Stacking
    many networks this way is much slower than baking them into the weights.
    """
    if not loaded_networks:
        return original_forward(module, input)

    input = devices.cond_cast_unet(input)

    # make sure the layer runs with its original, unmodified parameters
    network_restore_weights_from_backup(module)
    network_reset_cached_weight(module)

    output = original_forward(module, input)

    layer_name = getattr(module, 'network_layer_name', None)
    for net in loaded_networks:
        net_module = net.modules.get(layer_name, None)
        if net_module is not None:
            output = net_module.forward(input, output)

    return output
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
    Forget which networks are baked into layer `self` and drop its weight and bias backups.

    The next network_apply_weights() call then takes fresh backups from the layer's
    current parameters. The bias backup must be cleared together with the weight
    backup. Otherwise, after load_state_dict() replaces the parameters, a later restore
    would copy the stale bias back over the newly loaded one.
    """
    self.network_current_names = ()
    self.network_weights_backup = None
    self.network_bias_backup = None
def network_Linear_forward(self, input):
    """
    Forward for torch.nn.Linear layers with the active networks applied.

    With shared.opts.lora_functional set, networks are evaluated per call through
    network_forward. Otherwise they are baked into the weights first. The unpatched
    forward is read from torch.nn.Linear_forward_before_network (presumably saved
    when this hook was installed).
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return torch.nn.Linear_forward_before_network(self, input)

    return network_forward(self, input, torch.nn.Linear_forward_before_network)
def network_Linear_load_state_dict(self, *args, **kwargs):
    """Drop cached network state for this Linear, then run torch.nn.Linear_load_state_dict_before_network."""
    network_reset_cached_weight(self)
    original_load_state_dict = torch.nn.Linear_load_state_dict_before_network
    return original_load_state_dict(self, *args, **kwargs)
def network_Conv2d_forward(self, input):
    """
    Forward for torch.nn.Conv2d layers with the active networks applied.

    With shared.opts.lora_functional set, networks are evaluated per call through
    network_forward. Otherwise they are baked into the weights first. The unpatched
    forward is read from torch.nn.Conv2d_forward_before_network.
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return torch.nn.Conv2d_forward_before_network(self, input)

    return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
def network_Conv2d_load_state_dict(self, *args, **kwargs):
    """Drop cached network state for this Conv2d, then run torch.nn.Conv2d_load_state_dict_before_network."""
    network_reset_cached_weight(self)
    original_load_state_dict = torch.nn.Conv2d_load_state_dict_before_network
    return original_load_state_dict(self, *args, **kwargs)
def network_GroupNorm_forward(self, input):
    """
    Forward for torch.nn.GroupNorm layers with the active networks applied.

    With shared.opts.lora_functional set, networks are evaluated per call through
    network_forward. Otherwise they are baked into the weights first. The unpatched
    forward is read from torch.nn.GroupNorm_forward_before_network.
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return torch.nn.GroupNorm_forward_before_network(self, input)

    return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    """Drop cached network state for this GroupNorm, then run torch.nn.GroupNorm_load_state_dict_before_network."""
    network_reset_cached_weight(self)
    original_load_state_dict = torch.nn.GroupNorm_load_state_dict_before_network
    return original_load_state_dict(self, *args, **kwargs)
def network_LayerNorm_forward(self, input):
    """
    Forward for torch.nn.LayerNorm layers with the active networks applied.

    With shared.opts.lora_functional set, networks are evaluated per call through
    network_forward. Otherwise they are baked into the weights first. The unpatched
    forward is read from torch.nn.LayerNorm_forward_before_network.
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return torch.nn.LayerNorm_forward_before_network(self, input)

    return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    """Drop cached network state for this LayerNorm, then run torch.nn.LayerNorm_load_state_dict_before_network."""
    network_reset_cached_weight(self)
    original_load_state_dict = torch.nn.LayerNorm_load_state_dict_before_network
    return original_load_state_dict(self, *args, **kwargs)
def network_MultiheadAttention_forward(self, *args, **kwargs):
    """
    Forward for torch.nn.MultiheadAttention with the active networks baked into its weights.

    Unlike the other layer hooks, this one always uses network_apply_weights; the
    lora_functional option is not consulted here.
    """
    network_apply_weights(self)
    original_forward = torch.nn.MultiheadAttention_forward_before_network
    return original_forward(self, *args, **kwargs)
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    """Drop cached network state for this MultiheadAttention, then run torch.nn.MultiheadAttention_load_state_dict_before_network."""
    network_reset_cached_weight(self)
    original_load_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
    return original_load_state_dict(self, *args, **kwargs)
def list_available_networks():
    """
    Rescan the LoRA directories and rebuild the module-level network registries.

    Scanned files:
    - .pt/.ckpt/.safetensors files under shared.cmd_opts.lora_dir (created if missing);
    - the same extensions under shared.cmd_opts.lyco_dir_backcompat.

    Each file is registered in available_networks under its base file name. It is also
    registered in available_network_aliases under both that name and its alias. An alias
    that is already registered is additionally recorded, lower-cased, in
    forbidden_network_aliases. Files that raise OSError while being read are reported
    and skipped.
    """
    for registry in (available_networks, available_network_aliases, forbidden_network_aliases, available_network_hash_lookup):
        registry.clear()
    forbidden_network_aliases.update({"none": 1, "Addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))

    for filename in candidates:
        if os.path.isdir(filename):
            continue

        stem, _extension = os.path.splitext(os.path.basename(filename))
        name = stem

        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry

        # an alias seen before is ambiguous, so it is flagged before being overwritten
        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases.update({name: entry, entry.alias: entry})
# Splits an AddNet-style "name(hash)" value. Group 1 is the text before the last
# parenthesized hex hash. Because "(.*)" is greedy, whitespace in front of the
# parenthesis stays in group 1 instead of being consumed by "\s*".
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
def infotext_pasted(infotext, params):
    """
    Turn "AddNet" LoRA entries of pasted generation parameters into <lora:...> prompt tags.

    Does nothing when scripts.scripts_txt2img.infotext_fields already contains an
    "AddNet Module 1" field, because that extension then handles these parameters.

    For every "AddNet Model N" key whose "AddNet Module N" value is "LoRA", a
    "<lora:name:weight>" tag is built:
    - the name has any "(hash)" suffix removed via re_network_name;
    - the weight is "AddNet Weight A N" (default "1.0").

    The tags are appended to params["Prompt"] on a new line.
    """
    field_labels = [field[1] for field in scripts.scripts_txt2img.infotext_fields]
    if "AddNet Module 1" in field_labels:
        return  # if the other extension is active, it will handle those fields, no need to do anything

    tags = []
    for param_key in params:
        if not param_key.startswith("AddNet Model"):
            continue

        num = param_key[13:]
        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get("AddNet Model " + num)
        if name is None:
            continue

        hashed = re_network_name.match(name)
        if hashed:
            name = hashed.group(1)

        multiplier = params.get("AddNet Weight A " + num, "1.0")
        tags.append(f"<lora:{name}:{multiplier}>")

    if tags:
        params["Prompt"] += "\n" + "".join(tags)
# Module-level registries shared by the functions in this file.
available_networks = dict()             # file base name -> NetworkOnDisk (rebuilt by list_available_networks)
available_network_aliases = dict()      # base name and alias -> NetworkOnDisk
loaded_networks = list()                # networks made active by the last load_networks() call
networks_in_memory = dict()             # name -> loaded network; earliest-inserted entries are purged first
available_network_hash_lookup = dict()  # cleared by list_available_networks(); presumably filled elsewhere
forbidden_network_aliases = dict()      # reserved names and ambiguous (duplicate) aliases

# populate the registries once at import time
list_available_networks()