From 06a7475f84f224c2675778fd7d883ff04e906475 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Wed, 15 May 2024 16:09:10 -0400 Subject: [PATCH] Disable nan check by default --- modules/cmd_args.py | 3 ++- modules/devices.py | 4 +--- modules/launch_utils.py | 8 ++++++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 016a33d10..26335903d 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -69,7 +69,8 @@ parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefe parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*") parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*") parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization") -parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") +parser.add_argument("--disable-nan-check", action='store_true', help="[Deprecated] do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") +parser.add_argument("--enable-nan-check", action='store_true', help="Check if produced images/latent spaces have nans at extra performance cost. (~20ms/it)") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device") parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") diff --git a/modules/devices.py b/modules/devices.py index e4f671ac6..096918ca4 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -230,7 +230,7 @@ class NansException(Exception): def test_for_nans(x, where): - if shared.cmd_opts.disable_nan_check: + if not shared.cmd_opts.enable_nan_check: return if not torch.all(torch.isnan(x)).item(): @@ -250,8 +250,6 @@ def test_for_nans(x, where): else: message = "A tensor with all NaNs was produced." - message += " Use --disable-nan-check commandline argument to disable this check." - raise NansException(message) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 5812b0e58..ddc411076 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -440,6 +440,10 @@ def prepare_environment(): git_pull_recursive(extensions_dir) startup_timer.record("update extensions") + if args.disable_nan_check: + print("Nan check disabled by default. --disable-nan-check argument is now ignored. " + "Use --enable-nan-check to re-enable nan check.") + if "--exit" in sys.argv: print("Exiting because of --exit argument") exit(0) @@ -454,8 +458,8 @@ def configure_for_tests(): sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt")) if "--skip-torch-cuda-test" not in sys.argv: sys.argv.append("--skip-torch-cuda-test") - if "--disable-nan-check" not in sys.argv: - sys.argv.append("--disable-nan-check") + if "--enable-nan-check" in sys.argv: + sys.argv.remove("--enable-nan-check") os.environ['COMMANDLINE_ARGS'] = ""