async weight mover

wfjsw 2024-02-07 02:26:10 -06:00
parent 2e93bdce0c
commit c1702ea498
2 changed files with 126 additions and 6 deletions

modules/devices.py

@@ -269,3 +269,43 @@ def first_time_calculation():
     x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
     conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
     conv2d(x)
+
+
+# Backend-agnostic accessors for the stream/event API used by the async
+# weight mover: prefer CUDA, fall back to XPU, else None (CPU-only).
+def get_stream_impl():
+    if torch.cuda.is_available():
+        return torch.cuda.Stream
+
+    if has_xpu():
+        return torch.xpu.Stream
+
+    return None
+
+
+def get_stream_wrapper():
+    if torch.cuda.is_available():
+        return torch.cuda.stream
+
+    if has_xpu():
+        return torch.xpu.stream
+
+    return None
+
+
+def get_event_impl():
+    if torch.cuda.is_available():
+        return torch.cuda.Event
+
+    if has_xpu():
+        return torch.xpu.Event
+
+    return None
+
+
+def get_current_stream():
+    if torch.cuda.is_available():
+        return torch.cuda.current_stream(device)
+
+    if has_xpu():
+        return torch.xpu.current_stream(device)
+
+    return None
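
These helpers pick whichever stream/event implementation matches the active backend, so callers never touch torch.cuda or torch.xpu directly. A minimal sketch of how they compose, with hypothetical tensor names and assuming a CUDA or XPU device is present (not part of this commit):

import torch
from modules import devices

# hypothetical example tensors; the CPU weight is pinned so the copy can be async
weights_cpu = torch.randn(1024, 1024).pin_memory()
x = torch.randn(64, 1024, device=devices.device)

side_stream = devices.get_stream_impl()(device=devices.device)
with devices.get_stream_wrapper()(stream=side_stream):
    # enqueue the host-to-device copy on the side stream
    weights_gpu = weights_cpu.to(devices.device, non_blocking=True)
    copy_done = side_stream.record_event()

# make the compute stream wait on the copy event, not the whole side stream
devices.get_current_stream().wait_event(copy_done)
out = x @ weights_gpu.t()

Waiting on the recorded event rather than synchronizing the whole side stream lets later prefetches stay queued behind this copy.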

modules/lowvram.py

@@ -4,6 +4,77 @@ from modules import devices, shared

 module_in_gpu = None
 cpu = torch.device("cpu")
+stream_impl = devices.get_stream_impl()
+stream_wrapper = devices.get_stream_wrapper()
+
+
+class ModelMover:
+    @classmethod
+    def register(cls, model, lookahead_distance=1):
+        instance = cls(model, lookahead_distance)
+        setattr(model, 'lowvram_model_mover', instance)
+        return instance
+
+    def __init__(self, model, lookahead_distance=1):
+        self.model = model
+        self.lookahead_distance = lookahead_distance
+        self.hook_handles = []
+        self.submodules_list = self.get_module_list()
+        self.submodules_indexer = {}
+        self.module_movement_events = {}
+        self.default_stream = devices.get_current_stream()
+        self.model_mover_stream = stream_impl(device=devices.device)
+
+        for i, module in enumerate(self.submodules_list):
+            self.submodules_indexer[module] = i
+
+    def get_module_list(self):
+        return []
+
+    def install(self):
+        for module in self.submodules_list:
+            self.hook_handles.append(module.register_forward_pre_hook(self._pre_forward_hook))
+            self.hook_handles.append(module.register_forward_hook(self._post_forward_hook))
+
+    def uninstall(self):
+        for handle in self.hook_handles:
+            handle.remove()
+
+    def _pre_forward_hook(self, module, _):
+        with stream_wrapper(stream=self.model_mover_stream):
+            idx = self.submodules_indexer[module]
+            # prefetch this module and the next lookahead_distance - 1 on the mover stream
+            for i in range(idx, idx + self.lookahead_distance):
+                submodule = self.submodules_list[i % len(self.submodules_list)]
+                if submodule in self.module_movement_events:
+                    # already queued for (or resident on) the GPU
+                    continue
+                submodule.to(devices.device, non_blocking=True)
+                self.module_movement_events[submodule] = self.model_mover_stream.record_event()
+
+        this_event = self.module_movement_events.get(module, None)
+        if this_event is not None:
+            # block the compute stream only until this module's copy has finished
+            self.default_stream.wait_event(this_event)
+        else:
+            print(f"Module {module._get_name()} was not moved to GPU. Taking slow path")
+            module.to(devices.device, non_blocking=True)
+
+    def _post_forward_hook(self, module, _1, _2):
+        with stream_wrapper(stream=self.model_mover_stream):
+            del self.module_movement_events[module]
+            module.to(cpu, non_blocking=True)
+
+
+class DiffModelMover(ModelMover):
+    def get_module_list(self):
+        modules = []
+        modules.append(self.model.time_embed)
+        for block in self.model.input_blocks:
+            modules.append(block)
+        modules.append(self.model.middle_block)
+        for block in self.model.output_blocks:
+            modules.append(block)
+        return modules
+
+
 def send_everything_to_cpu():
     global module_in_gpu
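
The mover's contract: get_module_list() returns the submodules in execution order; the forward pre-hook prefetches the next lookahead_distance entries onto the GPU via a dedicated stream, recording an event per module; the compute stream then waits only on the event of the module about to run; and the forward hook releases the module back to the CPU afterwards. A hedged sketch of reusing the base class outside this commit (the toy model, ToyMover, and the `from modules import ...` layout are assumptions):

import torch
import torch.nn as nn
from modules import devices, lowvram  # assuming this module layout

class ToyStack(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(64, 64) for _ in range(8))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class ToyMover(lowvram.ModelMover):
    def get_module_list(self):
        # must mirror the order the submodules run in forward()
        return list(self.model.layers)

model = ToyStack()  # parameters stay on the CPU between uses
mover = ToyMover.register(model, lookahead_distance=2)
mover.install()
out = model(torch.randn(1, 64, device=devices.device))
mover.uninstall()

Because the prefetch index wraps modulo the module list, the start of the next forward pass is already being staged while the tail of the current one runs.
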
@@ -43,6 +114,19 @@ def setup_for_low_vram(sd_model, use_medvram):
         """
         global module_in_gpu

+        # debug trace for the slow path: work out a printable name for the
+        # module, whichever naming attribute it exposes
+        try:
+            name = module._get_name()
+        except AttributeError:
+            try:
+                name = module.__name__
+            except AttributeError:
+                try:
+                    name = module.__class__.__name__
+                except AttributeError:
+                    name = str(module)
+        print(f"Moving {module.__module__}.{name} to GPU")
+
         module = parents.get(module, module)

         if module_in_gpu == module:
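
One caveat worth noting: the non_blocking=True moves used throughout this file only overlap with compute when the source CPU storage is pinned (page-locked); for pageable memory the copy degrades to a synchronous one. A sketch of pinning a model's CPU weights ahead of time (hypothetical helper, not part of this commit):

import torch

def pin_module_memory(model: torch.nn.Module):
    # replace each CPU parameter/buffer with a page-locked copy so that
    # .to(device, non_blocking=True) can actually overlap with GPU compute
    for p in model.parameters():
        if p.device.type == 'cpu':
            p.data = p.data.pin_memory()
    for b in model.buffers():
        if b.device.type == 'cpu':
            b.data = b.data.pin_memory()
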
@@ -135,12 +219,8 @@ def setup_for_low_vram(sd_model, use_medvram):
     diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored

     # install hooks for bits of third model
-    diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
-    for block in diff_model.input_blocks:
-        block.register_forward_pre_hook(send_me_to_gpu)
-    diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
-    for block in diff_model.output_blocks:
-        block.register_forward_pre_hook(send_me_to_gpu)
+    mover = DiffModelMover.register(diff_model, lookahead_distance=8)
+    mover.install()


 def is_enabled(sd_model):
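
With lookahead_distance=8, at most eight UNet blocks plus the one executing are held in VRAM while copies for upcoming blocks overlap with the current block's compute; a larger distance hides more latency at the cost of more resident weights. A hedged sketch of re-tuning the distance from calling code (paths and values are assumptions, not part of this commit):

from modules import lowvram, shared

diff_model = shared.sd_model.model.diffusion_model  # assuming the usual SD layout
mover = getattr(diff_model, 'lowvram_model_mover', None)  # set by register()
if mover is not None:
    mover.uninstall()  # drop the old hooks before re-registering

# trade-off: a larger lookahead hides more copy latency but pins more
# blocks in VRAM at once
mover = lowvram.DiffModelMover.register(diff_model, lookahead_distance=4)
mover.install()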