diff --git a/modules/sd_models.py b/modules/sd_models.py index a29c8c1a9..b2dd005a6 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -165,6 +165,9 @@ def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash + if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"): + checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy() + if checkpoint_info not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") @@ -198,16 +201,14 @@ def load_model_weights(model, checkpoint_info): model.first_stage_model.load_state_dict(vae_dict) model.first_stage_model.to(devices.dtype_vae) - - if shared.opts.sd_checkpoint_cache > 0: - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: - checkpoints_loaded.popitem(last=False) # LRU else: print(f"Loading weights [{sd_model_hash}] from cache") - checkpoints_loaded.move_to_end(checkpoint_info) model.load_state_dict(checkpoints_loaded[checkpoint_info]) + if shared.opts.sd_checkpoint_cache > 0: + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) # LRU + model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info