add option to revert to old behavior

wfjsw 2024-02-21 00:22:45 -06:00
parent 8828c9ecc5
commit 1d950c7776
3 changed files with 28 additions and 21 deletions
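In short: the streamlined low VRAM path, which moves weights to the GPU just in time on a separate stream, is now gated behind a single use_streamlined_lowvram flag, and a new use_non_streamlined_lowvram setting lets users fall back to the previous behavior. A minimal, self-contained sketch of that gating; apart from the two flag names, everything here (the _Opts class, stream_supported) is a stand-in for shared.opts and the stream-capability probe in modules.devices:

# Sketch only: condensed from the lowvram hunks below, with stand-in objects.
import torch

class _Opts:  # stand-in for shared.opts
    use_non_streamlined_lowvram = False  # the new setting, off by default

opts = _Opts()
stream_supported = torch.cuda.is_available()  # rough stand-in for get_stream_impl()/get_stream_wrapper()

use_streamlined_lowvram = not opts.use_non_streamlined_lowvram and stream_supported

if use_streamlined_lowvram:
    print("streamlined low VRAM: weights moved just in time on a side stream")
else:
    print("non-streamlined low VRAM: the pre-existing hook-based behavior")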

modules/interrogate.py

@@ -10,7 +10,7 @@ import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
-from modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils
+from modules import devices, paths, shared, modelloader, errors, torch_utils
blip_image_eval_size = 384
clip_model_name = 'ViT-L/14'
@@ -186,6 +186,7 @@ class InterrogateModels:
        res = ""
        shared.state.begin(job="interrogate")
        try:
+            from modules import lowvram
            lowvram.send_everything_to_cpu()
            devices.torch_gc()

modules/lowvram.py

@@ -9,6 +9,9 @@ stream_impl = devices.get_stream_impl()
stream_wrapper = devices.get_stream_wrapper()
+use_streamlined_lowvram = not shared.opts.use_non_streamlined_lowvram and stream_impl is not None and stream_wrapper is not None
def is_same_device(device1, device2):
    tensor1_device_type = device1.type
    tensor2_device_type = device2.type
@@ -29,31 +32,31 @@ class RTTensorMoverPatches:
            __name__,
            torch.nn.functional,
            "linear",
-            self.create_wrapper(torch.nn.functional.linear),
+            self._wrapper_default(torch.nn.functional.linear),
        )
        self.conv2d_original = patches.patch(
            __name__,
            torch.nn.functional,
            "conv2d",
-            self.create_wrapper(torch.nn.functional.conv2d),
+            self._wrapper_default(torch.nn.functional.conv2d),
        )
        self.conv3d_original = patches.patch(
            __name__,
            torch.nn.functional,
            "conv3d",
-            self.create_wrapper(torch.nn.functional.conv3d),
+            self._wrapper_default(torch.nn.functional.conv3d),
        )
        self.group_norm_original = patches.patch(
            __name__,
            torch.nn.functional,
            "group_norm",
-            self.create_wrapper(torch.nn.functional.group_norm, type=2),
+            self._wrapper_group_norm(torch.nn.functional.group_norm),
        )
        self.layer_norm_original = patches.patch(
            __name__,
            torch.nn.functional,
            "layer_norm",
-            self.create_wrapper(torch.nn.functional.layer_norm, type=2),
+            self._wrapper_layer_norm(torch.nn.functional.layer_norm),
        )
    @contextmanager
@@ -92,21 +95,23 @@ class RTTensorMoverPatches:
        for k in to_remove:
            del self.stash[k]
-    def create_wrapper(self, original, type=1):
-        if type == 2:
-            def wrapper(input, arg1, weight, bias, *args, **kwargs):
-                with self.wrap_weight_biases(input, weight, bias) as (w, b):
-                    return original(input, arg1, w, b, *args, **kwargs)
-            return wrapper
-        else:
-            def wrapper(input, weight, bias, *args, **kwargs):
-                with self.wrap_weight_biases(input, weight, bias) as (w, b):
-                    return original(input, w, b, *args, **kwargs)
-            return wrapper
+    def _wrapper_default(self, original):
+        def wrapper(input, weight, bias, *args, **kwargs):
+            with self.wrap_weight_biases(input, weight, bias) as (w, b):
+                return original(input, w, b, *args, **kwargs)
+        return wrapper
+    def _wrapper_group_norm(self, original):
+        def wrapper(input, num_groups, weight, bias, *args, **kwargs):
+            with self.wrap_weight_biases(input, weight, bias) as (w, b):
+                return original(input, num_groups, w, b, *args, **kwargs)
+        return wrapper
+    def _wrapper_layer_norm(self, original):
+        def wrapper(input, normalized_shape, weight, bias, *args, **kwargs):
+            with self.wrap_weight_biases(input, weight, bias) as (w, b):
+                return original(input, normalized_shape, w, b, *args, **kwargs)
+        return wrapper
    def close(self):
        patches.undo(__name__, torch.nn.functional, "linear")
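The split of create_wrapper into _wrapper_default, _wrapper_group_norm and _wrapper_layer_norm mirrors the argument order of the wrapped torch.nn.functional calls: linear, conv2d and conv3d take (input, weight, bias, ...), while group_norm and layer_norm insert num_groups or normalized_shape before weight and bias, so a single generic wrapper would pass the moved tensors in the wrong positions. A small standalone illustration of those signatures (example tensors only, not code from the repository):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)

# weight and bias are the 2nd and 3rd positional arguments here ...
F.linear(x, torch.randn(4, 8), torch.randn(4))

# ... but sit one position later for the norm ops, hence the dedicated wrappers
F.layer_norm(x, (8,), torch.randn(8), torch.randn(8))
F.group_norm(torch.randn(2, 4, 6, 6), 2, torch.randn(4), torch.randn(4))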
@@ -117,7 +122,7 @@ class RTTensorMoverPatches:
rtmover = None
-if stream_impl is not None and stream_wrapper is not None:
+if use_streamlined_lowvram:
    rtmover = RTTensorMoverPatches()
@@ -264,7 +269,7 @@ def setup_for_low_vram(sd_model, use_medvram):
    # install hooks for bits of third model
-    if stream_impl is not None and stream_wrapper is not None:
+    if use_streamlined_lowvram:
        # put it into pinned memory to achieve data transfer overlap
        diff_model.time_embed._apply(lambda x: x.pin_memory())
        for block in diff_model.input_blocks:
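The pin_memory() call above supports the data transfer overlap mentioned in the comment: host-to-device copies from page-locked memory can run asynchronously on a separate stream and overlap with kernels that are already executing. The option added by this commit is aimed at devices where that overlap is not supported. A rough, self-contained illustration of the overlap pattern (not code from the repository):

import torch

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    cpu_weight = torch.randn(1024, 1024).pin_memory()  # page-locked host tensor

    with torch.cuda.stream(copy_stream):
        # asynchronous host-to-device copy; returns immediately
        gpu_weight = cpu_weight.to("cuda", non_blocking=True)

    # work queued on the default stream may overlap with the copy;
    # wait for the copy stream before actually using the weight
    torch.cuda.current_stream().wait_stream(copy_stream)
    print(gpu_weight.sum())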

modules/shared_options.py

@@ -216,6 +216,7 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
    "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond commandline argument"),
    "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),
    "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. Use more system ram."),
+    "use_non_streamlined_lowvram": OptionInfo(False, "Use non-streamlined low VRAM mode").info("Do not use the streamlined mode for low VRAM cards. For devices that do not support copying memory between host and device concurrently while executing a kernel. Significantly decreases performance."),
}))
options_templates.update(options_section(('compatibility', "Compatibility", "sd"), {
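The new checkbox appears in the Optimizations section of the settings defined above. Because use_streamlined_lowvram is computed at import time in the lowvram module (second file in this commit), the setting presumably takes effect only after the webui is restarted rather than immediately after saving.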