use override of sd_vae

This commit is contained in:
w-e-w 2024-05-19 05:14:32 +09:00
parent 1f392517f8
commit 1e696b028a
1 changed files with 9 additions and 15 deletions

View File

@ -118,21 +118,16 @@ def apply_size(p, x: str, xs) -> None:
def find_vae(name: str):
if name.lower() in ['auto', 'automatic']:
return modules.sd_vae.unspecified
if name.lower() == 'none':
return None
else:
choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
if len(choices) == 0:
print(f"No VAE found for {name}; using automatic")
return modules.sd_vae.unspecified
else:
return modules.sd_vae.vae_dict[choices[0]]
match name := name.lower().strip():
case 'auto', 'automatic':
return 'Automatic'
case 'none':
return 'None'
return next((k for k in modules.sd_vae.vae_dict if k.lower() == name), print(f'No VAE found for {name}; using Automatic') or 'Automatic')
def apply_vae(p, x, xs):
modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))
p.override_settings['sd_vae'] = find_vae(x)
def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
@ -270,7 +265,7 @@ axis_options = [
AxisOption("Extra noise", float, apply_override("img2img_extra_noise")),
AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['None'] + list(sd_vae.vae_dict)),
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['Automatic', 'None'] + list(sd_vae.vae_dict)),
AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
@ -399,10 +394,9 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
class SharedSettingsStackHelper(object):
def __enter__(self):
self.vae = opts.sd_vae
pass
def __exit__(self, exc_type, exc_value, tb):
opts.data["sd_vae"] = self.vae
modules.sd_models.reload_model_weights()
modules.sd_vae.reload_vae_weights()