add override_settings support for infotext API

This commit is contained in:
AUTOMATIC1111 2023-12-30 12:11:09 +03:00
parent bb07cb6a0d
commit ba92135a2b
2 changed files with 54 additions and 22 deletions

View File

@ -341,6 +341,7 @@ class Api:
params = generation_parameters_copypaste.parse_generation_parameters(request.infotext)
handled_fields = {}
for field in generation_parameters_copypaste.paste_fields[tabname]["fields"]:
if not field.api:
continue
@ -355,6 +356,15 @@ class Api:
value = target_type(value)
setattr(request, field.api, value)
handled_fields[field.label] = 1
if request.override_settings is None:
request.override_settings = {}
overriden_settings = generation_parameters_copypaste.get_override_settings(params, skip_fields=handled_fields)
for infotext_text, setting_name, value in overriden_settings:
if setting_name not in request.override_settings:
request.override_settings[setting_name] = value
return params

View File

@ -390,6 +390,48 @@ def create_override_settings_dict(text_pairs):
return res
def get_override_settings(params, *, skip_fields=None):
"""Returns a list of settings overrides from the infotext parameters dictionary.
This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns
a list of tuples containing the parameter name, setting name, and new value cast to correct type.
It checks for conditions before adding an override:
- ignores settings that match the current value
- ignores parameter keys present in skip_fields argument.
Example input:
{"Clip skip": "2"}
Example output:
[("Clip skip", "CLIP_stop_at_last_layers", 2)]
"""
res = []
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
if param_name in (skip_fields or {}):
continue
v = params.get(param_name, None)
if v is None:
continue
if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
continue
v = shared.opts.cast_value(setting_name, v)
current_value = getattr(shared.opts, setting_name, None)
if v == current_value:
continue
res.append((param_name, setting_name, v))
return res
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
@ -431,29 +473,9 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
already_handled_fields = {key: 1 for _, key in paste_fields}
def paste_settings(params):
vals = {}
vals = get_override_settings(params, skip_fields=already_handled_fields)
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
if param_name in already_handled_fields:
continue
v = params.get(param_name, None)
if v is None:
continue
if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
continue
v = shared.opts.cast_value(setting_name, v)
current_value = getattr(shared.opts, setting_name, None)
if v == current_value:
continue
vals[param_name] = v
vals_pairs = [f"{k}: {v}" for k, v in vals.items()]
vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]
return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))