Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2024-06-07 21:20:49 +00:00)

Commit fffc902698 ("better impl"), parent 0caa75312e
@@ -1,3 +1,4 @@
from contextlib import contextmanager, nullcontext
import torch
from torch.utils.weak import WeakIdKeyDictionary
from modules import devices, shared, patches
@@ -9,11 +10,21 @@ stream_impl = devices.get_stream_impl()
stream_wrapper = devices.get_stream_wrapper()


class SmartTensorMoverPatches:
def is_same_device(device1, device2):
    tensor1_device_type = device1.type
    tensor2_device_type = device2.type
    tensor1_device_index = device1.index or 0
    tensor2_device_index = device2.index or 0
    return (
        tensor1_device_type == tensor2_device_type
        and tensor1_device_index == tensor2_device_index
    )


class RTTensorMoverPatches:
    def __init__(self):
        self.memo = WeakIdKeyDictionary()
        self.cleanup_memo = WeakIdKeyDictionary()
        self.model_mover_stream = stream_impl(device=devices.device)
        self.mover_stream = stream_impl(device=devices.device)
        self.calc_stream = stream_impl(device=devices.device)
        self.stash = {}

        self.linear_original = patches.patch(
            __name__,
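For context, is_same_device exists because torch.device("cuda") carries index None while torch.device("cuda:0") carries index 0, so plain equality can report a spurious mismatch between the two. A minimal sketch, not part of this commit:

import torch

d1 = torch.device("cuda")        # index is None
d2 = torch.device("cuda:0")      # index is 0

print(d1 == d2)                  # False: raw equality distinguishes None from 0
print(is_same_device(d1, d2))    # True: both normalize to ("cuda", 0)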
@@ -46,112 +57,58 @@ class SmartTensorMoverPatches:
            self.create_wrapper(torch.nn.functional.layer_norm, type=2),
        )

    @contextmanager
    def wrap_weight_biases(self, input, weight, bias):
        if not is_same_device(input.device, devices.device):
            yield (weight, bias)
            return

        moved = False
        before_calc_event, after_calc_event = None, None
        with stream_wrapper(stream=self.mover_stream):
            if weight is not None and not is_same_device(weight.device, input.device):
                weight = weight.to(device=input.device, copy=True, non_blocking=weight.is_pinned())
                moved = True
            if bias is not None and not is_same_device(bias.device, input.device):
                bias = bias.to(device=input.device, copy=True, non_blocking=bias.is_pinned())
                moved = True
            before_calc_event = self.mover_stream.record_event()

        if not moved:
            yield (weight, bias)
            return

        with stream_wrapper(stream=self.calc_stream):
            if before_calc_event is not None:
                self.calc_stream.wait_event(before_calc_event)
            yield (weight, bias)
            after_calc_event = self.calc_stream.record_event()
            self.stash[id(after_calc_event)] = (weight, bias, after_calc_event)

        to_remove = []
        for k, (w, b, e) in self.stash.items():
            if e.query():
                to_remove.append(k)

        for k in to_remove:
            del self.stash[k]

    def create_wrapper(self, original, type=1):
        if type == 2:

            def wrapper(input, arg1, weight, bias, *args, **kwargs):
                record_cleanup_weight, record_cleanup_bias = False, False
                current_stream = devices.get_current_stream()

                if weight is not None:
                    new_weight, weight_event = self.memo.get(weight, (None, None))
                    if weight_event is not None:
                        weight = new_weight
                        record_cleanup_weight = True
                        current_stream.wait_event(weight_event)

                if bias is not None:
                    new_bias, bias_event = self.memo.get(bias, (None, None))
                    if bias_event is not None:
                        bias = new_bias
                        record_cleanup_bias = True
                        current_stream.wait_event(bias_event)

                result = original(input, arg1, weight, bias, *args, **kwargs)

                if record_cleanup_weight:
                    self.cleanup_memo[weight] = current_stream.record_event()

                if record_cleanup_bias:
                    self.cleanup_memo[bias] = current_stream.record_event()

                return result
                with self.wrap_weight_biases(input, weight, bias) as (w, b):
                    return original(input, arg1, w, b, *args, **kwargs)

            return wrapper
        else:

            def wrapper(input, weight, bias, *args, **kwargs):
                record_cleanup_weight, record_cleanup_bias = False, False
                current_stream = devices.get_current_stream()

                if weight is not None:
                    new_weight, weight_event = self.memo.get(weight, (None, None))
                    if weight_event is not None:
                        weight = new_weight
                        record_cleanup_weight = True
                        current_stream.wait_event(weight_event)

                if bias is not None:
                    new_bias, bias_event = self.memo.get(bias, (None, None))
                    if bias_event is not None:
                        bias = new_bias
                        record_cleanup_bias = True
                        current_stream.wait_event(bias_event)

                result = original(input, weight, bias, *args, **kwargs)

                if record_cleanup_weight:
                    self.cleanup_memo[weight] = current_stream.record_event()

                if record_cleanup_bias:
                    self.cleanup_memo[bias] = current_stream.record_event()

                return result
                with self.wrap_weight_biases(input, weight, bias) as (w, b):
                    return original(input, w, b, *args, **kwargs)

            return wrapper

    def __contains__(self, tensor):
        return tensor in self.memo

    def move(self, tensor, device=None):
        device = device or tensor.device
        memo_tensor, memo_event = self.memo.get(tensor, (None, None))

        if memo_tensor is not None:
            return memo_tensor, memo_event

        with stream_wrapper(stream=self.model_mover_stream):
            new_tensor = tensor.to(device=device, copy=True, non_blocking=True)
            new_event = self.model_mover_stream.record_event()
            self.memo[tensor] = (new_tensor, new_event)

        return self.memo[tensor]

    def _forget(self, tensor, tensor_on_device=None):
        if tensor_on_device is not None:
            tensor_used_event = self.cleanup_memo.get(tensor_on_device, None)
            if tensor_used_event is not None:
                self.model_mover_stream.wait_event(tensor_used_event)
            self.cleanup_memo.pop(tensor_on_device, None)
        del self.memo[tensor]

    def forget(self, tensor):
        on_device_tensor = self.memo.get(tensor, None)
        self._forget(tensor, on_device_tensor)

    def forget_batch(self, tensors):
        for tensor in tensors:
            tensor_on_device, _ = self.memo.get(tensor, (None, None))
            if tensor_on_device is not None:
                self._forget(tensor, tensor_on_device)

    def forget_all(self):
        for (tensor_on_device, _) in self.memo.values():
            if tensor_on_device in self.cleanup_memo:
                self.model_mover_stream.wait_event(self.cleanup_memo[tensor_on_device])
        self.cleanup_memo.clear()
        self.memo.clear()

    def close(self):
        patches.undo(__name__, torch.nn.functional, "linear")
        patches.undo(__name__, torch.nn.functional, "conv2d")
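For context on the pattern wrap_weight_biases implements: host-to-device copies issued on a dedicated CUDA stream from pinned memory can overlap with compute, and events order the copy stream against the compute stream. A minimal, self-contained sketch of that idea, assuming a CUDA device; the names copy_stream and calc_stream are illustrative and not from the codebase:

import torch

assert torch.cuda.is_available()
device = torch.device("cuda")

copy_stream = torch.cuda.Stream(device=device)   # dedicated stream for H2D copies
calc_stream = torch.cuda.Stream(device=device)   # dedicated stream for compute

weight = torch.randn(4096, 4096).pin_memory()    # pinned host tensor allows async copy
x = torch.randn(64, 4096, device=device)

with torch.cuda.stream(copy_stream):
    weight_gpu = weight.to(device, non_blocking=True)
    copied = copy_stream.record_event()           # marks completion of the copy

with torch.cuda.stream(calc_stream):
    calc_stream.wait_event(copied)                # compute waits only for this copy
    y = torch.nn.functional.linear(x, weight_gpu)

calc_stream.synchronize()                         # make the result safe to read afterwards
print(y.shape)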
@@ -160,242 +117,21 @@ class SmartTensorMoverPatches:
        patches.undo(__name__, torch.nn.functional, "layer_norm")


mover = None
rtmover = None
if stream_impl is not None and stream_wrapper is not None:
    mover = SmartTensorMoverPatches()
    rtmover = RTTensorMoverPatches()


class ModelMover:
    @classmethod
    def register(cls, model, max_prefetch=1):
        instance = cls(model, max_prefetch)
        model.lowvram_model_mover = instance
        return instance

    def __init__(self, model, max_prefetch=1):
        self.model = model
        self.lookahead_distance = max_prefetch
        self.hook_handles = []
        self.submodules_list = self.get_module_list()

        for c in self.submodules_list:
            c._apply(lambda x: x.pin_memory())

        self.submodules_indexer = {}
        self.module_movement_events = {}
        self.default_stream = devices.get_current_stream()
        self.model_mover_stream = stream_impl(device=devices.device)

        for i, module in enumerate(self.submodules_list):
            self.submodules_indexer[module] = i

    def get_module_list(self):
        return []

    def install(self):
        for i in range(len(self.submodules_list)):
            self.hook_handles.append(
                self.submodules_list[i].register_forward_pre_hook(
                    self._pre_forward_hook
                )
            )
            self.hook_handles.append(
                self.submodules_list[i].register_forward_hook(self._post_forward_hook)
            )

    def uninstall(self):
        for handle in self.hook_handles:
            handle.remove()

    def _pre_forward_hook(self, module, _):
        with stream_wrapper(stream=self.model_mover_stream):
            idx = self.submodules_indexer[module]
            for i in range(idx, idx + self.lookahead_distance):
                submodule = self.submodules_list[i % len(self.submodules_list)]
                if submodule in self.module_movement_events:
                    # already in GPU
                    continue
                submodule.to(devices.device, non_blocking=True)
                self.module_movement_events[submodule] = (
                    self.model_mover_stream.record_event()
                )

        this_event = self.module_movement_events.get(module, None)
        if this_event is not None:
            self.default_stream.wait_event(this_event)
        else:
            print(
                f"Module {module.__name__} was not moved to GPU. Taking slow path"
            )
            submodule.to(devices.device, non_blocking=True)

    def _post_forward_hook(self, module, _1, _2):
        with stream_wrapper(stream=self.model_mover_stream):
            del self.module_movement_events[module]
            module.to(cpu, non_blocking=True)


def calc_wrapper():
    if rtmover is not None:
        return stream_wrapper(stream=rtmover.calc_stream)
    return nullcontext()


class SmartModelMover:
    @classmethod
    def register(cls, model, vram_allowance=0, max_prefetch=10):
        instance = cls(model, vram_allowance, max_prefetch)
        model.lowvram_model_mover = instance
        return instance

    def __init__(self, model, vram_allowance=0, max_prefetch=10):
        self.model = model
        self.vram_allowance = vram_allowance * 1024 * 1024
        self.vram_allowance_remaining = vram_allowance * 1024 * 1024
        self.max_prefetch = max_prefetch
        self.hook_handles = []
        submodules_list = self.get_module_list()

        for c in submodules_list:
            c._apply(lambda x: x.pin_memory())

        self.submodules_list = [
            k for c in submodules_list for k in self.get_childrens(c)
        ]
        self.parameters_list = [
            list(x.parameters()) for x in self.submodules_list
        ]
        self.parameters_sizes = [
            sum([p.numel() * p.element_size() for p in x]) for x in self.parameters_list
        ]
        self.online_modules = set()
        self.online_module_count = 0
        self.submodules_indexer = {}

        for i, module in enumerate(self.submodules_list):
            self.submodules_indexer[module] = i

    def test_children(self, op):
        return op.__class__.__name__ in [
            "Conv2d",
            "Conv3d",
            "Linear",
            "GroupNorm",
            "LayerNorm",
        ]

    def get_childrens(self, container):
        if isinstance(container, torch.nn.Sequential):
            # return [c for c in container]
            return [cc for c in container for cc in self.get_childrens(c)]
        if "children" in dir(container):
            childrens = [
                cc for c in container.children() for cc in self.get_childrens(c)
            ]
            if len(childrens) > 0:
                return childrens
        return [container]

    def drain_allowance(self, idx):
        parameters_len = len(self.parameters_list)

        # no vram limitation is set
        if self.vram_allowance <= 0:
            # fetch up to max_prefetch parameters
            while self.online_module_count < self.max_prefetch and self.online_module_count < parameters_len:
                param = self.parameters_list[idx]
                self.online_modules.add(idx)
                self.online_module_count += 1
                yield param
                idx = (idx + 1) % parameters_len
            return

        # if there is still vram allowance, and it has not reached max_prefetch
        while self.vram_allowance_remaining > 0 and (
            self.max_prefetch < 1 or self.online_module_count < self.max_prefetch
        ) and self.online_module_count < parameters_len:
            param = self.parameters_list[idx]
            param_size = self.parameters_sizes[idx]

            # empty module or already online
            if len(param) == 0 or idx in self.online_modules:
                self.online_modules.add(idx)
                self.online_module_count += 1
                idx = (idx + 1) % parameters_len
                continue

            # if the parameter size is bigger than the remaining vram allowance, and there are already online modules
            if (
                param_size > self.vram_allowance_remaining
                and self.online_module_count > 0
            ):
                return
            self.vram_allowance_remaining -= param_size
            self.online_modules.add(idx)
            self.online_module_count += 1
            yield param
            idx = (idx + 1) % parameters_len

    def fill_allowance(self, idx):
        if self.vram_allowance > 0:
            self.vram_allowance_remaining += self.parameters_sizes[idx]
        self.online_modules.remove(idx)
        self.online_module_count -= 1

    def get_module_list(self):
        return []

    def install(self):
        for submodule in self.submodules_list:
            self.hook_handles.append(
                submodule.register_forward_pre_hook(self._pre_forward_hook)
            )
            self.hook_handles.append(
                submodule.register_forward_hook(self._post_forward_hook)
            )

    def uninstall(self):
        for handle in self.hook_handles:
            handle.remove()

        for idx in self.online_modules:
            mover.forget_batch(self.parameters_list[idx])

    def preload(self):
        idx = 0
        for parameters in self.drain_allowance(idx):
            for param in parameters:
                mover.move(param, device=devices.device)

    def _pre_forward_hook(self, module, *args, **kwargs):
        idx = self.submodules_indexer[module]
        for parameters in self.drain_allowance(idx):
            for param in parameters:
                mover.move(param, device=devices.device)

    def _post_forward_hook(self, module, *args, **kwargs):
        idx = self.submodules_indexer[module]

        mover.forget_batch(self.parameters_list[idx])
        self.fill_allowance(idx)


class DiffModelMover(ModelMover):
    def get_module_list(self):
        modules = []
        modules.append(self.model.time_embed)
        for block in self.model.input_blocks:
            modules.append(block)
        modules.append(self.model.middle_block)
        for block in self.model.output_blocks:
            modules.append(block)
        return modules


class DiffSmartModelMover(SmartModelMover):
    def get_module_list(self):
        modules = []
        modules.append(self.model.time_embed)
        for block in self.model.input_blocks:
            modules.append(block)
        modules.append(self.model.middle_block)
        for block in self.model.output_blocks:
            modules.append(block)
        return modules


def calc_sync():
    if rtmover is not None:
        return rtmover.calc_stream.synchronize()
    return nullcontext()


def send_everything_to_cpu():
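The movers above are built on torch module hooks: register_forward_pre_hook runs before a submodule's forward pass (used here to prefetch upcoming blocks), and register_forward_hook runs after it (used to release them). A minimal sketch of that wiring on a toy model, assuming nothing about the webui classes:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

def pre_hook(module, args):
    # runs before module.forward; a mover would prefetch the next blocks here
    print(f"about to run {module.__class__.__name__}")

def post_hook(module, args, output):
    # runs after module.forward; a mover would offload/forget weights here
    print(f"finished {module.__class__.__name__}")

handles = []
for submodule in model:
    handles.append(submodule.register_forward_pre_hook(pre_hook))
    handles.append(submodule.register_forward_hook(post_hook))

model(torch.randn(2, 8))

for handle in handles:
    handle.remove()   # uninstall, mirroring ModelMover.uninstall()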
@@ -530,9 +266,13 @@ def setup_for_low_vram(sd_model, use_medvram):
    # install hooks for bits of third model

    if stream_impl is not None and stream_wrapper is not None:
        mp = DiffSmartModelMover.register(diff_model, vram_allowance=512, max_prefetch=70)
        mp.install()
        mp.preload()
        # put it into pinned memory to achieve data transfer overlap
        diff_model.time_embed._apply(lambda x: x.pin_memory())
        for block in diff_model.input_blocks:
            block._apply(lambda x: x.pin_memory())
        diff_model.middle_block._apply(lambda x: x.pin_memory())
        for block in diff_model.output_blocks:
            block._apply(lambda x: x.pin_memory())
    else:
        diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
        for block in diff_model.input_blocks:
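The pin_memory() calls added in this hunk matter because only page-locked host memory lets .to(..., non_blocking=True) behave as a truly asynchronous copy; with ordinary pageable memory the transfer is effectively synchronous with respect to the host. A small sketch of the difference, assuming a CUDA device:

import torch

assert torch.cuda.is_available()

pageable = torch.randn(1024, 1024)
pinned = torch.randn(1024, 1024).pin_memory()

# Source is pageable memory, so this copy cannot genuinely overlap with host work.
a = pageable.to("cuda", non_blocking=True)

# Source is pinned memory, so this copy can proceed asynchronously and overlap compute.
b = pinned.to("cuda", non_blocking=True)

torch.cuda.synchronize()   # wait for outstanding copies before relying on the results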
@@ -782,7 +782,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
        sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())

        res = process_images_inner(p)
        with lowvram.calc_wrapper():
            res = process_images_inner(p)
            lowvram.calc_sync()

    finally:
        sd_models.apply_token_merging(p.sd_model, 0)
@@ -446,7 +446,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
                    module.fp16_weight = module.weight.data.clone().cpu().half()
                    if module.bias is not None:
                        module.fp16_bias = module.bias.data.clone().cpu().half()
                module.to(torch.float8_e4m3fn)
                module.to(torch.float8_e4m3fn)._apply(lambda x: x.pin_memory())
        model.first_stage_model = first_stage
        timer.record("apply fp8")
    else:
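For reference, torch.float8_e4m3fn (available in recent PyTorch releases) is used here as a storage dtype: parameters are kept in 8 bits and upcast to a regular float dtype before compute. A minimal sketch of that idea, not the webui code path; the fp16_copy name is illustrative:

import torch

linear = torch.nn.Linear(16, 4)

fp16_copy = linear.weight.data.clone().half()   # keep a higher-precision copy, as the diff does
linear.to(torch.float8_e4m3fn)                  # store parameters in 8-bit floating point

x = torch.randn(2, 16)
w = linear.weight.float()                       # upcast for compute; fp8 is not used in the matmul directly
b = linear.bias.float()
y = torch.nn.functional.linear(x, w, b)
print(y.dtype, linear.weight.dtype)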