From e8df8a9f520fc5a6672db205d251768f2bd205bb Mon Sep 17 00:00:00 2001
From: wfjsw
Date: Sun, 10 Mar 2024 00:25:56 -1000
Subject: [PATCH] avoid oom on slow cards

---
 modules/lowvram.py        | 18 ++++++++++++++----
 modules/sd_models.py      |  2 +-
 modules/shared_options.py |  1 +
 3 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/modules/lowvram.py b/modules/lowvram.py
index ee14f2afa..a638d7c6e 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -27,6 +27,10 @@ class RTTensorMoverPatches:
         self.mover_stream = stream_impl(device=devices.device)
         self.calc_stream = stream_impl(device=devices.device)
         self.stash = {}
+        self.speed_limit_loop_len = shared.opts.lowvram_max_loaded_module or 3
+        self.speed_limit_loop_head = 0
+        self.speed_limit_loop = [None] * self.speed_limit_loop_len
+
         self.linear_original = patches.patch(
             __name__,
             torch.nn.functional,
@@ -95,6 +99,12 @@ class RTTensorMoverPatches:
         for k in to_remove:
             del self.stash[k]
 
+        if self.speed_limit_loop[self.speed_limit_loop_head] is not None:
+            self.mover_stream.wait_event(self.speed_limit_loop[self.speed_limit_loop_head])
+
+        self.speed_limit_loop_head = (self.speed_limit_loop_head + 1) % self.speed_limit_loop_len
+        self.speed_limit_loop[self.speed_limit_loop_head] = after_calc_event
+
     def _wrapper_default(self, original):
         def wrapper(input, weight, bias=None, *args, **kwargs):
             with self.wrap_weight_biases(input, weight, bias) as (w, b):
@@ -271,12 +281,12 @@ def setup_for_low_vram(sd_model, use_medvram):
 
     if use_streamlined_lowvram:
         # put it into pinned memory to achieve data transfer overlap
-        diff_model.time_embed._apply(lambda x: x.pin_memory())
+        diff_model.time_embed._apply(lambda x: x.pin_memory(device=devices.device))
         for block in diff_model.input_blocks:
-            block._apply(lambda x: x.pin_memory())
-        diff_model.middle_block._apply(lambda x: x.pin_memory())
+            block._apply(lambda x: x.pin_memory(device=devices.device))
+        diff_model.middle_block._apply(lambda x: x.pin_memory(device=devices.device))
         for block in diff_model.output_blocks:
-            block._apply(lambda x: x.pin_memory())
+            block._apply(lambda x: x.pin_memory(device=devices.device))
     else:
         diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
         for block in diff_model.input_blocks:
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 8f6f352f8..1ab76dc9c 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -448,7 +448,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
                         module.fp16_bias = module.bias.data.clone().cpu().half()
                 module.to(torch.float8_e4m3fn)._apply(
                     lambda x: (
-                        x.pin_memory()
+                        x.pin_memory(device=devices.device)
                         if not x.is_sparse and x.device.type == "cpu"
                         else x
                     )
diff --git a/modules/shared_options.py b/modules/shared_options.py
index d0838fd7b..40ab5e75c 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -217,6 +217,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
     "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
     "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
     "use_non_streamlined_lowvram": OptionInfo(False, "Use non-streamlined low VRAM mode").info("Do not use the streamlined mode for low VRAM cards. For devices that do not support concurrently copy memory between host and device while executing a kernel. Significantly decreases performance."),
+    "lowvram_max_loaded_module": OptionInfo(3, "Maximum number of loaded modules in low VRAM mode", gr.Slider, {"minimum": 1, "maximum": 40, "step": 1}).info("Maximum number of loaded modules in low VRAM mode. Decrease this value if you encounter an out-of-memory error."),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
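
Note (reviewer commentary, not part of the patch): the new fields on RTTensorMoverPatches form a small ring buffer of CUDA events, sized by the new lowvram_max_loaded_module option, that stops the weight-mover stream from prefetching more than that many modules ahead of the compute stream. When the ring wraps around to a slot whose event has not completed yet, the mover stream waits on it, so at most that many prefetched modules ever occupy VRAM at once, which is what avoids the OOM on cards whose compute falls behind the host-to-device copies. The sketch below illustrates the same idea in isolation; ModuleAheadLimiter and the toy streams/workload are hypothetical names for this note, not code from the repository.

    import torch

    class ModuleAheadLimiter:
        # Ring buffer of CUDA events that caps how many modules the copy stream
        # may prefetch ahead of the compute stream (illustrative sketch only).
        def __init__(self, max_loaded_modules=3):
            self.ring = [None] * max_loaded_modules
            self.head = 0

        def throttle(self, mover_stream, after_calc_event):
            # If the slot we are about to reuse still holds an event, the compute
            # stream has not finished that older module yet; make the mover stream
            # wait for it before it is allowed to prefetch another module.
            if self.ring[self.head] is not None:
                mover_stream.wait_event(self.ring[self.head])
            self.head = (self.head + 1) % len(self.ring)
            self.ring[self.head] = after_calc_event

    if torch.cuda.is_available():
        mover, calc = torch.cuda.Stream(), torch.cuda.Stream()
        limiter = ModuleAheadLimiter(max_loaded_modules=3)
        for _ in range(10):  # pretend each iteration copies in and runs one module
            with torch.cuda.stream(mover):
                w = torch.randn(1024, 1024, pin_memory=True).to("cuda", non_blocking=True)
            calc.wait_stream(mover)        # compute must see the finished copy
            with torch.cuda.stream(calc):
                _ = w @ w                  # stand-in for the module's forward pass
                done = torch.cuda.Event()
                done.record(calc)
            limiter.throttle(mover, done)  # mover blocks once the ring wraps to a pending event
        torch.cuda.synchronize()

Throttling on the mover stream rather than blocking the host keeps the Python side free to keep queueing work; the ring length corresponds directly to the option's "maximum number of loaded modules".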