mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Merge branch 'master' of github.com:AUTOMATIC1111/stable-diffusion-webui
This commit is contained in:
commit
d6fdfde9d7
@ -11,25 +11,41 @@ from omegaconf import OmegaConf
|
|||||||
|
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
from ldm.util import instantiate_from_config, ismap
|
from ldm.util import instantiate_from_config, ismap
|
||||||
|
from modules import shared, sd_hijack
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
|
||||||
|
cached_ldsr_model: torch.nn.Module = None
|
||||||
|
|
||||||
|
|
||||||
# Create LDSR Class
|
# Create LDSR Class
|
||||||
class LDSR:
|
class LDSR:
|
||||||
def load_model_from_config(self, half_attention):
|
def load_model_from_config(self, half_attention):
|
||||||
print(f"Loading model from {self.modelPath}")
|
global cached_ldsr_model
|
||||||
pl_sd = torch.load(self.modelPath, map_location="cpu")
|
|
||||||
sd = pl_sd["state_dict"]
|
if shared.opts.ldsr_cached and cached_ldsr_model is not None:
|
||||||
config = OmegaConf.load(self.yamlPath)
|
print(f"Loading model from cache")
|
||||||
config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
|
model: torch.nn.Module = cached_ldsr_model
|
||||||
model = instantiate_from_config(config.model)
|
else:
|
||||||
model.load_state_dict(sd, strict=False)
|
print(f"Loading model from {self.modelPath}")
|
||||||
model.cuda()
|
pl_sd = torch.load(self.modelPath, map_location="cpu")
|
||||||
if half_attention:
|
sd = pl_sd["state_dict"]
|
||||||
model = model.half()
|
config = OmegaConf.load(self.yamlPath)
|
||||||
|
config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
|
||||||
|
model: torch.nn.Module = instantiate_from_config(config.model)
|
||||||
|
model.load_state_dict(sd, strict=False)
|
||||||
|
model = model.to(shared.device)
|
||||||
|
if half_attention:
|
||||||
|
model = model.half()
|
||||||
|
if shared.cmd_opts.opt_channelslast:
|
||||||
|
model = model.to(memory_format=torch.channels_last)
|
||||||
|
|
||||||
|
sd_hijack.model_hijack.hijack(model) # apply optimization
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if shared.opts.ldsr_cached:
|
||||||
|
cached_ldsr_model = model
|
||||||
|
|
||||||
model.eval()
|
|
||||||
return {"model": model}
|
return {"model": model}
|
||||||
|
|
||||||
def __init__(self, model_path, yaml_path):
|
def __init__(self, model_path, yaml_path):
|
||||||
@ -94,7 +110,8 @@ class LDSR:
|
|||||||
down_sample_method = 'Lanczos'
|
down_sample_method = 'Lanczos'
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
if torch.cuda.is_available:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
im_og = image
|
im_og = image
|
||||||
width_og, height_og = im_og.size
|
width_og, height_og = im_og.size
|
||||||
@ -131,7 +148,9 @@ class LDSR:
|
|||||||
|
|
||||||
del model
|
del model
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
if torch.cuda.is_available:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
return a
|
return a
|
||||||
|
|
||||||
|
|
||||||
@ -146,7 +165,7 @@ def get_cond(selected_path):
|
|||||||
c = rearrange(c, '1 c h w -> 1 h w c')
|
c = rearrange(c, '1 c h w -> 1 h w c')
|
||||||
c = 2. * c - 1.
|
c = 2. * c - 1.
|
||||||
|
|
||||||
c = c.to(torch.device("cuda"))
|
c = c.to(shared.device)
|
||||||
example["LR_image"] = c
|
example["LR_image"] = c
|
||||||
example["image"] = c_up
|
example["image"] = c_up
|
||||||
|
|
||||||
|
@ -59,6 +59,7 @@ def on_ui_settings():
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
|
shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
|
||||||
|
shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))
|
||||||
|
|
||||||
|
|
||||||
script_callbacks.on_ui_settings(on_ui_settings)
|
script_callbacks.on_ui_settings(on_ui_settings)
|
||||||
|
@ -88,7 +88,7 @@ function checkBrackets(evt) {
|
|||||||
if(counterElt.title != '') {
|
if(counterElt.title != '') {
|
||||||
counterElt.style = 'color: #FF5555;';
|
counterElt.style = 'color: #FF5555;';
|
||||||
} else {
|
} else {
|
||||||
counterElt.style = 'color: #000;';
|
counterElt.style = '';
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -13,13 +13,15 @@ from skimage import exposure
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
import modules.images as images
|
import modules.images as images
|
||||||
import modules.styles
|
import modules.styles
|
||||||
|
import modules.sd_models as sd_models
|
||||||
|
import modules.sd_vae as sd_vae
|
||||||
import logging
|
import logging
|
||||||
from ldm.data.util import AddMiDaS
|
from ldm.data.util import AddMiDaS
|
||||||
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
||||||
@ -454,8 +456,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
for k, v in p.override_settings.items():
|
for k, v in p.override_settings.items():
|
||||||
setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model impossible
|
setattr(opts, k, v)
|
||||||
if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet since it is relatively fast to load on-change, while SD models are not
|
if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet
|
||||||
|
if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model
|
||||||
|
if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE
|
||||||
|
|
||||||
res = process_images_inner(p)
|
res = process_images_inner(p)
|
||||||
|
|
||||||
@ -463,6 +467,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
for k, v in stored_opts.items():
|
for k, v in stored_opts.items():
|
||||||
setattr(opts, k, v)
|
setattr(opts, k, v)
|
||||||
if k == 'sd_hypernetwork': shared.reload_hypernetworks()
|
if k == 'sd_hypernetwork': shared.reload_hypernetworks()
|
||||||
|
if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
|
||||||
|
if k == 'sd_vae': sd_vae.reload_vae_weights()
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@ -571,9 +577,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
if opts.filter_nsfw:
|
if p.scripts is not None:
|
||||||
import modules.safety as safety
|
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
|
||||||
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
|
|
||||||
|
|
||||||
for i, x_sample in enumerate(x_samples_ddim):
|
for i, x_sample in enumerate(x_samples_ddim):
|
||||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||||
|
@ -1,42 +0,0 @@
|
|||||||
import torch
|
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
|
||||||
from transformers import AutoFeatureExtractor
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
import modules.shared as shared
|
|
||||||
|
|
||||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
|
||||||
safety_feature_extractor = None
|
|
||||||
safety_checker = None
|
|
||||||
|
|
||||||
def numpy_to_pil(images):
|
|
||||||
"""
|
|
||||||
Convert a numpy image or a batch of images to a PIL image.
|
|
||||||
"""
|
|
||||||
if images.ndim == 3:
|
|
||||||
images = images[None, ...]
|
|
||||||
images = (images * 255).round().astype("uint8")
|
|
||||||
pil_images = [Image.fromarray(image) for image in images]
|
|
||||||
|
|
||||||
return pil_images
|
|
||||||
|
|
||||||
# check and replace nsfw content
|
|
||||||
def check_safety(x_image):
|
|
||||||
global safety_feature_extractor, safety_checker
|
|
||||||
|
|
||||||
if safety_feature_extractor is None:
|
|
||||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
|
||||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
|
||||||
|
|
||||||
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
|
||||||
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
|
||||||
|
|
||||||
return x_checked_image, has_nsfw_concept
|
|
||||||
|
|
||||||
|
|
||||||
def censor_batch(x):
|
|
||||||
x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
|
|
||||||
x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
|
|
||||||
x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
|
|
||||||
|
|
||||||
return x
|
|
@ -88,6 +88,17 @@ class Script:
|
|||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def postprocess_batch(self, p, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Same as process_batch(), but called for every batch after it has been generated.
|
||||||
|
|
||||||
|
**kwargs will have same items as process_batch, and also:
|
||||||
|
- batch_number - index of current batch, from 0 to number of batches-1
|
||||||
|
- images - torch tensor with all generated images, with values ranging from 0 to 1;
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
def postprocess(self, p, processed, *args):
|
def postprocess(self, p, processed, *args):
|
||||||
"""
|
"""
|
||||||
This function is called after processing ends for AlwaysVisible scripts.
|
This function is called after processing ends for AlwaysVisible scripts.
|
||||||
@ -347,6 +358,15 @@ class ScriptRunner:
|
|||||||
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
|
print(f"Error running postprocess: {script.filename}", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
def postprocess_batch(self, p, images, **kwargs):
|
||||||
|
for script in self.alwayson_scripts:
|
||||||
|
try:
|
||||||
|
script_args = p.script_args[script.args_from:script.args_to]
|
||||||
|
script.postprocess_batch(p, *script_args, images=images, **kwargs)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
def before_component(self, component, **kwargs):
|
def before_component(self, component, **kwargs):
|
||||||
for script in self.scripts:
|
for script in self.scripts:
|
||||||
try:
|
try:
|
||||||
|
@ -367,7 +367,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
"use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
|
||||||
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
|
||||||
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
||||||
"filter_nsfw": OptionInfo(False, "Filter NSFW content"),
|
|
||||||
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||||
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
|
||||||
}))
|
}))
|
||||||
|
@ -206,12 +206,13 @@ def refresh_available_extensions_from_data(hide_tags):
|
|||||||
if url is None:
|
if url is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
existing = installed_extension_urls.get(normalize_git_url(url), None)
|
||||||
|
extension_tags = extension_tags + ["installed"] if existing else extension_tags
|
||||||
|
|
||||||
if len([x for x in extension_tags if x in tags_to_hide]) > 0:
|
if len([x for x in extension_tags if x in tags_to_hide]) > 0:
|
||||||
hidden += 1
|
hidden += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
existing = installed_extension_urls.get(normalize_git_url(url), None)
|
|
||||||
|
|
||||||
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
|
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
|
||||||
|
|
||||||
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
|
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
|
||||||
@ -222,7 +223,11 @@ def refresh_available_extensions_from_data(hide_tags):
|
|||||||
<td>{html.escape(description)}</td>
|
<td>{html.escape(description)}</td>
|
||||||
<td>{install_code}</td>
|
<td>{install_code}</td>
|
||||||
</tr>
|
</tr>
|
||||||
"""
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
for tag in [x for x in extension_tags if x not in tags]:
|
||||||
|
tags[tag] = tag
|
||||||
|
|
||||||
code += """
|
code += """
|
||||||
</tbody>
|
</tbody>
|
||||||
@ -272,7 +277,7 @@ def create_ui():
|
|||||||
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
|
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"])
|
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
|
||||||
|
|
||||||
install_result = gr.HTML()
|
install_result = gr.HTML()
|
||||||
available_extensions_table = gr.HTML()
|
available_extensions_table = gr.HTML()
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
accelerate
|
accelerate
|
||||||
basicsr
|
basicsr
|
||||||
diffusers
|
|
||||||
fairscale==0.4.4
|
fairscale==0.4.4
|
||||||
fonts
|
fonts
|
||||||
font-roboto
|
font-roboto
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
transformers==4.19.2
|
transformers==4.19.2
|
||||||
diffusers==0.3.0
|
|
||||||
accelerate==0.12.0
|
accelerate==0.12.0
|
||||||
basicsr==1.4.2
|
basicsr==1.4.2
|
||||||
gfpgan==1.3.8
|
gfpgan==1.3.8
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
function gradioApp(){
|
function gradioApp() {
|
||||||
return document.getElementsByTagName('gradio-app')[0].shadowRoot;
|
const gradioShadowRoot = document.getElementsByTagName('gradio-app')[0].shadowRoot
|
||||||
|
return !!gradioShadowRoot ? gradioShadowRoot : document;
|
||||||
}
|
}
|
||||||
|
|
||||||
function get_uiCurrentTab() {
|
function get_uiCurrentTab() {
|
||||||
@ -82,4 +83,4 @@ function uiElementIsVisible(el) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
return isVisible;
|
return isVisible;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user