stable-diffusion-webui/modules/lowvram.py
2024-03-09 19:12:42 -10:00

573 lines
20 KiB
Python

import torch
from torch.utils.weak import WeakIdKeyDictionary
from modules import devices, shared, patches
module_in_gpu = None
cpu = torch.device("cpu")
stream_impl = devices.get_stream_impl()
stream_wrapper = devices.get_stream_wrapper()
class SmartTensorMoverPatches:
def __init__(self):
self.memo = WeakIdKeyDictionary()
self.cleanup_memo = WeakIdKeyDictionary()
self.model_mover_stream = stream_impl(device=devices.device)
self.linear_original = patches.patch(
__name__,
torch.nn.functional,
"linear",
self.create_wrapper(torch.nn.functional.linear),
)
self.conv2d_original = patches.patch(
__name__,
torch.nn.functional,
"conv2d",
self.create_wrapper(torch.nn.functional.conv2d),
)
self.conv3d_original = patches.patch(
__name__,
torch.nn.functional,
"conv3d",
self.create_wrapper(torch.nn.functional.conv3d),
)
self.group_norm_original = patches.patch(
__name__,
torch.nn.functional,
"group_norm",
self.create_wrapper(torch.nn.functional.group_norm, type=2),
)
self.layer_norm_original = patches.patch(
__name__,
torch.nn.functional,
"layer_norm",
self.create_wrapper(torch.nn.functional.layer_norm, type=2),
)
def create_wrapper(self, original, type=1):
if type == 2:
def wrapper(input, arg1, weight, bias, *args, **kwargs):
record_cleanup_weight, record_cleanup_bias = False, False
current_stream = devices.get_current_stream()
if weight is not None:
new_weight, weight_event = self.memo.get(weight, (None, None))
if weight_event is not None:
weight = new_weight
record_cleanup_weight = True
current_stream.wait_event(weight_event)
if bias is not None:
new_bias, bias_event = self.memo.get(bias, (None, None))
if bias_event is not None:
bias = new_bias
record_cleanup_bias = True
current_stream.wait_event(bias_event)
result = original(input, arg1, weight, bias, *args, **kwargs)
if record_cleanup_weight:
self.cleanup_memo[weight] = current_stream.record_event()
if record_cleanup_bias:
self.cleanup_memo[bias] = current_stream.record_event()
return result
return wrapper
else:
def wrapper(input, weight, bias, *args, **kwargs):
record_cleanup_weight, record_cleanup_bias = False, False
current_stream = devices.get_current_stream()
if weight is not None:
new_weight, weight_event = self.memo.get(weight, (None, None))
if weight_event is not None:
weight = new_weight
record_cleanup_weight = True
current_stream.wait_event(weight_event)
if bias is not None:
new_bias, bias_event = self.memo.get(bias, (None, None))
if bias_event is not None:
bias = new_bias
record_cleanup_bias = True
current_stream.wait_event(bias_event)
result = original(input, weight, bias, *args, **kwargs)
if record_cleanup_weight:
self.cleanup_memo[weight] = current_stream.record_event()
if record_cleanup_bias:
self.cleanup_memo[bias] = current_stream.record_event()
return result
return wrapper
def __contains__(self, tensor):
return tensor in self.memo
def move(self, tensor, device=None):
device = device or tensor.device
memo_tensor, memo_event = self.memo.get(tensor, (None, None))
if memo_tensor is not None:
return memo_tensor, memo_event
with stream_wrapper(stream=self.model_mover_stream):
new_tensor = tensor.to(device=device, copy=True, non_blocking=True)
new_event = self.model_mover_stream.record_event()
self.memo[tensor] = (new_tensor, new_event)
return self.memo[tensor]
def _forget(self, tensor, tensor_on_device=None):
if tensor_on_device is not None:
tensor_used_event = self.cleanup_memo.get(tensor_on_device, None)
if tensor_used_event is not None:
self.model_mover_stream.wait_event(tensor_used_event)
self.cleanup_memo.pop(tensor_on_device, None)
del self.memo[tensor]
def forget(self, tensor):
on_device_tensor = self.memo.get(tensor, None)
self._forget(tensor, on_device_tensor)
def forget_batch(self, tensors):
for tensor in tensors:
tensor_on_device, _ = self.memo.get(tensor, (None, None))
if tensor_on_device is not None:
self._forget(tensor, tensor_on_device)
def forget_all(self):
for (tensor_on_device, _) in self.memo.values():
if tensor_on_device in self.cleanup_memo:
self.model_mover_stream.wait_event(self.cleanup_memo[tensor_on_device])
self.cleanup_memo.clear()
self.memo.clear()
def close(self):
patches.undo(__name__, torch.nn.functional, "linear")
patches.undo(__name__, torch.nn.functional, "conv2d")
patches.undo(__name__, torch.nn.functional, "conv3d")
patches.undo(__name__, torch.nn.functional, "group_norm")
patches.undo(__name__, torch.nn.functional, "layer_norm")
mover = None
if stream_impl is not None and stream_wrapper is not None:
mover = SmartTensorMoverPatches()
class ModelMover:
@classmethod
def register(cls, model, max_prefetch=1):
instance = cls(model, max_prefetch)
model.lowvram_model_mover = instance
return instance
def __init__(self, model, max_prefetch=1):
self.model = model
self.lookahead_distance = max_prefetch
self.hook_handles = []
self.submodules_list = self.get_module_list()
for c in self.submodules_list:
c._apply(lambda x: x.pin_memory())
self.submodules_indexer = {}
self.module_movement_events = {}
self.default_stream = devices.get_current_stream()
self.model_mover_stream = stream_impl(device=devices.device)
for i, module in enumerate(self.submodules_list):
self.submodules_indexer[module] = i
def get_module_list(self):
return []
def install(self):
for i in range(len(self.submodules_list)):
self.hook_handles.append(
self.submodules_list[i].register_forward_pre_hook(
self._pre_forward_hook
)
)
self.hook_handles.append(
self.submodules_list[i].register_forward_hook(self._post_forward_hook)
)
def uninstall(self):
for handle in self.hook_handles:
handle.remove()
def _pre_forward_hook(self, module, _):
with stream_wrapper(stream=self.model_mover_stream):
idx = self.submodules_indexer[module]
for i in range(idx, idx + self.lookahead_distance):
submodule = self.submodules_list[i % len(self.submodules_list)]
if submodule in self.module_movement_events:
# already in GPU
continue
submodule.to(devices.device, non_blocking=True)
self.module_movement_events[submodule] = (
self.model_mover_stream.record_event()
)
this_event = self.module_movement_events.get(module, None)
if this_event is not None:
self.default_stream.wait_event(this_event)
else:
print(
f"Module {module.__name__} was not moved to GPU. Taking slow path"
)
submodule.to(devices.device, non_blocking=True)
def _post_forward_hook(self, module, _1, _2):
with stream_wrapper(stream=self.model_mover_stream):
del self.module_movement_events[module]
module.to(cpu, non_blocking=True)
class SmartModelMover:
@classmethod
def register(cls, model, vram_allowance=0, max_prefetch=10):
instance = cls(model, vram_allowance, max_prefetch)
model.lowvram_model_mover = instance
return instance
def __init__(self, model, vram_allowance=0, max_prefetch=10):
self.model = model
self.vram_allowance = vram_allowance * 1024 * 1024
self.vram_allowance_remaining = vram_allowance * 1024 * 1024
self.max_prefetch = max_prefetch
self.hook_handles = []
submodules_list = self.get_module_list()
for c in submodules_list:
c._apply(lambda x: x.pin_memory())
self.submodules_list = [
k for c in submodules_list for k in self.get_childrens(c)
]
self.parameters_list = [
list(x.parameters()) for x in self.submodules_list
]
self.parameters_sizes = [
sum([p.numel() * p.element_size() for p in x]) for x in self.parameters_list
]
self.online_modules = set()
self.online_module_count = 0
self.submodules_indexer = {}
for i, module in enumerate(self.submodules_list):
self.submodules_indexer[module] = i
def test_children(self, op):
return op.__class__.__name__ in [
"Conv2d",
"Conv3d",
"Linear",
"GroupNorm",
"LayerNorm",
]
def get_childrens(self, container):
if isinstance(container, torch.nn.Sequential):
# return [c for c in container]
return [cc for c in container for cc in self.get_childrens(c)]
if "children" in dir(container):
childrens = [
cc for c in container.children() for cc in self.get_childrens(c)
]
if len(childrens) > 0:
return childrens
return [container]
def drain_allowance(self, idx):
parameters_len = len(self.parameters_list)
# no vram limitation is set
if self.vram_allowance <= 0:
# fetch up to max_prefetch parameters
while self.online_module_count < self.max_prefetch and self.online_module_count < parameters_len:
param = self.parameters_list[idx]
self.online_modules.add(idx)
self.online_module_count += 1
yield param
idx = (idx + 1) % parameters_len
return
# if there is still vram allowance, and it has not reached max_prefetch
while self.vram_allowance_remaining > 0 and (
self.max_prefetch < 1 or self.online_module_count < self.max_prefetch
) and self.online_module_count < parameters_len:
param = self.parameters_list[idx]
param_size = self.parameters_sizes[idx]
# empty module or already online
if len(param) == 0 or idx in self.online_modules:
self.online_modules.add(idx)
self.online_module_count += 1
idx = (idx + 1) % parameters_len
continue
# if the parameter size is bigger than the remaining vram allowance, and there are already online modules
if (
param_size > self.vram_allowance_remaining
and self.online_module_count > 0
):
return
self.vram_allowance_remaining -= param_size
self.online_modules.add(idx)
self.online_module_count += 1
yield param
idx = (idx + 1) % parameters_len
def fill_allowance(self, idx):
if self.vram_allowance > 0:
self.vram_allowance_remaining += self.parameters_sizes[idx]
self.online_modules.remove(idx)
self.online_module_count -= 1
def get_module_list(self):
return []
def install(self):
for submodule in self.submodules_list:
self.hook_handles.append(
submodule.register_forward_pre_hook(self._pre_forward_hook)
)
self.hook_handles.append(
submodule.register_forward_hook(self._post_forward_hook)
)
def uninstall(self):
for handle in self.hook_handles:
handle.remove()
for idx in self.online_modules:
mover.forget_batch(self.parameters_list[idx])
def preload(self):
idx = 0
for parameters in self.drain_allowance(idx):
for param in parameters:
mover.move(param, device=devices.device)
def _pre_forward_hook(self, module, *args, **kwargs):
idx = self.submodules_indexer[module]
for parameters in self.drain_allowance(idx):
for param in parameters:
mover.move(param, device=devices.device)
def _post_forward_hook(self, module, *args, **kwargs):
idx = self.submodules_indexer[module]
mover.forget_batch(self.parameters_list[idx])
self.fill_allowance(idx)
class DiffModelMover(ModelMover):
def get_module_list(self):
modules = []
modules.append(self.model.time_embed)
for block in self.model.input_blocks:
modules.append(block)
modules.append(self.model.middle_block)
for block in self.model.output_blocks:
modules.append(block)
return modules
class DiffSmartModelMover(SmartModelMover):
def get_module_list(self):
modules = []
modules.append(self.model.time_embed)
for block in self.model.input_blocks:
modules.append(block)
modules.append(self.model.middle_block)
for block in self.model.output_blocks:
modules.append(block)
return modules
def send_everything_to_cpu():
global module_in_gpu
if module_in_gpu is not None:
module_in_gpu.to(cpu)
module_in_gpu = None
mover.forget_all()
def is_needed(sd_model):
return (
shared.cmd_opts.lowvram
or shared.cmd_opts.medvram
or shared.cmd_opts.medvram_sdxl
and hasattr(sd_model, "conditioner")
)
def apply(sd_model):
enable = is_needed(sd_model)
shared.parallel_processing_allowed = not enable
if enable:
setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
else:
sd_model.lowvram = False
def setup_for_low_vram(sd_model, use_medvram):
if getattr(sd_model, "lowvram", False):
return
sd_model.lowvram = True
parents = {}
def send_me_to_gpu(module, _):
"""send this module to GPU; send whatever tracked module was previous in GPU to CPU;
we add this as forward_pre_hook to a lot of modules and this way all but one of them will
be in CPU
"""
global module_in_gpu
module = parents.get(module, module)
if module_in_gpu == module:
return
if module_in_gpu is not None:
module_in_gpu.to(cpu)
module.to(devices.device)
module_in_gpu = module
# see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
first_stage_model = sd_model.first_stage_model
first_stage_model_encode = sd_model.first_stage_model.encode
first_stage_model_decode = sd_model.first_stage_model.decode
def first_stage_model_encode_wrap(x):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_encode(x)
def first_stage_model_decode_wrap(z):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_decode(z)
to_remain_in_cpu = [
(sd_model, "first_stage_model"),
(sd_model, "depth_model"),
(sd_model, "embedder"),
(sd_model, "model"),
(sd_model, "embedder"),
]
is_sdxl = hasattr(sd_model, "conditioner")
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, "model")
if is_sdxl:
to_remain_in_cpu.append((sd_model, "conditioner"))
elif is_sd2:
to_remain_in_cpu.append((sd_model.cond_stage_model, "model"))
else:
to_remain_in_cpu.append((sd_model.cond_stage_model, "transformer"))
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
stored = []
for obj, field in to_remain_in_cpu:
module = getattr(obj, field, None)
stored.append(module)
setattr(obj, field, None)
# send the model to GPU.
sd_model.to(devices.device)
# put modules back. the modules will be in CPU.
for (obj, field), module in zip(to_remain_in_cpu, stored):
setattr(obj, field, module)
# register hooks for those the first three models
if is_sdxl:
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
elif is_sd2:
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(
send_me_to_gpu
)
parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
parents[sd_model.cond_stage_model.model.token_embedding] = (
sd_model.cond_stage_model
)
else:
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
if sd_model.depth_model:
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
if sd_model.embedder:
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
if use_medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
else:
diff_model = sd_model.model.diffusion_model
# the third remaining model is still too big for 4 GB, so we also do the same for its submodules
# so that only one of them is in GPU at a time
stored = (
diff_model.input_blocks,
diff_model.middle_block,
diff_model.output_blocks,
diff_model.time_embed,
)
(
diff_model.input_blocks,
diff_model.middle_block,
diff_model.output_blocks,
diff_model.time_embed,
) = (None, None, None, None)
sd_model.model.to(devices.device)
(
diff_model.input_blocks,
diff_model.middle_block,
diff_model.output_blocks,
diff_model.time_embed,
) = stored
# install hooks for bits of third model
if stream_impl is not None and stream_wrapper is not None:
mp = DiffSmartModelMover.register(diff_model, vram_allowance=512, max_prefetch=70)
mp.install()
mp.preload()
else:
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.input_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
def is_enabled(sd_model):
return sd_model.lowvram