mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Add KL Optimal scheduler
This commit is contained in:
parent
ddb28b33a3
commit
83266205d0
@ -31,6 +31,15 @@ def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
|
|||||||
return torch.FloatTensor(sigs).to(device)
|
return torch.FloatTensor(sigs).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def kl_optimal(n, sigma_min, sigma_max, device):
|
||||||
|
alpha_min = torch.arctan(torch.tensor(sigma_min, device=device))
|
||||||
|
alpha_max = torch.arctan(torch.tensor(sigma_max, device=device))
|
||||||
|
sigmas = torch.empty((n+1,), device=device)
|
||||||
|
for i in range(n+1):
|
||||||
|
sigmas[i] = torch.tan((i/n) * alpha_min + (1.0-i/n) * alpha_max)
|
||||||
|
return sigmas
|
||||||
|
|
||||||
|
|
||||||
schedulers = [
|
schedulers = [
|
||||||
Scheduler('automatic', 'Automatic', None),
|
Scheduler('automatic', 'Automatic', None),
|
||||||
Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
|
Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),
|
||||||
@ -38,6 +47,7 @@ schedulers = [
|
|||||||
Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential),
|
Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential),
|
||||||
Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0),
|
Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0),
|
||||||
Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]),
|
Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]),
|
||||||
|
Scheduler('kl_optimal', 'KL Optimal', kl_optimal),
|
||||||
]
|
]
|
||||||
|
|
||||||
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}
|
schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}
|
||||||
|
Loading…
Reference in New Issue
Block a user