mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Allow use of mutiple styles csv files
* https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/14122 Fix edge case where style text has multiple {prompt} placeholders * https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/14005
This commit is contained in:
parent
f0f100e67b
commit
26a0c29587
@ -1,4 +1,5 @@
|
|||||||
import csv
|
import csv
|
||||||
|
import fnmatch
|
||||||
import os
|
import os
|
||||||
import os.path
|
import os.path
|
||||||
import re
|
import re
|
||||||
@ -10,6 +11,23 @@ class PromptStyle(typing.NamedTuple):
|
|||||||
name: str
|
name: str
|
||||||
prompt: str
|
prompt: str
|
||||||
negative_prompt: str
|
negative_prompt: str
|
||||||
|
path: str = None
|
||||||
|
|
||||||
|
|
||||||
|
def clean_text(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Iterating through a list of regular expressions and replacement strings, we
|
||||||
|
clean up the prompt and style text to make it easier to match against each
|
||||||
|
other.
|
||||||
|
"""
|
||||||
|
re_list = [
|
||||||
|
("multiple commas", re.compile("(,+\s+)+,?"), ", "),
|
||||||
|
("multiple spaces", re.compile("\s{2,}"), " "),
|
||||||
|
]
|
||||||
|
for _, regex, replace in re_list:
|
||||||
|
text = regex.sub(replace, text)
|
||||||
|
|
||||||
|
return text.strip(", ")
|
||||||
|
|
||||||
|
|
||||||
def merge_prompts(style_prompt: str, prompt: str) -> str:
|
def merge_prompts(style_prompt: str, prompt: str) -> str:
|
||||||
@ -26,41 +44,64 @@ def apply_styles_to_prompt(prompt, styles):
|
|||||||
for style in styles:
|
for style in styles:
|
||||||
prompt = merge_prompts(style, prompt)
|
prompt = merge_prompts(style, prompt)
|
||||||
|
|
||||||
return prompt
|
return clean_text(prompt)
|
||||||
|
|
||||||
|
|
||||||
re_spaces = re.compile(" +")
|
def unwrap_style_text_from_prompt(style_text, prompt):
|
||||||
|
"""
|
||||||
|
Checks the prompt to see if the style text is wrapped around it. If so,
|
||||||
|
returns True plus the prompt text without the style text. Otherwise, returns
|
||||||
|
False with the original prompt.
|
||||||
|
|
||||||
|
Note that the "cleaned" version of the style text is only used for matching
|
||||||
def extract_style_text_from_prompt(style_text, prompt):
|
purposes here. It isn't returned; the original style text is not modified.
|
||||||
stripped_prompt = re.sub(re_spaces, " ", prompt.strip())
|
"""
|
||||||
stripped_style_text = re.sub(re_spaces, " ", style_text.strip())
|
stripped_prompt = clean_text(prompt)
|
||||||
|
stripped_style_text = clean_text(style_text)
|
||||||
if "{prompt}" in stripped_style_text:
|
if "{prompt}" in stripped_style_text:
|
||||||
left, right = stripped_style_text.split("{prompt}", 2)
|
# Work out whether the prompt is wrapped in the style text. If so, we
|
||||||
|
# return True and the "inner" prompt text that isn't part of the style.
|
||||||
|
try:
|
||||||
|
left, right = stripped_style_text.split("{prompt}", 2)
|
||||||
|
except ValueError as e:
|
||||||
|
# If the style text has multple "{prompt}"s, we can't split it into
|
||||||
|
# two parts. This is an error, but we can't do anything about it.
|
||||||
|
print(f"Unable to compare style text to prompt:\n{style_text}")
|
||||||
|
print(f"Error: {e}")
|
||||||
|
return False, prompt
|
||||||
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
|
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
|
||||||
prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
|
prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
|
||||||
return True, prompt
|
return True, prompt
|
||||||
else:
|
else:
|
||||||
|
# Work out whether the given prompt ends with the style text. If so, we
|
||||||
|
# return True and the prompt text up to where the style text starts.
|
||||||
if stripped_prompt.endswith(stripped_style_text):
|
if stripped_prompt.endswith(stripped_style_text):
|
||||||
prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
|
prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
|
||||||
|
if prompt.endswith(", "):
|
||||||
if prompt.endswith(', '):
|
|
||||||
prompt = prompt[:-2]
|
prompt = prompt[:-2]
|
||||||
|
|
||||||
return True, prompt
|
return True, prompt
|
||||||
|
|
||||||
return False, prompt
|
return False, prompt
|
||||||
|
|
||||||
|
|
||||||
def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
|
def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
|
||||||
|
"""
|
||||||
|
Takes a style and compares it to the prompt and negative prompt. If the style
|
||||||
|
matches, returns True plus the prompt and negative prompt with the style text
|
||||||
|
removed. Otherwise, returns False with the original prompt and negative prompt.
|
||||||
|
"""
|
||||||
if not style.prompt and not style.negative_prompt:
|
if not style.prompt and not style.negative_prompt:
|
||||||
return False, prompt, negative_prompt
|
return False, prompt, negative_prompt
|
||||||
|
|
||||||
match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
|
match_positive, extracted_positive = unwrap_style_text_from_prompt(
|
||||||
|
style.prompt, prompt
|
||||||
|
)
|
||||||
if not match_positive:
|
if not match_positive:
|
||||||
return False, prompt, negative_prompt
|
return False, prompt, negative_prompt
|
||||||
|
|
||||||
match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
|
match_negative, extracted_negative = unwrap_style_text_from_prompt(
|
||||||
|
style.negative_prompt, negative_prompt
|
||||||
|
)
|
||||||
if not match_negative:
|
if not match_negative:
|
||||||
return False, prompt, negative_prompt
|
return False, prompt, negative_prompt
|
||||||
|
|
||||||
@ -69,25 +110,88 @@ def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
|
|||||||
|
|
||||||
class StyleDatabase:
|
class StyleDatabase:
|
||||||
def __init__(self, path: str):
|
def __init__(self, path: str):
|
||||||
self.no_style = PromptStyle("None", "", "")
|
self.no_style = PromptStyle("None", "", "", None)
|
||||||
self.styles = {}
|
self.styles = {}
|
||||||
self.path = path
|
self.path = path
|
||||||
|
|
||||||
|
folder, file = os.path.split(self.path)
|
||||||
|
self.default_file = file.split("*")[0] + ".csv"
|
||||||
|
if self.default_file == ".csv":
|
||||||
|
self.default_file = "styles.csv"
|
||||||
|
self.default_path = os.path.join(folder, self.default_file)
|
||||||
|
|
||||||
|
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
|
||||||
|
|
||||||
self.reload()
|
self.reload()
|
||||||
|
|
||||||
def reload(self):
|
def reload(self):
|
||||||
|
"""
|
||||||
|
Clears the style database and reloads the styles from the CSV file(s)
|
||||||
|
matching the path used to initialize the database.
|
||||||
|
"""
|
||||||
self.styles.clear()
|
self.styles.clear()
|
||||||
|
|
||||||
if not os.path.exists(self.path):
|
path, filename = os.path.split(self.path)
|
||||||
return
|
|
||||||
|
|
||||||
with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
|
if "*" in filename:
|
||||||
|
fileglob = filename.split("*")[0] + "*.csv"
|
||||||
|
filelist = []
|
||||||
|
for file in os.listdir(path):
|
||||||
|
if fnmatch.fnmatch(file, fileglob):
|
||||||
|
filelist.append(file)
|
||||||
|
# Add a visible divider to the style list
|
||||||
|
half_len = round(len(file) / 2)
|
||||||
|
divider = f"{'-' * (20 - half_len)} {file.upper()}"
|
||||||
|
divider = f"{divider} {'-' * (40 - len(divider))}"
|
||||||
|
self.styles[divider] = PromptStyle(
|
||||||
|
f"{divider}", None, None, "do_not_save"
|
||||||
|
)
|
||||||
|
# Add styles from this CSV file
|
||||||
|
self.load_from_csv(os.path.join(path, file))
|
||||||
|
if len(filelist) == 0:
|
||||||
|
print(f"No styles found in {path} matching {fileglob}")
|
||||||
|
return
|
||||||
|
elif not os.path.exists(self.path):
|
||||||
|
print(f"Style database not found: {self.path}")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self.load_from_csv(self.path)
|
||||||
|
|
||||||
|
def load_from_csv(self, path: str):
|
||||||
|
with open(path, "r", encoding="utf-8-sig", newline="") as file:
|
||||||
reader = csv.DictReader(file, skipinitialspace=True)
|
reader = csv.DictReader(file, skipinitialspace=True)
|
||||||
for row in reader:
|
for row in reader:
|
||||||
|
# Ignore empty rows or rows starting with a comment
|
||||||
|
if not row or row["name"].startswith("#"):
|
||||||
|
continue
|
||||||
# Support loading old CSV format with "name, text"-columns
|
# Support loading old CSV format with "name, text"-columns
|
||||||
prompt = row["prompt"] if "prompt" in row else row["text"]
|
prompt = row["prompt"] if "prompt" in row else row["text"]
|
||||||
negative_prompt = row.get("negative_prompt", "")
|
negative_prompt = row.get("negative_prompt", "")
|
||||||
self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
|
# Add style to database
|
||||||
|
self.styles[row["name"]] = PromptStyle(
|
||||||
|
row["name"], prompt, negative_prompt, path
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_style_paths(self) -> list():
|
||||||
|
"""
|
||||||
|
Returns a list of all distinct paths, including the default path, of
|
||||||
|
files that styles are loaded from."""
|
||||||
|
# Update any styles without a path to the default path
|
||||||
|
for style in list(self.styles.values()):
|
||||||
|
if not style.path:
|
||||||
|
self.styles[style.name] = style._replace(path=self.default_path)
|
||||||
|
|
||||||
|
# Create a list of all distinct paths, including the default path
|
||||||
|
style_paths = set()
|
||||||
|
style_paths.add(self.default_path)
|
||||||
|
for _, style in self.styles.items():
|
||||||
|
if style.path:
|
||||||
|
style_paths.add(style.path)
|
||||||
|
|
||||||
|
# Remove any paths for styles that are just list dividers
|
||||||
|
style_paths.remove("do_not_save")
|
||||||
|
|
||||||
|
return list(style_paths)
|
||||||
|
|
||||||
def get_style_prompts(self, styles):
|
def get_style_prompts(self, styles):
|
||||||
return [self.styles.get(x, self.no_style).prompt for x in styles]
|
return [self.styles.get(x, self.no_style).prompt for x in styles]
|
||||||
@ -96,20 +200,53 @@ class StyleDatabase:
|
|||||||
return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
|
return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
|
||||||
|
|
||||||
def apply_styles_to_prompt(self, prompt, styles):
|
def apply_styles_to_prompt(self, prompt, styles):
|
||||||
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
|
return apply_styles_to_prompt(
|
||||||
|
prompt, [self.styles.get(x, self.no_style).prompt for x in styles]
|
||||||
|
)
|
||||||
|
|
||||||
def apply_negative_styles_to_prompt(self, prompt, styles):
|
def apply_negative_styles_to_prompt(self, prompt, styles):
|
||||||
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
|
return apply_styles_to_prompt(
|
||||||
|
prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]
|
||||||
|
)
|
||||||
|
|
||||||
def save_styles(self, path: str) -> None:
|
def save_styles(self, path: str = None) -> None:
|
||||||
# Always keep a backup file around
|
# The path argument is deprecated, but kept for backwards compatibility
|
||||||
if os.path.exists(path):
|
_ = path
|
||||||
shutil.copy(path, f"{path}.bak")
|
|
||||||
|
|
||||||
with open(path, "w", encoding="utf-8-sig", newline='') as file:
|
# Update any styles without a path to the default path
|
||||||
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
|
for style in list(self.styles.values()):
|
||||||
writer.writeheader()
|
if not style.path:
|
||||||
writer.writerows(style._asdict() for k, style in self.styles.items())
|
self.styles[style.name] = style._replace(path=self.default_path)
|
||||||
|
|
||||||
|
# Create a list of all distinct paths, including the default path
|
||||||
|
style_paths = set()
|
||||||
|
style_paths.add(self.default_path)
|
||||||
|
for _, style in self.styles.items():
|
||||||
|
if style.path:
|
||||||
|
style_paths.add(style.path)
|
||||||
|
|
||||||
|
# Remove any paths for styles that are just list dividers
|
||||||
|
style_paths.remove("do_not_save")
|
||||||
|
|
||||||
|
csv_names = [os.path.split(path)[1].lower() for path in style_paths]
|
||||||
|
|
||||||
|
for style_path in style_paths:
|
||||||
|
# Always keep a backup file around
|
||||||
|
if os.path.exists(style_path):
|
||||||
|
shutil.copy(style_path, f"{style_path}.bak")
|
||||||
|
|
||||||
|
# Write the styles to the CSV file
|
||||||
|
with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
|
||||||
|
writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
|
||||||
|
writer.writeheader()
|
||||||
|
for style in (s for s in self.styles.values() if s.path == style_path):
|
||||||
|
# Skip style list dividers, e.g. "STYLES.CSV"
|
||||||
|
if style.name.lower().strip("# ") in csv_names:
|
||||||
|
continue
|
||||||
|
# Write style fields, ignoring the path field
|
||||||
|
writer.writerow(
|
||||||
|
{k: v for k, v in style._asdict().items() if k != "path"}
|
||||||
|
)
|
||||||
|
|
||||||
def extract_styles_from_prompt(self, prompt, negative_prompt):
|
def extract_styles_from_prompt(self, prompt, negative_prompt):
|
||||||
extracted = []
|
extracted = []
|
||||||
@ -120,7 +257,9 @@ class StyleDatabase:
|
|||||||
found_style = None
|
found_style = None
|
||||||
|
|
||||||
for style in applicable_styles:
|
for style in applicable_styles:
|
||||||
is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
|
is_match, new_prompt, new_neg_prompt = extract_original_prompts(
|
||||||
|
style, prompt, negative_prompt
|
||||||
|
)
|
||||||
if is_match:
|
if is_match:
|
||||||
found_style = style
|
found_style = style
|
||||||
prompt = new_prompt
|
prompt = new_prompt
|
||||||
|
Loading…
Reference in New Issue
Block a user