Add a prompt order option to XY plot script

This commit is contained in:
DoTheSneedful 2022-10-03 22:20:09 -04:00 committed by AUTOMATIC1111
parent 5ef0baf5ea
commit 1c5604791d

View File

@ -1,5 +1,6 @@
from collections import namedtuple from collections import namedtuple
from copy import copy from copy import copy
from itertools import permutations
import random import random
from PIL import Image from PIL import Image
@ -28,6 +29,27 @@ def apply_prompt(p, x, xs):
p.prompt = p.prompt.replace(xs[0], x) p.prompt = p.prompt.replace(xs[0], x)
p.negative_prompt = p.negative_prompt.replace(xs[0], x) p.negative_prompt = p.negative_prompt.replace(xs[0], x)
def apply_order(p, x, xs):
token_order = []
# Initally grab the tokens from the prompt so they can be later be replaced in order of earliest seen in the prompt
for token in x:
token_order.append((p.prompt.find(token), token))
token_order.sort(key=lambda t: t[0])
search_from_pos = 0
for idx, token in enumerate(x):
original_pos, old_token = token_order[idx]
# Get position of the token again as it will likely change as tokens are being replaced
pos = p.prompt.find(old_token)
if original_pos >= 0:
# Avoid trying to replace what was just replaced by searching later in the prompt string
p.prompt = p.prompt[0:search_from_pos] + p.prompt[search_from_pos:].replace(old_token, token, 1)
search_from_pos = pos + len(token)
samplers_dict = {} samplers_dict = {}
for i, sampler in enumerate(modules.sd_samplers.samplers): for i, sampler in enumerate(modules.sd_samplers.samplers):
@ -60,7 +82,8 @@ def format_value_add_label(p, opt, x):
def format_value(p, opt, x): def format_value(p, opt, x):
if type(x) == float: if type(x) == float:
x = round(x, 8) x = round(x, 8)
if type(x) == type(list()):
x = str(x)
return x return x
def do_nothing(p, x, xs): def do_nothing(p, x, xs):
@ -89,6 +112,7 @@ axis_options = [
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
AxisOption("Eta", float, apply_field("eta"), format_value_add_label), AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
AxisOption("Prompt order", type(list()), apply_order, format_value),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
] ]
@ -159,8 +183,12 @@ class Script(scripts.Script):
if opt.label == 'Nothing': if opt.label == 'Nothing':
return [0] return [0]
if opt.type == type(list()):
valslist = [x for x in vals]
else:
valslist = [x.strip() for x in vals.split(",")] valslist = [x.strip() for x in vals.split(",")]
if opt.type == int: if opt.type == int:
valslist_ext = [] valslist_ext = []
@ -212,9 +240,17 @@ class Script(scripts.Script):
return valslist return valslist
x_opt = axis_options[x_type] x_opt = axis_options[x_type]
if x_opt.label == "Prompt order":
x_values = list(permutations([x.strip() for x in x_values.split(",")]))
xs = process_axis(x_opt, x_values) xs = process_axis(x_opt, x_values)
y_opt = axis_options[y_type] y_opt = axis_options[y_type]
if y_opt.label == "Prompt order":
y_values = list(permutations([y.strip() for y in y_values.split(",")]))
ys = process_axis(y_opt, y_values) ys = process_axis(y_opt, y_values)
def fix_axis_seeds(axis_opt, axis_list): def fix_axis_seeds(axis_opt, axis_list):