From 67efee33a6c65e58b3f6c788993d0e68a33e4fd0 Mon Sep 17 00:00:00 2001 From: klimaleksus Date: Mon, 28 Nov 2022 16:29:43 +0500 Subject: [PATCH] Make VAE step sequential to prevent VRAM spikes --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index edceb5323..fd995b8aa 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -530,8 +530,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.autocast(): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) - samples_ddim = samples_ddim.to(devices.dtype_vae) - x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) + x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] + x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) del samples_ddim