diff --git a/webui.py b/webui.py index c70a11c7c..39f9ae9a8 100644 --- a/webui.py +++ b/webui.py @@ -1,6 +1,7 @@ import os import threading +from modules import devices from modules.paths import script_path import signal @@ -47,6 +48,8 @@ def wrap_queued_call(func): def wrap_gradio_gpu_call(func): def f(*args, **kwargs): + devices.torch_gc() + shared.state.sampling_step = 0 shared.state.job_count = -1 shared.state.job_no = 0 @@ -62,6 +65,8 @@ def wrap_gradio_gpu_call(func): shared.state.job = "" shared.state.job_count = 0 + devices.torch_gc() + return res return modules.ui.wrap_gradio_call(f)