better impl

This commit is contained in:
wfjsw 2024-02-20 18:49:44 -06:00
parent 0caa75312e
commit fffc902698
3 changed files with 76 additions and 334 deletions

View File

@ -1,3 +1,4 @@
from contextlib import contextmanager, nullcontext
import torch
from torch.utils.weak import WeakIdKeyDictionary
from modules import devices, shared, patches
@ -9,11 +10,21 @@ stream_impl = devices.get_stream_impl()
stream_wrapper = devices.get_stream_wrapper()
class SmartTensorMoverPatches:
def is_same_device(device1, device2):
tensor1_device_type = device1.type
tensor2_device_type = device2.type
tensor1_device_index = device1.index or 0
tensor2_device_index = device2.index or 0
return (
tensor1_device_type == tensor2_device_type
and tensor1_device_index == tensor2_device_index
)
class RTTensorMoverPatches:
def __init__(self):
self.memo = WeakIdKeyDictionary()
self.cleanup_memo = WeakIdKeyDictionary()
self.model_mover_stream = stream_impl(device=devices.device)
self.mover_stream = stream_impl(device=devices.device)
self.calc_stream = stream_impl(device=devices.device)
self.stash = {}
self.linear_original = patches.patch(
__name__,
@ -46,112 +57,58 @@ class SmartTensorMoverPatches:
self.create_wrapper(torch.nn.functional.layer_norm, type=2),
)
@contextmanager
def wrap_weight_biases(self, input, weight, bias):
if not is_same_device(input.device, devices.device):
yield (weight, bias)
return
moved = False
before_calc_event, after_calc_event = None, None
with stream_wrapper(stream=self.mover_stream):
if weight is not None and not is_same_device(weight.device, input.device):
weight = weight.to(device=input.device, copy=True, non_blocking=weight.is_pinned())
moved = True
if bias is not None and not is_same_device(bias.device, input.device):
bias = bias.to(device=input.device, copy=True, non_blocking=bias.is_pinned())
moved = True
before_calc_event = self.mover_stream.record_event()
if not moved:
yield (weight, bias)
return
with stream_wrapper(stream=self.calc_stream):
if before_calc_event is not None:
self.calc_stream.wait_event(before_calc_event)
yield (weight, bias)
after_calc_event = self.calc_stream.record_event()
self.stash[id(after_calc_event)] = (weight, bias, after_calc_event)
to_remove = []
for k, (w, b, e) in self.stash.items():
if e.query():
to_remove.append(k)
for k in to_remove:
del self.stash[k]
def create_wrapper(self, original, type=1):
if type == 2:
def wrapper(input, arg1, weight, bias, *args, **kwargs):
record_cleanup_weight, record_cleanup_bias = False, False
current_stream = devices.get_current_stream()
if weight is not None:
new_weight, weight_event = self.memo.get(weight, (None, None))
if weight_event is not None:
weight = new_weight
record_cleanup_weight = True
current_stream.wait_event(weight_event)
if bias is not None:
new_bias, bias_event = self.memo.get(bias, (None, None))
if bias_event is not None:
bias = new_bias
record_cleanup_bias = True
current_stream.wait_event(bias_event)
result = original(input, arg1, weight, bias, *args, **kwargs)
if record_cleanup_weight:
self.cleanup_memo[weight] = current_stream.record_event()
if record_cleanup_bias:
self.cleanup_memo[bias] = current_stream.record_event()
return result
with self.wrap_weight_biases(input, weight, bias) as (w, b):
return original(input, arg1, w, b, *args, **kwargs)
return wrapper
else:
def wrapper(input, weight, bias, *args, **kwargs):
record_cleanup_weight, record_cleanup_bias = False, False
current_stream = devices.get_current_stream()
if weight is not None:
new_weight, weight_event = self.memo.get(weight, (None, None))
if weight_event is not None:
weight = new_weight
record_cleanup_weight = True
current_stream.wait_event(weight_event)
if bias is not None:
new_bias, bias_event = self.memo.get(bias, (None, None))
if bias_event is not None:
bias = new_bias
record_cleanup_bias = True
current_stream.wait_event(bias_event)
result = original(input, weight, bias, *args, **kwargs)
if record_cleanup_weight:
self.cleanup_memo[weight] = current_stream.record_event()
if record_cleanup_bias:
self.cleanup_memo[bias] = current_stream.record_event()
return result
with self.wrap_weight_biases(input, weight, bias) as (w, b):
return original(input, w, b, *args, **kwargs)
return wrapper
def __contains__(self, tensor):
return tensor in self.memo
def move(self, tensor, device=None):
device = device or tensor.device
memo_tensor, memo_event = self.memo.get(tensor, (None, None))
if memo_tensor is not None:
return memo_tensor, memo_event
with stream_wrapper(stream=self.model_mover_stream):
new_tensor = tensor.to(device=device, copy=True, non_blocking=True)
new_event = self.model_mover_stream.record_event()
self.memo[tensor] = (new_tensor, new_event)
return self.memo[tensor]
def _forget(self, tensor, tensor_on_device=None):
if tensor_on_device is not None:
tensor_used_event = self.cleanup_memo.get(tensor_on_device, None)
if tensor_used_event is not None:
self.model_mover_stream.wait_event(tensor_used_event)
self.cleanup_memo.pop(tensor_on_device, None)
del self.memo[tensor]
def forget(self, tensor):
on_device_tensor = self.memo.get(tensor, None)
self._forget(tensor, on_device_tensor)
def forget_batch(self, tensors):
for tensor in tensors:
tensor_on_device, _ = self.memo.get(tensor, (None, None))
if tensor_on_device is not None:
self._forget(tensor, tensor_on_device)
def forget_all(self):
for (tensor_on_device, _) in self.memo.values():
if tensor_on_device in self.cleanup_memo:
self.model_mover_stream.wait_event(self.cleanup_memo[tensor_on_device])
self.cleanup_memo.clear()
self.memo.clear()
def close(self):
patches.undo(__name__, torch.nn.functional, "linear")
patches.undo(__name__, torch.nn.functional, "conv2d")
@ -160,242 +117,21 @@ class SmartTensorMoverPatches:
patches.undo(__name__, torch.nn.functional, "layer_norm")
mover = None
rtmover = None
if stream_impl is not None and stream_wrapper is not None:
mover = SmartTensorMoverPatches()
rtmover = RTTensorMoverPatches()
class ModelMover:
@classmethod
def register(cls, model, max_prefetch=1):
instance = cls(model, max_prefetch)
model.lowvram_model_mover = instance
return instance
def __init__(self, model, max_prefetch=1):
self.model = model
self.lookahead_distance = max_prefetch
self.hook_handles = []
self.submodules_list = self.get_module_list()
for c in self.submodules_list:
c._apply(lambda x: x.pin_memory())
self.submodules_indexer = {}
self.module_movement_events = {}
self.default_stream = devices.get_current_stream()
self.model_mover_stream = stream_impl(device=devices.device)
for i, module in enumerate(self.submodules_list):
self.submodules_indexer[module] = i
def get_module_list(self):
return []
def install(self):
for i in range(len(self.submodules_list)):
self.hook_handles.append(
self.submodules_list[i].register_forward_pre_hook(
self._pre_forward_hook
)
)
self.hook_handles.append(
self.submodules_list[i].register_forward_hook(self._post_forward_hook)
)
def uninstall(self):
for handle in self.hook_handles:
handle.remove()
def _pre_forward_hook(self, module, _):
with stream_wrapper(stream=self.model_mover_stream):
idx = self.submodules_indexer[module]
for i in range(idx, idx + self.lookahead_distance):
submodule = self.submodules_list[i % len(self.submodules_list)]
if submodule in self.module_movement_events:
# already in GPU
continue
submodule.to(devices.device, non_blocking=True)
self.module_movement_events[submodule] = (
self.model_mover_stream.record_event()
)
this_event = self.module_movement_events.get(module, None)
if this_event is not None:
self.default_stream.wait_event(this_event)
else:
print(
f"Module {module.__name__} was not moved to GPU. Taking slow path"
)
submodule.to(devices.device, non_blocking=True)
def _post_forward_hook(self, module, _1, _2):
with stream_wrapper(stream=self.model_mover_stream):
del self.module_movement_events[module]
module.to(cpu, non_blocking=True)
def calc_wrapper():
if rtmover is not None:
return stream_wrapper(stream=rtmover.calc_stream)
return nullcontext()
class SmartModelMover:
@classmethod
def register(cls, model, vram_allowance=0, max_prefetch=10):
instance = cls(model, vram_allowance, max_prefetch)
model.lowvram_model_mover = instance
return instance
def __init__(self, model, vram_allowance=0, max_prefetch=10):
self.model = model
self.vram_allowance = vram_allowance * 1024 * 1024
self.vram_allowance_remaining = vram_allowance * 1024 * 1024
self.max_prefetch = max_prefetch
self.hook_handles = []
submodules_list = self.get_module_list()
for c in submodules_list:
c._apply(lambda x: x.pin_memory())
self.submodules_list = [
k for c in submodules_list for k in self.get_childrens(c)
]
self.parameters_list = [
list(x.parameters()) for x in self.submodules_list
]
self.parameters_sizes = [
sum([p.numel() * p.element_size() for p in x]) for x in self.parameters_list
]
self.online_modules = set()
self.online_module_count = 0
self.submodules_indexer = {}
for i, module in enumerate(self.submodules_list):
self.submodules_indexer[module] = i
def test_children(self, op):
return op.__class__.__name__ in [
"Conv2d",
"Conv3d",
"Linear",
"GroupNorm",
"LayerNorm",
]
def get_childrens(self, container):
if isinstance(container, torch.nn.Sequential):
# return [c for c in container]
return [cc for c in container for cc in self.get_childrens(c)]
if "children" in dir(container):
childrens = [
cc for c in container.children() for cc in self.get_childrens(c)
]
if len(childrens) > 0:
return childrens
return [container]
def drain_allowance(self, idx):
parameters_len = len(self.parameters_list)
# no vram limitation is set
if self.vram_allowance <= 0:
# fetch up to max_prefetch parameters
while self.online_module_count < self.max_prefetch and self.online_module_count < parameters_len:
param = self.parameters_list[idx]
self.online_modules.add(idx)
self.online_module_count += 1
yield param
idx = (idx + 1) % parameters_len
return
# if there is still vram allowance, and it has not reached max_prefetch
while self.vram_allowance_remaining > 0 and (
self.max_prefetch < 1 or self.online_module_count < self.max_prefetch
) and self.online_module_count < parameters_len:
param = self.parameters_list[idx]
param_size = self.parameters_sizes[idx]
# empty module or already online
if len(param) == 0 or idx in self.online_modules:
self.online_modules.add(idx)
self.online_module_count += 1
idx = (idx + 1) % parameters_len
continue
# if the parameter size is bigger than the remaining vram allowance, and there are already online modules
if (
param_size > self.vram_allowance_remaining
and self.online_module_count > 0
):
return
self.vram_allowance_remaining -= param_size
self.online_modules.add(idx)
self.online_module_count += 1
yield param
idx = (idx + 1) % parameters_len
def fill_allowance(self, idx):
if self.vram_allowance > 0:
self.vram_allowance_remaining += self.parameters_sizes[idx]
self.online_modules.remove(idx)
self.online_module_count -= 1
def get_module_list(self):
return []
def install(self):
for submodule in self.submodules_list:
self.hook_handles.append(
submodule.register_forward_pre_hook(self._pre_forward_hook)
)
self.hook_handles.append(
submodule.register_forward_hook(self._post_forward_hook)
)
def uninstall(self):
for handle in self.hook_handles:
handle.remove()
for idx in self.online_modules:
mover.forget_batch(self.parameters_list[idx])
def preload(self):
idx = 0
for parameters in self.drain_allowance(idx):
for param in parameters:
mover.move(param, device=devices.device)
def _pre_forward_hook(self, module, *args, **kwargs):
idx = self.submodules_indexer[module]
for parameters in self.drain_allowance(idx):
for param in parameters:
mover.move(param, device=devices.device)
def _post_forward_hook(self, module, *args, **kwargs):
idx = self.submodules_indexer[module]
mover.forget_batch(self.parameters_list[idx])
self.fill_allowance(idx)
class DiffModelMover(ModelMover):
def get_module_list(self):
modules = []
modules.append(self.model.time_embed)
for block in self.model.input_blocks:
modules.append(block)
modules.append(self.model.middle_block)
for block in self.model.output_blocks:
modules.append(block)
return modules
class DiffSmartModelMover(SmartModelMover):
def get_module_list(self):
modules = []
modules.append(self.model.time_embed)
for block in self.model.input_blocks:
modules.append(block)
modules.append(self.model.middle_block)
for block in self.model.output_blocks:
modules.append(block)
return modules
def calc_sync():
if rtmover is not None:
return rtmover.calc_stream.synchronize()
return nullcontext()
def send_everything_to_cpu():
@ -530,9 +266,13 @@ def setup_for_low_vram(sd_model, use_medvram):
# install hooks for bits of third model
if stream_impl is not None and stream_wrapper is not None:
mp = DiffSmartModelMover.register(diff_model, vram_allowance=512, max_prefetch=70)
mp.install()
mp.preload()
# put it into pinned memory to achieve data transfer overlap
diff_model.time_embed._apply(lambda x: x.pin_memory())
for block in diff_model.input_blocks:
block._apply(lambda x: x.pin_memory())
diff_model.middle_block._apply(lambda x: x.pin_memory())
for block in diff_model.output_blocks:
block._apply(lambda x: x.pin_memory())
else:
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.input_blocks:

View File

@ -782,7 +782,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
res = process_images_inner(p)
with lowvram.calc_wrapper():
res = process_images_inner(p)
lowvram.calc_sync()
finally:
sd_models.apply_token_merging(p.sd_model, 0)

View File

@ -446,7 +446,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
module.fp16_weight = module.weight.data.clone().cpu().half()
if module.bias is not None:
module.fp16_bias = module.bias.data.clone().cpu().half()
module.to(torch.float8_e4m3fn)
module.to(torch.float8_e4m3fn)._apply(lambda x: x.pin_memory())
model.first_stage_model = first_stage
timer.record("apply fp8")
else: