vectorize kl-optimal sigma calculation

Co-authored-by: mamei16 <marcel.1710@live.de>
This commit is contained in:
drhead 2024-04-28 00:15:58 -04:00 committed by GitHub
parent 83266205d0
commit 3a215deff2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -34,9 +34,8 @@ def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
def kl_optimal(n, sigma_min, sigma_max, device): def kl_optimal(n, sigma_min, sigma_max, device):
alpha_min = torch.arctan(torch.tensor(sigma_min, device=device)) alpha_min = torch.arctan(torch.tensor(sigma_min, device=device))
alpha_max = torch.arctan(torch.tensor(sigma_max, device=device)) alpha_max = torch.arctan(torch.tensor(sigma_max, device=device))
sigmas = torch.empty((n+1,), device=device) step_indices = torch.arange(n + 1, device=device)
for i in range(n+1): sigmas = torch.tan(step_indices / n * alpha_min + (1.0 - step_indices / n) * alpha_max)
sigmas[i] = torch.tan((i/n) * alpha_min + (1.0-i/n) * alpha_max)
return sigmas return sigmas