From c1702ea4984d0788089b88b3f1760b54045375c0 Mon Sep 17 00:00:00 2001 From: wfjsw Date: Wed, 7 Feb 2024 02:26:10 -0600 Subject: [PATCH] async weight mover --- modules/devices.py | 40 ++++++++++++++++++++ modules/lowvram.py | 92 +++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 126 insertions(+), 6 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index e4f671ac6..2b92f97a1 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -269,3 +269,43 @@ def first_time_calculation(): x = torch.zeros((1, 1, 3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d(x) + + +def get_stream_impl(): + if torch.cuda.is_available(): + return torch.cuda.Stream + + if has_xpu(): + return torch.xpu.Stream + + return None + + +def get_stream_wrapper(): + if torch.cuda.is_available(): + return torch.cuda.stream + + if has_xpu(): + return torch.xpu.stream + + return None + + +def get_event_impl(): + if torch.cuda.is_available(): + return torch.cuda.Event + + if has_xpu(): + return torch.xpu.Event + + return None + + +def get_current_stream(): + if torch.cuda.is_available(): + return torch.cuda.current_stream(device) + + if has_xpu(): + return torch.xpu.current_stream(device) + + return None diff --git a/modules/lowvram.py b/modules/lowvram.py index 45701046b..7ecadd7aa 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -4,6 +4,77 @@ from modules import devices, shared module_in_gpu = None cpu = torch.device("cpu") +stream_impl = devices.get_stream_impl() +stream_wrapper = devices.get_stream_wrapper() + +class ModelMover: + @classmethod + def register(cls, model, lookahead_distance=1): + instance = cls(model, lookahead_distance) + setattr(model, 'lowvram_model_mover', instance) + return instance + + def __init__(self, model, lookahead_distance=1): + self.model = model + self.lookahead_distance = lookahead_distance + self.hook_handles = [] + self.submodules_list = self.get_module_list() + self.submodules_indexer = {} + self.module_movement_events = {} + self.default_stream = devices.get_current_stream() + self.model_mover_stream = stream_impl(device=devices.device) + + for i, module in enumerate(self.submodules_list): + self.submodules_indexer[module] = i + + def get_module_list(self): + return [] + + def install(self): + for i in range(len(self.submodules_list)): + self.hook_handles.append(self.submodules_list[i].register_forward_pre_hook(self._pre_forward_hook)) + self.hook_handles.append(self.submodules_list[i].register_forward_hook(self._post_forward_hook)) + + def uninstall(self): + for handle in self.hook_handles: + handle.remove() + + def _pre_forward_hook(self, module, _): + + with stream_wrapper(stream=self.model_mover_stream): + idx = self.submodules_indexer[module] + for i in range(idx, idx + self.lookahead_distance): + submodule = self.submodules_list[i % len(self.submodules_list)] + if submodule in self.module_movement_events: + # already in GPU + continue + submodule.to(devices.device, non_blocking=True) + self.module_movement_events[submodule] = self.model_mover_stream.record_event() + + this_event = self.module_movement_events.get(module, None) + if this_event is not None: + self.default_stream.wait_event(this_event) + else: + print(f"Module {module.__name__} was not moved to GPU. Taking slow path") + submodule.to(devices.device, non_blocking=True) + + def _post_forward_hook(self, module, _1, _2): + with stream_wrapper(stream=self.model_mover_stream): + del self.module_movement_events[module] + module.to(cpu, non_blocking=True) + + +class DiffModelMover(ModelMover): + def get_module_list(self): + modules = [] + modules.append(self.model.time_embed) + for block in self.model.input_blocks: + modules.append(block) + modules.append(self.model.middle_block) + for block in self.model.output_blocks: + modules.append(block) + return modules + def send_everything_to_cpu(): global module_in_gpu @@ -43,6 +114,19 @@ def setup_for_low_vram(sd_model, use_medvram): """ global module_in_gpu + try: + name = module._get_name() + except: + try: + name = module.__name__ + except: + try: + name = module.__class__.__name__ + except: + name = str(module) + + print(f"Moving {module.__module__}.{name} to GPU") + module = parents.get(module, module) if module_in_gpu == module: @@ -135,12 +219,8 @@ def setup_for_low_vram(sd_model, use_medvram): diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored # install hooks for bits of third model - diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu) - for block in diff_model.input_blocks: - block.register_forward_pre_hook(send_me_to_gpu) - diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu) - for block in diff_model.output_blocks: - block.register_forward_pre_hook(send_me_to_gpu) + mover = DiffModelMover.register(diff_model, lookahead_distance=8) + mover.install() def is_enabled(sd_model):