Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git (synced 2024-06-07 21:20:49 +00:00)
Merge pull request #5586 from wywywywy/ldsr-improvements
LDSR improvements - cache / optimization / opt_channelslast
Commit: 685f9631b5
@@ -11,25 +11,41 @@ from omegaconf import OmegaConf
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.util import instantiate_from_config, ismap
+from modules import shared, sd_hijack
 
 warnings.filterwarnings("ignore", category=UserWarning)
 
+cached_ldsr_model: torch.nn.Module = None
+
 
 # Create LDSR Class
 class LDSR:
     def load_model_from_config(self, half_attention):
-        print(f"Loading model from {self.modelPath}")
-        pl_sd = torch.load(self.modelPath, map_location="cpu")
-        sd = pl_sd["state_dict"]
-        config = OmegaConf.load(self.yamlPath)
-        config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
-        model = instantiate_from_config(config.model)
-        model.load_state_dict(sd, strict=False)
-        model.cuda()
-        if half_attention:
-            model = model.half()
-
-        model.eval()
+        global cached_ldsr_model
+
+        if shared.opts.ldsr_cached and cached_ldsr_model is not None:
+            print(f"Loading model from cache")
+            model: torch.nn.Module = cached_ldsr_model
+        else:
+            print(f"Loading model from {self.modelPath}")
+            pl_sd = torch.load(self.modelPath, map_location="cpu")
+            sd = pl_sd["state_dict"]
+            config = OmegaConf.load(self.yamlPath)
+            config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
+            model: torch.nn.Module = instantiate_from_config(config.model)
+            model.load_state_dict(sd, strict=False)
+            model = model.to(shared.device)
+            if half_attention:
+                model = model.half()
+            if shared.cmd_opts.opt_channelslast:
+                model = model.to(memory_format=torch.channels_last)
+
+            sd_hijack.model_hijack.hijack(model) # apply optimization
+            model.eval()
+
+            if shared.opts.ldsr_cached:
+                cached_ldsr_model = model
+
         return {"model": model}
 
     def __init__(self, model_path, yaml_path):
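The heart of this hunk is a process-wide cache: a module-level global keeps the loaded LDSR model alive so that, when the new "Cache LDSR model in memory" option is enabled, subsequent upscales skip the torch.load / instantiate_from_config round trip. The hunk also replaces the hard-coded model.cuda() with model.to(shared.device) and, when the webui runs with cmd_opts.opt_channelslast set, converts the model to channels-last memory format before sd_hijack applies its optimizations. A minimal runnable sketch of the caching pattern follows; build_model and load_model are illustrative stand-ins, not webui code:

import torch.nn as nn

# Module-level cache, mirroring the diff's cached_ldsr_model global.
cached_ldsr_model: nn.Module = None


def build_model() -> nn.Module:
    # Stand-in for the expensive torch.load + instantiate_from_config
    # path inside LDSR.load_model_from_config().
    return nn.Sequential(nn.Conv2d(3, 3, kernel_size=3, padding=1))


def load_model(use_cache: bool) -> nn.Module:
    global cached_ldsr_model
    if use_cache and cached_ldsr_model is not None:
        print("Loading model from cache")
        return cached_ldsr_model
    model = build_model()
    model.eval()
    if use_cache:
        cached_ldsr_model = model  # keep it around for the next upscale
    return model


m1 = load_model(use_cache=True)
m2 = load_model(use_cache=True)
assert m1 is m2  # the second call reused the cached instance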
@@ -94,6 +110,7 @@ class LDSR:
         down_sample_method = 'Lanczos'
 
         gc.collect()
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available:
+            torch.cuda.empty_cache()
 
         im_og = image
@@ -131,7 +148,9 @@ class LDSR:
 
         del model
         gc.collect()
-        torch.cuda.empty_cache()
+        if torch.cuda.is_available:
+            torch.cuda.empty_cache()
 
         return a
 
+
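One caveat about the guard added in both hunks above: torch.cuda.is_available is referenced without parentheses, so the if tests the function object itself, which is always truthy, and the branch runs even on CPU-only machines. The intended guard is the call form, sketched here:

import torch

# Call the function; the bare attribute `torch.cuda.is_available` is
# always truthy, so it never actually guards anything.
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # release cached GPU blocks back to the driver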
@@ -146,7 +165,7 @@ def get_cond(selected_path):
     c = rearrange(c, '1 c h w -> 1 h w c')
     c = 2. * c - 1.
 
-    c = c.to(torch.device("cuda"))
+    c = c.to(shared.device)
     example["LR_image"] = c
     example["image"] = c_up
 
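This get_cond() fix follows the same portability theme as the first hunk: the conditioning tensor moves to the webui's configured shared.device instead of an unconditional torch.device("cuda"). A device-agnostic sketch of the same preprocessing step, using random stand-in data:

import torch

# shared.device plays this role in the webui, resolving to whatever
# device the user configured; this fallback chain is an approximation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

c = torch.rand(1, 64, 64, 3)  # stand-in for the rearranged LR image
c = 2. * c - 1.               # rescale [0, 1] -> [-1, 1], as in get_cond()
c = c.to(device)

The final hunk below comes from a different file (the extension script that registers the LDSR UI settings) rather than the model-architecture module edited above.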
@@ -59,6 +59,7 @@ def on_ui_settings():
     import gradio as gr
 
     shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
+    shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))
 
 
 script_callbacks.on_ui_settings(on_ui_settings)
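Options registered through shared.opts.add_option() surface as plain attributes on shared.opts, which is how the first hunk reads the new checkbox back as shared.opts.ldsr_cached. A toy illustration of that round trip; the Options class below is a hypothetical stand-in, not the webui's implementation:

class Options:
    # Stand-in for modules.shared.opts: registered options are later
    # read back as ordinary attributes.
    def __init__(self):
        self._data = {}

    def add_option(self, key, default):
        self._data[key] = default

    def __getattr__(self, key):
        try:
            return self._data[key]
        except KeyError:
            raise AttributeError(key)


opts = Options()
opts.add_option("ldsr_cached", False)  # as in on_ui_settings()
print(opts.ldsr_cached)                # read back by the LDSR loader -> False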