commit ed69979d9d
parent e8df8a9f52
Author: wfjsw
Date:   2024-03-10 00:56:47 -10:00


@@ -27,9 +27,8 @@ class RTTensorMoverPatches:
         self.mover_stream = stream_impl(device=devices.device)
         self.calc_stream = stream_impl(device=devices.device)
         self.stash = {}
-        self.speed_limit_loop_len = shared.opts.lowvram_max_loaded_module or 3
         self.speed_limit_loop_head = 0
-        self.speed_limit_loop = [None] * self.speed_limit_loop_len
+        self.speed_limit_loop = []
         self.linear_original = patches.patch(
@@ -99,12 +98,14 @@ class RTTensorMoverPatches:
         for k in to_remove:
             del self.stash[k]
+        if len(self.speed_limit_loop) < shared.opts.lowvram_max_loaded_module:
+            self.speed_limit_loop.extend([None] * (shared.opts.lowvram_max_loaded_module - len(self.speed_limit_loop)))
+        self.speed_limit_loop[self.speed_limit_loop_head] = after_calc_event
+        self.speed_limit_loop_head = (self.speed_limit_loop_head + 1) % shared.opts.lowvram_max_loaded_module
         if self.speed_limit_loop[self.speed_limit_loop_head] is not None:
             self.mover_stream.wait_event(self.speed_limit_loop[self.speed_limit_loop_head])
-        self.speed_limit_loop_head = (self.speed_limit_loop_head + 1) % self.speed_limit_loop_len
-        self.speed_limit_loop[self.speed_limit_loop_head] = after_calc_event
 
     def _wrapper_default(self, original):
         def wrapper(input, weight, bias=None, *args, **kwargs):
             with self.wrap_weight_biases(input, weight, bias) as (w, b):
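
For reference, the added lines replace the fixed-size ring (speed_limit_loop_len, set once in __init__) with one that is grown on demand, so the live shared.opts.lowvram_max_loaded_module value takes effect without reinitialization: the newest after-calc event is stored at the head, the head advances, and once the ring has wrapped the mover stream waits on the event found in the next slot. Below is a minimal sketch of that ring-buffer throttle under those assumptions; ThrottleRing, push_and_wait and max_in_flight are hypothetical names, and plain integers stand in for the stream events used in the real code.

# Sketch only: ring-buffer throttle in the spirit of speed_limit_loop above.
# Names here (ThrottleRing, push_and_wait, max_in_flight) are hypothetical.
class ThrottleRing:
    def __init__(self, max_in_flight):
        self.max_in_flight = max_in_flight  # callable returning the current limit
        self.slots = []                     # grown lazily, like speed_limit_loop
        self.head = 0

    def push_and_wait(self, after_calc_event, wait):
        limit = self.max_in_flight()
        # Grow the ring if the configured limit was raised since the last call.
        if len(self.slots) < limit:
            self.slots.extend([None] * (limit - len(self.slots)))
        # Store the newest event, then advance the head to the oldest slot.
        self.slots[self.head] = after_calc_event
        self.head = (self.head + 1) % limit
        # Once the ring has wrapped, throttle by waiting on the event recorded
        # limit - 1 calls earlier (mover_stream.wait_event in the real code).
        oldest = self.slots[self.head]
        if oldest is not None:
            wait(oldest)

if __name__ == "__main__":
    # Toy run: with a limit of 3, the wait first fires on the third call
    # (step 2), on the event recorded at step 0.
    ring = ThrottleRing(max_in_flight=lambda: 3)
    for step in range(6):
        ring.push_and_wait(step, wait=lambda e: print(f"step {step}: wait on event {e}"))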