mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Merge pull request #14707 from AUTOMATIC1111/multi-styles-base-styles-file
re-work multi --styles-file
This commit is contained in:
commit
c17f7ee694
@ -88,7 +88,7 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anythin
|
|||||||
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
|
||||||
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
|
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
|
||||||
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
|
||||||
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
|
parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[])
|
||||||
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
|
||||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
@ -11,7 +12,7 @@ parser = shared_cmd_options.parser
|
|||||||
|
|
||||||
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
|
batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
|
||||||
parallel_processing_allowed = True
|
parallel_processing_allowed = True
|
||||||
styles_filename = cmd_opts.styles_file
|
styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.styles_file) > 0 else [os.path.join(data_path, 'styles.csv')]
|
||||||
config_filename = cmd_opts.ui_settings_file
|
config_filename = cmd_opts.ui_settings_file
|
||||||
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
||||||
|
|
||||||
|
@ -1,16 +1,15 @@
|
|||||||
|
from pathlib import Path
|
||||||
import csv
|
import csv
|
||||||
import fnmatch
|
|
||||||
import os
|
import os
|
||||||
import os.path
|
|
||||||
import typing
|
import typing
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
class PromptStyle(typing.NamedTuple):
|
class PromptStyle(typing.NamedTuple):
|
||||||
name: str
|
name: str
|
||||||
prompt: str
|
prompt: str | None
|
||||||
negative_prompt: str
|
negative_prompt: str | None
|
||||||
path: str = None
|
path: str | None = None
|
||||||
|
|
||||||
|
|
||||||
def merge_prompts(style_prompt: str, prompt: str) -> str:
|
def merge_prompts(style_prompt: str, prompt: str) -> str:
|
||||||
@ -79,14 +78,19 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
|
|||||||
|
|
||||||
|
|
||||||
class StyleDatabase:
|
class StyleDatabase:
|
||||||
def __init__(self, path: str):
|
def __init__(self, paths: list[str | Path]):
|
||||||
self.no_style = PromptStyle("None", "", "", None)
|
self.no_style = PromptStyle("None", "", "", None)
|
||||||
self.styles = {}
|
self.styles = {}
|
||||||
self.path = path
|
self.paths = paths
|
||||||
|
self.all_styles_files: list[Path] = []
|
||||||
|
|
||||||
folder, file = os.path.split(self.path)
|
folder, file = os.path.split(self.paths[0])
|
||||||
filename, _, ext = file.partition('*')
|
if '*' in file or '?' in file:
|
||||||
self.default_path = os.path.join(folder, filename + ext)
|
# if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
|
||||||
|
self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
|
||||||
|
self.paths.insert(0, self.default_path)
|
||||||
|
else:
|
||||||
|
self.default_path = Path(self.paths[0])
|
||||||
|
|
||||||
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
|
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
|
||||||
|
|
||||||
@ -99,33 +103,31 @@ class StyleDatabase:
|
|||||||
"""
|
"""
|
||||||
self.styles.clear()
|
self.styles.clear()
|
||||||
|
|
||||||
path, filename = os.path.split(self.path)
|
# scans for all styles files
|
||||||
|
all_styles_files = []
|
||||||
|
for pattern in self.paths:
|
||||||
|
folder, file = os.path.split(pattern)
|
||||||
|
if '*' in file or '?' in file:
|
||||||
|
found_files = Path(folder).glob(file)
|
||||||
|
[all_styles_files.append(file) for file in found_files]
|
||||||
|
else:
|
||||||
|
# if os.path.exists(pattern):
|
||||||
|
all_styles_files.append(Path(pattern))
|
||||||
|
|
||||||
if "*" in filename:
|
# Remove any duplicate entries
|
||||||
fileglob = filename.split("*")[0] + "*.csv"
|
seen = set()
|
||||||
filelist = []
|
self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]
|
||||||
for file in os.listdir(path):
|
|
||||||
if fnmatch.fnmatch(file, fileglob):
|
|
||||||
filelist.append(file)
|
|
||||||
# Add a visible divider to the style list
|
|
||||||
half_len = round(len(file) / 2)
|
|
||||||
divider = f"{'-' * (20 - half_len)} {file.upper()}"
|
|
||||||
divider = f"{divider} {'-' * (40 - len(divider))}"
|
|
||||||
self.styles[divider] = PromptStyle(
|
|
||||||
f"{divider}", None, None, "do_not_save"
|
|
||||||
)
|
|
||||||
# Add styles from this CSV file
|
|
||||||
self.load_from_csv(os.path.join(path, file))
|
|
||||||
if len(filelist) == 0:
|
|
||||||
print(f"No styles found in {path} matching {fileglob}")
|
|
||||||
return
|
|
||||||
elif not os.path.exists(self.path):
|
|
||||||
print(f"Style database not found: {self.path}")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
self.load_from_csv(self.path)
|
|
||||||
|
|
||||||
def load_from_csv(self, path: str):
|
for styles_file in self.all_styles_files:
|
||||||
|
if len(all_styles_files) > 1:
|
||||||
|
# add divider when more than styles file
|
||||||
|
# '---------------- STYLES ----------------'
|
||||||
|
divider = f' {styles_file.stem.upper()} '.center(40, '-')
|
||||||
|
self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
|
||||||
|
if styles_file.is_file():
|
||||||
|
self.load_from_csv(styles_file)
|
||||||
|
|
||||||
|
def load_from_csv(self, path: str | Path):
|
||||||
with open(path, "r", encoding="utf-8-sig", newline="") as file:
|
with open(path, "r", encoding="utf-8-sig", newline="") as file:
|
||||||
reader = csv.DictReader(file, skipinitialspace=True)
|
reader = csv.DictReader(file, skipinitialspace=True)
|
||||||
for row in reader:
|
for row in reader:
|
||||||
@ -137,7 +139,7 @@ class StyleDatabase:
|
|||||||
negative_prompt = row.get("negative_prompt", "")
|
negative_prompt = row.get("negative_prompt", "")
|
||||||
# Add style to database
|
# Add style to database
|
||||||
self.styles[row["name"]] = PromptStyle(
|
self.styles[row["name"]] = PromptStyle(
|
||||||
row["name"], prompt, negative_prompt, path
|
row["name"], prompt, negative_prompt, str(path)
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_style_paths(self) -> set:
|
def get_style_paths(self) -> set:
|
||||||
@ -145,11 +147,11 @@ class StyleDatabase:
|
|||||||
# Update any styles without a path to the default path
|
# Update any styles without a path to the default path
|
||||||
for style in list(self.styles.values()):
|
for style in list(self.styles.values()):
|
||||||
if not style.path:
|
if not style.path:
|
||||||
self.styles[style.name] = style._replace(path=self.default_path)
|
self.styles[style.name] = style._replace(path=str(self.default_path))
|
||||||
|
|
||||||
# Create a list of all distinct paths, including the default path
|
# Create a list of all distinct paths, including the default path
|
||||||
style_paths = set()
|
style_paths = set()
|
||||||
style_paths.add(self.default_path)
|
style_paths.add(str(self.default_path))
|
||||||
for _, style in self.styles.items():
|
for _, style in self.styles.items():
|
||||||
if style.path:
|
if style.path:
|
||||||
style_paths.add(style.path)
|
style_paths.add(style.path)
|
||||||
@ -177,7 +179,6 @@ class StyleDatabase:
|
|||||||
|
|
||||||
def save_styles(self, path: str = None) -> None:
|
def save_styles(self, path: str = None) -> None:
|
||||||
# The path argument is deprecated, but kept for backwards compatibility
|
# The path argument is deprecated, but kept for backwards compatibility
|
||||||
_ = path
|
|
||||||
|
|
||||||
style_paths = self.get_style_paths()
|
style_paths = self.get_style_paths()
|
||||||
|
|
||||||
|
@ -22,9 +22,12 @@ def save_style(name, prompt, negative_prompt):
|
|||||||
if not name:
|
if not name:
|
||||||
return gr.update(visible=False)
|
return gr.update(visible=False)
|
||||||
|
|
||||||
style = styles.PromptStyle(name, prompt, negative_prompt)
|
existing_style = shared.prompt_styles.styles.get(name)
|
||||||
|
path = existing_style.path if existing_style is not None else None
|
||||||
|
|
||||||
|
style = styles.PromptStyle(name, prompt, negative_prompt, path)
|
||||||
shared.prompt_styles.styles[style.name] = style
|
shared.prompt_styles.styles[style.name] = style
|
||||||
shared.prompt_styles.save_styles(shared.styles_filename)
|
shared.prompt_styles.save_styles()
|
||||||
|
|
||||||
return gr.update(visible=True)
|
return gr.update(visible=True)
|
||||||
|
|
||||||
@ -34,7 +37,7 @@ def delete_style(name):
|
|||||||
return
|
return
|
||||||
|
|
||||||
shared.prompt_styles.styles.pop(name, None)
|
shared.prompt_styles.styles.pop(name, None)
|
||||||
shared.prompt_styles.save_styles(shared.styles_filename)
|
shared.prompt_styles.save_styles()
|
||||||
|
|
||||||
return '', '', ''
|
return '', '', ''
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user