add postprocess_batch_list callback

This commit is contained in:
ljleb 2023-07-24 13:52:24 -04:00
parent f451994053
commit ca45ff1ae6
2 changed files with 55 additions and 1 deletions

View File

@ -717,7 +717,25 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))] p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
def infotext(iteration=0, position_in_batch=0, use_main_prompt=False): def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt) all_prompts = p.all_prompts[:]
all_seeds = p.all_seeds[:]
all_subseeds = p.all_subseeds[:]
# apply changes to generation data
all_prompts[n * p.batch_size:(n + 1) * p.batch_size] = p.prompts
all_seeds[n * p.batch_size:(n + 1) * p.batch_size] = p.seeds
all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] = p.subseeds
# update p.all_negative_prompts in case extensions changed the size of the batch
# create_infotext below uses it
old_negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] = p.negative_prompts
try:
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
finally:
# restore p.all_negative_prompts in case extensions changed the size of the batch
p.all_negative_prompts[n * p.batch_size:n * p.batch_size + len(p.negative_prompts)] = old_negative_prompts
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings() model_hijack.embedding_db.load_textual_inversion_embeddings()
@ -806,6 +824,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None: if p.scripts is not None:
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n) p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
postprocess_batch_list_args = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
p.scripts.postprocess_batch_list(p, postprocess_batch_list_args, batch_number=n)
x_samples_ddim = postprocess_batch_list_args.images
for i, x_sample in enumerate(x_samples_ddim): for i, x_sample in enumerate(x_samples_ddim):
p.batch_index = i p.batch_index = i

View File

@ -16,6 +16,11 @@ class PostprocessImageArgs:
self.image = image self.image = image
class PostprocessBatchListArgs:
def __init__(self, images):
self.images = images
class Script: class Script:
name = None name = None
"""script's internal name derived from title""" """script's internal name derived from title"""
@ -156,6 +161,25 @@ class Script:
pass pass
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
"""
Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
This is useful when you want to update the entire batch instead of individual images.
You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
If the number of images is different from the batch size when returning,
then the script has the responsibility to also update the following attributes in the processing object (p):
- p.prompts
- p.negative_prompts
- p.seeds
- p.subseeds
**kwargs will have same items as process_batch, and also:
- batch_number - index of current batch, from 0 to number of batches-1
"""
pass
def postprocess_image(self, p, pp: PostprocessImageArgs, *args): def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
""" """
Called for every image after it has been generated. Called for every image after it has been generated.
@ -536,6 +560,14 @@ class ScriptRunner:
except Exception: except Exception:
errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True) errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch_list(p, pp, *script_args, **kwargs)
except Exception:
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs): def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts: for script in self.alwayson_scripts:
try: try: