mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Implementation for sgm_uniform branch
This commit is contained in:
parent
c4a00affc5
commit
a6b5a513f9
12
modules/sd_samplers_custom_schedulers.py
Normal file
12
modules/sd_samplers_custom_schedulers.py
Normal file
@ -0,0 +1,12 @@
|
||||
import torch
|
||||
|
||||
|
||||
def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
|
||||
start = inner_model.sigma_to_t(torch.tensor(sigma_max))
|
||||
end = inner_model.sigma_to_t(torch.tensor(sigma_min))
|
||||
sigs = [
|
||||
inner_model.t_to_sigma(ts)
|
||||
for ts in torch.linspace(start, end, n)[:-1]
|
||||
]
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs).to(device)
|
@ -3,6 +3,7 @@ import inspect
|
||||
import k_diffusion.sampling
|
||||
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
|
||||
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
|
||||
from modules.sd_samplers_custom_schedulers import sgm_uniform
|
||||
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
|
||||
|
||||
from modules.shared import opts
|
||||
@ -62,7 +63,8 @@ k_diffusion_scheduler = {
|
||||
'Automatic': None,
|
||||
'karras': k_diffusion.sampling.get_sigmas_karras,
|
||||
'exponential': k_diffusion.sampling.get_sigmas_exponential,
|
||||
'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
|
||||
'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential,
|
||||
'sgm_uniform' : sgm_uniform,
|
||||
}
|
||||
|
||||
|
||||
@ -121,6 +123,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
|
||||
if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
|
||||
sigmas_kwargs['rho'] = opts.rho
|
||||
p.extra_generation_params["Schedule rho"] = opts.rho
|
||||
if opts.k_sched_type == 'sgm_uniform':
|
||||
# Ensure the "step" will be target step + 1
|
||||
steps += 1 if not discard_next_to_last_sigma else 0
|
||||
sigmas_kwargs['inner_model'] = self.model_wrap
|
||||
sigmas_kwargs.pop('rho', None)
|
||||
|
||||
sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
|
||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
||||
|
@ -368,7 +368,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
|
||||
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
|
||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
|
||||
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
||||
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential", "sgm_uniform"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
||||
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
||||
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
|
||||
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
|
||||
|
Loading…
Reference in New Issue
Block a user