diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 079118761..cede1cb08 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -93,6 +93,8 @@ callback_map = dict( callbacks_infotext_pasted=[], callbacks_script_unloaded=[], callbacks_before_ui=[], + callbacks_on_reload=[], + callbacks_on_polling=[], ) @@ -100,7 +102,6 @@ def clear_callbacks(): for callback_list in callback_map.values(): callback_list.clear() - def app_started_callback(demo: Optional[Blocks], app: FastAPI): for c in callback_map['callbacks_app_started']: try: @@ -108,6 +109,20 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI): except Exception: report_exception(c, 'app_started_callback') +def app_polling_callback(demo: Optional[Blocks], app: FastAPI): + for c in callback_map['callbacks_on_polling']: + try: + c.callback() + except Exception: + report_exception(c, 'callbacks_on_polling') + +def app_reload_callback(demo: Optional[Blocks], app: FastAPI): + for c in callback_map['callbacks_on_reload']: + try: + c.callback() + except Exception: + report_exception(c, 'callbacks_on_reload') + def model_loaded_callback(sd_model): for c in callback_map['callbacks_model_loaded']: @@ -254,6 +269,14 @@ def on_app_started(callback): add_callback(callback_map['callbacks_app_started'], callback) +def on_polling(callback): + """register a function to be called on each polling of the server.""" + add_callback(callback_map['callbacks_on_polling'], callback) + +def on_before_reload(callback): + """register a function to be called just before the server reloads.""" + add_callback(callback_map['callbacks_on_reload'], callback) + def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is passed as an argument; this function is also called when the script is reloaded. """ diff --git a/webui.py b/webui.py index ae3285c6d..98b6d394f 100644 --- a/webui.py +++ b/webui.py @@ -264,7 +264,9 @@ def create_api(app): def wait_on_server(demo=None): while 1: time.sleep(0.5) + modules.script_callbacks.app_polling_callback(None, demo) if shared.state.need_restart: + modules.script_callbacks.app_reload_callback(None, demo) shared.state.need_restart = False time.sleep(0.5) demo.close()