remove profiler

wfjsw 2024-02-12 00:25:40 -06:00
parent a58ee39e38
commit cf3cc4c762


@@ -208,33 +208,31 @@ class ModelMover:
             handle.remove()

     def _pre_forward_hook(self, module, _):
-        with torch.profiler.record_function("lowvram prehook"):
-            with stream_wrapper(stream=self.model_mover_stream):
-                idx = self.submodules_indexer[module]
-                for i in range(idx, idx + self.lookahead_distance):
-                    submodule = self.submodules_list[i % len(self.submodules_list)]
-                    if submodule in self.module_movement_events:
-                        # already in GPU
-                        continue
-                    submodule.to(devices.device, non_blocking=True)
-                    self.module_movement_events[submodule] = (
-                        self.model_mover_stream.record_event()
-                    )
-            this_event = self.module_movement_events.get(module, None)
-            if this_event is not None:
-                self.default_stream.wait_event(this_event)
-            else:
-                print(
-                    f"Module {module.__name__} was not moved to GPU. Taking slow path"
-                )
-                submodule.to(devices.device, non_blocking=True)
+        with stream_wrapper(stream=self.model_mover_stream):
+            idx = self.submodules_indexer[module]
+            for i in range(idx, idx + self.lookahead_distance):
+                submodule = self.submodules_list[i % len(self.submodules_list)]
+                if submodule in self.module_movement_events:
+                    # already in GPU
+                    continue
+                submodule.to(devices.device, non_blocking=True)
+                self.module_movement_events[submodule] = (
+                    self.model_mover_stream.record_event()
+                )
+        this_event = self.module_movement_events.get(module, None)
+        if this_event is not None:
+            self.default_stream.wait_event(this_event)
+        else:
+            print(
+                f"Module {module.__name__} was not moved to GPU. Taking slow path"
+            )
+            submodule.to(devices.device, non_blocking=True)

     def _post_forward_hook(self, module, _1, _2):
-        with torch.profiler.record_function("lowvram posthook"):
-            with stream_wrapper(stream=self.model_mover_stream):
-                del self.module_movement_events[module]
-                module.to(cpu, non_blocking=True)
+        with stream_wrapper(stream=self.model_mover_stream):
+            del self.module_movement_events[module]
+            module.to(cpu, non_blocking=True)


 class SmartModelMover:
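
For context on what this commit strips out: torch.profiler.record_function only attaches a named region to an active torch.profiler trace, so removing the two wrappers does not change how the mover behaves; it just drops the per-call labelling overhead on the hot path. A minimal, standalone sketch (not code from this repository) of what such a label does under a profiler:

import torch
from torch.profiler import profile, record_function, ProfilerActivity

def run_step(model, x):
    # The label only matters while a profiler trace is active; otherwise the
    # context manager is pure overhead, which is what this commit removes.
    with record_function("lowvram prehook"):
        return model(x)

model = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    run_step(model, x)

# The "lowvram prehook" region shows up as its own row in the profiler table.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))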
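
Separately, the hooks in the diff rely on a lookahead prefetch over a side CUDA stream: each upcoming submodule is moved with non_blocking=True, an event is recorded on the mover stream, and the default stream waits on that event right before the submodule runs. Below is a simplified, hypothetical sketch of that pattern, not the repository's implementation; names such as copy_stream are illustrative and a CUDA device is assumed.

import torch

device = torch.device("cuda")
copy_stream = torch.cuda.Stream()
default_stream = torch.cuda.current_stream()

module = torch.nn.Linear(1024, 1024)        # stand-in for one submodule
x = torch.randn(16, 1024, device=device)

# Prefetch: issue the weight copy on the side stream and remember the event.
# The copy only overlaps with compute if the CPU-side weights sit in pinned
# memory; otherwise non_blocking=True is effectively a synchronous copy.
with torch.cuda.stream(copy_stream):
    module.to(device, non_blocking=True)
    done = copy_stream.record_event()

# Make the compute stream wait for the copy right before the module is used,
# mirroring default_stream.wait_event(this_event) in the hook above.
default_stream.wait_event(done)
y = module(x)

The recorded event is the cross-stream ordering guarantee; when no event exists for a module, the hook prints its "Taking slow path" message and issues the copy on the spot instead of having it overlap with earlier compute.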