diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
index a87d1ef93..8b048ae00 100644
--- a/extensions-builtin/LDSR/ldsr_model_arch.py
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
@@ -11,25 +11,46 @@ from omegaconf import OmegaConf
 
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.util import instantiate_from_config, ismap
+from modules import shared, sd_hijack
 
 warnings.filterwarnings("ignore", category=UserWarning)
 
+cached_ldsr_model: torch.nn.Module = None
+
 
 # Create LDSR Class
 class LDSR:
     def load_model_from_config(self, half_attention):
-        print(f"Loading model from {self.modelPath}")
-        pl_sd = torch.load(self.modelPath, map_location="cpu")
-        sd = pl_sd["state_dict"]
-        config = OmegaConf.load(self.yamlPath)
-        config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
-        model = instantiate_from_config(config.model)
-        model.load_state_dict(sd, strict=False)
-        model.cuda()
-        if half_attention:
-            model = model.half()
+        """Return ``{"model": model}``, reusing the cached model when enabled.
+
+        NOTE(review): a cached model is returned as-is, whatever
+        ``half_attention`` is on this call -- confirm that is intended.
+        """
+        global cached_ldsr_model
+
+        if shared.opts.ldsr_cached and cached_ldsr_model is not None:
+            print("Loading model from cache")
+            model: torch.nn.Module = cached_ldsr_model
+        else:
+            print(f"Loading model from {self.modelPath}")
+            pl_sd = torch.load(self.modelPath, map_location="cpu")
+            sd = pl_sd["state_dict"]
+            config = OmegaConf.load(self.yamlPath)
+            config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
+            model: torch.nn.Module = instantiate_from_config(config.model)
+            model.load_state_dict(sd, strict=False)
+            model = model.to(shared.device)
+            if half_attention:
+                model = model.half()
+            if shared.cmd_opts.opt_channelslast:
+                model = model.to(memory_format=torch.channels_last)
+
+            sd_hijack.model_hijack.hijack(model) # apply optimization
+            model.eval()
+
+            if shared.opts.ldsr_cached:
+                cached_ldsr_model = model
 
-        model.eval()
         return {"model": model}
 
     def __init__(self, model_path, yaml_path):
@@ -94,7 +115,9 @@ class LDSR:
         down_sample_method = 'Lanczos'
 
         gc.collect()
-        torch.cuda.empty_cache()
+        # is_available must be *called*; the bare function object is always truthy.
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
         im_og = image
         width_og, height_og = im_og.size
@@ -131,7 +154,10 @@ class LDSR:
 
         del model
         gc.collect()
-        torch.cuda.empty_cache()
+        # is_available must be *called*; the bare function object is always truthy.
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
         return a
 
 
@@ -146,7 +172,7 @@ def get_cond(selected_path):
     c = rearrange(c, '1 c h w -> 1 h w c')
     c = 2. * c - 1.
 
-    c = c.to(torch.device("cuda"))
+    c = c.to(shared.device)
     example["LR_image"] = c
     example["image"] = c_up
 
diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py
index 5c96037de..29d5f94ed 100644
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -59,6 +59,7 @@ def on_ui_settings():
     import gradio as gr
 
     shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
+    shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))
 
 
 script_callbacks.on_ui_settings(on_ui_settings)