fix for PLMS live previews in txt2img

This commit is contained in:
AUTOMATIC 2022-09-08 19:34:20 +03:00
parent ca3861e05f
commit fe4e3c2673

View File

@ -87,7 +87,7 @@ ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdq
class VanillaStableDiffusionSampler: class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model): def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model) self.sampler = constructor(sd_model)
self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else None self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
self.mask = None self.mask = None
self.nmask = None self.nmask = None
self.init_latent = None self.init_latent = None
@ -113,7 +113,9 @@ class VanillaStableDiffusionSampler:
return samples return samples
def sample(self, p, x, conditioning, unconditional_conditioning): def sample(self, p, x, conditioning, unconditional_conditioning):
self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs) for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname):
setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs))
self.mask = None self.mask = None
self.nmask = None self.nmask = None
self.init_latent = None self.init_latent = None