scheduler selection in main UI

AUTOMATIC1111 2024-03-20 09:17:11 +03:00
parent 060e55dfe3
commit 25cd53d775
11 changed files with 175 additions and 108 deletions


@@ -146,7 +146,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
return batch_results
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -193,10 +193,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
prompt=prompt,
negative_prompt=negative_prompt,
styles=prompt_styles,
sampler_name=sampler_name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
cfg_scale=cfg_scale,
width=width,
height=height,


@@ -152,6 +152,7 @@ class StableDiffusionProcessing:
seed_resize_from_w: int = -1
seed_enable_extras: bool = True
sampler_name: str = None
scheduler: str = None
batch_size: int = 1
n_iter: int = 1
steps: int = 50
@@ -721,6 +722,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
generation_params = {
"Steps": p.steps,
"Sampler": p.sampler_name,
"Schedule type": p.scheduler,
"CFG scale": p.cfg_scale,
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],


@@ -0,0 +1,79 @@
import gradio as gr
import functools

from modules import scripts, sd_samplers, sd_schedulers, shared
from modules.infotext_utils import PasteField
from modules.ui_components import FormRow, FormGroup


def get_sampler_from_infotext(d: dict):
    return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]


def get_scheduler_from_infotext(d: dict):
    return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]


@functools.cache
def get_sampler_and_scheduler(sampler_name, scheduler_name):
    default_sampler = sd_samplers.samplers[0]
    found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])

    name = sampler_name or default_sampler.name

    for scheduler in sd_schedulers.schedulers:
        name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]

        for name_option in name_options:
            if name.endswith(" " + name_option):
                found_scheduler = scheduler
                name = name[0:-(len(name_option) + 1)]
                break

    sampler = sd_samplers.all_samplers_map.get(name, default_sampler)

    # revert back to Automatic if it's the default scheduler for the selected sampler
    if sampler.options.get('scheduler', None) == found_scheduler.name:
        found_scheduler = sd_schedulers.schedulers[0]

    return sampler.name, found_scheduler.label
class ScriptSampler(scripts.ScriptBuiltinUI):
    section = "sampler"

    def __init__(self):
        self.steps = None
        self.sampler_name = None
        self.scheduler = None

    def title(self):
        return "Sampler"

    def ui(self, is_img2img):
        sampler_names = [x.name for x in sd_samplers.visible_samplers()]
        scheduler_names = [x.label for x in sd_schedulers.schedulers]

        if shared.opts.samplers_in_dropdown:
            with FormRow(elem_id=f"sampler_selection_{self.tabname}"):
                self.sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0])
                self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0])
                self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20)
        else:
            with FormGroup(elem_id=f"sampler_selection_{self.tabname}"):
                self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20)
                self.sampler_name = gr.Radio(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0])
                self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0])

        self.infotext_fields = [
            PasteField(self.steps, "Steps", api="steps"),
            PasteField(self.sampler_name, get_sampler_from_infotext, api="sampler_name"),
            PasteField(self.scheduler, get_scheduler_from_infotext, api="scheduler"),
        ]

        return self.steps, self.sampler_name, self.scheduler

    def setup(self, p, steps, sampler_name, scheduler):
        p.steps = steps
        p.sampler_name = sampler_name
        p.scheduler = scheduler
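An illustrative aside (not part of the new file): the suffix-splitting rule in get_sampler_and_scheduler, sketched as a standalone function so it runs without the webui. The label list is a stand-in for sd_schedulers.schedulers:

def split_sampler_name(name, scheduler_labels):
    # mirror the loop above: strip a trailing " <scheduler>" suffix if present
    for label in scheduler_labels:
        if name.endswith(" " + label):
            return name[:-(len(label) + 1)], label
    return name, "Automatic"

labels = ["Karras", "Exponential", "Polyexponential", "SGM Uniform"]
print(split_sampler_name("DPM++ 2M Karras", labels))  # ('DPM++ 2M', 'Karras')
print(split_sampler_name("Euler a", labels))          # ('Euler a', 'Automatic')

So an old infotext that says "DPM++ 2M Karras" should paste back as sampler "DPM++ 2M" with schedule type "Karras".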


@@ -352,6 +352,9 @@ class ScriptBuiltinUI(Script):
return f'{tabname}{item_id}'
def show(self, is_img2img):
return AlwaysVisible
current_basedir = paths.script_path


@@ -1,7 +1,10 @@
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared
from __future__ import annotations
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common
# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
samples_to_image_grid = sd_samplers_common.samples_to_image_grid
sample_to_image = sd_samplers_common.sample_to_image
all_samplers = [
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
@@ -10,8 +13,8 @@ all_samplers = [
]
all_samplers_map = {x.name: x for x in all_samplers}
samplers = []
samplers_for_img2img = []
samplers: list[sd_samplers_common.SamplerData] = []
samplers_for_img2img: list[sd_samplers_common.SamplerData] = []
samplers_map = {}
samplers_hidden = {}
@@ -57,4 +60,8 @@ def visible_sampler_names():
return [x.name for x in samplers if x.name not in samplers_hidden]
def visible_samplers():
return [x for x in samplers if x.name not in samplers_hidden]
set_samplers()


@@ -1,12 +0,0 @@
import torch


def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
    start = inner_model.sigma_to_t(torch.tensor(sigma_max))
    end = inner_model.sigma_to_t(torch.tensor(sigma_min))
    sigs = [
        inner_model.t_to_sigma(ts)
        for ts in torch.linspace(start, end, n)[:-1]
    ]
    sigs += [0.0]
    return torch.FloatTensor(sigs).to(device)


@@ -1,41 +1,28 @@
import torch
import inspect
import k_diffusion.sampling
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
from modules.sd_samplers_custom_schedulers import sgm_uniform
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
from modules.shared import opts
import modules.shared as shared
samplers_k_diffusion = [
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_exp'], {'scheduler': 'exponential', "brownian_noise": True}),
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {'scheduler': 'karras'}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {'scheduler': 'exponential', "brownian_noise": True}),
('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
('Euler', 'sample_euler', ['k_euler'], {}),
('LMS', 'sample_lms', ['k_lms'], {}),
('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {"brownian_noise": True, "solver_type": "heun"}),
('DPM++ 2M SDE Heun Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_ka'], {'scheduler': 'karras', "brownian_noise": True, "solver_type": "heun"}),
('DPM++ 2M SDE Heun Exponential', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun_exp'], {'scheduler': 'exponential', "brownian_noise": True, "solver_type": "heun"}),
('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "second_order": True}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', "second_order": True}),
]
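An illustrative aside (not from the diff): the dedicated Karras/Exponential entries are gone, and each remaining entry's options dict is what get_sigmas consults when the scheduler dropdown is left on Automatic. For example, with a tuple copied from the table above:

label, funcname, aliases, options = ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True})
print(options.get('scheduler', None))  # None -> Automatic uses the model's own sigmas
# 'Restart' still declares {'scheduler': 'karras', ...}, so Automatic keeps the Karras schedule there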
@@ -59,13 +46,7 @@ sampler_extra_params = {
}
k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
k_diffusion_scheduler = {
'Automatic': None,
'karras': k_diffusion.sampling.get_sigmas_karras,
'exponential': k_diffusion.sampling.get_sigmas_exponential,
'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential,
'sgm_uniform' : sgm_uniform,
}
k_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers}
class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
@@ -98,47 +79,44 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
steps += 1 if discard_next_to_last_sigma else 0
scheduler_name = p.scheduler or 'Automatic'
if scheduler_name == 'Automatic':
scheduler_name = self.config.options.get('scheduler', None)
scheduler = sd_schedulers.schedulers_map.get(scheduler_name)
m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif opts.k_sched_type != "Automatic":
m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
sigmas_kwargs = {
'sigma_min': sigma_min,
'sigma_max': sigma_max,
}
elif scheduler is None or scheduler.function is None:
sigmas = self.model_wrap.get_sigmas(steps)
else:
sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}
sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
p.extra_generation_params["Schedule type"] = opts.k_sched_type
p.extra_generation_params["Schedule type"] = scheduler.label
if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
if opts.sigma_min != 0 and opts.sigma_min != m_sigma_min:
sigmas_kwargs['sigma_min'] = opts.sigma_min
p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
if opts.sigma_max != 0 and opts.sigma_max != m_sigma_max:
sigmas_kwargs['sigma_max'] = opts.sigma_max
p.extra_generation_params["Schedule max sigma"] = opts.sigma_max
default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.
if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
if scheduler.default_rho != -1 and opts.rho != 0 and opts.rho != scheduler.default_rho:
sigmas_kwargs['rho'] = opts.rho
p.extra_generation_params["Schedule rho"] = opts.rho
if opts.k_sched_type == 'sgm_uniform':
if scheduler.need_inner_model:
sigmas_kwargs['inner_model'] = self.model_wrap
if scheduler.name == "sgm_uniform": # XXX check this
# Ensure the "step" will be target step + 1
steps += 1 if not discard_next_to_last_sigma else 0
sigmas_kwargs['inner_model'] = self.model_wrap
sigmas_kwargs.pop('rho', None)
sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential':
m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device)
else:
sigmas = self.model_wrap.get_sigmas(steps)
sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=shared.device)
if discard_next_to_last_sigma:
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
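An illustrative aside (not from the diff): a condensed, standalone sketch of the scheduler-resolution order the rewritten get_sigmas implements. Here p_scheduler, sampler_options and schedulers_map stand in for p.scheduler, self.config.options and sd_schedulers.schedulers_map, and the sampler_noise_scheduler_override branch is omitted:

def resolve_scheduler(p_scheduler, sampler_options, schedulers_map):
    name = p_scheduler or 'Automatic'
    if name == 'Automatic':
        # fall back to the sampler's own declared default, if it has one
        name = sampler_options.get('scheduler', None)
    return schedulers_map.get(name)  # None means: use model_wrap.get_sigmas(steps)

print(resolve_scheduler('Karras', {}, {'Karras': 'karras entry', 'karras': 'karras entry'}))  # 'karras entry'
print(resolve_scheduler('Automatic', {'scheduler': 'karras'}, {'karras': 'karras entry'}))    # 'karras entry'
print(resolve_scheduler('Automatic', {}, {'karras': 'karras entry'}))                         # None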

modules/sd_schedulers.py (new file, 38 lines)

@@ -0,0 +1,38 @@
import dataclasses

import torch

import k_diffusion


@dataclasses.dataclass
class Scheduler:
    name: str
    label: str
    function: any

    default_rho: float = -1
    need_inner_model: bool = False
    aliases: list = None


def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
    start = inner_model.sigma_to_t(torch.tensor(sigma_max))
    end = inner_model.sigma_to_t(torch.tensor(sigma_min))
    sigs = [
        inner_model.t_to_sigma(ts)
        for ts in torch.linspace(start, end, n)[:-1]
    ]
    sigs += [0.0]
    return torch.FloatTensor(sigs).to(device)


schedulers = [
    Scheduler('automatic', 'Automatic', None),
    Scheduler('karras', 'Karras', k_diffusion.sampling.get_sigmas_karras, default_rho=7.0),
    Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential),
    Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0),
    Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=["SGMUniform"]),
]

schedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}
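An illustrative aside (not from the new file): schedulers_map is keyed by both the internal name and the UI label, so either form resolves to the same entry. The sigma values below are stand-ins for a model's real range, and the snippet assumes the webui environment so that modules and k_diffusion import:

from modules import sd_schedulers

karras = sd_schedulers.schedulers_map["karras"]
assert sd_schedulers.schedulers_map["Karras"] is karras

# build a 20-step Karras sigma schedule (rho defaults to 7.0 inside k-diffusion)
sigmas = karras.function(n=20, sigma_min=0.0292, sigma_max=14.6146, device="cpu")
print(sigmas.shape)  # typically torch.Size([21]) -- k-diffusion appends a trailing 0.0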


@@ -368,7 +368,6 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential", "sgm_uniform"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),


@@ -11,7 +11,7 @@ from PIL import Image
import gradio as gr
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):
override_settings = create_override_settings_dict(override_settings_texts)
if force_enable_hr:
@@ -24,10 +24,8 @@ def txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, ne
prompt=prompt,
styles=prompt_styles,
negative_prompt=negative_prompt,
sampler_name=sampler_name,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
cfg_scale=cfg_scale,
width=width,
height=height,


@@ -12,7 +12,7 @@ import numpy as np
from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensons # noqa: F401
from modules import gradio_extensons, sd_schedulers # noqa: F401
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
from modules.paths import script_path
@@ -229,19 +229,6 @@ def create_output_panel(tabname, outdir, toprow=None):
return ui_common.create_output_panel(tabname, outdir, toprow)
def create_sampler_and_steps_selection(choices, tabname):
if opts.samplers_in_dropdown:
with FormRow(elem_id=f"sampler_selection_{tabname}"):
sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
else:
with FormGroup(elem_id=f"sampler_selection_{tabname}"):
steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20)
sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0])
return steps, sampler_name
def ordered_ui_categories():
user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)}
@@ -295,9 +282,6 @@ def create_ui():
if category == "prompt":
toprow.create_inline_toprow_prompts()
if category == "sampler":
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
elif category == "dimensions":
with FormRow():
with gr.Column(elem_id="txt2img_column_size", scale=4):
@@ -396,8 +380,6 @@ def create_ui():
toprow.prompt,
toprow.negative_prompt,
toprow.ui_styles.dropdown,
steps,
sampler_name,
batch_count,
batch_size,
cfg_scale,
@@ -461,8 +443,6 @@ def create_ui():
txt2img_paste_fields = [
PasteField(toprow.prompt, "Prompt", api="prompt"),
PasteField(toprow.negative_prompt, "Negative prompt", api="negative_prompt"),
PasteField(steps, "Steps", api="steps"),
PasteField(sampler_name, "Sampler", api="sampler_name"),
PasteField(cfg_scale, "CFG scale", api="cfg_scale"),
PasteField(width, "Size-1", api="width"),
PasteField(height, "Size-2", api="height"),
@@ -488,11 +468,13 @@ def create_ui():
paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
))
steps = scripts.scripts_txt2img.script('Sampler').steps
txt2img_preview_params = [
toprow.prompt,
toprow.negative_prompt,
steps,
sampler_name,
scripts.scripts_txt2img.script('Sampler').sampler_name,
cfg_scale,
scripts.scripts_txt2img.script('Seed').seed,
width,
@@ -623,9 +605,6 @@ def create_ui():
with FormRow():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
if category == "sampler":
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
elif category == "dimensions":
with FormRow():
with gr.Column(elem_id="img2img_column_size", scale=4):
@@ -754,8 +733,6 @@ def create_ui():
inpaint_color_sketch_orig,
init_img_inpaint,
init_mask_inpaint,
steps,
sampler_name,
mask_blur,
mask_alpha,
inpainting_fill,
@@ -840,6 +817,8 @@ def create_ui():
**interrogate_args,
)
steps = scripts.scripts_img2img.script('Sampler').steps
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])
toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])
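An illustrative aside (not from the diff): with create_sampler_and_steps_selection removed, UI code now reaches the shared widgets through the scripts registry, as the steps line above does. A condensed sketch of that access pattern, assuming the webui environment and a UI that has already been built:

from modules import scripts  # webui module; populated once create_ui() has run

sampler_script = scripts.scripts_img2img.script('Sampler')  # the ScriptSampler instance
steps = sampler_script.steps                                # gr.Slider
sampler_name = sampler_script.sampler_name                  # gr.Dropdown (or gr.Radio)
scheduler = sampler_script.scheduler                        # gr.Dropdown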
@@ -848,8 +827,6 @@ def create_ui():
img2img_paste_fields = [
(toprow.prompt, "Prompt"),
(toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
(sampler_name, "Sampler"),
(cfg_scale, "CFG scale"),
(image_cfg_scale, "Image CFG scale"),
(width, "Size-1"),