diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 6f8ccf1d2..2ca17d8bb 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,4 +1,4 @@ -from collections import namedtuple +from collections import namedtuple, deque import numpy as np from math import floor import torch @@ -344,18 +344,28 @@ class CFGDenoiser(torch.nn.Module): class TorchHijack: - def __init__(self, kdiff_sampler): - self.kdiff_sampler = kdiff_sampler + def __init__(self, sampler_noises): + # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based + # implementation. + self.sampler_noises = deque(sampler_noises) def __getattr__(self, item): if item == 'randn_like': - return self.kdiff_sampler.randn_like + return self.randn_like if hasattr(torch, item): return getattr(torch, item) raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + def randn_like(self, x): + if self.sampler_noises: + noise = self.sampler_noises.popleft() + if noise.shape == x.shape: + return noise + + return torch.randn_like(x) + class KDiffusionSampler: def __init__(self, funcname, sd_model): @@ -367,7 +377,6 @@ class KDiffusionSampler: self.extra_params = sampler_extra_params.get(funcname, []) self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.sampler_noises = None - self.sampler_noise_index = 0 self.stop_at = None self.eta = None self.default_eta = 1.0 @@ -400,26 +409,14 @@ class KDiffusionSampler: def number_of_needed_noises(self, p): return p.steps - def randn_like(self, x): - noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None - - if noise is not None and x.shape == noise.shape: - res = noise - else: - res = torch.randn_like(x) - - self.sampler_noise_index += 1 - return res - def initialize(self, p): self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap.step = 0 - self.sampler_noise_index = 0 self.eta = p.eta or opts.eta_ancestral if self.sampler_noises is not None: - k_diffusion.sampling.torch = TorchHijack(self) + k_diffusion.sampling.torch = TorchHijack(self.sampler_noises) extra_params_kwargs = {} for param_name in self.extra_params: