Verify architecture for loaded Spandrel models

This commit is contained in:
Aarni Koskela 2023-12-30 16:37:03 +02:00
parent c756133541
commit 4ad0c0c0a8
8 changed files with 22 additions and 5 deletions

View File

@ -121,7 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else: else:
filename = path filename = path
return modelloader.load_spandrel_model(filename, device=device) return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
def on_ui_settings(): def on_ui_settings():

View File

@ -75,6 +75,7 @@ class UpscalerSwinIR(Upscaler):
filename, filename,
device=self._get_device(), device=self._get_device(),
dtype=devices.dtype, dtype=devices.dtype,
expected_architecture="SwinIR",
) )
if getattr(opts, 'SWIN_torch_compile', False): if getattr(opts, 'SWIN_torch_compile', False):
try: try:

View File

@ -37,6 +37,7 @@ class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
return modelloader.load_spandrel_model( return modelloader.load_spandrel_model(
model_path, model_path,
device=devices.device_codeformer, device=devices.device_codeformer,
expected_architecture='CodeFormer',
).model ).model
raise ValueError("No codeformer model found") raise ValueError("No codeformer model found")

View File

@ -49,6 +49,7 @@ class UpscalerESRGAN(Upscaler):
return modelloader.load_spandrel_model( return modelloader.load_spandrel_model(
filename, filename,
device=('cpu' if devices.device_esrgan.type == 'mps' else None), device=('cpu' if devices.device_esrgan.type == 'mps' else None),
expected_architecture='ESRGAN',
) )

View File

@ -37,6 +37,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
net = modelloader.load_spandrel_model( net = modelloader.load_spandrel_model(
model_path, model_path,
device=self.get_device(), device=self.get_device(),
expected_architecture='GFPGAN',
).model ).model
net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
return net return net

View File

@ -39,4 +39,5 @@ class UpscalerHAT(Upscaler):
return modelloader.load_spandrel_model( return modelloader.load_spandrel_model(
path, path,
device=devices.device_esrgan, # TODO: should probably be device_hat device=devices.device_esrgan, # TODO: should probably be device_hat
expected_architecture='HAT',
) )

View File

@ -6,6 +6,8 @@ import shutil
import importlib import importlib
from urllib.parse import urlparse from urllib.parse import urlparse
import torch
from modules import shared from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
from modules.paths import script_path, models_path from modules.paths import script_path, models_path
@ -183,9 +185,18 @@ def load_upscalers():
) )
def load_spandrel_model(path, *, device, half: bool = False, dtype=None): def load_spandrel_model(
path: str,
*,
device: str | torch.device | None,
half: bool = False,
dtype: str | None = None,
expected_architecture: str | None = None,
):
import spandrel import spandrel
model = spandrel.ModelLoader(device=device).load_from_file(path) model = spandrel.ModelLoader(device=device).load_from_file(path)
if expected_architecture and model.architecture != expected_architecture:
raise TypeError(f"Model {path} is not a {expected_architecture} model")
if half: if half:
model = model.model.half() model = model.model.half()
if dtype: if dtype:

View File

@ -1,9 +1,9 @@
import os import os
from modules.upscaler_utils import upscale_with_model
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts
from modules import modelloader, errors from modules import modelloader, errors
from modules.shared import cmd_opts, opts
from modules.upscaler import Upscaler, UpscalerData
from modules.upscaler_utils import upscale_with_model
class UpscalerRealESRGAN(Upscaler): class UpscalerRealESRGAN(Upscaler):
@ -40,6 +40,7 @@ class UpscalerRealESRGAN(Upscaler):
info.local_data_path, info.local_data_path,
device=self.device, device=self.device,
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
expected_architecture="RealESRGAN",
) )
return upscale_with_model( return upscale_with_model(
mod, mod,