Better naming

This commit is contained in:
Kohaku-Blueleaf 2023-11-19 15:56:31 +08:00
parent f383af2729
commit 043d2edcf6

View File

@ -128,7 +128,7 @@ def manual_cast_forward(self, *args, **kwargs):
@contextlib.contextmanager
def manual_autocast():
def manual_cast():
for module_type in patch_module_list:
org_forward = module_type.forward
module_type.forward = manual_cast_forward
@ -148,10 +148,10 @@ def autocast(disable=False):
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
return manual_autocast()
return manual_cast()
if has_mps() and shared.cmd_opts.precision != "full":
return manual_autocast()
return manual_cast()
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()