mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
async weight mover
This commit is contained in:
parent 2e93bdce0c
commit c1702ea498
@@ -269,3 +269,43 @@ def first_time_calculation():
    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
    conv2d(x)


def get_stream_impl():
    if torch.cuda.is_available():
        return torch.cuda.Stream

    if has_xpu():
        return torch.xpu.Stream

    return None


def get_stream_wrapper():
    if torch.cuda.is_available():
        return torch.cuda.stream

    if has_xpu():
        return torch.xpu.stream

    return None


def get_event_impl():
    if torch.cuda.is_available():
        return torch.cuda.Event

    if has_xpu():
        return torch.xpu.Event

    return None


def get_current_stream():
    if torch.cuda.is_available():
        return torch.cuda.current_stream(device)

    if has_xpu():
        return torch.xpu.current_stream(device)

    return None
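
The helpers above only select the backend-specific stream and event types (CUDA or XPU) and fall back to None elsewhere. A minimal sketch of the pattern they enable, written directly against torch.cuda rather than the webui's devices module and assuming a CUDA machine (not part of the commit):

# usage sketch: prefetch a weight on a side stream; the compute stream waits only at the point of use
import torch

device = torch.device("cuda")

copy_stream = torch.cuda.Stream(device=device)             # what get_stream_impl() returns on CUDA
weight_cpu = torch.randn(4096, 4096, pin_memory=True)      # pinned memory lets the H2D copy run asynchronously
x = torch.randn(64, 4096, device=device)

with torch.cuda.stream(copy_stream):                       # what get_stream_wrapper() returns on CUDA
    weight_gpu = weight_cpu.to(device, non_blocking=True)  # the copy is queued on the side stream
    copied = copy_stream.record_event()                    # event marks when the copy has finished

torch.cuda.current_stream(device).wait_event(copied)       # compute stream blocks only just before use
y = x @ weight_gpu
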
@@ -4,6 +4,77 @@ from modules import devices, shared
module_in_gpu = None
cpu = torch.device("cpu")

stream_impl = devices.get_stream_impl()
stream_wrapper = devices.get_stream_wrapper()


class ModelMover:
    @classmethod
    def register(cls, model, lookahead_distance=1):
        instance = cls(model, lookahead_distance)
        setattr(model, 'lowvram_model_mover', instance)
        return instance

    def __init__(self, model, lookahead_distance=1):
        self.model = model
        self.lookahead_distance = lookahead_distance
        self.hook_handles = []
        self.submodules_list = self.get_module_list()
        self.submodules_indexer = {}
        self.module_movement_events = {}
        self.default_stream = devices.get_current_stream()
        self.model_mover_stream = stream_impl(device=devices.device)

        for i, module in enumerate(self.submodules_list):
            self.submodules_indexer[module] = i

    def get_module_list(self):
        return []

    def install(self):
        for i in range(len(self.submodules_list)):
            self.hook_handles.append(self.submodules_list[i].register_forward_pre_hook(self._pre_forward_hook))
            self.hook_handles.append(self.submodules_list[i].register_forward_hook(self._post_forward_hook))

    def uninstall(self):
        for handle in self.hook_handles:
            handle.remove()

    def _pre_forward_hook(self, module, _):
        with stream_wrapper(stream=self.model_mover_stream):
            idx = self.submodules_indexer[module]
            for i in range(idx, idx + self.lookahead_distance):
                submodule = self.submodules_list[i % len(self.submodules_list)]
                if submodule in self.module_movement_events:
                    # already in GPU
                    continue
                submodule.to(devices.device, non_blocking=True)
                self.module_movement_events[submodule] = self.model_mover_stream.record_event()

        this_event = self.module_movement_events.get(module, None)
        if this_event is not None:
            self.default_stream.wait_event(this_event)
        else:
print(f"Module {module.__name__} was not moved to GPU. Taking slow path")
|
||||
submodule.to(devices.device, non_blocking=True)

    def _post_forward_hook(self, module, _1, _2):
        with stream_wrapper(stream=self.model_mover_stream):
            del self.module_movement_events[module]
            module.to(cpu, non_blocking=True)


class DiffModelMover(ModelMover):
    def get_module_list(self):
        modules = []
        modules.append(self.model.time_embed)
        for block in self.model.input_blocks:
            modules.append(block)
        modules.append(self.model.middle_block)
        for block in self.model.output_blocks:
            modules.append(block)
        return modules


def send_everything_to_cpu():
    global module_in_gpu
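
ModelMover is deliberately abstract: get_module_list() returns an empty list, and subclasses enumerate their blocks in the order the forward pass visits them, which is what the prefetch loop relies on. A hypothetical wiring sketch on a toy model (ToyMover and the nn.Sequential are made up for illustration; the real subclass is DiffModelMover above, and a CUDA/XPU device is assumed since stream_impl is None otherwise):

import torch
import torch.nn as nn

class ToyMover(ModelMover):
    def get_module_list(self):
        # prefetch order must match the order the blocks run in forward()
        return list(self.model)

toy = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)])  # weights start on the CPU
mover = ToyMover.register(toy, lookahead_distance=2)             # keep roughly two blocks ahead of compute in VRAM
mover.install()

# each forward pass now prefetches upcoming blocks on model_mover_stream,
# waits on the recorded event right before a block runs, and evicts it afterwards
out = toy(torch.randn(4, 1024, device=devices.device))
mover.uninstall()
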
@@ -43,6 +114,19 @@ def setup_for_low_vram(sd_model, use_medvram):
        """
        global module_in_gpu

        try:
            name = module._get_name()
        except:
            try:
                name = module.__name__
            except:
                try:
                    name = module.__class__.__name__
                except:
                    name = str(module)

        print(f"Moving {module.__module__}.{name} to GPU")

        module = parents.get(module, module)

        if module_in_gpu == module:
@@ -135,12 +219,8 @@ def setup_for_low_vram(sd_model, use_medvram):
     diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored

     # install hooks for bits of third model
-    diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
-    for block in diff_model.input_blocks:
-        block.register_forward_pre_hook(send_me_to_gpu)
-    diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
-    for block in diff_model.output_blocks:
-        block.register_forward_pre_hook(send_me_to_gpu)
+    mover = DiffModelMover.register(diff_model, lookahead_distance=8)
+    mover.install()


 def is_enabled(sd_model):
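
Since register() stashes the mover on the model and install() keeps every hook handle, the new behaviour can also be undone at runtime; a hypothetical teardown (not part of this commit):

mover = getattr(diff_model, 'lowvram_model_mover', None)
if mover is not None:
    mover.uninstall()   # removes the forward pre/post hooks it registered
    diff_model.to(cpu)  # cpu is the torch.device("cpu") defined at the top of lowvram.py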