load_spandrel_model: make `half` `prefer_half`

As discussed with the Spandrel folks, it's good to heed Spandrel's
"supports half precision" flag to avoid e.g. black blotches and what-not.
This commit is contained in:
Aarni Koskela 2023-12-31 19:52:32 +02:00
parent 51f1cca852
commit 2cacbc124c
2 changed files with 15 additions and 7 deletions

View File

@ -139,23 +139,31 @@ def load_upscalers():
def load_spandrel_model(
path: str,
path: str | os.PathLike,
*,
device: str | torch.device | None,
half: bool = False,
prefer_half: bool = False,
dtype: str | torch.dtype | None = None,
expected_architecture: str | None = None,
) -> spandrel.ModelDescriptor:
import spandrel
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
if expected_architecture and model_descriptor.architecture != expected_architecture:
logger.warning(
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
)
if half:
model_descriptor.model.half()
half = False
if prefer_half:
if model_descriptor.supports_half:
model_descriptor.model.half()
half = True
else:
logger.info("Model %s does not support half precision, ignoring --half", path)
if dtype:
model_descriptor.model.to(dtype=dtype)
model_descriptor.model.eval()
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
logger.debug(
"Loaded %s from %s (device=%s, half=%s, dtype=%s)",
model_descriptor, path, device, half, dtype,
)
return model_descriptor

View File

@ -39,7 +39,7 @@ class UpscalerRealESRGAN(Upscaler):
model_descriptor = modelloader.load_spandrel_model(
info.local_data_path,
device=self.device,
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel
)
return upscale_with_model(