remove profiler

This commit is contained in:
wfjsw 2024-02-12 00:25:40 -06:00
parent a58ee39e38
commit cf3cc4c762

View File

@ -208,33 +208,31 @@ class ModelMover:
handle.remove() handle.remove()
def _pre_forward_hook(self, module, _):
    """Forward pre-hook: prefetch upcoming submodules and wait for ``module``.

    Starting at ``module``'s index in ``self.submodules_list``, the next
    ``self.lookahead_distance`` submodules (wrapping around the end of the
    list) are moved to ``devices.device`` with non-blocking copies issued
    under ``stream_wrapper(stream=self.model_mover_stream)``. An event is
    recorded on that stream after each move and stored in
    ``self.module_movement_events``. ``self.default_stream`` is then made to
    wait on the event belonging to ``module``.

    Args:
        module: The module whose forward is about to run.
        _: Positional inputs supplied by the hook machinery (unused).
    """
    total = len(self.submodules_list)  # loop-invariant, hoisted

    with stream_wrapper(stream=self.model_mover_stream):
        idx = self.submodules_indexer[module]
        for i in range(idx, idx + self.lookahead_distance):
            submodule = self.submodules_list[i % total]
            if submodule in self.module_movement_events:
                # already in GPU
                continue

            submodule.to(devices.device, non_blocking=True)
            self.module_movement_events[submodule] = (
                self.model_mover_stream.record_event()
            )

    # NOTE(review): the original indentation was lost; this assumes the wait
    # and the slow path run outside the mover-stream context -- confirm.
    this_event = self.module_movement_events.get(module, None)
    if this_event is not None:
        self.default_stream.wait_event(this_event)
    else:
        # nn.Module instances have no ``__name__``; report the class name.
        print(
            f"Module {type(module).__name__} was not moved to GPU. Taking slow path"
        )
        # Move the module that is about to execute, not the loop variable
        # left over from the prefetch loop (unbound if the loop never ran).
        module.to(devices.device, non_blocking=True)
def _post_forward_hook(self, module, _1, _2):
    """Forward hook: forget ``module``'s movement event and send it to ``cpu``.

    Args:
        module: The module whose forward has just run.
        _1: Forward inputs supplied by the hook machinery (unused).
        _2: Forward output supplied by the hook machinery (unused).
    """
    with stream_wrapper(stream=self.model_mover_stream):
        # A module that reached forward without a recorded event (the
        # pre-hook's slow path) has no entry; tolerate that rather than
        # raising KeyError.
        self.module_movement_events.pop(module, None)
        # NOTE(review): this copy is issued under the mover stream; nothing
        # visible here orders it after the forward's kernels -- confirm.
        module.to(cpu, non_blocking=True)
class SmartModelMover: class SmartModelMover: