From ee470cc6a32ae0c89ca32d71adac02b2d434f59a Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 2 Mar 2024 06:54:11 +0300 Subject: [PATCH] style changes for #14979 --- modules/sd_models.py | 60 ++++++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index db72e120f..747fc39ee 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -552,36 +552,48 @@ def repair_config(sd_config): karlo_path = os.path.join(paths.models_path, 'karlo') sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path) + +def rescale_zero_terminal_snr_abar(alphas_cumprod): + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= (alphas_bar_sqrt_T) + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt + alphas_bar[-1] = 4.8973451890853435e-08 + return alphas_bar + + def apply_alpha_schedule_override(sd_model, p=None): - def rescale_zero_terminal_snr_abar(alphas_cumprod): - alphas_bar_sqrt = alphas_cumprod.sqrt() + """ + Applies an override to the alpha schedule of the model according to settings. + - downcasts the alpha schedule to half precision + - rescales the alpha schedule to have zero terminal SNR + """ - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + if not hasattr(sd_model, 'alphas_cumprod') or not hasattr(sd_model, 'alphas_cumprod_original'): + return - # Shift so the last timestep is zero. - alphas_bar_sqrt -= (alphas_bar_sqrt_T) + sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device) - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + if opts.use_downcasted_alpha_bar: + if p is not None: + p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar + sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device) - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas_bar[-1] = 4.8973451890853435e-08 - return alphas_bar + if opts.sd_noise_schedule == "Zero Terminal SNR": + if p is not None: + p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule + sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device) - if hasattr(sd_model, 'alphas_cumprod') and hasattr(sd_model, 'alphas_cumprod_original'): - sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device) - - if opts.use_downcasted_alpha_bar: - if p is not None: - p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar - sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device) - if opts.sd_noise_schedule == "Zero Terminal SNR": - if p is not None: - p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule - sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device) sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'