diff --git a/modules/lowvram.py b/modules/lowvram.py
index 019eb2788..f25fbefef 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -1,3 +1,4 @@
+from contextlib import contextmanager, nullcontext
 import torch
 from torch.utils.weak import WeakIdKeyDictionary
 from modules import devices, shared, patches
@@ -9,11 +10,21 @@
 stream_impl = devices.get_stream_impl()
 stream_wrapper = devices.get_stream_wrapper()
 
-class SmartTensorMoverPatches:
+def is_same_device(device1, device2):
+    tensor1_device_type = device1.type
+    tensor2_device_type = device2.type
+    tensor1_device_index = device1.index or 0
+    tensor2_device_index = device2.index or 0
+    return (
+        tensor1_device_type == tensor2_device_type
+        and tensor1_device_index == tensor2_device_index
+    )
+
+class RTTensorMoverPatches:
     def __init__(self):
-        self.memo = WeakIdKeyDictionary()
-        self.cleanup_memo = WeakIdKeyDictionary()
-        self.model_mover_stream = stream_impl(device=devices.device)
+        self.mover_stream = stream_impl(device=devices.device)
+        self.calc_stream = stream_impl(device=devices.device)
+        self.stash = {}
 
         self.linear_original = patches.patch(
             __name__,
@@ -46,112 +57,58 @@
             self.create_wrapper(torch.nn.functional.layer_norm, type=2),
         )
 
+    @contextmanager
+    def wrap_weight_biases(self, input, weight, bias):
+        if not is_same_device(input.device, devices.device):
+            yield (weight, bias)
+            return
+
+        moved = False
+        before_calc_event, after_calc_event = None, None
+        with stream_wrapper(stream=self.mover_stream):
+            if weight is not None and not is_same_device(weight.device, input.device):
+                weight = weight.to(device=input.device, copy=True, non_blocking=weight.is_pinned())
+                moved = True
+            if bias is not None and not is_same_device(bias.device, input.device):
+                bias = bias.to(device=input.device, copy=True, non_blocking=bias.is_pinned())
+                moved = True
+            before_calc_event = self.mover_stream.record_event()
+
+        if not moved:
+            yield (weight, bias)
+            return
+
+        with stream_wrapper(stream=self.calc_stream):
+            if before_calc_event is not None:
+                self.calc_stream.wait_event(before_calc_event)
+            yield (weight, bias)
+            after_calc_event = self.calc_stream.record_event()
+            self.stash[id(after_calc_event)] = (weight, bias, after_calc_event)
+
+        to_remove = []
+        for k, (w, b, e) in self.stash.items():
+            if e.query():
+                to_remove.append(k)
+
+        for k in to_remove:
+            del self.stash[k]
+
     def create_wrapper(self, original, type=1):
         if type == 2:
 
             def wrapper(input, arg1, weight, bias, *args, **kwargs):
-                record_cleanup_weight, record_cleanup_bias = False, False
-                current_stream = devices.get_current_stream()
-
-                if weight is not None:
-                    new_weight, weight_event = self.memo.get(weight, (None, None))
-                    if weight_event is not None:
-                        weight = new_weight
-                        record_cleanup_weight = True
-                        current_stream.wait_event(weight_event)
-
-                if bias is not None:
-                    new_bias, bias_event = self.memo.get(bias, (None, None))
-                    if bias_event is not None:
-                        bias = new_bias
-                        record_cleanup_bias = True
-                        current_stream.wait_event(bias_event)
-
-                result = original(input, arg1, weight, bias, *args, **kwargs)
-
-                if record_cleanup_weight:
-                    self.cleanup_memo[weight] = current_stream.record_event()
-
-                if record_cleanup_bias:
-                    self.cleanup_memo[bias] = current_stream.record_event()
-
-                return result
+                with self.wrap_weight_biases(input, weight, bias) as (w, b):
+                    return original(input, arg1, w, b, *args, **kwargs)
 
             return wrapper
         else:
 
             def wrapper(input, weight, bias, *args, **kwargs):
-                record_cleanup_weight, record_cleanup_bias = False, False
-                current_stream = devices.get_current_stream()
-
-                if weight is not None:
-                    new_weight, weight_event = self.memo.get(weight, (None, None))
-                    if weight_event is not None:
-                        weight = new_weight
-                        record_cleanup_weight = True
-                        current_stream.wait_event(weight_event)
-
-                if bias is not None:
-                    new_bias, bias_event = self.memo.get(bias, (None, None))
-                    if bias_event is not None:
-                        bias = new_bias
-                        record_cleanup_bias = True
-                        current_stream.wait_event(bias_event)
-
-                result = original(input, weight, bias, *args, **kwargs)
-
-                if record_cleanup_weight:
-                    self.cleanup_memo[weight] = current_stream.record_event()
-
-                if record_cleanup_bias:
-                    self.cleanup_memo[bias] = current_stream.record_event()
-
-                return result
+                with self.wrap_weight_biases(input, weight, bias) as (w, b):
+                    return original(input, w, b, *args, **kwargs)
 
             return wrapper
 
-    def __contains__(self, tensor):
-        return tensor in self.memo
-
-    def move(self, tensor, device=None):
-        device = device or tensor.device
-        memo_tensor, memo_event = self.memo.get(tensor, (None, None))
-
-        if memo_tensor is not None:
-            return memo_tensor, memo_event
-
-        with stream_wrapper(stream=self.model_mover_stream):
-            new_tensor = tensor.to(device=device, copy=True, non_blocking=True)
-            new_event = self.model_mover_stream.record_event()
-            self.memo[tensor] = (new_tensor, new_event)
-
-        return self.memo[tensor]
-
-    def _forget(self, tensor, tensor_on_device=None):
-        if tensor_on_device is not None:
-            tensor_used_event = self.cleanup_memo.get(tensor_on_device, None)
-            if tensor_used_event is not None:
-                self.model_mover_stream.wait_event(tensor_used_event)
-            self.cleanup_memo.pop(tensor_on_device, None)
-        del self.memo[tensor]
-
-    def forget(self, tensor):
-        on_device_tensor = self.memo.get(tensor, None)
-        self._forget(tensor, on_device_tensor)
-
-    def forget_batch(self, tensors):
-        for tensor in tensors:
-            tensor_on_device, _ = self.memo.get(tensor, (None, None))
-            if tensor_on_device is not None:
-                self._forget(tensor, tensor_on_device)
-
-    def forget_all(self):
-        for (tensor_on_device, _) in self.memo.values():
-            if tensor_on_device in self.cleanup_memo:
-                self.model_mover_stream.wait_event(self.cleanup_memo[tensor_on_device])
-        self.cleanup_memo.clear()
-        self.memo.clear()
-
     def close(self):
         patches.undo(__name__, torch.nn.functional, "linear")
         patches.undo(__name__, torch.nn.functional, "conv2d")
@@ -160,242 +117,21 @@ class SmartTensorMoverPatches:
         patches.undo(__name__, torch.nn.functional, "layer_norm")
 
 
-mover = None
+rtmover = None
 if stream_impl is not None and stream_wrapper is not None:
-    mover = SmartTensorMoverPatches()
+    rtmover = RTTensorMoverPatches()
 
 
-class ModelMover:
-    @classmethod
-    def register(cls, model, max_prefetch=1):
-        instance = cls(model, max_prefetch)
-        model.lowvram_model_mover = instance
-        return instance
-
-    def __init__(self, model, max_prefetch=1):
-        self.model = model
-        self.lookahead_distance = max_prefetch
-        self.hook_handles = []
-        self.submodules_list = self.get_module_list()
-
-        for c in self.submodules_list:
-            c._apply(lambda x: x.pin_memory())
-
-        self.submodules_indexer = {}
-        self.module_movement_events = {}
-        self.default_stream = devices.get_current_stream()
-        self.model_mover_stream = stream_impl(device=devices.device)
-
-        for i, module in enumerate(self.submodules_list):
-            self.submodules_indexer[module] = i
-
-    def get_module_list(self):
-        return []
-
-    def install(self):
-        for i in range(len(self.submodules_list)):
-            self.hook_handles.append(
-                self.submodules_list[i].register_forward_pre_hook(
-                    self._pre_forward_hook
-                )
-            )
-            self.hook_handles.append(
-                self.submodules_list[i].register_forward_hook(self._post_forward_hook)
-            )
-
-    def uninstall(self):
-        for handle in self.hook_handles:
-            handle.remove()
-
-    def _pre_forward_hook(self, module, _):
-        with stream_wrapper(stream=self.model_mover_stream):
-            idx = self.submodules_indexer[module]
-            for i in range(idx, idx + self.lookahead_distance):
-                submodule = self.submodules_list[i % len(self.submodules_list)]
-                if submodule in self.module_movement_events:
-                    # already in GPU
-                    continue
-                submodule.to(devices.device, non_blocking=True)
-                self.module_movement_events[submodule] = (
-                    self.model_mover_stream.record_event()
-                )
-
-            this_event = self.module_movement_events.get(module, None)
-            if this_event is not None:
-                self.default_stream.wait_event(this_event)
-            else:
-                print(
-                    f"Module {module.__name__} was not moved to GPU. Taking slow path"
-                )
-                submodule.to(devices.device, non_blocking=True)
-
-    def _post_forward_hook(self, module, _1, _2):
-        with stream_wrapper(stream=self.model_mover_stream):
-            del self.module_movement_events[module]
-            module.to(cpu, non_blocking=True)
+def calc_wrapper():
+    if rtmover is not None:
+        return stream_wrapper(stream=rtmover.calc_stream)
+    return nullcontext()
 
 
-class SmartModelMover:
-    @classmethod
-    def register(cls, model, vram_allowance=0, max_prefetch=10):
-        instance = cls(model, vram_allowance, max_prefetch)
-        model.lowvram_model_mover = instance
-        return instance
-
-    def __init__(self, model, vram_allowance=0, max_prefetch=10):
-        self.model = model
-        self.vram_allowance = vram_allowance * 1024 * 1024
-        self.vram_allowance_remaining = vram_allowance * 1024 * 1024
-        self.max_prefetch = max_prefetch
-        self.hook_handles = []
-        submodules_list = self.get_module_list()
-
-        for c in submodules_list:
-            c._apply(lambda x: x.pin_memory())
-
-        self.submodules_list = [
-            k for c in submodules_list for k in self.get_childrens(c)
-        ]
-        self.parameters_list = [
-            list(x.parameters()) for x in self.submodules_list
-        ]
-        self.parameters_sizes = [
-            sum([p.numel() * p.element_size() for p in x]) for x in self.parameters_list
-        ]
-        self.online_modules = set()
-        self.online_module_count = 0
-        self.submodules_indexer = {}
-
-        for i, module in enumerate(self.submodules_list):
-            self.submodules_indexer[module] = i
-
-    def test_children(self, op):
-        return op.__class__.__name__ in [
-            "Conv2d",
-            "Conv3d",
-            "Linear",
-            "GroupNorm",
-            "LayerNorm",
-        ]
-
-    def get_childrens(self, container):
-        if isinstance(container, torch.nn.Sequential):
-            # return [c for c in container]
-            return [cc for c in container for cc in self.get_childrens(c)]
-        if "children" in dir(container):
-            childrens = [
-                cc for c in container.children() for cc in self.get_childrens(c)
-            ]
-            if len(childrens) > 0:
-                return childrens
-        return [container]
-
-    def drain_allowance(self, idx):
-        parameters_len = len(self.parameters_list)
-
-        # no vram limitation is set
-        if self.vram_allowance <= 0:
-            # fetch up to max_prefetch parameters
-            while self.online_module_count < self.max_prefetch and self.online_module_count < parameters_len:
-                param = self.parameters_list[idx]
-                self.online_modules.add(idx)
-                self.online_module_count += 1
-                yield param
-                idx = (idx + 1) % parameters_len
-            return
-
-        # if there is still vram allowance, and it has not reached max_prefetch
-        while self.vram_allowance_remaining > 0 and (
-            self.max_prefetch < 1 or self.online_module_count < self.max_prefetch
-        ) and self.online_module_count < parameters_len:
-            param = self.parameters_list[idx]
-            param_size = self.parameters_sizes[idx]
-
-            # empty module or already online
-            if len(param) == 0 or idx in self.online_modules:
-                self.online_modules.add(idx)
-                self.online_module_count += 1
-                idx = (idx + 1) % parameters_len
-                continue
-
-            # if the parameter size is bigger than the remaining vram allowance, and there are already online modules
-            if (
-                param_size > self.vram_allowance_remaining
-                and self.online_module_count > 0
-            ):
-                return
-            self.vram_allowance_remaining -= param_size
-            self.online_modules.add(idx)
-            self.online_module_count += 1
-            yield param
-            idx = (idx + 1) % parameters_len
-
-    def fill_allowance(self, idx):
-        if self.vram_allowance > 0:
-            self.vram_allowance_remaining += self.parameters_sizes[idx]
-        self.online_modules.remove(idx)
-        self.online_module_count -= 1
-
-    def get_module_list(self):
-        return []
-
-    def install(self):
-        for submodule in self.submodules_list:
-            self.hook_handles.append(
-                submodule.register_forward_pre_hook(self._pre_forward_hook)
-            )
-            self.hook_handles.append(
-                submodule.register_forward_hook(self._post_forward_hook)
-            )
-
-    def uninstall(self):
-        for handle in self.hook_handles:
-            handle.remove()
-
-        for idx in self.online_modules:
-            mover.forget_batch(self.parameters_list[idx])
-
-    def preload(self):
-        idx = 0
-        for parameters in self.drain_allowance(idx):
-            for param in parameters:
-                mover.move(param, device=devices.device)
-
-    def _pre_forward_hook(self, module, *args, **kwargs):
-        idx = self.submodules_indexer[module]
-        for parameters in self.drain_allowance(idx):
-            for param in parameters:
-                mover.move(param, device=devices.device)
-
-    def _post_forward_hook(self, module, *args, **kwargs):
-        idx = self.submodules_indexer[module]
-
-        mover.forget_batch(self.parameters_list[idx])
-        self.fill_allowance(idx)
-
-
-class DiffModelMover(ModelMover):
-    def get_module_list(self):
-        modules = []
-        modules.append(self.model.time_embed)
-        for block in self.model.input_blocks:
-            modules.append(block)
-        modules.append(self.model.middle_block)
-        for block in self.model.output_blocks:
-            modules.append(block)
-        return modules
-
-
-class DiffSmartModelMover(SmartModelMover):
-    def get_module_list(self):
-        modules = []
-        modules.append(self.model.time_embed)
-        for block in self.model.input_blocks:
-            modules.append(block)
-        modules.append(self.model.middle_block)
-        for block in self.model.output_blocks:
-            modules.append(block)
-        return modules
+def calc_sync():
+    if rtmover is not None:
+        return rtmover.calc_stream.synchronize()
+    return nullcontext()
 
 
 def send_everything_to_cpu():
@@ -530,9 +266,13 @@
 
     # install hooks for bits of third model
     if stream_impl is not None and stream_wrapper is not None:
-        mp = DiffSmartModelMover.register(diff_model, vram_allowance=512, max_prefetch=70)
-        mp.install()
-        mp.preload()
+        # put it into pinned memory to achieve data transfer overlap
+        diff_model.time_embed._apply(lambda x: x.pin_memory())
+        for block in diff_model.input_blocks:
+            block._apply(lambda x: x.pin_memory())
+        diff_model.middle_block._apply(lambda x: x.pin_memory())
+        for block in diff_model.output_blocks:
+            block._apply(lambda x: x.pin_memory())
     else:
         diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
         for block in diff_model.input_blocks:
diff --git a/modules/processing.py b/modules/processing.py
index 86194b057..e4a4ae907 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -782,7 +782,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 
         sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
 
-        res = process_images_inner(p)
+        with lowvram.calc_wrapper():
+            res = process_images_inner(p)
+        lowvram.calc_sync()
 
     finally:
         sd_models.apply_token_merging(p.sd_model, 0)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index b35aecbca..f6daa0b4c 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -446,7 +446,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
                     module.fp16_weight = module.weight.data.clone().cpu().half()
                     if module.bias is not None:
                         module.fp16_bias = module.bias.data.clone().cpu().half()
-                module.to(torch.float8_e4m3fn)
+                module.to(torch.float8_e4m3fn)._apply(lambda x: x.pin_memory())
         model.first_stage_model = first_stage
         timer.record("apply fp8")
     else:
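
Note (editorial, not part of the patch): the diff replaces the prefetching ModelMover/SmartModelMover machinery with RTTensorMoverPatches, which overlaps weight/bias uploads with compute by issuing host-to-device copies on a dedicated mover stream, making the calc stream wait on a "copy done" event, and stashing the on-GPU copies until a "compute done" event reports completion. The sketch below is a minimal illustration of that stream/event pattern written against plain torch.cuda; the tensor names and sizes are made up and it is not code from the repository.

    import torch

    device = torch.device("cuda")
    mover_stream = torch.cuda.Stream(device=device)  # plays the role of rtmover.mover_stream
    calc_stream = torch.cuda.Stream(device=device)   # plays the role of rtmover.calc_stream

    weight_cpu = torch.randn(4096, 4096).pin_memory()  # pinned memory allows a truly async H2D copy
    x = torch.randn(64, 4096, device=device)

    # 1) upload on the mover stream and record an event that completes once the copy has finished
    with torch.cuda.stream(mover_stream):
        weight_gpu = weight_cpu.to(device, non_blocking=True)
        copy_done = mover_stream.record_event()

    # 2) compute on the calc stream only after the copy event; keeping weight_gpu referenced
    #    until the compute event fires is what rtmover.stash does for the patched ops
    with torch.cuda.stream(calc_stream):
        calc_stream.wait_event(copy_done)
        y = x @ weight_gpu.t()
        compute_done = calc_stream.record_event()  # queried later to decide when the copy can be dropped

    # 3) the caller finally joins the calc stream, as lowvram.calc_sync() does after process_images_inner
    calc_stream.synchronize()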