From 1d950c777615bcd2e236d1423bfc8b5f6af02039 Mon Sep 17 00:00:00 2001
From: wfjsw
Date: Wed, 21 Feb 2024 00:22:45 -0600
Subject: [PATCH] add option to revert to old behavior

---
 modules/interrogate.py    |  3 ++-
 modules/lowvram.py        | 45 ++++++++++++++++++++++-----------------
 modules/shared_options.py |  1 +
 3 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/modules/interrogate.py b/modules/interrogate.py
index c93e7aa86..0edd2fc10 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -10,7 +10,7 @@ import torch.hub
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
-from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
+from modules import devices, paths, shared, modelloader, errors, torch_utils
 
 blip_image_eval_size = 384
 clip_model_name = 'ViT-L/14'
@@ -186,6 +186,7 @@ class InterrogateModels:
         res = ""
         shared.state.begin(job="interrogate")
         try:
+            from modules import lowvram
             lowvram.send_everything_to_cpu()
             devices.torch_gc()
 
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 3c171a058..6a2012b3a 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -9,6 +9,9 @@ stream_impl = devices.get_stream_impl()
 stream_wrapper = devices.get_stream_wrapper()
 
 
+use_streamlined_lowvram = not shared.opts.use_non_streamlined_lowvram and stream_impl is not None and stream_wrapper is not None
+
+
 def is_same_device(device1, device2):
     tensor1_device_type = device1.type
     tensor2_device_type = device2.type
@@ -29,31 +32,31 @@ class RTTensorMoverPatches:
             __name__,
             torch.nn.functional,
             "linear",
-            self.create_wrapper(torch.nn.functional.linear),
+            self._wrapper_default(torch.nn.functional.linear),
         )
         self.conv2d_original = patches.patch(
             __name__,
             torch.nn.functional,
             "conv2d",
-            self.create_wrapper(torch.nn.functional.conv2d),
+            self._wrapper_default(torch.nn.functional.conv2d),
         )
         self.conv3d_original = patches.patch(
             __name__,
             torch.nn.functional,
             "conv3d",
-            self.create_wrapper(torch.nn.functional.conv3d),
+            self._wrapper_default(torch.nn.functional.conv3d),
         )
         self.group_norm_original = patches.patch(
             __name__,
             torch.nn.functional,
             "group_norm",
-            self.create_wrapper(torch.nn.functional.group_norm, type=2),
+            self._wrapper_group_norm(torch.nn.functional.group_norm),
         )
         self.layer_norm_original = patches.patch(
             __name__,
             torch.nn.functional,
             "layer_norm",
-            self.create_wrapper(torch.nn.functional.layer_norm, type=2),
+            self._wrapper_layer_norm(torch.nn.functional.layer_norm),
         )
 
     @contextmanager
@@ -92,21 +95,23 @@ class RTTensorMoverPatches:
         for k in to_remove:
             del self.stash[k]
 
-    def create_wrapper(self, original, type=1):
-        if type == 2:
+    def _wrapper_default(self, original):
+        def wrapper(input, weight, bias, *args, **kwargs):
+            with self.wrap_weight_biases(input, weight, bias) as (w, b):
+                return original(input, w, b, *args, **kwargs)
+        return wrapper
 
-            def wrapper(input, arg1, weight, bias, *args, **kwargs):
-                with self.wrap_weight_biases(input, weight, bias) as (w, b):
-                    return original(input, arg1, w, b, *args, **kwargs)
+    def _wrapper_group_norm(self, original):
+        def wrapper(input, num_groups, weight, bias, *args, **kwargs):
+            with self.wrap_weight_biases(input, weight, bias) as (w, b):
+                return original(input, num_groups, w, b, *args, **kwargs)
+        return wrapper
 
-            return wrapper
-        else:
-
-            def wrapper(input, weight, bias, *args, **kwargs):
-                with self.wrap_weight_biases(input, weight, bias) as (w, b):
-                    return original(input, w, b, *args, **kwargs)
-
-            return wrapper
+    def _wrapper_layer_norm(self, original):
+        def wrapper(input, normalized_shape, weight, bias, *args, **kwargs):
+            with self.wrap_weight_biases(input, weight, bias) as (w, b):
+                return original(input, normalized_shape, w, b, *args, **kwargs)
+        return wrapper
 
     def close(self):
         patches.undo(__name__, torch.nn.functional, "linear")
@@ -117,7 +122,7 @@ class RTTensorMoverPatches:
 
 
 rtmover = None
-if stream_impl is not None and stream_wrapper is not None:
+if use_streamlined_lowvram:
     rtmover = RTTensorMoverPatches()
 
 
@@ -264,7 +269,7 @@ def setup_for_low_vram(sd_model, use_medvram):
 
 
     # install hooks for bits of third model
-    if stream_impl is not None and stream_wrapper is not None:
+    if use_streamlined_lowvram:
         # put it into pinned memory to achieve data transfer overlap
         diff_model.time_embed._apply(lambda x: x.pin_memory())
         for block in diff_model.input_blocks:
diff --git a/modules/shared_options.py b/modules/shared_options.py
index 21643afe0..d0838fd7b 100644
--- a/modules/shared_options.py
+++ b/modules/shared_options.py
@@ -216,6 +216,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
     "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond commandline argument"),
     "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
     "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
+    "use_non_streamlined_lowvram": OptionInfo(False, "Use non-streamlined low VRAM mode").info("Do not use the streamlined mode for low VRAM cards. For devices that do not support copying memory between host and device concurrently with kernel execution. Significantly decreases performance."),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {