diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 9b00f76e9..4d2026e1c 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -88,10 +88,13 @@ def refresh_vae_list(): def find_vae_near_checkpoint(checkpoint_file): - checkpoint_path = os.path.splitext(checkpoint_file)[0] - for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]: - if os.path.isfile(vae_location): - return vae_location + checkpoint_path = os.path.basename(checkpoint_file).split('.', 1)[0] + print(f"checkpoint: {checkpoint_path}") + for vae_file in vae_dict.values(): + vae_path = os.path.basename(vae_file).split('.', 1)[0] + print(f"vae: {vae_path}") + if vae_path == checkpoint_path: + return vae_file return None