handle non blocking better and case of single image

This commit is contained in:
drhead 2024-05-19 18:34:09 -04:00 committed by GitHub
parent 4eb7cb443d
commit 27e35f13fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 9 additions and 5 deletions

View File

@ -72,10 +72,6 @@ def decode_first_stage(model, x):
return samples_to_images_tensor(x, approx_index, model)
def sample_to_image(samples, index=0, approximation=None):
return single_sample_to_image(samples[index], approximation)
if torch.cuda.is_available():
lp_stream = torch.cuda.Stream()
live_preview_stream_context = torch.cuda.stream(lp_stream)
@ -83,9 +79,17 @@ else:
lp_stream = None
live_preview_stream_context = nullcontext()
def sample_to_image(samples, index=0, approximation=None):
with live_preview_stream_context:
sample = single_sample_to_image(samples[index], approximation, non_blocking=lp_stream is not None)
if lp_stream is not None:
lp_stream.synchronize()
return Image.fromarray(sample.numpy())
def samples_to_image_grid(samples, approximation=None):
with live_preview_stream_context:
sample_tensors = [single_sample_to_image(sample, approximation, non_blocking=True) for sample in samples]
sample_tensors = [single_sample_to_image(sample, approximation, non_blocking=lp_stream is not None) for sample in samples]
if lp_stream is not None:
lp_stream.synchronize()
return images.image_grid([Image.fromarray(sample.numpy()) for sample in sample_tensors])