Add NPU Support

This commit is contained in:
wangshuai09 2024-01-27 17:21:32 +08:00
parent cf2772fab0
commit ec124607f4
7 changed files with 62 additions and 3 deletions

View File

@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache
import torch
from modules import errors, shared
from modules import errors, shared, npu_specific
if sys.platform == "darwin":
from modules import mac_specific
@ -40,6 +40,9 @@ def get_optimal_device_name():
if has_xpu():
return xpu_specific.get_xpu_device_string()
if npu_specific.has_npu:
return npu_specific.get_npu_device_string()
return "cpu"
@ -67,6 +70,9 @@ def torch_gc():
if has_xpu():
xpu_specific.torch_xpu_gc()
if npu_specific.has_npu:
npu_specific.torch_npu_gc()
def enable_tf32():
if torch.cuda.is_available():
@ -164,4 +170,3 @@ def first_time_calculation():
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)

View File

@ -143,13 +143,17 @@ def initialize_rest(*, reload_script_modules=False):
its optimization may be None because the list of optimizers has not been filled
by that time, so we apply optimization again.
"""
from modules import devices
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
if devices.npu_specific.has_npu:
import torch
torch.npu.set_device(0)
shared.sd_model # noqa: B018
if sd_hijack.current_optimizer is None:
sd_hijack.apply_optimizations()
from modules import devices
devices.first_time_calculation()
if not shared.cmd_opts.skip_load_model_at_start:
Thread(target=load_model).start()

34
modules/npu_specific.py Normal file
View File

@ -0,0 +1,34 @@
import importlib
import torch
from modules import shared
def check_for_npu():
    """Report whether an NPU is usable through the ``torch_npu`` package.

    Returns:
        bool: ``False`` when ``torch_npu`` is not installed or when probing
        the device raises ``RuntimeError``; otherwise the result of
        ``torch.npu.is_available()``.
    """
    # A bare ``import importlib`` does not guarantee that the ``util``
    # submodule is loaded, so import it explicitly before calling find_spec.
    import importlib.util

    if importlib.util.find_spec("torch_npu") is None:
        return False

    import torch_npu

    try:
        # Selecting the device is part of the probe, so it lives inside the
        # ``try``: a host with torch_npu installed but no usable NPU should
        # report False rather than fail while this module is being imported.
        torch_npu.npu.set_device(0)
        # Will raise a RuntimeError if no NPU is found
        _ = torch.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False
def get_npu_device_string():
    """Return the torch device string for the NPU to use.

    Uses ``shared.cmd_opts.device_id`` when it is set, and falls back to
    device 0 otherwise.
    """
    device_id = shared.cmd_opts.device_id
    return "npu:0" if device_id is None else f"npu:{device_id}"
def torch_npu_gc():
    """Release cached NPU memory on the configured NPU device."""
    # Workaround for a torch_npu bug: device 0 is selected explicitly before
    # entering the device context. Revert once the bug is fixed, see
    # https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
    torch.npu.set_device(0)
    device_scope = torch.npu.device(get_npu_device_string())
    with device_scope:
        torch.npu.empty_cache()
# Evaluated once when this module is imported; callers read this flag
# (``npu_specific.has_npu``) rather than calling check_for_npu() again.
has_npu = check_for_npu()

View File

@ -151,6 +151,10 @@ class EmbeddingDatabase:
return embedding
def get_expected_shape(self):
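# Workaround for a torch_npu bug: select NPU device 0 explicitly before using the model (presumably the same issue as https://gitee.com/ascend/pytorch/issues/I8KECW - confirm); remove once fixed upstream.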
if devices.npu_specific.has_npu:
import torch
torch.npu.set_device(0)
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1]

View File

@ -5,6 +5,8 @@ accelerate
basicsr
blendmodes
clean-fid
cloudpickle
decorator
einops
fastapi>=0.90.1
gfpgan
@ -26,9 +28,11 @@ resize-right
safetensors
scikit-image>=0.19
synr==0.5.0
timm
tomesd
torch
torchdiffeq
torchsde
tornado
transformers==4.30.2

View File

@ -4,6 +4,8 @@ accelerate==0.21.0
basicsr==1.4.2
blendmodes==2022
clean-fid==0.1.35
cloudpickle==3.0.0
decorator==5.1.1
einops==0.4.1
fastapi==0.94.0
gfpgan==1.3.8
@ -23,10 +25,12 @@ realesrgan==0.3.0
resize-right==0.0.2
safetensors==0.3.1
scikit-image==0.21.0
synr==0.5.0
timm==0.9.2
tomesd==0.1.3
torch
torchdiffeq==0.2.3
torchsde==0.2.6
tornado==6.4
transformers==4.30.2
httpx==0.24.1

View File

@ -159,6 +159,10 @@ then
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
then
export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
elif echo "$gpu_info" | grep -q "Huawei" && [[ -z "${TORCH_COMMAND}" ]]
then
export TORCH_COMMAND="pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu; pip install torch_npu"
fi
fi