mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
this time for sure
This commit is contained in:
parent
a64fbe8928
commit
cc53db6652
@ -538,8 +538,12 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DecodedSamples(list):
|
||||||
|
already_decoded = True
|
||||||
|
|
||||||
|
|
||||||
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
|
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
|
||||||
samples = []
|
samples = DecodedSamples()
|
||||||
|
|
||||||
for i in range(batch.shape[0]):
|
for i in range(batch.shape[0]):
|
||||||
sample = decode_first_stage(model, batch[i:i + 1])[0]
|
sample = decode_first_stage(model, batch[i:i + 1])[0]
|
||||||
@ -793,7 +797,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||||
|
|
||||||
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
if getattr(samples_ddim, 'already_decoded', False):
|
||||||
|
x_samples_ddim = samples_ddim
|
||||||
|
else:
|
||||||
|
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
||||||
|
|
||||||
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
||||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
@ -1161,9 +1169,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||||
|
|
||||||
|
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
||||||
|
|
||||||
self.is_hr_pass = False
|
self.is_hr_pass = False
|
||||||
|
|
||||||
return samples
|
return decoded_samples
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
super().close()
|
super().close()
|
||||||
|
Loading…
Reference in New Issue
Block a user