Disable nan check by default

This commit is contained in:
huchenlei 2024-05-15 16:09:10 -04:00
parent 1c0a0c4c26
commit 06a7475f84
3 changed files with 9 additions and 6 deletions

View File

@ -69,7 +69,8 @@ parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefe
parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*") parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*") parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization") parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--disable-nan-check", action='store_true', help="[Deprecated] do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--enable-nan-check", action='store_true', help="Check if produced images/latent spaces have nans at extra performance cost. (~20ms/it)")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device") parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device")
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")

View File

@ -230,7 +230,7 @@ class NansException(Exception):
def test_for_nans(x, where): def test_for_nans(x, where):
if shared.cmd_opts.disable_nan_check: if not shared.cmd_opts.enable_nan_check:
return return
if not torch.all(torch.isnan(x)).item(): if not torch.all(torch.isnan(x)).item():
@ -250,8 +250,6 @@ def test_for_nans(x, where):
else: else:
message = "A tensor with all NaNs was produced." message = "A tensor with all NaNs was produced."
message += " Use --disable-nan-check commandline argument to disable this check."
raise NansException(message) raise NansException(message)

View File

@ -440,6 +440,10 @@ def prepare_environment():
git_pull_recursive(extensions_dir) git_pull_recursive(extensions_dir)
startup_timer.record("update extensions") startup_timer.record("update extensions")
if args.disable_nan_check:
print("Nan check disabled by default. --disable-nan-check argument is now ignored. "
"Use --enable-nan-check to re-enable nan check.")
if "--exit" in sys.argv: if "--exit" in sys.argv:
print("Exiting because of --exit argument") print("Exiting because of --exit argument")
exit(0) exit(0)
@ -454,8 +458,8 @@ def configure_for_tests():
sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt")) sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
if "--skip-torch-cuda-test" not in sys.argv: if "--skip-torch-cuda-test" not in sys.argv:
sys.argv.append("--skip-torch-cuda-test") sys.argv.append("--skip-torch-cuda-test")
if "--disable-nan-check" not in sys.argv: if "--enable-nan-check" in sys.argv:
sys.argv.append("--disable-nan-check") sys.argv.remove("--enable-nan-check")
os.environ['COMMANDLINE_ARGS'] = "" os.environ['COMMANDLINE_ARGS'] = ""