From a183de04e3f965083e7f3462201327d30c36b958 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 20:03:33 +0800 Subject: [PATCH] Execute model_loaded_callback after moving to target device --- modules/sd_models.py | 6 +++--- modules/sd_vae.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 50bc209e4..2c0457715 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -842,13 +842,13 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False): sd_hijack.model_hijack.hijack(sd_model) timer.record("hijack") - script_callbacks.model_loaded_callback(sd_model) - timer.record("script callbacks") - if not sd_model.lowvram: sd_model.to(devices.device) timer.record("move model to device") + script_callbacks.model_loaded_callback(sd_model) + timer.record("script callbacks") + print(f"Weights loaded in {timer.summary()}.") model_data.set_sd_model(sd_model) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 31306d8ba..43687e48d 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -273,10 +273,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): load_vae(sd_model, vae_file, vae_source) sd_hijack.model_hijack.hijack(sd_model) - script_callbacks.model_loaded_callback(sd_model) if not sd_model.lowvram: sd_model.to(devices.device) + script_callbacks.model_loaded_callback(sd_model) + print("VAE weights loaded.") return sd_model