avoid oom on slow cards

wfjsw 2024-03-10 00:25:56 -10:00
parent 572f4cddb8
commit e8df8a9f52
3 changed files with 16 additions and 5 deletions

View File

@@ -27,6 +27,10 @@ class RTTensorMoverPatches:
         self.mover_stream = stream_impl(device=devices.device)
         self.calc_stream = stream_impl(device=devices.device)
         self.stash = {}
+        self.speed_limit_loop_len = shared.opts.lowvram_max_loaded_module or 3
+        self.speed_limit_loop_head = 0
+        self.speed_limit_loop = [None] * self.speed_limit_loop_len
         self.linear_original = patches.patch(
             __name__,
@@ -95,6 +99,12 @@ class RTTensorMoverPatches:
         for k in to_remove:
             del self.stash[k]
+        if self.speed_limit_loop[self.speed_limit_loop_head] is not None:
+            self.mover_stream.wait_event(self.speed_limit_loop[self.speed_limit_loop_head])
+        self.speed_limit_loop_head = (self.speed_limit_loop_head + 1) % self.speed_limit_loop_len
+        self.speed_limit_loop[self.speed_limit_loop_head] = after_calc_event
     def _wrapper_default(self, original):
        def wrapper(input, weight, bias=None, *args, **kwargs):
            with self.wrap_weight_biases(input, weight, bias) as (w, b):
@@ -271,12 +281,12 @@ def setup_for_low_vram(sd_model, use_medvram):
     if use_streamlined_lowvram:
         # put it into pinned memory to achieve data transfer overlap
-        diff_model.time_embed._apply(lambda x: x.pin_memory())
+        diff_model.time_embed._apply(lambda x: x.pin_memory(device=devices.device))
         for block in diff_model.input_blocks:
-            block._apply(lambda x: x.pin_memory())
-        diff_model.middle_block._apply(lambda x: x.pin_memory())
+            block._apply(lambda x: x.pin_memory(device=devices.device))
+        diff_model.middle_block._apply(lambda x: x.pin_memory(device=devices.device))
         for block in diff_model.output_blocks:
-            block._apply(lambda x: x.pin_memory())
+            block._apply(lambda x: x.pin_memory(device=devices.device))
     else:
         diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
         for block in diff_model.input_blocks:

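Taken together, the hunks above add a small throttle: a fixed-length ring of after-compute events, sized by the new lowvram_max_loaded_module option, that the mover stream must wait on before prefetching more weights, so only a bounded number of modules are resident on the GPU at once. A minimal sketch of that pattern follows; it is an illustration only, and the names max_in_flight, ring, head and throttle are invented here rather than taken from the repository.

import torch

max_in_flight = 3                    # plays the role of opts.lowvram_max_loaded_module
ring = [None] * max_in_flight        # one event slot per module allowed in flight
head = 0

copy_stream = torch.cuda.Stream()    # stands in for self.mover_stream
calc_stream = torch.cuda.Stream()    # stands in for self.calc_stream

def throttle(after_calc_event):
    # Before queuing the next host->device weight copy, make the copy stream
    # wait for the compute of the module issued max_in_flight steps ago, then
    # remember the current module's completion event in its slot.
    global head
    if ring[head] is not None:
        copy_stream.wait_event(ring[head])
    head = (head + 1) % max_in_flight
    ring[head] = after_calc_event

Because the wait is issued on the copy stream rather than on the CPU, already-queued compute keeps running; the cap only limits how far weight uploads may run ahead, which is what bounds peak VRAM on cards whose compute is slower than the transfers.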
View File

@@ -448,7 +448,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
                         module.fp16_bias = module.bias.data.clone().cpu().half()
                 module.to(torch.float8_e4m3fn)._apply(
                     lambda x: (
-                        x.pin_memory()
+                        x.pin_memory(device=devices.device)
                         if not x.is_sparse and x.device.type == "cpu"
                         else x
                     )

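The pin_memory changes in both files matter because page-locked host memory is what allows the asynchronous host-to-device copies that the streamlined low-VRAM mode relies on; passing device= presumably pins the memory for the accelerator actually in use. A rough, self-contained illustration of the overlap pattern (not code from the webui) follows:

import torch

copy_stream = torch.cuda.Stream()

# Page-locked host tensor: required for a truly asynchronous H2D copy.
weight_cpu = torch.randn(4096, 4096).pin_memory()
x = torch.randn(4096, 4096, device="cuda")

with torch.cuda.stream(copy_stream):
    # With a pinned source, non_blocking=True lets this copy overlap
    # whatever kernels are running on the default stream.
    weight_gpu = weight_cpu.to("cuda", non_blocking=True)

# The consumer must not touch the weight until the copy has finished.
torch.cuda.current_stream().wait_stream(copy_stream)
y = x @ weight_gpu

If the source tensor were ordinary pageable memory, the same call would fall back to a staged, effectively synchronizing copy and the overlap that the "data transfer overlap" comment in the first file aims for would largely be lost.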
View File

@@ -217,6 +217,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
     "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
     "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
     "use_non_streamlined_lowvram": OptionInfo(False, "Use non-streamlined low VRAM mode").info("Do not use the streamlined mode for low VRAM cards. For devices that do not support concurrently copy memory between host and device while executing a kernel. Significantly decreases performance."),
+    "lowvram_max_loaded_module": OptionInfo(3, "Maximum number of loaded modules in low VRAM mode", gr.Slider, {"minimum": 1, "maximum": 40, "step": 1}).info("Maximum number of loaded modules in low VRAM mode. Decrease this value if you encounter out of memory error."),
 }))
 options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {