mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
added resize seeds and variation seeds features
This commit is contained in:
parent
003b60b94e
commit
b1707553cf
@ -136,7 +136,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
|||||||
color_active = (0, 0, 0)
|
color_active = (0, 0, 0)
|
||||||
color_inactive = (153, 153, 153)
|
color_inactive = (153, 153, 153)
|
||||||
|
|
||||||
pad_left = width * 3 // 4 if len(ver_texts) > 0 else 0
|
pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4
|
||||||
|
|
||||||
cols = im.width // width
|
cols = im.width // width
|
||||||
rows = im.height // height
|
rows = im.height // height
|
||||||
|
@ -11,7 +11,7 @@ from modules.ui import plaintext_to_html
|
|||||||
import modules.images as images
|
import modules.images as images
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
|
|
||||||
def img2img(prompt: str, negative_prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, denoising_strength_change_factor: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
|
def img2img(prompt: str, negative_prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, denoising_strength_change_factor: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
|
||||||
is_inpaint = mode == 1
|
is_inpaint = mode == 1
|
||||||
is_loopback = mode == 2
|
is_loopback = mode == 2
|
||||||
is_upscale = mode == 3
|
is_upscale = mode == 3
|
||||||
@ -34,6 +34,10 @@ def img2img(prompt: str, negative_prompt: str, init_img, init_img_with_mask, ste
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
subseed=subseed,
|
||||||
|
subseed_strength=subseed_strength,
|
||||||
|
seed_resize_from_h=seed_resize_from_h,
|
||||||
|
seed_resize_from_w=seed_resize_from_w,
|
||||||
sampler_index=sampler_index,
|
sampler_index=sampler_index,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
|
@ -29,7 +29,7 @@ def torch_gc():
|
|||||||
|
|
||||||
|
|
||||||
class StableDiffusionProcessing:
|
class StableDiffusionProcessing:
|
||||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
|
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
|
||||||
self.sd_model = sd_model
|
self.sd_model = sd_model
|
||||||
self.outpath_samples: str = outpath_samples
|
self.outpath_samples: str = outpath_samples
|
||||||
self.outpath_grids: str = outpath_grids
|
self.outpath_grids: str = outpath_grids
|
||||||
@ -37,6 +37,10 @@ class StableDiffusionProcessing:
|
|||||||
self.prompt_for_display: str = None
|
self.prompt_for_display: str = None
|
||||||
self.negative_prompt: str = (negative_prompt or "")
|
self.negative_prompt: str = (negative_prompt or "")
|
||||||
self.seed: int = seed
|
self.seed: int = seed
|
||||||
|
self.subseed: int = subseed
|
||||||
|
self.subseed_strength: float = subseed_strength
|
||||||
|
self.seed_resize_from_h: int = seed_resize_from_h
|
||||||
|
self.seed_resize_from_w: int = seed_resize_from_w
|
||||||
self.sampler_index: int = sampler_index
|
self.sampler_index: int = sampler_index
|
||||||
self.batch_size: int = batch_size
|
self.batch_size: int = batch_size
|
||||||
self.n_iter: int = n_iter
|
self.n_iter: int = n_iter
|
||||||
@ -84,23 +88,67 @@ class Processed:
|
|||||||
|
|
||||||
return json.dumps(obj)
|
return json.dumps(obj)
|
||||||
|
|
||||||
|
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
||||||
|
def slerp(val, low, high):
|
||||||
|
low_norm = low/torch.norm(low, dim=1, keepdim=True)
|
||||||
|
high_norm = high/torch.norm(high, dim=1, keepdim=True)
|
||||||
|
omega = torch.acos((low_norm*high_norm).sum(1))
|
||||||
|
so = torch.sin(omega)
|
||||||
|
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
|
||||||
|
return res
|
||||||
|
|
||||||
def create_random_tensors(shape, seeds):
|
|
||||||
|
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
|
||||||
xs = []
|
xs = []
|
||||||
for seed in seeds:
|
for i, seed in enumerate(seeds):
|
||||||
torch.manual_seed(seed)
|
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
|
||||||
|
|
||||||
|
subnoise = None
|
||||||
|
if subseeds is not None:
|
||||||
|
subseed = 0 if i >= len(subseeds) else subseeds[i]
|
||||||
|
torch.manual_seed(subseed)
|
||||||
|
subnoise = torch.randn(noise_shape, device=shared.device)
|
||||||
|
|
||||||
# randn results depend on device; gpu and cpu get different results for same seed;
|
# randn results depend on device; gpu and cpu get different results for same seed;
|
||||||
# the way I see it, it's better to do this on CPU, so that everyone gets same result;
|
# the way I see it, it's better to do this on CPU, so that everyone gets same result;
|
||||||
# but the original script had it like this so I do not dare change it for now because
|
# but the original script had it like this, so I do not dare change it for now because
|
||||||
# it will break everyone's seeds.
|
# it will break everyone's seeds.
|
||||||
xs.append(torch.randn(shape, device=shared.device))
|
torch.manual_seed(seed)
|
||||||
x = torch.stack(xs)
|
noise = torch.randn(noise_shape, device=shared.device)
|
||||||
|
|
||||||
|
if subnoise is not None:
|
||||||
|
#noise = subnoise * subseed_strength + noise * (1 - subseed_strength)
|
||||||
|
noise = slerp(subseed_strength, noise, subnoise)
|
||||||
|
|
||||||
|
if noise_shape != shape:
|
||||||
|
#noise = torch.nn.functional.interpolate(noise.unsqueeze(1), size=shape[1:], mode="bilinear").squeeze()
|
||||||
|
# noise_shape = (64, 80)
|
||||||
|
# shape = (64, 72)
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
x = torch.randn(shape, device=shared.device)
|
||||||
|
dx = (shape[2] - noise_shape[2]) // 2 # -4
|
||||||
|
dy = (shape[1] - noise_shape[1]) // 2
|
||||||
|
w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
|
||||||
|
h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
|
||||||
|
tx = 0 if dx < 0 else dx
|
||||||
|
ty = 0 if dy < 0 else dy
|
||||||
|
dx = max(-dx, 0)
|
||||||
|
dy = max(-dy, 0)
|
||||||
|
|
||||||
|
x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
|
||||||
|
noise = x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
xs.append(noise)
|
||||||
|
x = torch.stack(xs).to(shared.device)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed):
|
def fix_seed(p):
|
||||||
return int(random.randrange(4294967294)) if seed is None or seed == -1 else seed
|
p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == -1 else p.seed
|
||||||
|
p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed
|
||||||
|
|
||||||
|
|
||||||
def process_images(p: StableDiffusionProcessing) -> Processed:
|
def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
@ -111,7 +159,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
assert p.prompt is not None
|
assert p.prompt is not None
|
||||||
torch_gc()
|
torch_gc()
|
||||||
|
|
||||||
seed = set_seed(p.seed)
|
fix_seed(p)
|
||||||
|
|
||||||
os.makedirs(p.outpath_samples, exist_ok=True)
|
os.makedirs(p.outpath_samples, exist_ok=True)
|
||||||
os.makedirs(p.outpath_grids, exist_ok=True)
|
os.makedirs(p.outpath_grids, exist_ok=True)
|
||||||
@ -125,20 +173,31 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
else:
|
else:
|
||||||
all_prompts = p.batch_size * p.n_iter * [prompt]
|
all_prompts = p.batch_size * p.n_iter * [prompt]
|
||||||
|
|
||||||
if type(seed) == list:
|
if type(p.seed) == list:
|
||||||
all_seeds = seed
|
all_seeds = int(p.seed)
|
||||||
else:
|
else:
|
||||||
all_seeds = [int(seed + x) for x in range(len(all_prompts))]
|
all_seeds = [int(p.seed + x) for x in range(len(all_prompts))]
|
||||||
|
|
||||||
|
if type(p.subseed) == list:
|
||||||
|
all_subseeds = p.subseed
|
||||||
|
else:
|
||||||
|
all_subseeds = [int(p.subseed + x) for x in range(len(all_prompts))]
|
||||||
|
|
||||||
def infotext(iteration=0, position_in_batch=0):
|
def infotext(iteration=0, position_in_batch=0):
|
||||||
|
index = position_in_batch + iteration * p.batch_size
|
||||||
|
|
||||||
generation_params = {
|
generation_params = {
|
||||||
"Steps": p.steps,
|
"Steps": p.steps,
|
||||||
"Sampler": samplers[p.sampler_index].name,
|
"Sampler": samplers[p.sampler_index].name,
|
||||||
"CFG scale": p.cfg_scale,
|
"CFG scale": p.cfg_scale,
|
||||||
"Seed": all_seeds[position_in_batch + iteration * p.batch_size],
|
"Seed": all_seeds[index],
|
||||||
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
||||||
|
"Size": f"{p.width}x{p.height}",
|
||||||
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
||||||
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
||||||
|
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
||||||
|
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
||||||
|
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.extra_generation_params is not None:
|
if p.extra_generation_params is not None:
|
||||||
@ -174,7 +233,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
comments += model_hijack.comments
|
comments += model_hijack.comments
|
||||||
|
|
||||||
# we manually generate all input noises because each one should have a specific seed
|
# we manually generate all input noises because each one should have a specific seed
|
||||||
x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)
|
x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=all_subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
|
||||||
|
|
||||||
if p.n_iter > 1:
|
if p.n_iter > 1:
|
||||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||||
@ -231,10 +290,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
output_images.insert(0, grid)
|
output_images.insert(0, grid)
|
||||||
|
|
||||||
if opts.grid_save:
|
if opts.grid_save:
|
||||||
images.save_image(grid, p.outpath_grids, "grid", seed, all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
|
||||||
|
|
||||||
torch_gc()
|
torch_gc()
|
||||||
return Processed(p, output_images, seed, infotext())
|
return Processed(p, output_images, all_seeds[0], infotext())
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||||
|
@ -62,7 +62,6 @@ class State:
|
|||||||
current_image = None
|
current_image = None
|
||||||
current_image_sampling_step = 0
|
current_image_sampling_step = 0
|
||||||
|
|
||||||
|
|
||||||
def interrupt(self):
|
def interrupt(self):
|
||||||
self.interrupted = True
|
self.interrupted = True
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ import modules.processing as processing
|
|||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
|
|
||||||
|
|
||||||
def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, *args):
|
def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, *args):
|
||||||
p = StableDiffusionProcessingTxt2Img(
|
p = StableDiffusionProcessingTxt2Img(
|
||||||
sd_model=shared.sd_model,
|
sd_model=shared.sd_model,
|
||||||
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
|
||||||
@ -14,6 +14,10 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, r
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
subseed=subseed,
|
||||||
|
subseed_strength=subseed_strength,
|
||||||
|
seed_resize_from_h=seed_resize_from_h,
|
||||||
|
seed_resize_from_w=seed_resize_from_w,
|
||||||
sampler_index=sampler_index,
|
sampler_index=sampler_index,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
n_iter=n_iter,
|
n_iter=n_iter,
|
||||||
|
@ -192,6 +192,40 @@ def visit(x, func, path=""):
|
|||||||
func(path + "/" + str(x.label), x)
|
func(path + "/" + str(x.label), x)
|
||||||
|
|
||||||
|
|
||||||
|
def create_seed_inputs():
|
||||||
|
with gr.Row():
|
||||||
|
seed = gr.Number(label='Seed', value=-1)
|
||||||
|
subseed = gr.Number(label='Variation seed', value=-1, visible=False)
|
||||||
|
seed_checkbox = gr.Checkbox(label="Extra", elem_id="subseed_show", value=False)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, visible=False)
|
||||||
|
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0, visible=False)
|
||||||
|
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0, visible=False)
|
||||||
|
|
||||||
|
def change_visiblity(show):
|
||||||
|
|
||||||
|
return {
|
||||||
|
subseed: gr_show(show),
|
||||||
|
subseed_strength: gr_show(show),
|
||||||
|
seed_resize_from_h: gr_show(show),
|
||||||
|
seed_resize_from_w: gr_show(show),
|
||||||
|
}
|
||||||
|
|
||||||
|
seed_checkbox.change(
|
||||||
|
change_visiblity,
|
||||||
|
inputs=[seed_checkbox],
|
||||||
|
outputs=[
|
||||||
|
subseed,
|
||||||
|
subseed_strength,
|
||||||
|
seed_resize_from_h,
|
||||||
|
seed_resize_from_w
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w
|
||||||
|
|
||||||
|
|
||||||
def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
||||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -220,7 +254,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
|||||||
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||||
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||||
|
|
||||||
seed = gr.Number(label='Seed', value=-1)
|
seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w = create_seed_inputs()
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
|
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
|
||||||
@ -260,6 +294,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
|||||||
batch_size,
|
batch_size,
|
||||||
cfg_scale,
|
cfg_scale,
|
||||||
seed,
|
seed,
|
||||||
|
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
] + custom_inputs,
|
] + custom_inputs,
|
||||||
@ -357,7 +392,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
|||||||
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
|
||||||
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
|
||||||
|
|
||||||
seed = gr.Number(label='Seed', value=-1)
|
seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w = create_seed_inputs()
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
|
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
|
||||||
@ -440,6 +475,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
|||||||
denoising_strength,
|
denoising_strength,
|
||||||
denoising_strength_change_factor,
|
denoising_strength_change_factor,
|
||||||
seed,
|
seed,
|
||||||
|
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
resize_mode,
|
resize_mode,
|
||||||
|
@ -46,6 +46,11 @@ titles = {
|
|||||||
"Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.",
|
"Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.",
|
||||||
|
|
||||||
"Roll": "Add a random artist to the prompt.",
|
"Roll": "Add a random artist to the prompt.",
|
||||||
|
|
||||||
|
"Variation seed": "Seed of a different picture to be mixed into the generation.",
|
||||||
|
"Variation strength": "How strong of a variation to produce. At 0, there will be no effect. At 1, you will get the complete picture with variation seed (except for ancestral samplers, where you will just get something).",
|
||||||
|
"Resize seed from height": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",
|
||||||
|
"Resize seed from width": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution",
|
||||||
}
|
}
|
||||||
|
|
||||||
function gradioApp(){
|
function gradioApp(){
|
||||||
|
@ -50,7 +50,7 @@ class Script(scripts.Script):
|
|||||||
return [put_at_start]
|
return [put_at_start]
|
||||||
|
|
||||||
def run(self, p, put_at_start):
|
def run(self, p, put_at_start):
|
||||||
seed = modules.processing.set_seed(p.seed)
|
modules.processing.fix_seed(p)
|
||||||
|
|
||||||
original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
|
original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt
|
||||||
|
|
||||||
|
@ -2,6 +2,8 @@ from collections import namedtuple
|
|||||||
from copy import copy
|
from copy import copy
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import modules.scripts as scripts
|
import modules.scripts as scripts
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
@ -46,18 +48,27 @@ def format_value_add_label(p, opt, x):
|
|||||||
def format_value(p, opt, x):
|
def format_value(p, opt, x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def do_nothing(p, x, xs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def format_nothing(p, opt, x):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
|
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
|
||||||
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
|
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
|
||||||
|
|
||||||
|
|
||||||
axis_options = [
|
axis_options = [
|
||||||
|
AxisOption("Nothing", str, do_nothing, format_nothing),
|
||||||
AxisOption("Seed", int, apply_field("seed"), format_value_add_label),
|
AxisOption("Seed", int, apply_field("seed"), format_value_add_label),
|
||||||
|
AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label),
|
||||||
|
AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label),
|
||||||
AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
|
AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
|
||||||
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
|
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
|
||||||
AxisOption("Prompt S/R", str, apply_prompt, format_value),
|
AxisOption("Prompt S/R", str, apply_prompt, format_value),
|
||||||
AxisOption("Sampler", str, apply_sampler, format_value),
|
AxisOption("Sampler", str, apply_sampler, format_value),
|
||||||
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label) # as it is now all AxisOptionImg2Img items must go after AxisOption ones
|
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -90,6 +101,7 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell):
|
|||||||
|
|
||||||
|
|
||||||
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
|
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
|
||||||
|
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
|
||||||
|
|
||||||
class Script(scripts.Script):
|
class Script(scripts.Script):
|
||||||
def title(self):
|
def title(self):
|
||||||
@ -99,17 +111,17 @@ class Script(scripts.Script):
|
|||||||
current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img]
|
current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img]
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, visible=False, type="index", elem_id="x_type")
|
x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, visible=False, type="index", elem_id="x_type")
|
||||||
x_values = gr.Textbox(label="X values", visible=False, lines=1)
|
x_values = gr.Textbox(label="X values", visible=False, lines=1)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, visible=False, type="index", elem_id="y_type")
|
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[4].label, visible=False, type="index", elem_id="y_type")
|
||||||
y_values = gr.Textbox(label="Y values", visible=False, lines=1)
|
y_values = gr.Textbox(label="Y values", visible=False, lines=1)
|
||||||
|
|
||||||
return [x_type, x_values, y_type, y_values]
|
return [x_type, x_values, y_type, y_values]
|
||||||
|
|
||||||
def run(self, p, x_type, x_values, y_type, y_values):
|
def run(self, p, x_type, x_values, y_type, y_values):
|
||||||
p.seed = modules.processing.set_seed(p.seed)
|
modules.processing.fix_seed(p)
|
||||||
p.batch_size = 1
|
p.batch_size = 1
|
||||||
p.batch_count = 1
|
p.batch_count = 1
|
||||||
|
|
||||||
@ -132,6 +144,21 @@ class Script(scripts.Script):
|
|||||||
valslist_ext.append(val)
|
valslist_ext.append(val)
|
||||||
|
|
||||||
valslist = valslist_ext
|
valslist = valslist_ext
|
||||||
|
elif opt.type == float:
|
||||||
|
valslist_ext = []
|
||||||
|
|
||||||
|
for val in valslist:
|
||||||
|
m = re_range_float.fullmatch(val)
|
||||||
|
if m is not None:
|
||||||
|
start = float(m.group(1))
|
||||||
|
end = float(m.group(2))
|
||||||
|
step = float(m.group(3)) if m.group(3) is not None else 1
|
||||||
|
|
||||||
|
valslist_ext += np.arange(start, end + step, step).tolist()
|
||||||
|
else:
|
||||||
|
valslist_ext.append(val)
|
||||||
|
|
||||||
|
valslist = valslist_ext
|
||||||
|
|
||||||
valslist = [opt.type(x) for x in valslist]
|
valslist = [opt.type(x) for x in valslist]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user