diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index fd63e47f5..6b7979e20 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -87,7 +87,7 @@ ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdq class VanillaStableDiffusionSampler: def __init__(self, constructor, sd_model): self.sampler = constructor(sd_model) - self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else None + self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms self.mask = None self.nmask = None self.init_latent = None @@ -113,7 +113,9 @@ class VanillaStableDiffusionSampler: return samples def sample(self, p, x, conditioning, unconditional_conditioning): - self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs) + for fieldname in ['p_sample_ddim', 'p_sample_plms']: + if hasattr(self.sampler, fieldname): + setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)) self.mask = None self.nmask = None self.init_latent = None