From 0fddb4a1c06a6e2122add7eee3b001a6d473baee Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 30 Nov 2022 08:02:39 -0500 Subject: [PATCH] Rework MPS randn fix, add randn_like fix torch.manual_seed() already sets a CPU generator, so there is no reason to create a CPU generator manually. torch.randn_like also needs a MPS fix for k-diffusion, but a torch hijack with randn_like already exists so it can also be used for that. --- modules/devices.py | 15 +++------------ modules/sd_samplers.py | 8 +++++--- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index f00079c6b..046460fa0 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -66,24 +66,15 @@ dtype_vae = torch.float16 def randn(seed, shape): - # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. - if device.type == 'mps': - generator = torch.Generator(device=cpu) - generator.manual_seed(seed) - noise = torch.randn(shape, generator=generator, device=cpu).to(device) - return noise - torch.manual_seed(seed) + if device.type == 'mps': + return torch.randn(shape, device=cpu).to(device) return torch.randn(shape, device=device) def randn_without_seed(shape): - # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. if device.type == 'mps': - generator = torch.Generator(device=cpu) - noise = torch.randn(shape, generator=generator, device=cpu).to(device) - return noise - + return torch.randn(shape, device=cpu).to(device) return torch.randn(shape, device=device) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 8b11f569a..4c123d3b6 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -365,7 +365,10 @@ class TorchHijack: if noise.shape == x.shape: return noise - return torch.randn_like(x) + if x.device.type == 'mps': + return torch.randn_like(x, device=devices.cpu).to(x.device) + else: + return torch.randn_like(x) # MPS fix for randn in torchsde @@ -429,8 +432,7 @@ class KDiffusionSampler: self.model_wrap.step = 0 self.eta = p.eta or opts.eta_ancestral - if self.sampler_noises is not None: - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises) + k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) extra_params_kwargs = {} for param_name in self.extra_params: