Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git

avoid oom on slow cards

commit e8df8a9f52
parent 572f4cddb8
@@ -27,6 +27,10 @@ class RTTensorMoverPatches:
         self.mover_stream = stream_impl(device=devices.device)
         self.calc_stream = stream_impl(device=devices.device)
         self.stash = {}
+        self.speed_limit_loop_len = shared.opts.lowvram_max_loaded_module or 3
+        self.speed_limit_loop_head = 0
+        self.speed_limit_loop = [None] * self.speed_limit_loop_len
+
 
         self.linear_original = patches.patch(
             __name__,
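The new fields set up a fixed-length ring of CUDA events: its length comes from the new lowvram_max_loaded_module option, with "or 3" keeping the old default when the option is unset or zero, and every slot starts empty. A minimal stand-in for that state (the class and parameter names below are illustrative, not webui code):

    # Sketch only: mirrors the three fields added in the hunk above.
    class SpeedLimitRing:
        def __init__(self, max_loaded_modules=None):
            # falls back to 3 when the option is missing or 0, as in "... or 3"
            self.length = max_loaded_modules or 3
            self.head = 0
            self.slots = [None] * self.length  # one CUDA event per in-flight module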
@@ -95,6 +99,12 @@ class RTTensorMoverPatches:
             for k in to_remove:
                 del self.stash[k]
 
+            if self.speed_limit_loop[self.speed_limit_loop_head] is not None:
+                self.mover_stream.wait_event(self.speed_limit_loop[self.speed_limit_loop_head])
+
+            self.speed_limit_loop_head = (self.speed_limit_loop_head + 1) % self.speed_limit_loop_len
+            self.speed_limit_loop[self.speed_limit_loop_head] = after_calc_event
+
     def _wrapper_default(self, original):
         def wrapper(input, weight, bias=None, *args, **kwargs):
             with self.wrap_weight_biases(input, weight, bias) as (w, b):
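This hunk is where the ring is used: once per wrapped module, the mover stream waits on the compute event recorded a full window ago before it may queue the next weight upload, then the head advances and the current after_calc_event is stored. That caps how many modules can be in flight at once, which is what avoids the OOM on cards where compute falls far behind the copies. A rough, self-contained sketch of the same wait-advance-record cycle (a CUDA device is assumed; upload_and_run and the toy workload are made up, not webui code):

    import torch

    window = 3                     # plays the role of speed_limit_loop_len
    ring = [None] * window         # plays the role of speed_limit_loop
    head = 0                       # plays the role of speed_limit_loop_head

    mover_stream = torch.cuda.Stream()   # host-to-device copies
    calc_stream = torch.cuda.Stream()    # forward computation

    def upload_and_run(cpu_weight):
        """Upload one module's weights and 'compute' with them, never letting the
        copy stream get more than `window` modules ahead of the compute stream."""
        global head
        with torch.cuda.stream(mover_stream):
            if ring[head] is not None:
                mover_stream.wait_event(ring[head])      # throttle the uploads
            gpu_weight = cpu_weight.to("cuda", non_blocking=True)

        after_calc_event = torch.cuda.Event()
        with torch.cuda.stream(calc_stream):
            calc_stream.wait_stream(mover_stream)        # compute only after the copy lands
            out = gpu_weight.sum()                       # stand-in for the module forward
            after_calc_event.record(calc_stream)

        head = (head + 1) % window
        ring[head] = after_calc_event
        return out

    weights = [torch.randn(1024, 1024).pin_memory() for _ in range(8)]
    for w in weights:
        upload_and_run(w)
    torch.cuda.synchronize()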
@@ -271,12 +281,12 @@ def setup_for_low_vram(sd_model, use_medvram):
 
     if use_streamlined_lowvram:
         # put it into pinned memory to achieve data transfer overlap
-        diff_model.time_embed._apply(lambda x: x.pin_memory())
+        diff_model.time_embed._apply(lambda x: x.pin_memory(device=devices.device))
         for block in diff_model.input_blocks:
-            block._apply(lambda x: x.pin_memory())
-        diff_model.middle_block._apply(lambda x: x.pin_memory())
+            block._apply(lambda x: x.pin_memory(device=devices.device))
+            diff_model.middle_block._apply(lambda x: x.pin_memory(device=devices.device))
         for block in diff_model.output_blocks:
-            block._apply(lambda x: x.pin_memory())
+            block._apply(lambda x: x.pin_memory(device=devices.device))
     else:
         diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
         for block in diff_model.input_blocks:
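Pinning the UNet blocks' CPU tensors (now explicitly against devices.device) is what makes the overlap above possible: asynchronous host-to-device copies only truly overlap with kernels when the source memory is page-locked, otherwise the copy silently degrades to a synchronous one. A small sketch of the pattern, using plain pin_memory() and a stand-in nn.Linear instead of the UNet blocks:

    import torch
    import torch.nn as nn

    block = nn.Linear(4096, 4096)             # stand-in for a diff_model block
    block._apply(lambda t: t.pin_memory())    # same _apply pattern as the diff

    copy_stream = torch.cuda.Stream()
    with torch.cuda.stream(copy_stream):
        # Only genuinely asynchronous because the source tensors are pinned.
        block.to("cuda", non_blocking=True)
    copy_stream.synchronize()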
@@ -448,7 +448,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
                 module.fp16_bias = module.bias.data.clone().cpu().half()
             module.to(torch.float8_e4m3fn)._apply(
                 lambda x: (
-                    x.pin_memory()
+                    x.pin_memory(device=devices.device)
                     if not x.is_sparse and x.device.type == "cpu"
                     else x
                 )
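The same treatment is applied to FP8-stored weights in load_model_weights: after casting to torch.float8_e4m3fn, each dense CPU tensor is pinned so it can later be streamed to the GPU, while sparse tensors and tensors already on the device are left alone. A condensed sketch of that cast-and-pin step (PyTorch >= 2.1 assumed for the fp8 dtype; plain pin_memory() used here):

    import torch
    import torch.nn as nn

    module = nn.Linear(16, 16)

    # Cast to fp8 for compact storage, then pin the resulting dense CPU tensors;
    # the guard mirrors the one in the diff.
    module.to(torch.float8_e4m3fn)._apply(
        lambda x: x.pin_memory() if not x.is_sparse and x.device.type == "cpu" else x
    )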
@@ -217,6 +217,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"), {
     "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
     "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
     "use_non_streamlined_lowvram": OptionInfo(False, "Use non-streamlined low VRAM mode").info("Do not use the streamlined mode for low VRAM cards. For devices that do not support concurrently copy memory between host and device while executing a kernel. Significantly decreases performance."),
+    "lowvram_max_loaded_module": OptionInfo(3, "Maximum number of loaded modules in low VRAM mode", gr.Slider, {"minimum": 1, "maximum": 40, "step": 1}).info("Maximum number of loaded modules in low VRAM mode. Decrease this value if you encounter out of memory error."),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
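Finally, the ring length is exposed as a slider (1 to 40, default 3) under Optimizations; the lowvram code reads it back as shared.opts.lowvram_max_loaded_module, so lowering it directly shrinks the in-flight window at the cost of less copy/compute overlap. Illustrative only (a dummy options object, not the webui's):

    from types import SimpleNamespace

    # Dummy stand-in for shared.opts, showing how the new setting is consumed.
    opts = SimpleNamespace(lowvram_max_loaded_module=3)   # value set by the new slider
    speed_limit_loop_len = opts.lowvram_max_loaded_module or 3
    print(speed_limit_loop_len)   # decrease it if you still hit out-of-memory errors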