mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Merge branch 'master' into master
This commit is contained in:
commit
e05e46aa3f
@ -14,3 +14,9 @@ def get_optimal_device():
|
|||||||
return torch.device("mps")
|
return torch.device("mps")
|
||||||
|
|
||||||
return cpu
|
return cpu
|
||||||
|
|
||||||
|
|
||||||
|
def torch_gc():
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.ipc_collect()
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from modules import processing, shared, images
|
from modules import processing, shared, images, devices
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
import modules.gfpgan_model
|
import modules.gfpgan_model
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
@ -11,7 +11,7 @@ cached_images = {}
|
|||||||
|
|
||||||
|
|
||||||
def run_extras(image, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
|
def run_extras(image, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
|
||||||
processing.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
info = ""
|
info = ""
|
||||||
|
@ -243,16 +243,32 @@ def sanitize_filename_part(text):
|
|||||||
return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]
|
return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]
|
||||||
|
|
||||||
|
|
||||||
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, pnginfo_section_name='parameters'):
|
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, pnginfo_section_name='parameters', process_info=None):
|
||||||
# would be better to add this as an argument in future, but will do for now
|
# would be better to add this as an argument in future, but will do for now
|
||||||
is_a_grid = basename != ""
|
is_a_grid = basename != ""
|
||||||
|
|
||||||
if short_filename or prompt is None or seed is None:
|
if short_filename or prompt is None or seed is None:
|
||||||
file_decoration = ""
|
file_decoration = ""
|
||||||
elif opts.save_to_dirs:
|
elif opts.save_to_dirs:
|
||||||
file_decoration = f"-{seed}"
|
file_decoration = opts.samples_filename_format or "[SEED]"
|
||||||
else:
|
else:
|
||||||
file_decoration = f"-{seed}-{sanitize_filename_part(prompt)[:128]}"
|
file_decoration = opts.samples_filename_format or "[SEED]-[PROMPT]"
|
||||||
|
#file_decoration = f"-{seed}-{sanitize_filename_part(prompt)[:128]}"
|
||||||
|
|
||||||
|
#Add new filenames tags here
|
||||||
|
file_decoration = "-" + file_decoration
|
||||||
|
if seed is not None:
|
||||||
|
file_decoration = file_decoration.replace("[SEED]", str(seed))
|
||||||
|
if prompt is not None:
|
||||||
|
file_decoration = file_decoration.replace("[PROMPT]", sanitize_filename_part(prompt)[:128])
|
||||||
|
file_decoration = file_decoration.replace("[PROMPT_SPACES]", prompt.translate({ord(x): '' for x in invalid_filename_chars})[:128])
|
||||||
|
if process_info is not None:
|
||||||
|
file_decoration = file_decoration.replace("[STEPS]", str(process_info.steps))
|
||||||
|
file_decoration = file_decoration.replace("[CFG]", str(process_info.cfg_scale))
|
||||||
|
file_decoration = file_decoration.replace("[WIDTH]", str(process_info.width))
|
||||||
|
file_decoration = file_decoration.replace("[HEIGHT]", str(process_info.height))
|
||||||
|
file_decoration = file_decoration.replace("[SAMPLER]", str(process_info.sampler))
|
||||||
|
|
||||||
|
|
||||||
if extension == 'png' and opts.enable_pnginfo and info is not None:
|
if extension == 'png' and opts.enable_pnginfo and info is not None:
|
||||||
pnginfo = PngImagePlugin.PngInfo()
|
pnginfo = PngImagePlugin.PngInfo()
|
||||||
|
@ -3,6 +3,7 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageOps, ImageChops
|
from PIL import Image, ImageOps, ImageChops
|
||||||
|
|
||||||
|
from modules import devices
|
||||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@ -131,7 +132,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init
|
|||||||
upscaler = shared.sd_upscalers[upscaler_index]
|
upscaler = shared.sd_upscalers[upscaler_index]
|
||||||
img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2)
|
img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2)
|
||||||
|
|
||||||
processing.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
grid = images.split_grid(img, tile_w=width, tile_h=height, overlap=upscale_overlap)
|
grid = images.split_grid(img, tile_w=width, tile_h=height, overlap=upscale_overlap)
|
||||||
|
|
||||||
@ -179,7 +180,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init
|
|||||||
result_images.append(combined_image)
|
result_images.append(combined_image)
|
||||||
|
|
||||||
if opts.samples_save:
|
if opts.samples_save:
|
||||||
images.save_image(combined_image, p.outpath_samples, "", start_seed, prompt, opts.grid_format, info=initial_info)
|
images.save_image(combined_image, p.outpath_samples, "", start_seed, prompt, opts.samples_format, info=initial_info)
|
||||||
|
|
||||||
processed = Processed(p, result_images, seed, initial_info)
|
processed = Processed(p, result_images, seed, initial_info)
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import contextlib
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
@ -6,7 +7,6 @@ import re
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
|
||||||
@ -26,6 +26,7 @@ class InterrogateModels:
|
|||||||
clip_model = None
|
clip_model = None
|
||||||
clip_preprocess = None
|
clip_preprocess = None
|
||||||
categories = None
|
categories = None
|
||||||
|
dtype = None
|
||||||
|
|
||||||
def __init__(self, content_dir):
|
def __init__(self, content_dir):
|
||||||
self.categories = []
|
self.categories = []
|
||||||
@ -60,14 +61,20 @@ class InterrogateModels:
|
|||||||
def load(self):
|
def load(self):
|
||||||
if self.blip_model is None:
|
if self.blip_model is None:
|
||||||
self.blip_model = self.load_blip_model()
|
self.blip_model = self.load_blip_model()
|
||||||
|
if not shared.cmd_opts.no_half:
|
||||||
|
self.blip_model = self.blip_model.half()
|
||||||
|
|
||||||
self.blip_model = self.blip_model.to(shared.device)
|
self.blip_model = self.blip_model.to(shared.device)
|
||||||
|
|
||||||
if self.clip_model is None:
|
if self.clip_model is None:
|
||||||
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
self.clip_model, self.clip_preprocess = self.load_clip_model()
|
||||||
|
if not shared.cmd_opts.no_half:
|
||||||
|
self.clip_model = self.clip_model.half()
|
||||||
|
|
||||||
self.clip_model = self.clip_model.to(shared.device)
|
self.clip_model = self.clip_model.to(shared.device)
|
||||||
|
|
||||||
|
self.dtype = next(self.clip_model.parameters()).dtype
|
||||||
|
|
||||||
def unload(self):
|
def unload(self):
|
||||||
if not shared.opts.interrogate_keep_models_in_memory:
|
if not shared.opts.interrogate_keep_models_in_memory:
|
||||||
if self.clip_model is not None:
|
if self.clip_model is not None:
|
||||||
@ -76,14 +83,14 @@ class InterrogateModels:
|
|||||||
if self.blip_model is not None:
|
if self.blip_model is not None:
|
||||||
self.blip_model = self.blip_model.to(devices.cpu)
|
self.blip_model = self.blip_model.to(devices.cpu)
|
||||||
|
|
||||||
|
devices.torch_gc()
|
||||||
|
|
||||||
def rank(self, image_features, text_array, top_count=1):
|
def rank(self, image_features, text_array, top_count=1):
|
||||||
import clip
|
import clip
|
||||||
|
|
||||||
top_count = min(top_count, len(text_array))
|
top_count = min(top_count, len(text_array))
|
||||||
text_tokens = clip.tokenize([text for text in text_array]).cuda()
|
text_tokens = clip.tokenize([text for text in text_array]).to(shared.device)
|
||||||
with torch.no_grad():
|
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
||||||
text_features = self.clip_model.encode_text(text_tokens).float()
|
|
||||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
similarity = torch.zeros((1, len(text_array))).to(shared.device)
|
similarity = torch.zeros((1, len(text_array))).to(shared.device)
|
||||||
@ -94,13 +101,12 @@ class InterrogateModels:
|
|||||||
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
|
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
|
||||||
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
|
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
|
||||||
|
|
||||||
|
|
||||||
def generate_caption(self, pil_image):
|
def generate_caption(self, pil_image):
|
||||||
gpu_image = transforms.Compose([
|
gpu_image = transforms.Compose([
|
||||||
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
|
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
|
||||||
])(pil_image).unsqueeze(0).to(shared.device)
|
])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
|
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
|
||||||
@ -116,22 +122,23 @@ class InterrogateModels:
|
|||||||
caption = self.generate_caption(pil_image)
|
caption = self.generate_caption(pil_image)
|
||||||
res = caption
|
res = caption
|
||||||
|
|
||||||
images = self.clip_preprocess(pil_image).unsqueeze(0).to(shared.device)
|
images = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||||
image_features = self.clip_model.encode_image(images).float()
|
with torch.no_grad(), precision_scope("cuda"):
|
||||||
|
image_features = self.clip_model.encode_image(images).type(self.dtype)
|
||||||
|
|
||||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
if shared.opts.interrogate_use_builtin_artists:
|
if shared.opts.interrogate_use_builtin_artists:
|
||||||
artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
|
artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
|
||||||
|
|
||||||
res += ", " + artist[0]
|
res += ", " + artist[0]
|
||||||
|
|
||||||
for name, topn, items in self.categories:
|
for name, topn, items in self.categories:
|
||||||
matches = self.rank(image_features, items, top_count=topn)
|
matches = self.rank(image_features, items, top_count=topn)
|
||||||
for match, score in matches:
|
for match, score in matches:
|
||||||
res += ", " + match
|
res += ", " + match
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print(f"Error interrogating", file=sys.stderr)
|
print(f"Error interrogating", file=sys.stderr)
|
||||||
|
@ -10,6 +10,7 @@ from PIL import Image, ImageFilter, ImageOps
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
|
from modules import devices
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.sd_samplers import samplers, samplers_for_img2img
|
from modules.sd_samplers import samplers, samplers_for_img2img
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
@ -23,11 +24,6 @@ opt_C = 4
|
|||||||
opt_f = 8
|
opt_f = 8
|
||||||
|
|
||||||
|
|
||||||
def torch_gc():
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.ipc_collect()
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionProcessing:
|
class StableDiffusionProcessing:
|
||||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
|
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
|
||||||
@ -157,7 +153,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
||||||
|
|
||||||
assert p.prompt is not None
|
assert p.prompt is not None
|
||||||
torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
fix_seed(p)
|
fix_seed(p)
|
||||||
|
|
||||||
@ -258,7 +254,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
|
|
||||||
if p.restore_faces:
|
if p.restore_faces:
|
||||||
torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
x_sample = modules.face_restoration.restore_faces(x_sample)
|
x_sample = modules.face_restoration.restore_faces(x_sample)
|
||||||
|
|
||||||
@ -279,7 +275,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
image = image.convert('RGB')
|
image = image.convert('RGB')
|
||||||
|
|
||||||
if opts.samples_save and not p.do_not_save_samples:
|
if opts.samples_save and not p.do_not_save_samples:
|
||||||
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i))
|
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), process_info = Processed(p, output_images, all_seeds[0], infotext()))
|
||||||
|
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
|
|
||||||
@ -297,7 +293,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if opts.grid_save:
|
if opts.grid_save:
|
||||||
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
||||||
|
|
||||||
torch_gc()
|
devices.torch_gc()
|
||||||
return Processed(p, output_images, all_seeds[0], infotext())
|
return Processed(p, output_images, all_seeds[0], infotext())
|
||||||
|
|
||||||
|
|
||||||
|
@ -94,6 +94,7 @@ class Options:
|
|||||||
data = None
|
data = None
|
||||||
hide_dirs = {"visible": False} if cmd_opts.hide_ui_dir_config else None
|
hide_dirs = {"visible": False} if cmd_opts.hide_ui_dir_config else None
|
||||||
data_labels = {
|
data_labels = {
|
||||||
|
"samples_filename_format": OptionInfo("", "Samples filename format using following tags: [STEPS],[CFG],[PROMPT],[PROMPT_SPACES],[WIDTH],[HEIGHT],[SAMPLER],[SEED]. Leave blank for default."),
|
||||||
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to two directories below", component_args=hide_dirs),
|
"outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to two directories below", component_args=hide_dirs),
|
||||||
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
|
||||||
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
|
||||||
|
@ -4,7 +4,7 @@ import modules.scripts as scripts
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
from modules import images, processing
|
from modules import images, processing, devices
|
||||||
from modules.processing import Processed, process_images
|
from modules.processing import Processed, process_images
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
|
|
||||||
@ -77,7 +77,7 @@ class Script(scripts.Script):
|
|||||||
mask.height - down - (mask_blur//2 if down > 0 else 0)
|
mask.height - down - (mask_blur//2 if down > 0 else 0)
|
||||||
), fill="black")
|
), fill="black")
|
||||||
|
|
||||||
processing.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels)
|
grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels)
|
||||||
grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
|
grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
|
||||||
|
Loading…
Reference in New Issue
Block a user