import importlib
# `importlib.util` is a submodule: `import importlib` alone does not guarantee
# that the `importlib.util` attribute exists, so import it explicitly.
import importlib.util

import torch

from modules import shared


def check_for_npu() -> bool:
    """Return True if the ``torch_npu`` package is installed and reports an NPU.

    Returns False when ``torch_npu`` cannot be found. Otherwise imports it,
    probes the device count, and returns ``torch.npu.is_available()``.
    A ``RuntimeError`` raised by the probe is treated as "no NPU".
    """
    if importlib.util.find_spec("torch_npu") is None:
        return False
    # Imported lazily so machines without torch_npu never touch it; presumably
    # the import is also what makes `torch.npu` available below — TODO confirm.
    import torch_npu

    try:
        # Will raise a RuntimeError if no NPU is found
        _ = torch_npu.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False


def get_npu_device_string() -> str:
    """Return the NPU device string, e.g. ``"npu:1"``.

    Uses ``shared.cmd_opts.device_id`` when it is set, otherwise ``"npu:0"``.
    """
    if shared.cmd_opts.device_id is not None:
        return f"npu:{shared.cmd_opts.device_id}"
    return "npu:0"


def torch_npu_gc() -> None:
    """Release cached NPU memory on the configured device.

    NOTE(review): requires ``torch.npu`` to exist; presumably callers only
    invoke this when ``has_npu`` is True — verify at call sites.
    """
    with torch.npu.device(get_npu_device_string()):
        torch.npu.empty_cache()


# Evaluated once, at import time.
has_npu = check_for_npu()