Merge branch 'dev' into torch

This commit is contained in:
AUTOMATIC1111 2023-04-29 11:58:54 +03:00 committed by GitHub
commit f54cd3f158
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 332 additions and 99 deletions

View File

@ -8,7 +8,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_lora
if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

View File

@ -52,5 +52,5 @@ script_callbacks.on_before_ui(before_ui)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
}))

View File

@ -5,11 +5,15 @@ import traceback
import PIL.Image
import numpy as np
import torch
from tqdm import tqdm
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import devices, modelloader
from scunet_model_arch import SCUNet as net
from modules.shared import opts
from modules import images
class UpscalerScuNET(modules.upscaler.Upscaler):
@ -42,28 +46,78 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers.append(scaler_data2)
self.scalers = scalers
def do_upscale(self, img: PIL.Image, selected_file):
@staticmethod
@torch.no_grad()
def tiled_inference(img, model):
# test the image tile by tile
h, w = img.shape[2:]
tile = opts.SCUNET_tile
tile_overlap = opts.SCUNET_tile_overlap
if tile == 0:
return model(img)
device = devices.get_device_for('scunet')
assert tile % 8 == 0, "tile size should be a multiple of window_size"
sf = 1
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
W = torch.zeros_like(E, dtype=devices.dtype, device=device)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
for h_idx in h_idx_list:
for w_idx in w_idx_list:
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask)
pbar.update(1)
output = E.div_(W)
return output
def do_upscale(self, img: PIL.Image.Image, selected_file):
torch.cuda.empty_cache()
model = self.load_model(selected_file)
if model is None:
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
return img
device = devices.get_device_for('scunet')
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(device)
tile = opts.SCUNET_tile
h, w = img.height, img.width
np_img = np.array(img)
np_img = np_img[:, :, ::-1] # RGB to BGR
np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW
torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
output = output[:, :, ::-1]
if tile > h or tile > w:
_img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
_img[:, :, :h, :w] = torch_img # pad image
torch_img = _img
torch_output = self.tiled_inference(torch_img, model).squeeze(0)
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
del torch_img, torch_output
torch.cuda.empty_cache()
return PIL.Image.fromarray(output, 'RGB')
output = np_output.transpose((1, 2, 0)) # CHW to HWC
output = output[:, :, ::-1] # BGR to RGB
return PIL.Image.fromarray((output * 255).astype(np.uint8))
def load_model(self, path: str):
device = devices.get_device_for('scunet')
@ -84,4 +138,3 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
model = model.to(device)
return model

View File

@ -161,14 +161,6 @@ addContextMenuEventListener = initResponse[2];
appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
appendContextMenuOption('#roll','Roll three',
function(){
let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
setTimeout(function(){rollbutton.click()},100)
setTimeout(function(){rollbutton.click()},200)
setTimeout(function(){rollbutton.click()},300)
}
)
})();
//End example Context Menu Items

View File

@ -17,7 +17,7 @@ function keyupEditAttention(event){
// Find opening parenthesis around current cursor
const before = text.substring(0, selectionStart);
let beforeParen = before.lastIndexOf(OPEN);
if (beforeParen == -1) return false;
if (beforeParen == -1) return false;
let beforeParenClose = before.lastIndexOf(CLOSE);
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
@ -27,7 +27,7 @@ function keyupEditAttention(event){
// Find closing parenthesis around current cursor
const after = text.substring(selectionStart);
let afterParen = after.indexOf(CLOSE);
if (afterParen == -1) return false;
if (afterParen == -1) return false;
let afterParenOpen = after.indexOf(OPEN);
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
afterParen = after.indexOf(CLOSE, afterParen + 1);
@ -43,10 +43,28 @@ function keyupEditAttention(event){
target.setSelectionRange(selectionStart, selectionEnd);
return true;
}
function selectCurrentWord(){
if (selectionStart !== selectionEnd) return false;
const delimiters = opts.keyedit_delimiters + " \r\n\t";
// seek backward until to find beggining
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
selectionStart--;
}
// seek forward to find end
while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
selectionEnd++;
}
// If the user hasn't selected anything, let's select their current parenthesis block
if(! selectCurrentParenthesisBlock('<', '>')){
selectCurrentParenthesisBlock('(', ')')
target.setSelectionRange(selectionStart, selectionEnd);
return true;
}
// If the user hasn't selected anything, let's select their current parenthesis block or word
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
selectCurrentWord();
}
event.preventDefault();
@ -81,7 +99,13 @@ function keyupEditAttention(event){
weight = parseFloat(weight.toPrecision(12));
if(String(weight).length == 1) weight += ".0"
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
if (closeCharacter == ')' && weight == 1) {
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
selectionStart--;
selectionEnd--;
} else {
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
}
target.focus();
target.value = text;
@ -93,4 +117,4 @@ function keyupEditAttention(event){
addEventListener('keydown', (event) => {
keyupEditAttention(event);
});
});

View File

@ -16,7 +16,7 @@ onUiUpdate(function(){
let modalObserver = new MutationObserver(function(mutations) {
mutations.forEach(function(mutationRecord) {
let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
let selectedTab = gradioApp().querySelector('#tabs div button')?.innerText
if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
gradioApp().getElementById(selectedTab+"_generation_info_button").click()
});

View File

@ -251,8 +251,11 @@ document.addEventListener("DOMContentLoaded", function() {
modal.appendChild(modalNext)
gradioApp().appendChild(modal)
try {
gradioApp().appendChild(modal);
} catch (e) {
gradioApp().body.appendChild(modal);
}
document.body.appendChild(modal);

View File

@ -138,7 +138,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
return
}
if(elapsedFromStart > 5 && !res.queued && !res.active){
if(elapsedFromStart > 40 && !res.queued && !res.active){
removeProgressBar()
return
}

View File

@ -6,7 +6,6 @@ import uvicorn
import gradio as gr
from threading import Lock
from io import BytesIO
from gradio.processing_utils import decode_base64_to_file
from fastapi import APIRouter, Depends, FastAPI, Request, Response
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
@ -395,16 +394,11 @@ class Api:
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
reqDict = setUpscalers(req)
def prepareFiles(file):
file = decode_base64_to_file(file.data, file_path=file.name)
file.orig_name = file.name
return file
reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
reqDict.pop('imageList')
image_list = reqDict.pop('imageList', [])
image_folder = [decode_base64_to_image(x.data) for x in image_list]
with self.queue_lock:
result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])

View File

@ -92,14 +92,18 @@ def cond_cast_float(input):
def randn(seed, shape):
from modules.shared import opts
torch.manual_seed(seed)
if device.type == 'mps':
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def randn_without_seed(shape):
if device.type == 'mps':
from modules.shared import opts
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)

View File

@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
def activate(self, p, params_list):
additional = shared.opts.sd_hypernetwork
if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

View File

@ -284,6 +284,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
restore_old_hires_fix_params(res)
# Missing RNG means the default was set, which is GPU RNG
if "RNG" not in res:
res["RNG"] = "GPU"
return res
@ -304,6 +308,7 @@ infotext_to_setting_name_mapping = [
('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'),
('RNG', 'randn_source'),
]

View File

@ -352,6 +352,7 @@ class FilenameGenerator:
'prompt_no_styles': lambda self: self.prompt_no_style(),
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
'prompt_words': lambda self: self.prompt_words(),
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
}
default_time_format = '%Y%m%d%H%M%S'

View File

@ -151,7 +151,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
override_settings=override_settings,
)
p.scripts = modules.scripts.scripts_txt2img
p.scripts = modules.scripts.scripts_img2img
p.script_args = args
if shared.cmd_opts.enable_console_prompts:

View File

@ -32,7 +32,7 @@ def download_default_clip_interrogate_categories(content_dir):
category_types = ["artists", "flavors", "mediums", "movements"]
try:
os.makedirs(tmpdir)
os.makedirs(tmpdir, exist_ok=True)
for category_type in category_types:
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
os.rename(tmpdir, content_dir)
@ -41,7 +41,7 @@ def download_default_clip_interrogate_categories(content_dir):
errors.display(e, "downloading default CLIP interrogate categories")
finally:
if os.path.exists(tmpdir):
os.remove(tmpdir)
os.removedirs(tmpdir)
class InterrogateModels:

View File

@ -18,9 +18,15 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
if extras_mode == 1:
for img in image_folder:
image = Image.open(img)
if isinstance(img, Image.Image):
image = img
fn = ''
else:
image = Image.open(os.path.abspath(img.name))
fn = os.path.splitext(img.orig_name)[0]
image_data.append(image)
image_names.append(os.path.splitext(img.orig_name)[0])
image_names.append(fn)
elif extras_mode == 2:
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
assert input_dir, 'input directory not selected'

View File

@ -3,6 +3,7 @@ import math
import os
import sys
import warnings
import hashlib
import torch
import numpy as np
@ -476,6 +477,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
"Init image hash": getattr(p, 'init_img_hash', None),
"RNG": (opts.randn_source if opts.randn_source != "GPU" else None)
}
generation_params.update(p.extra_generation_params)
@ -1007,6 +1010,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.color_corrections = []
imgs = []
for img in self.init_images:
# Save init image
if opts.save_init_img:
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
image = images.flatten(img, opts.img2img_background_color)
if crop_region is None and self.resize_mode != 3:

View File

@ -60,3 +60,13 @@ def store_latent(decoded):
class InterruptedException(BaseException):
pass
if opts.randn_source == "CPU":
import torchsde._brownian.brownian_interval
def torchsde_randn(size, dtype, device, seed):
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
torchsde._brownian.brownian_interval._randn = torchsde_randn

View File

@ -190,7 +190,7 @@ class TorchHijack:
if noise.shape == x.shape:
return noise
if x.device.type == 'mps':
if opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device)
else:
return torch.randn_like(x)

View File

@ -39,6 +39,7 @@ restricted_opts = {
"outdir_grids",
"outdir_txt2img_grids",
"outdir_save",
"outdir_init_images"
}
ui_reorder_categories = [
@ -253,6 +254,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
@ -268,6 +270,7 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
}))
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
@ -283,6 +286,8 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
"SCUNET_tile": OptionInfo(256, "Tile size for SCUNET upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"SCUNET_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SCUNET upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
@ -331,6 +336,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
@ -361,7 +367,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
}))
options_templates.update(options_section(('ui', "User interface"), {
@ -382,6 +388,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),

View File

@ -171,8 +171,8 @@ def create_seed_inputs(target_interface):
with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
seed.style(container=False)
random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed')
reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed')
random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed', label='Random seed')
reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed', label='Reuse seed')
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
@ -468,7 +468,7 @@ def create_ui():
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
if opts.dimensions_and_batch_together:
with gr.Column(elem_id="txt2img_column_batch"):
@ -1204,7 +1204,7 @@ def create_ui():
with gr.Column(elem_id='ti_gallery_container'):
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
ti_progress = gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
@ -1705,7 +1705,7 @@ def create_ui():
if init_field is not None:
init_field(saved_value)
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
apply_field(x, 'visible')
if type(x) == gr.Slider:

View File

@ -125,7 +125,7 @@ Requested path was: {f}
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
with gr.Group(elem_id=f"{tabname}_gallery_container"):
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)
generation_info = None
with gr.Column():

View File

@ -13,7 +13,7 @@ def create_ui():
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")

View File

@ -5,7 +5,7 @@ basicsr
fonts
font-roboto
gfpgan
gradio==3.23
gradio==3.27
invisible-watermark
numpy
omegaconf

View File

@ -3,7 +3,7 @@ transformers==4.25.1
accelerate==0.18.0
basicsr==1.4.2
gfpgan==1.3.8
gradio==3.23
gradio==3.27
numpy==1.23.5
Pillow==9.4.0
realesrgan==0.3.0

View File

@ -1,9 +1,40 @@
import modules.scripts as scripts
import gradio as gr
import ast
import copy
from modules.processing import Processed
from modules.shared import opts, cmd_opts, state
def convertExpr2Expression(expr):
expr.lineno = 0
expr.col_offset = 0
result = ast.Expression(expr.value, lineno=0, col_offset = 0)
return result
def exec_with_return(code, module):
"""
like exec() but can return values
https://stackoverflow.com/a/52361938/5862977
"""
code_ast = ast.parse(code)
init_ast = copy.deepcopy(code_ast)
init_ast.body = code_ast.body[:-1]
last_ast = copy.deepcopy(code_ast)
last_ast.body = code_ast.body[-1:]
exec(compile(init_ast, "<ast>", "exec"), module.__dict__)
if type(last_ast.body[0]) == ast.Expr:
return eval(compile(convertExpr2Expression(last_ast.body[0]), "<ast>", "eval"), module.__dict__)
else:
exec(compile(last_ast, "<ast>", "exec"), module.__dict__)
class Script(scripts.Script):
def title(self):
@ -13,12 +44,23 @@ class Script(scripts.Script):
return cmd_opts.allow_code
def ui(self, is_img2img):
code = gr.Textbox(label="Python code", lines=1, elem_id=self.elem_id("code"))
example = """from modules.processing import process_images
return [code]
p.width = 768
p.height = 768
p.batch_size = 2
p.steps = 10
return process_images(p)
"""
def run(self, p, code):
code = gr.Code(value=example, language="python", label="Python code", elem_id=self.elem_id("code"))
indent_level = gr.Number(label='Indent level', value=2, precision=0, elem_id=self.elem_id("indent_level"))
return [code, indent_level]
def run(self, p, code, indent_level):
assert cmd_opts.allow_code, '--allow-code option must be enabled'
display_result_data = [[], -1, ""]
@ -29,13 +71,20 @@ class Script(scripts.Script):
display_result_data[2] = i
from types import ModuleType
compiled = compile(code, '', 'exec')
module = ModuleType("testmodule")
module.__dict__.update(globals())
module.p = p
module.display = display
exec(compiled, module.__dict__)
indent = " " * indent_level
indented = code.replace('\n', '\n' + indent)
body = f"""def __webuitemp__():
{indent}{indented}
__webuitemp__()"""
result = exec_with_return(body, module)
if isinstance(result, Processed):
return result
return Processed(p, *display_result_data)

View File

@ -4,8 +4,8 @@ import numpy as np
from modules import scripts_postprocessing, shared
import gradio as gr
from modules.ui_components import FormRow
from modules.ui_components import FormRow, ToolButton
from modules.ui import switch_values_symbol
upscale_cache = {}
@ -25,9 +25,12 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
with FormRow():
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with gr.Column(elem_id="upscaling_column_size", scale=4):
upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w")
upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h")
with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"):
upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn")
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
with FormRow():
extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
@ -36,6 +39,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False)
tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])

View File

@ -374,16 +374,19 @@ class Script(scripts.Script):
with gr.Row():
x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
x_values_dropdown = gr.Dropdown(label="X values",visible=False,multiselect=True,interactive=True)
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
with gr.Row():
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
y_values_dropdown = gr.Dropdown(label="Y values",visible=False,multiselect=True,interactive=True)
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
with gr.Row():
z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
z_values_dropdown = gr.Dropdown(label="Z values",visible=False,multiselect=True,interactive=True)
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
with gr.Row(variant="compact", elem_id="axis_options"):
@ -401,54 +404,74 @@ class Script(scripts.Script):
swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
def swap_axes(axis1_type, axis1_values, axis2_type, axis2_values):
return self.current_axis_options[axis2_type].label, axis2_values, self.current_axis_options[axis1_type].label, axis1_values
def swap_axes(axis1_type, axis1_values, axis1_values_dropdown, axis2_type, axis2_values, axis2_values_dropdown):
return self.current_axis_options[axis2_type].label, axis2_values, axis2_values_dropdown, self.current_axis_options[axis1_type].label, axis1_values, axis1_values_dropdown
xy_swap_args = [x_type, x_values, y_type, y_values]
xy_swap_args = [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown]
swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
yz_swap_args = [y_type, y_values, z_type, z_values]
yz_swap_args = [y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown]
swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
xz_swap_args = [x_type, x_values, z_type, z_values]
xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown]
swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
def fill(x_type):
axis = self.current_axis_options[x_type]
return ", ".join(axis.choices()) if axis.choices else gr.update()
return axis.choices() if axis.choices else gr.update()
fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values])
fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values_dropdown])
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values_dropdown])
fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values_dropdown])
def select_axis(x_type):
return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None)
def select_axis(axis_type,axis_values_dropdown):
choices = self.current_axis_options[axis_type].choices
has_choices = choices is not None
current_values = axis_values_dropdown
if has_choices:
choices = choices()
if isinstance(current_values,str):
current_values = current_values.split(",")
current_values = list(filter(lambda x: x in choices, current_values))
return gr.Button.update(visible=has_choices),gr.Textbox.update(visible=not has_choices),gr.update(choices=choices if has_choices else None,visible=has_choices,value=current_values)
x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button])
x_type.change(fn=select_axis, inputs=[x_type,x_values_dropdown], outputs=[fill_x_button,x_values,x_values_dropdown])
y_type.change(fn=select_axis, inputs=[y_type,y_values_dropdown], outputs=[fill_y_button,y_values,y_values_dropdown])
z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown])
def get_dropdown_update_from_params(axis,params):
val_key = axis + " Values"
vals = params.get(val_key,"")
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
return gr.update(value = valslist)
self.infotext_fields = (
(x_type, "X Type"),
(x_values, "X Values"),
(x_values_dropdown, lambda params:get_dropdown_update_from_params("X",params)),
(y_type, "Y Type"),
(y_values, "Y Values"),
(y_values_dropdown, lambda params:get_dropdown_update_from_params("Y",params)),
(z_type, "Z Type"),
(z_values, "Z Values"),
(z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)),
)
return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
if not no_fixed_seeds:
modules.processing.fix_seed(p)
if not opts.return_grid:
p.batch_size = 1
def process_axis(opt, vals):
def process_axis(opt, vals, vals_dropdown):
if opt.label == 'Nothing':
return [0]
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
if opt.choices is not None:
valslist = vals_dropdown
else:
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
if opt.type == int:
valslist_ext = []
@ -506,13 +529,19 @@ class Script(scripts.Script):
return valslist
x_opt = self.current_axis_options[x_type]
xs = process_axis(x_opt, x_values)
if x_opt.choices is not None:
x_values = ",".join(x_values_dropdown)
xs = process_axis(x_opt, x_values, x_values_dropdown)
y_opt = self.current_axis_options[y_type]
ys = process_axis(y_opt, y_values)
if y_opt.choices is not None:
y_values = ",".join(y_values_dropdown)
ys = process_axis(y_opt, y_values, y_values_dropdown)
z_opt = self.current_axis_options[z_type]
zs = process_axis(z_opt, z_values)
if z_opt.choices is not None:
z_values = ",".join(z_values_dropdown)
zs = process_axis(z_opt, z_values, z_values_dropdown)
# this could be moved to common code, but unlikely to be ever triggered anywhere else
Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes

View File

@ -312,6 +312,10 @@ div.dimensions-tools{
align-content: center;
}
div#extras_scale_to_tab div.form{
flex-direction: row;
}
#mode_img2img .gradio-image > div.fixed-height, #mode_img2img .gradio-image > div.fixed-height img{
height: 480px !important;
max-height: 480px !important;

View File

@ -69,6 +69,46 @@ else:
server_name = "0.0.0.0" if cmd_opts.listen else None
def fix_asyncio_event_loop_policy():
"""
The default `asyncio` event loop policy only automatically creates
event loops in the main threads. Other threads must create event
loops explicitly or `asyncio.get_event_loop` (and therefore
`.IOLoop.current`) will fail. Installing this policy allows event
loops to be created automatically on any thread, matching the
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
"""
import asyncio
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
# "Any thread" and "selector" should be orthogonal, but there's not a clean
# interface for composing policies so pick the right base.
_BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
else:
_BasePolicy = asyncio.DefaultEventLoopPolicy
class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
"""Event loop policy that allows loop creation on any thread.
Usage::
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
"""
def get_event_loop(self) -> asyncio.AbstractEventLoop:
try:
return super().get_event_loop()
except (RuntimeError, AssertionError):
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
# and changed to a RuntimeError in 3.4.3.
# "There is no current event loop in thread %r"
loop = self.new_event_loop()
self.set_event_loop(loop)
return loop
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
def check_versions():
if shared.cmd_opts.skip_version_check:
return
@ -101,6 +141,8 @@ Use --skip-version-check commandline argument to disable this check.
def initialize():
fix_asyncio_event_loop_policy()
check_versions()
extensions.list_extensions()
@ -128,9 +170,6 @@ def initialize():
modules.scripts.load_scripts()
startup_timer.record("load scripts")
modelloader.load_upscalers()
startup_timer.record("load upscalers")
modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")