This commit is contained in:
wfjsw 2024-03-10 00:56:47 -10:00
parent e8df8a9f52
commit ed69979d9d

View File

@ -27,9 +27,8 @@ class RTTensorMoverPatches:
self.mover_stream = stream_impl(device=devices.device)
self.calc_stream = stream_impl(device=devices.device)
self.stash = {}
self.speed_limit_loop_len = shared.opts.lowvram_max_loaded_module or 3
self.speed_limit_loop_head = 0
self.speed_limit_loop = [None] * self.speed_limit_loop_len
self.speed_limit_loop = []
self.linear_original = patches.patch(
@ -99,11 +98,13 @@ class RTTensorMoverPatches:
for k in to_remove:
del self.stash[k]
if len(self.speed_limit_loop) < shared.opts.lowvram_max_loaded_module:
self.speed_limit_loop.extend([None] * (shared.opts.lowvram_max_loaded_module - len(self.speed_limit_loop)))
self.speed_limit_loop[self.speed_limit_loop_head] = after_calc_event
self.speed_limit_loop_head = (self.speed_limit_loop_head + 1) % shared.opts.lowvram_max_loaded_module
if self.speed_limit_loop[self.speed_limit_loop_head] is not None:
self.mover_stream.wait_event(self.speed_limit_loop[self.speed_limit_loop_head])
self.speed_limit_loop_head = (self.speed_limit_loop_head + 1) % self.speed_limit_loop_len
self.speed_limit_loop[self.speed_limit_loop_head] = after_calc_event
def _wrapper_default(self, original):
def wrapper(input, weight, bias=None, *args, **kwargs):