diff --git a/modules/processing.py b/modules/processing.py index 2009d3bf8..187e98fde 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] + if p.scripts is not None: + p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds) + if len(prompts) == 0: break diff --git a/modules/scripts.py b/modules/scripts.py index 24056a12f..e6a505b39 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -80,6 +80,20 @@ class Script: pass + def before_process_batch(self, p, *args, **kwargs): + """ + Called before extra networks are parsed from the prompt, so you can add + new extra network keywords to the prompt with this callback. + + **kwargs will have those items: + - batch_number - index of current batch, from 0 to number of batches-1 + - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things + - seeds - list of seeds for current batch + - subseeds - list of subseeds for current batch + """ + + pass + def process_batch(self, p, *args, **kwargs): """ Same as process(), but called for every batch. @@ -388,6 +402,15 @@ class ScriptRunner: print(f"Error running process: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + def before_process_batch(self, p, **kwargs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.before_process_batch(p, *script_args, **kwargs) + except Exception: + print(f"Error running before_process_batch: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def process_batch(self, p, **kwargs): for script in self.alwayson_scripts: try: