Allow refiner to be triggered by model timestep instead of sampling

This commit is contained in:
drhead 2024-02-20 16:37:29 -05:00 committed by GitHub
parent 09d2e58811
commit 25eeeaa65f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -156,7 +156,16 @@ replace_torchsde_browinan()
def apply_refiner(cfg_denoiser):
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
if opts.refiner_switch_by_sample_steps:
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
else:
# torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
try:
timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma)))
except AttributeError: # for samplers that dont use sigmas (DDIM) sigma is actually the timestep
timestep = torch.max(sigma).to(dtype=int)
completed_ratio = (999 - timestep) / 1000
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info