stable-diffusion-webui/modules/lowvram.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

291 lines
11 KiB
Python
Raw Normal View History

2024-02-21 00:49:44 +00:00
from contextlib import contextmanager, nullcontext
import torch
2024-02-09 08:21:07 +00:00
from modules import devices, shared, patches
# The module most recently moved to the GPU by a send_me_to_gpu hook (None if none).
module_in_gpu = None
cpu = torch.device("cpu")

# Backend stream class and stream context-manager factory; either may be None,
# which disables the streamlined path below.
stream_impl = devices.get_stream_impl()
stream_wrapper = devices.get_stream_wrapper()

# Streamlined lowvram requires CUDA, must not be opted out of via the
# `use_non_streamlined_lowvram` option, and needs both stream helpers.
use_streamlined_lowvram = (
    torch.cuda.is_available()
    and not shared.opts.use_non_streamlined_lowvram
    and stream_impl is not None
    and stream_wrapper is not None
)
2024-02-21 06:22:45 +00:00
2024-02-21 00:49:44 +00:00
def is_same_device(device1, device2):
    """Return True when two ``torch.device`` objects name the same device.

    A device without an explicit index (e.g. ``torch.device("cuda")``) is
    treated as index 0.
    """
    if device1.type != device2.type:
        return False
    index1 = device1.index or 0
    index2 = device2.index or 0
    return index1 == index2
class RTTensorMoverPatches:
    """Patch ``torch.nn.functional`` ops so CPU-resident weights are streamed to the GPU.

    Construction creates a mover stream and a calc stream on ``devices.device``
    and patches ``linear``, ``conv2d``, ``conv3d``, ``group_norm`` and
    ``layer_norm`` in ``torch.nn.functional`` via ``patches.patch``. Each
    patched function routes its weight/bias through :meth:`wrap_weight_biases`.
    :meth:`close` undoes all five patches.
    """

    def __init__(self):
        self.mover_stream = stream_impl(device=devices.device)
        self.calc_stream = stream_impl(device=devices.device)
        # id(event) -> (weight, bias, event): GPU copies kept alive until the
        # calc-stream work that used them (marked by `event`) has completed.
        self.stash = {}

        # Each wrapper closes over the functional op read here, i.e. before
        # patches.patch replaces it. The return value of patches.patch is kept
        # as *_original (presumably the replaced function; see modules.patches).
        self.linear_original = patches.patch(
            __name__,
            torch.nn.functional,
            "linear",
            self._wrapper_default(torch.nn.functional.linear),
        )
        self.conv2d_original = patches.patch(
            __name__,
            torch.nn.functional,
            "conv2d",
            self._wrapper_default(torch.nn.functional.conv2d),
        )
        self.conv3d_original = patches.patch(
            __name__,
            torch.nn.functional,
            "conv3d",
            self._wrapper_default(torch.nn.functional.conv3d),
        )
        self.group_norm_original = patches.patch(
            __name__,
            torch.nn.functional,
            "group_norm",
            self._wrapper_group_norm(torch.nn.functional.group_norm),
        )
        self.layer_norm_original = patches.patch(
            __name__,
            torch.nn.functional,
            "layer_norm",
            self._wrapper_layer_norm(torch.nn.functional.layer_norm),
        )

    @contextmanager
    def wrap_weight_biases(self, input, weight, bias):
        """Yield ``(weight, bias)`` placed on ``input``'s device.

        If ``input`` is not on ``devices.device``, or neither tensor needs
        moving, the arguments are yielded unchanged. Otherwise:

        * every tensor whose device differs from ``input.device`` is copied
          there on ``self.mover_stream`` (``non_blocking`` when the source is
          pinned);
        * the ``with`` body runs with ``self.calc_stream`` current, after that
          stream waits on an event recorded once the copies were queued;
        * the copies are kept in ``self.stash`` until an event recorded on the
          calc stream after the body reports completion. This happens even if
          the body raises, so the copies are not freed while queued work may
          still read them.
        """
        if not is_same_device(input.device, devices.device):
            yield (weight, bias)
            return

        moved = False
        before_calc_event = None
        with stream_wrapper(stream=self.mover_stream):
            if weight is not None and not is_same_device(weight.device, input.device):
                weight = weight.to(device=input.device, copy=True, non_blocking=weight.is_pinned())
                moved = True
            if bias is not None and not is_same_device(bias.device, input.device):
                bias = bias.to(device=input.device, copy=True, non_blocking=bias.is_pinned())
                moved = True
            before_calc_event = self.mover_stream.record_event()

        if not moved:
            yield (weight, bias)
            return

        with stream_wrapper(stream=self.calc_stream):
            if before_calc_event is not None:
                # Order the computation after the host-to-device copies.
                self.calc_stream.wait_event(before_calc_event)
            try:
                yield (weight, bias)
            finally:
                # Always mark the end of the calc work and keep the copies
                # alive until it completes, even when the wrapped op raised.
                after_calc_event = self.calc_stream.record_event()
                self.stash[id(after_calc_event)] = (weight, bias, after_calc_event)

        # Release stashed copies whose calc-stream work has finished.
        finished = [key for key, (_, _, event) in self.stash.items() if event.query()]
        for key in finished:
            del self.stash[key]

    def _wrapper_default(self, original):
        """Wrap an op with signature ``(input, weight, bias=None, ...)`` (linear/conv)."""
        def wrapper(input, weight, bias=None, *args, **kwargs):
            with self.wrap_weight_biases(input, weight, bias) as (w, b):
                return original(input, w, b, *args, **kwargs)
        return wrapper

    def _wrapper_group_norm(self, original):
        """Wrap ``group_norm(input, num_groups, weight=None, bias=None, ...)``."""
        def wrapper(input, num_groups, weight=None, bias=None, *args, **kwargs):
            with self.wrap_weight_biases(input, weight, bias) as (w, b):
                return original(input, num_groups, w, b, *args, **kwargs)
        return wrapper

    def _wrapper_layer_norm(self, original):
        """Wrap ``layer_norm(input, normalized_shape, weight=None, bias=None, ...)``."""
        def wrapper(input, normalized_shape, weight=None, bias=None, *args, **kwargs):
            with self.wrap_weight_biases(input, weight, bias) as (w, b):
                return original(input, normalized_shape, w, b, *args, **kwargs)
        return wrapper

    def close(self):
        """Undo every ``torch.nn.functional`` patch installed by ``__init__``."""
        patches.undo(__name__, torch.nn.functional, "linear")
        patches.undo(__name__, torch.nn.functional, "conv2d")
        patches.undo(__name__, torch.nn.functional, "conv3d")
        patches.undo(__name__, torch.nn.functional, "group_norm")
        patches.undo(__name__, torch.nn.functional, "layer_norm")
2024-02-09 08:21:07 +00:00
2024-02-21 00:49:44 +00:00
# Module-wide tensor mover; only created (and the functional ops patched)
# when streamlined lowvram is enabled, otherwise None.
rtmover = RTTensorMoverPatches() if use_streamlined_lowvram else None
2024-02-09 08:21:07 +00:00
2024-02-21 00:49:44 +00:00
def calc_wrapper():
    """Return a context manager that makes the streamlined calc stream current.

    When streamlined lowvram is inactive (``rtmover`` is None), a
    ``nullcontext()`` is returned instead.
    """
    if rtmover is None:
        return nullcontext()
    return stream_wrapper(stream=rtmover.calc_stream)
def calc_sync():
    """Block until all work queued on the streamlined calc stream has finished.

    Does nothing when streamlined lowvram is inactive (``rtmover`` is None).
    Always returns ``None``; this is a synchronization call, not a context
    manager (use ``calc_wrapper()`` for that).
    """
    if rtmover is not None:
        rtmover.calc_stream.synchronize()
2024-02-07 08:26:10 +00:00
def send_everything_to_cpu():
    """Move the module currently tracked in ``module_in_gpu`` to the CPU and clear the tracker."""
    global module_in_gpu

    if module_in_gpu is None:
        return

    module_in_gpu.to(cpu)
    module_in_gpu = None
2023-08-22 15:49:08 +00:00
def is_needed(sd_model):
    """Tell whether low/medium-VRAM handling applies to *sd_model*.

    ``--lowvram`` or ``--medvram`` always enable it. ``--medvram-sdxl`` only
    enables it for models that have a ``conditioner`` attribute.
    """
    opts = shared.cmd_opts
    forced = opts.lowvram or opts.medvram
    if forced:
        return forced
    return opts.medvram_sdxl and hasattr(sd_model, 'conditioner')
2023-08-22 15:49:08 +00:00
def apply(sd_model):
    """Enable or disable low-VRAM handling for *sd_model* according to the command-line flags.

    Parallel processing is allowed only when low-VRAM handling is not needed.
    When it is needed, ``--lowvram`` selects the fine-grained mode and any
    other enabling flag selects medvram mode.
    """
    needed = is_needed(sd_model)
    shared.parallel_processing_allowed = not needed

    if not needed:
        sd_model.lowvram = False
        return

    setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
def setup_for_low_vram(sd_model, use_medvram):
    """Keep the large submodules of *sd_model* on the CPU and swap them onto the GPU on demand.

    The first stage model, depth model, embedder, UNet wrapper (``sd_model.model``)
    and text encoder are detached while the rest of the model is sent to
    ``devices.device``, then re-attached (still on the CPU). Forward pre-hooks,
    plus wrapped ``encode``/``decode`` for the first stage model, move each of
    them to the GPU when used and send the previously tracked module back to
    the CPU.

    Args:
        sd_model: the loaded model. Nothing is done if ``sd_model.lowvram``
            is already truthy; otherwise it is set to True.
        use_medvram: if True, ``sd_model.model`` is swapped as a single unit.
            Otherwise the diffusion model's ``time_embed``, ``input_blocks``,
            ``middle_block`` and ``output_blocks`` are handled individually.
            This uses pre-hooks, or, when streamlined lowvram is enabled, pins
            them in CPU memory for the patched functional ops to stream from.
    """
    if getattr(sd_model, 'lowvram', False):
        return

    sd_model.lowvram = True

    # hooked child module -> module that should actually be moved when the child's hook fires
    parents = {}

    def send_me_to_gpu(module, _):
        """send this module to GPU; send whatever tracked module was previous in GPU to CPU;
        we add this as forward_pre_hook to a lot of modules and this way all but one of them will
        be in CPU
        """
        global module_in_gpu

        module = parents.get(module, module)

        if module_in_gpu == module:
            return

        if module_in_gpu is not None:
            module_in_gpu.to(cpu)

        module.to(devices.device)
        module_in_gpu = module

    # see below for register_forward_pre_hook;
    # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
    # useless here, and we just replace those methods

    first_stage_model = sd_model.first_stage_model
    first_stage_model_encode = sd_model.first_stage_model.encode
    first_stage_model_decode = sd_model.first_stage_model.decode

    def first_stage_model_encode_wrap(x):
        send_me_to_gpu(first_stage_model, None)
        return first_stage_model_encode(x)

    def first_stage_model_decode_wrap(z):
        send_me_to_gpu(first_stage_model, None)
        return first_stage_model_decode(z)

    # Every (owner, attribute) pair must appear exactly once: the detach and
    # restore loops below replay this list in order, so a duplicate entry would
    # overwrite the restored module with the None placeholder.
    to_remain_in_cpu = [
        (sd_model, 'first_stage_model'),
        (sd_model, 'depth_model'),
        (sd_model, 'embedder'),
        (sd_model, 'model'),
    ]

    is_sdxl = hasattr(sd_model, 'conditioner')
    is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')

    if is_sdxl:
        to_remain_in_cpu.append((sd_model, 'conditioner'))
    elif is_sd2:
        to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
    else:
        to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))

    # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model.
    # Absent attributes (e.g. depth_model on non-depth models) end up set to None,
    # which keeps the attribute checks further down valid.
    stored = []
    for obj, field in to_remain_in_cpu:
        module = getattr(obj, field, None)
        stored.append(module)
        setattr(obj, field, None)

    # send the model to GPU.
    sd_model.to(devices.device)

    # put modules back. the modules will be in CPU.
    for (obj, field), module in zip(to_remain_in_cpu, stored):
        setattr(obj, field, module)

    # register hooks for the text encoder, the first stage model and, if present, depth model / embedder
    if is_sdxl:
        sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
    elif is_sd2:
        sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
        sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
        parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
        parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
    else:
        sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
        parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

    sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.encode = first_stage_model_encode_wrap
    sd_model.first_stage_model.decode = first_stage_model_decode_wrap
    if sd_model.depth_model is not None:
        sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
    if sd_model.embedder is not None:
        sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)

    if use_medvram:
        sd_model.model.register_forward_pre_hook(send_me_to_gpu)
    else:
        diff_model = sd_model.model.diffusion_model

        # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
        # so that only one of them is in GPU at a time
        stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
        sd_model.model.to(devices.device)
        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored

        # install hooks for bits of third model
        if use_streamlined_lowvram:
            # put it into pinned memory to achieve data transfer overlap
            diff_model.time_embed._apply(lambda x: x.pin_memory())
            for block in diff_model.input_blocks:
                block._apply(lambda x: x.pin_memory())
            diff_model.middle_block._apply(lambda x: x.pin_memory())
            for block in diff_model.output_blocks:
                block._apply(lambda x: x.pin_memory())
        else:
            diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
            for block in diff_model.input_blocks:
                block.register_forward_pre_hook(send_me_to_gpu)
            diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
            for block in diff_model.output_blocks:
                block.register_forward_pre_hook(send_me_to_gpu)
def is_enabled(sd_model):
    """Return the ``lowvram`` flag stored on *sd_model* (set by ``apply`` / ``setup_for_low_vram``)."""
    enabled = sd_model.lowvram
    return enabled