progress bar description for k-diffsuion for 88393097

This commit is contained in:
AUTOMATIC 2022-09-01 15:22:42 +03:00
parent 49fcdbefa3
commit 2d5689a051

View File

@ -35,6 +35,7 @@ import traceback
from collections import namedtuple
from contextlib import nullcontext
import signal
import tqdm
import k_diffusion.sampling
from ldm.util import instantiate_from_config
@ -842,6 +843,7 @@ class StableDiffusionProcessing:
self.extra_generation_params: dict = extra_generation_params
self.overlay_images = overlay_images
self.paste_to = None
self.progress_info = ""
def init(self):
pass
@ -917,7 +919,6 @@ class CFGDenoiser(nn.Module):
return denoised
class KDiffusionSampler:
def __init__(self, funcname):
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
@ -938,12 +939,18 @@ class KDiffusionSampler:
self.model_wrap_cfg.nmask = p.nmask
self.model_wrap_cfg.init_latent = p.init_latent
if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: tqdm.tqdm(range(*args), desc=p.progress_info, **kwargs)
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
sigmas = self.model_wrap.get_sigmas(p.steps)
x = x * sigmas[0]
if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: tqdm.tqdm(range(*args), desc=p.progress_info, **kwargs)
samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
return samples_ddim
@ -1030,6 +1037,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
# we manually generate all input noises because each one should have a specific seed
x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)
p.progress_info = f"Batch {n+1} out of {p.n_iter}"
samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
x_samples_ddim = model.decode_first_stage(samples_ddim)