Merge pull request #13718 from avantcontra/bugfix_gfpgan_custom_path

fix bug when using --gfpgan-models-path
This commit is contained in:
AUTOMATIC1111 2023-11-03 20:19:58 +03:00 committed by GitHub
commit 452ab8fe72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -9,6 +9,7 @@ from modules import paths, shared, devices, modelloader, errors
model_dir = "GFPGAN" model_dir = "GFPGAN"
user_path = None user_path = None
model_path = os.path.join(paths.models_path, model_dir) model_path = os.path.join(paths.models_path, model_dir)
model_file_path = None
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
have_gfpgan = False have_gfpgan = False
loaded_gfpgan_model = None loaded_gfpgan_model = None
@ -17,6 +18,7 @@ loaded_gfpgan_model = None
def gfpgann(): def gfpgann():
global loaded_gfpgan_model global loaded_gfpgan_model
global model_path global model_path
global model_file_path
if loaded_gfpgan_model is not None: if loaded_gfpgan_model is not None:
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan) loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
return loaded_gfpgan_model return loaded_gfpgan_model
@ -24,17 +26,24 @@ def gfpgann():
if gfpgan_constructor is None: if gfpgan_constructor is None:
return None return None
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
if len(models) == 1 and models[0].startswith("http"): if len(models) == 1 and models[0].startswith("http"):
model_file = models[0] model_file = models[0]
elif len(models) != 0: elif len(models) != 0:
latest_file = max(models, key=os.path.getctime) gfp_models = []
for item in models:
if 'GFPGAN' in os.path.basename(item):
gfp_models.append(item)
latest_file = max(gfp_models, key=os.path.getctime)
model_file = latest_file model_file = latest_file
else: else:
print("Unable to load gfpgan model!") print("Unable to load gfpgan model!")
return None return None
if hasattr(facexlib.detection.retinaface, 'device'): if hasattr(facexlib.detection.retinaface, 'device'):
facexlib.detection.retinaface.device = devices.device_gfpgan facexlib.detection.retinaface.device = devices.device_gfpgan
model_file_path = model_file
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan) model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
loaded_gfpgan_model = model loaded_gfpgan_model = model
@ -77,19 +86,25 @@ def setup_model(dirname):
global user_path global user_path
global have_gfpgan global have_gfpgan
global gfpgan_constructor global gfpgan_constructor
global model_file_path
facexlib_path = model_path
if dirname is not None:
facexlib_path = dirname
load_file_from_url_orig = gfpgan.utils.load_file_from_url load_file_from_url_orig = gfpgan.utils.load_file_from_url
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
def my_load_file_from_url(**kwargs): def my_load_file_from_url(**kwargs):
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path)) return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
def facex_load_file_from_url(**kwargs): def facex_load_file_from_url(**kwargs):
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None)) return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
def facex_load_file_from_url2(**kwargs): def facex_load_file_from_url2(**kwargs):
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None)) return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
gfpgan.utils.load_file_from_url = my_load_file_from_url gfpgan.utils.load_file_from_url = my_load_file_from_url
facexlib.detection.load_file_from_url = facex_load_file_from_url facexlib.detection.load_file_from_url = facex_load_file_from_url