2022-10-21 22:11:07 +00:00
|
|
|
import sys, os, shlex
|
2022-10-04 09:32:22 +00:00
|
|
|
import contextlib
|
2022-09-11 05:11:27 +00:00
|
|
|
import torch
|
2022-09-12 13:34:13 +00:00
|
|
|
from modules import errors
|
|
|
|
|
2022-11-12 07:00:49 +00:00
|
|
|
|
2022-11-12 03:02:40 +00:00
|
|
|
# has_mps is only available in nightly PyTorch (for now) and macOS 12.3+.
|
|
|
|
# check `getattr` and try it for compatibility
|
|
|
|
def has_mps() -> bool:
    """Report whether the MPS backend can actually be used.

    ``torch.has_mps`` is read through ``getattr`` so that a missing attribute
    counts as "no". When the flag is truthy, a one-element tensor is moved to
    the ``mps`` device as a probe; any exception from the probe also counts
    as "no".
    """
    mps_flagged = getattr(torch, 'has_mps', False)
    if mps_flagged:
        try:
            torch.zeros(1).to(torch.device("mps"))
        except Exception:
            # Deliberately broad: any probe failure means MPS is unusable.
            return False
        return True
    return False
|
2022-09-11 05:11:27 +00:00
|
|
|
|
2022-09-11 15:48:36 +00:00
|
|
|
|
2022-10-21 22:11:07 +00:00
|
|
|
def extract_device_id(args, name):
    """Return the entry that follows the first entry of *args* containing *name*.

    Args:
        args: Sequence of command-line style arguments.
        name: Value to look for. It is matched with ``in`` against each
            entry, so for string entries a substring match is enough.

    Returns:
        The entry immediately after the first match, or ``None`` when nothing
        matches or when the match is the last entry and has no value after it.
    """
    for index, arg in enumerate(args):
        if name in arg:
            value_index = index + 1
            # A flag in final position has no value; report "not found"
            # instead of raising IndexError.
            return args[value_index] if value_index < len(args) else None
    return None
|
2022-09-11 15:48:36 +00:00
|
|
|
|
2022-11-12 07:00:49 +00:00
|
|
|
|
2022-09-11 05:11:27 +00:00
|
|
|
def get_optimal_device():
    """Choose the torch device to run on.

    Returns a CUDA device when CUDA is available: ``cuda:<id>`` if
    ``shared.cmd_opts.device_id`` is set, plain ``cuda`` otherwise. In every
    other case the module-level ``cpu`` device is returned.
    """
    if not torch.cuda.is_available():
        # MPS auto-selection is currently switched off:
        # if has_mps():
        #     return torch.device("mps")
        return cpu

    # Function-scope import (presumably to avoid a circular import - confirm).
    from modules import shared

    selected_id = shared.cmd_opts.device_id
    if selected_id is None:
        return torch.device("cuda")
    return torch.device(f"cuda:{selected_id}")
|
2022-09-11 20:24:24 +00:00
|
|
|
|
|
|
|
|
|
|
|
def torch_gc():
    """Call ``torch.cuda.empty_cache()`` and ``torch.cuda.ipc_collect()``.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return

    cuda = torch.cuda
    cuda.empty_cache()
    cuda.ipc_collect()
|
2022-09-12 13:34:13 +00:00
|
|
|
|
|
|
|
|
|
|
|
def enable_tf32():
    """Set ``allow_tf32`` for CUDA matmul and cuDNN when CUDA is available.

    Leaves both flags untouched when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return

    backends = torch.backends
    backends.cuda.matmul.allow_tf32 = True
    backends.cudnn.allow_tf32 = True
|
|
|
|
|
|
|
|
|
|
|
|
# Invoke enable_tf32 through errors.run at import time, labelled "Enabling TF32".
# NOTE(review): errors.run lives in modules.errors (not shown here); presumably
# it reports a failure under that label instead of raising - confirm.
errors.run(enable_tf32, "Enabling TF32")
|
2022-09-12 17:09:32 +00:00
|
|
|
|
2022-11-12 07:00:49 +00:00
|
|
|
# Shared CPU device: the fallback returned by get_optimal_device() and the
# device on which randn()/randn_without_seed() sample noise for MPS.
cpu = torch.device("cpu")
|
2022-10-25 03:04:50 +00:00
|
|
|
# Device slots, all None at import time. Nothing in this part of the file
# assigns them (presumably filled in during startup elsewhere - confirm).
# randn() and randn_without_seed() read `device.type`, so `device` must be set
# before they are called.
device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
|
2022-10-02 12:03:39 +00:00
|
|
|
# Default dtype; autocast() returns a no-op context when this is torch.float32.
dtype = torch.float16
|
2022-10-10 13:11:14 +00:00
|
|
|
# NOTE(review): going by the name, presumably the dtype used for the VAE; it is
# not referenced in this part of the file - confirm against callers.
dtype_vae = torch.float16
|
2022-09-12 17:09:32 +00:00
|
|
|
|
2022-11-12 07:00:49 +00:00
|
|
|
|
2022-09-12 17:09:32 +00:00
|
|
|
def randn(seed, shape):
    """Return standard-normal noise of *shape* on ``device``, seeded by *seed*.

    The global torch RNG is seeded on every path. Previously the MPS path
    seeded only a throwaway generator, so later unseeded draws from the global
    RNG did not follow *seed* there, unlike on other devices.

    Args:
        seed: Integer seed passed to ``torch.manual_seed``.
        shape: Shape of the tensor to generate.

    Returns:
        A tensor of normally distributed values located on the module-level
        ``device``.
    """
    torch.manual_seed(seed)

    # Pytorch currently doesn't handle setting randomness correctly when the
    # metal backend is used, so sample on the CPU and move the result over.
    if device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)

    return torch.randn(shape, device=device)
|
|
|
|
|
2022-09-13 18:49:58 +00:00
|
|
|
|
|
|
|
def randn_without_seed(shape):
    """Return standard-normal noise of *shape* on ``device`` without reseeding.

    Values are drawn from torch's global RNG, so successive calls continue
    the current random stream (whatever ``torch.manual_seed`` last set).

    Args:
        shape: Shape of the tensor to generate.

    Returns:
        A tensor of normally distributed values located on the module-level
        ``device``.
    """
    # Pytorch currently doesn't handle setting randomness correctly when the
    # metal backend is used, so sample on the CPU and move the result over.
    if device.type == 'mps':
        # Use the global CPU RNG here. A freshly constructed, never-seeded
        # torch.Generator always starts from the same default seed, so
        # building one per call would return identical noise every time.
        return torch.randn(shape, device=cpu).to(device)

    return torch.randn(shape, device=device)
|
|
|
|
|
2022-10-04 09:32:22 +00:00
|
|
|
|
2022-10-10 13:11:14 +00:00
|
|
|
def autocast(disable=False):
    """Return the autocast context manager to use, or a no-op one.

    ``contextlib.nullcontext()`` is returned when *disable* is true, when the
    module-level ``dtype`` is ``torch.float32``, or when
    ``shared.cmd_opts.precision`` is ``"full"``. Otherwise the result is
    ``torch.autocast("cuda")``.
    """
    from modules import shared

    skip_autocast = (
        disable
        or dtype == torch.float32
        or shared.cmd_opts.precision == "full"
    )
    if skip_autocast:
        return contextlib.nullcontext()
    return torch.autocast("cuda")
|
2022-10-25 06:01:57 +00:00
|
|
|
|
2022-11-12 07:00:49 +00:00
|
|
|
|
2022-10-25 06:01:57 +00:00
|
|
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
2022-11-12 07:00:49 +00:00
|
|
|
def mps_contiguous(input_tensor, device):
    """Return ``input_tensor.contiguous()`` when *device* is MPS.

    For any other device type the tensor is returned unchanged (same object).
    """
    if device.type != 'mps':
        return input_tensor
    return input_tensor.contiguous()
|
|
|
|
|
|
|
|
|
|
|
|
def mps_contiguous_to(input_tensor, device):
    """Move *input_tensor* to *device*, passing it through ``mps_contiguous`` first."""
    prepared = mps_contiguous(input_tensor, device)
    return prepared.to(device)
|