This commit is contained in:
drhead 2024-05-19 18:34:12 -04:00 committed by GitHub
commit 83fc442a19
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 21 additions and 10 deletions

View File

@ -1,6 +1,6 @@
import inspect
from collections import namedtuple
import numpy as np
from contextlib import nullcontext
import torch
from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
@ -59,15 +59,11 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
return x_sample
def single_sample_to_image(sample, approximation=None):
def single_sample_to_image(sample, approximation=None, non_blocking=False):
x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
x_sample = 255. * x_sample.permute(1, 2, 0)
return x_sample.to(device='cpu', dtype=torch.uint8, non_blocking=non_blocking)
def decode_first_stage(model, x):
@ -76,12 +72,27 @@ def decode_first_stage(model, x):
return samples_to_images_tensor(x, approx_index, model)
if torch.cuda.is_available():
lp_stream = torch.cuda.Stream()
live_preview_stream_context = torch.cuda.stream(lp_stream)
else:
lp_stream = None
live_preview_stream_context = nullcontext()
def sample_to_image(samples, index=0, approximation=None):
return single_sample_to_image(samples[index], approximation)
with live_preview_stream_context:
sample = single_sample_to_image(samples[index], approximation, non_blocking=lp_stream is not None)
if lp_stream is not None:
lp_stream.synchronize()
return Image.fromarray(sample.numpy())
def samples_to_image_grid(samples, approximation=None):
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
with live_preview_stream_context:
sample_tensors = [single_sample_to_image(sample, approximation, non_blocking=lp_stream is not None) for sample in samples]
if lp_stream is not None:
lp_stream.synchronize()
return images.image_grid([Image.fromarray(sample.numpy()) for sample in sample_tensors])
def images_tensor_to_samples(image, approximation=None, model=None):