stable-diffusion-webui/modules/ui_extra_networks.py
2024-04-22 14:01:19 -04:00

1229 lines
48 KiB
Python

import functools
import html
import json
import os.path
import re
import urllib.parse
from base64 import b64decode
from io import BytesIO
from pathlib import Path
from typing import Callable, Optional
import gradio as gr
from fastapi.exceptions import HTTPException
from PIL import Image
from starlette.responses import FileResponse, JSONResponse, Response
from modules import errors, extra_networks, shared, util
from modules.images import read_info_from_image, save_image_with_geninfo
from modules.infotext_utils import image_from_url_text
from modules.ui_common import OutputPanel
from modules.ui_extra_networks_user_metadata import UserMetadataEditor
# Every registered ExtraNetworksPage, in registration order (see register_page / initialize).
extra_pages: list["ExtraNetworksPage"] = []
# Union of allowed_directories_for_previews() over all registered pages; fetch_file
# only serves files located below one of these directories.
allowed_dirs: set = set()
# Preview extensions that are always accepted; allowed_preview_extensions() adds the
# configured samples format on top of these.
default_allowed_preview_extensions: list = ["png", "jpg", "jpeg", "webp", "gif"]
class ListItem:
    """Base entry of the client-side lists (Cards View / Tree View).

    Attributes:
        id [str]: The ID of this list item.
        html [str]: The pre-rendered HTML string for this item.
    """

    def __init__(self, _id: str, _html: str) -> None:
        self.id, self.html = _id, _html
class CardListItem(ListItem):
    """A single card of the Cards View.

    Attributes:
        visible [bool]: Whether the item should be shown in the list.
        abspath [str]: Normalized path of the item's file.
        relpath [str]: Path of the item relative to the parent of its allowed directory.
        sort_keys [dict]: Flat mapping of sort mode name -> sort key string.
        search_terms [str]: String containing multiple search terms joined with spaces.
        search_only [bool]: Whether the item is only shown when it matches a search.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Defaults only; the real values are filled in after construction.
        self.visible: bool = False
        self.abspath: str = ""
        self.relpath: str = ""
        self.sort_keys: dict = {}
        self.search_terms: str = ""
        self.search_only: bool = False
class TreeListItem(ListItem):
    """A single row of the Tree View.

    Attributes:
        node [Optional[DirectoryTreeNode]]: The tree node this row renders.
        visible [bool]: Whether the item should be shown in the list.
        expanded [bool]: Whether the item children should be shown.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Defaults only; the node and flags are assigned after construction.
        self.node: Optional[DirectoryTreeNode] = None
        self.visible: bool = False
        self.expanded: bool = False
class DirectoryTreeNode:
    """A file or directory in the tree backing the Tree View and Dirs View.

    Attributes:
        root_dir [str]: The root directory used to generate a relative path for this node.
        abspath [str]: The absolute path of this node.
        parent [DirectoryTreeNode]: The parent node of this node.
        depth [int]: The depth of this node in the tree. (folder level)
        is_dir [bool]: Whether this node is a directory or file.
        item [Optional[dict]]: The item data dictionary (only set for files).
        relpath [str]: Relative path from `root_dir` to this node.
        children [list[DirectoryTreeNode]]: List of direct child nodes of this node.
    """

    def __init__(
        self,
        root_dir: str,
        abspath: str,
        parent: Optional["DirectoryTreeNode"] = None,
    ) -> None:
        self.root_dir = root_dir
        self.abspath = abspath
        self.parent = parent
        self.depth = 0
        self.is_dir = False
        self.item = None
        self.relpath = os.path.relpath(self.abspath, self.root_dir)
        self.children: list["DirectoryTreeNode"] = []
        # If a parent is passed, then we add this instance to the parent's children.
        if self.parent is not None:
            self.depth = self.parent.depth + 1
            self.parent.add_child(self)

    def add_child(self, child: "DirectoryTreeNode") -> None:
        """Appends `child` to this node's direct children."""
        self.children.append(child)

    def _resolves_to_self_or_ancestor(self, realpath: str) -> bool:
        """Returns True if this node or any of its ancestors resolves to `realpath`."""
        node = self
        while node is not None:
            if os.path.realpath(node.abspath) == realpath:
                return True
            node = node.parent
        return False

    def build(self, items: dict[str, dict], include_hidden: bool = False) -> None:
        """Builds a tree of nodes as children of this instance.

        Args:
            items: A dictionary where keys are absolute filepaths for directories/files.
                The values are dictionaries representing extra networks items.
            include_hidden: Whether to include hidden directories in the tree.
        """
        self.is_dir = os.path.isdir(self.abspath)
        if not self.is_dir:
            self.item = items.get(self.abspath, None)
            return

        try:
            entries = os.listdir(self.abspath)
        except OSError:
            # Unreadable directory (e.g. permission denied): keep it as an empty
            # node instead of failing the whole page.
            return

        for name in entries:
            child_path = os.path.join(self.abspath, name)
            if os.path.isdir(child_path):
                # Skip hidden directories if include_hidden is False
                if name.startswith(".") and not include_hidden:
                    continue
                # A symlink that points back at this node or an ancestor would
                # otherwise recurse until RecursionError.
                if os.path.islink(child_path) and self._resolves_to_self_or_ancestor(os.path.realpath(child_path)):
                    continue
            elif child_path not in items:
                # Files are only added when they are known items.
                continue
            DirectoryTreeNode(self.root_dir, child_path, self).build(items, include_hidden)

    def flatten(self, res: dict, dirs_only: bool = False) -> None:
        """Flattens the keys/values of the tree nodes into a dictionary.

        Args:
            res: The dictionary result updated in place (abspath -> node). On initial
                call, should be passed as an empty dictionary.
            dirs_only: Whether to only add directories to the result.

        Raises:
            KeyError: If any nodes in the tree have the same ID.
        """
        if self.abspath in res:
            raise KeyError(f"duplicate key: {self.abspath}")
        if not dirs_only or self.is_dir:
            res[self.abspath] = self
        for child in self.children:
            child.flatten(res, dirs_only)

    def to_sorted_list(self, res: list) -> None:
        """Sorts the tree by absolute path and groups by directories/files.

        Since we are sorting a directory tree, we always want the directories to come
        before the files. So we have to sort these two lists separately.

        Args:
            res: The list result updated in place. On initial call, should be passed
                as an empty list.
        """
        res.append(self)

        def by_path(node):
            return shared.natural_sort_key(node.abspath)

        dir_children = [x for x in self.children if x.is_dir]
        file_children = [x for x in self.children if not x.is_dir]
        for child in sorted(dir_children, key=by_path):
            child.to_sorted_list(res)
        for child in sorted(file_children, key=by_path):
            child.to_sorted_list(res)

    def apply(self, fn: Callable) -> None:
        """Recursively calls passed function with instance for entire tree."""
        fn(self)
        for child in self.children:
            child.apply(fn)
class ExtraNetworksUi:
    """UI components for Extra Networks.

    Attributes:
        button_save_preview: Gradio button for saving previews.
        pages: Gradio HTML elements for an ExtraNetworks page.
        pages_contents: HTML string content for `pages`.
        preview_target_filename: Gradio textbox for entering filename.
        related_tabs: Gradio Tab instances for each ExtraNetworksPage.
        stored_extra_pages: `ExtraNetworksPage` instance for each page.
        tabname: The primary page tab name (i.e. `txt2img`, `img2img`)
        unrelated_tabs: Gradio Tab instances tracked separately from `related_tabs`.
        user_metadata_editors: The metadata editor objects for a page.
    """

    def __init__(self, tabname: str):
        self.tabname = tabname
        # Every dict below is keyed by "{tabname}_{page.extra_networks_tabname}".
        self.pages: dict[str, gr.HTML] = {}
        self.pages_contents: dict[str, str] = {}
        self.stored_extra_pages: dict[str, ExtraNetworksPage] = {}
        self.related_tabs: dict[str, gr.Tab] = {}
        self.user_metadata_editors: dict[str, UserMetadataEditor] = {}
        self.unrelated_tabs: list[gr.Tab] = []
        self.button_save_preview: Optional[gr.Button] = None
        self.preview_target_filename: Optional[gr.Textbox] = None

        # Snapshot the registered pages in the user's preferred order.
        ordered_pages = pages_in_preferred_order(extra_pages.copy())
        self.stored_extra_pages.update(
            (f"{self.tabname}_{page.extra_networks_tabname}", page) for page in ordered_pages
        )
class ExtraNetworksPage:
    """Base class for one Extra Networks tab.

    Subclasses override `list_items`, `create_item` and
    `allowed_directories_for_previews`. This class turns the listed items into the
    HTML and JSON datasets served to the Cards, Tree and Dirs views.
    """

    def __init__(self, title):
        self.title = title
        self.name = title.lower()
        # This is the actual name of the extra networks tab (not txt2img/img2img).
        self.extra_networks_tabname = self.name.replace(" ", "_")
        self.allow_prompt = True
        self.allow_negative_prompt = False
        # item name -> metadata dict (only items that have metadata; see create_html)
        self.metadata = {}
        # item name -> item dict
        self.items = {}
        # div_id -> CardListItem
        self.cards = {}
        # div_id -> TreeListItem
        self.tree = {}
        # absolute root directory path -> DirectoryTreeNode
        self.tree_roots = {}
        self.lister = util.MassFileLister()
        # HTML Templates
        self.pane_tpl = shared.html("extra-networks-pane.html")
        self.card_tpl = shared.html("extra-networks-card.html")
        self.tree_row_tpl = shared.html("extra-networks-tree-row.html")
        self.btn_copy_path_tpl = shared.html("extra-networks-btn-copy-path.html")
        self.btn_show_metadata_tpl = shared.html("extra-networks-btn-show-metadata.html")
        self.btn_edit_metadata_tpl = shared.html("extra-networks-btn-edit-metadata.html")
        self.btn_dirs_view_item_tpl = shared.html("extra-networks-btn-dirs-view-item.html")
        # Sorted lists
        # These just store ints so it won't use hardly any memory to just sort ahead
        # of time for each sort mode. These are lists of keys for each file.
        self.keys_sorted = {}
        self.keys_by_name = []
        self.keys_by_path = []
        self.keys_by_created = []
        self.keys_by_modified = []

    def refresh(self):
        # Whenever we refresh, we want to build our datasets from scratch.
        self.items = {}
        self.cards = {}
        self.tree = {}
        self.tree_roots = {}

    def read_user_metadata(self, item, use_cache=True):
        """Stores the user metadata of `item` in `item["user_metadata"]`.

        A "description" found in that metadata also overwrites `item["description"]`.
        With `use_cache=False` this page's lister is not passed to the loader.
        """
        filename = os.path.normpath(item.get("filename", None))
        metadata = extra_networks.get_user_metadata(filename, lister=self.lister if use_cache else None)
        desc = metadata.get("description", None)
        if desc is not None:
            item["description"] = desc
        item["user_metadata"] = metadata

    def link_preview(self, filename):
        """Returns the thumbnail URL for `filename`, with its mtime as a cache-buster."""
        quoted_filename = urllib.parse.quote(filename.replace("\\", "/"))
        mtime, _ = self.lister.mctime(filename)
        return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"

    def search_terms_from_path(self, filename, possible_directories=None):
        """Returns `filename` relative to the parent of the first matching directory, or ""."""
        abspath = os.path.abspath(filename)
        for parentdir in possible_directories if possible_directories is not None else self.allowed_directories_for_previews():
            parentdir = os.path.dirname(os.path.abspath(parentdir))
            if abspath.startswith(parentdir):
                return os.path.relpath(abspath, parentdir)
        return ""

    def _is_search_only(self, filename: str) -> bool:
        """Returns whether an item should only be shown when searching for it.

        An item is search-only when any path component *below* its allowed directory
        (including the file name) starts with ".". Components above the allowed
        directory are ignored, so installing the webui inside a hidden directory
        (e.g. ~/.local) does not hide every model. Files outside every allowed
        directory are only checked by their own name.
        """
        show_hidden_models = str(shared.opts.extra_networks_hidden_models).strip().lower()
        if show_hidden_models == "always":
            return False
        abspath = os.path.abspath(filename)
        for allowed_dir in self.allowed_directories_for_previews():
            absdir = os.path.abspath(allowed_dir)
            if abspath.startswith(absdir + os.sep):
                relpath = os.path.relpath(abspath, absdir)
                return any(x.startswith(".") for x in relpath.split(os.sep))
        return os.path.basename(abspath).startswith(".")

    @staticmethod
    def _format_data_attributes(data_attributes: dict) -> str:
        """Renders `{"data-x": value}` pairs as an HTML attribute string.

        Booleans render as a bare key when True and are omitted when False. Other
        values are inserted verbatim (callers add quotes/escaping) and are skipped
        when None or empty. Every emitted attribute is followed by one space.
        """
        res = ""
        for k, v in data_attributes.items():
            if isinstance(v, bool):
                # Boolean data attributes only need a key when true.
                if v:
                    res += f"{k} "
            elif v not in [None, "", "''", '""']:
                res += f"{k}={v} "
        return res

    def build_tree_html_row(
        self,
        tabname: str,
        label: str,
        btn_type: str,
        btn_title: Optional[str] = None,
        data_attributes: Optional[dict] = None,
        dir_is_empty: bool = False,
        item: Optional[dict] = None,
        onclick_extra: Optional[str] = None,
    ) -> str:
        """Generates HTML for a single row of the Tree View

        Args:
            tabname:
                "txt2img" or "img2img"
            label:
                The text to display for this row (inserted as-is; escape beforehand).
            btn_type:
                "dir" or "file"
            btn_title:
                Optional hover text for the row. Defaults to `label`.
            data_attributes:
                Dictionary defining data attributes to add to the row's tag.
                Ex: {"one": "1"} would generate <div data-one="1"></div>
            dir_is_empty:
                Whether the directory is empty. Only useful if btn_type=="dir".
            item:
                Dictionary containing item data such as filename, hash, etc.
            onclick_extra:
                Additional javascript code to add to the row's `onclick` attribute.

        Raises:
            ValueError: If `btn_type` is not "file" or "dir".
        """
        if btn_type not in ["file", "dir"]:
            raise ValueError(f"Invalid button type: {btn_type}")
        if data_attributes is None:
            data_attributes = {}
        label = label.strip()
        # If not specified, title will just reflect the label
        btn_title = btn_title.strip() if btn_title else f'"{label}"'
        action_list_item_action_leading = "<i class='tree-list-item-action-chevron'></i>"
        action_list_item_visual_leading = "🗀"
        action_list_item_visual_trailing = ""
        action_list_item_action_trailing = ""
        if dir_is_empty:
            action_list_item_action_leading = "<i class='tree-list-item-action-chevron' style='visibility: hidden'></i>"
        if btn_type == "file":
            action_list_item_visual_leading = "🗎"
            # Action buttons
            if item is not None:
                action_list_item_action_trailing += self.get_button_row(tabname, item)
        else:
            action_list_item_action_trailing += (
                "<div class='button-row'>"
                "<div class='tree-list-item-action-expand card-button' title='Expand All'></div>"
                "<div class='tree-list-item-action-collapse card-button' title='Collapse All'></div>"
                "</div>"
            )
        res = self.tree_row_tpl.format(
            **{
                "data_attributes": self._format_data_attributes(data_attributes),
                "search_terms": "",
                "btn_type": btn_type,
                "btn_title": btn_title,
                "tabname": tabname,
                "onclick_extra": onclick_extra if onclick_extra else "",
                "extra_networks_tabname": self.extra_networks_tabname,
                "action_list_item_action_leading": action_list_item_action_leading,
                "action_list_item_visual_leading": action_list_item_visual_leading,
                "action_list_item_label": label,
                "action_list_item_visual_trailing": action_list_item_visual_trailing,
                "action_list_item_action_trailing": action_list_item_action_trailing,
            }
        )
        # Collapse the template onto one line with single spaces.
        res = res.strip()
        res = re.sub(" +", " ", res.replace("\n", ""))
        return res

    def get_button_row(self, tabname: str, item: dict) -> str:
        """Generates a row of buttons for use in Tree/Cards View items."""
        metadata = item.get("metadata", None)
        name = item.get("name", "")
        filename = os.path.normpath(item.get("filename", ""))
        button_row_tpl = '<div class="button-row">{btn_copy_path}{btn_edit_item}{btn_metadata}</div>'
        btn_copy_path = self.btn_copy_path_tpl.format(clipboard_text=filename)
        btn_edit_item = self.btn_edit_metadata_tpl.format(
            tabname=tabname,
            extra_networks_tabname=self.extra_networks_tabname,
            name=name,
        )
        # The "show metadata" button only appears when the item has metadata.
        btn_metadata = ""
        if metadata:
            btn_metadata = self.btn_show_metadata_tpl.format(
                extra_networks_tabname=self.extra_networks_tabname,
                name=name,
            )
        return button_row_tpl.format(
            btn_copy_path=btn_copy_path,
            btn_edit_item=btn_edit_item,
            btn_metadata=btn_metadata,
        )

    def create_card_html(
        self,
        tabname: str,
        item: dict,
        div_id: Optional[str] = None,
    ) -> str:
        """Generates HTML for a single ExtraNetworks Item.

        Args:
            tabname: The name of the active tab.
            item: Dictionary containing item information.
            div_id: Optional ID written to the card's `data-div-id` attribute.

        Returns:
            HTML string generated for this item.
            Empty if the item is hidden and hidden models are set to "never" be shown.
        """
        style = f"font-size: {shared.opts.extra_networks_card_text_scale*100}%;"
        if shared.opts.extra_networks_card_height:
            style += f"height: {shared.opts.extra_networks_card_height}px;"
        if shared.opts.extra_networks_card_width:
            style += f"width: {shared.opts.extra_networks_card_width}px;"
        background_image = ""
        preview = html.escape(item.get("preview", "") or "")
        if preview:
            background_image = f'<img src="{preview}" class="preview" loading="lazy">'
        onclick = item.get("onclick", None)
        if onclick is None:
            onclick = html.escape(f"extraNetworksCardOnClick(event, '{tabname}_{self.extra_networks_tabname}');")
        button_row = self.get_button_row(tabname, item)
        filename = os.path.normpath(item.get("filename", ""))
        # Hidden items are normally shown only when searching for them; with the
        # "never" option they are not rendered at all.
        show_hidden_models = str(shared.opts.extra_networks_hidden_models).strip().lower()
        if show_hidden_models == "never" and self._is_search_only(filename):
            return ""
        description = ""
        if shared.opts.extra_networks_card_show_desc:
            description = item.get("description", "") or ""
        if not shared.opts.extra_networks_card_description_is_html:
            description = html.escape(description)
        data_name = item.get("name", "").strip()
        data_path = os.path.normpath(item.get("filename", "").strip())
        data_attributes = {
            "data-div-id": f'"{div_id}"' if div_id else '""',
            # Escaped so quotes in names/paths cannot break out of the attribute.
            "data-name": f'"{html.escape(data_name)}"',
            "data-path": f'"{html.escape(data_path)}"',
            "data-hash": item.get("shorthash", None),
            "data-prompt": item.get("prompt", "").strip(),
            "data-neg-prompt": item.get("negative_prompt", "").strip(),
            "data-allow-neg": self.allow_negative_prompt,
        }
        return self.card_tpl.format(
            style=style,
            onclick=onclick,
            data_attributes=self._format_data_attributes(data_attributes),
            background_image=background_image,
            button_row=button_row,
            name=html.escape(item["name"].strip()),
            description=description,
        )

    def generate_cards_view_data(self, tabname: str) -> dict:
        """Generates the datasets and HTML used to display the Cards View.

        Replaces `self.cards` and `self.keys_sorted` with freshly built data, so
        cards from a previous, larger listing do not linger.

        Returns:
            A dictionary containing necessary info for the client.
            {
                search_keys: array of strings,
                sort_<mode>: string, (for various sort modes),
                visible: True, // all cards are visible by default.
            }
            Return does not contain the HTML since that is fetched by client.
        """
        cards = {}
        for i, item in enumerate(self.items.values()):
            div_id = str(i)
            card = CardListItem(div_id, self.create_card_html(tabname=tabname, item=item, div_id=div_id))
            card.abspath = os.path.normpath(item.get("filename", ""))
            for parent_dir in self.allowed_directories_for_previews():
                parent_dir = os.path.dirname(os.path.abspath(parent_dir))
                if card.abspath.startswith(parent_dir):
                    card.relpath = os.path.relpath(card.abspath, parent_dir)
                    break
            card.sort_keys = {k.strip().lower().replace(" ", "_"): html.escape(str(v)) for k, v in item.get("sort_keys", {}).items()}
            card.search_terms = " ".join(item.get("search_terms", []))
            card.search_only = self._is_search_only(card.abspath)
            cards[div_id] = card

        # Sort cards for all sort modes. Items missing a sort key sort as "".
        keys_sorted = {}
        for mode in ["name", "path", "date_created", "date_modified"]:
            keys_sorted[mode] = sorted(
                cards.keys(),
                key=lambda k, mode=mode: shared.natural_sort_key(cards[k].sort_keys.get(mode, "")),
            )

        res = {}
        for div_id, card_item in cards.items():
            rel_parent_dir = os.path.dirname(card_item.relpath)
            if card_item.search_only:
                # Report the path starting at the first hidden component.
                parents = card_item.relpath.split(os.sep)
                idxs = [i for i, x in enumerate(parents) if x.startswith(".")]
                if len(idxs) > 0:
                    rel_parent_dir = os.sep.join(parents[idxs[0]:])
                else:
                    print(f"search_only is enabled but no hidden dir found: {card_item.abspath}")
            res[div_id] = {
                **{f"sort_{mode}": key for mode, key in card_item.sort_keys.items()},
                "rel_parent_dir": rel_parent_dir,
                "search_terms": card_item.search_terms,
                "search_only": card_item.search_only,
                "visible": not card_item.search_only,
            }

        # Swap in the new data in one assignment each.
        self.cards = cards
        self.keys_sorted = keys_sorted
        return res

    def generate_tree_view_data(self, tabname: str) -> dict:
        """Generates the datasets and HTML used to display the Tree View.

        Replaces `self.tree` with freshly built rows.

        Returns:
            A dictionary containing necessary info for the client.
            {
                parent: None or div_id,
                children: list of div_id's,
                visible: bool,
                expanded: bool,
            }
            Return does not contain the HTML since that is fetched by client.
        """
        if not self.tree_roots:
            self.tree = {}
            return {}

        # Flatten roots into a single sorted list of nodes.
        # Directories always come before files. After that, natural sort is used.
        sorted_nodes = []
        for node in self.tree_roots.values():
            _sorted_nodes = []
            node.to_sorted_list(_sorted_nodes)
            sorted_nodes.extend(_sorted_nodes)

        path_to_div_id = {}
        div_id_to_node = {}  # reverse mapping
        # First assign div IDs to each node. Used for parent ID lookup later.
        for i, node in enumerate(sorted_nodes):
            div_id = str(i)
            path_to_div_id[node.abspath] = div_id
            div_id_to_node[div_id] = node

        show_files = shared.opts.extra_networks_tree_view_show_files is True
        tree = {}
        for div_id, node in div_id_to_node.items():
            parent_id = None
            if node.parent is not None:
                parent_id = path_to_div_id.get(node.parent.abspath, None)

            if node.is_dir:  # directory
                if show_files:
                    dir_is_empty = node.children == []
                else:
                    dir_is_empty = all(not x.is_dir for x in node.children)
                row_html = self.build_tree_html_row(
                    tabname=tabname,
                    label=html.escape(os.path.basename(node.abspath)),
                    btn_type="dir",
                    btn_title=f'"{html.escape(node.abspath)}"',
                    dir_is_empty=dir_is_empty,
                    data_attributes={
                        "data-div-id": f'"{div_id}"',
                        "data-parent-id": f'"{parent_id}"',
                        "data-tree-entry-type": "dir",
                        "data-depth": node.depth,
                        "data-path": f'"{html.escape(node.relpath)}"',
                        "data-expanded": node.parent is None,  # Expand root directories
                    },
                )
            else:  # file
                if not show_files:
                    # Don't add file if files are disabled in the options.
                    continue
                onclick = node.item.get("onclick", None)
                if onclick is None:
                    onclick = html.escape(f"extraNetworksCardOnClick(event, '{tabname}_{self.extra_networks_tabname}');")
                item_name = node.item.get("name", "").strip()
                data_path = os.path.normpath(node.item.get("filename", "").strip())
                row_html = self.build_tree_html_row(
                    tabname=tabname,
                    label=html.escape(item_name),
                    btn_type="file",
                    data_attributes={
                        "data-div-id": f'"{div_id}"',
                        "data-parent-id": f'"{parent_id}"',
                        "data-tree-entry-type": "file",
                        "data-name": f'"{html.escape(item_name)}"',
                        "data-depth": node.depth,
                        "data-path": f'"{html.escape(data_path)}"',
                        "data-hash": node.item.get("shorthash", None),
                        "data-prompt": node.item.get("prompt", "").strip(),
                        "data-neg-prompt": node.item.get("negative_prompt", "").strip(),
                        "data-allow-neg": self.allow_negative_prompt,
                    },
                    item=node.item,
                    onclick_extra=onclick,
                )
            tree_item = TreeListItem(div_id, row_html)
            tree_item.node = node
            tree[div_id] = tree_item

        # Expand all root directories and set them to active so they are displayed.
        # Lookups use .get() because rows may have been skipped (e.g. files hidden).
        for path in self.tree_roots.keys():
            root_item = tree.get(path_to_div_id[path])
            if root_item is None:
                continue
            root_item.expanded = True
            root_item.visible = True
            # Set all direct children to active
            for child_node in root_item.node.children:
                child_item = tree.get(path_to_div_id.get(child_node.abspath))
                if child_item is not None:
                    child_item.visible = True

        res = {}
        for div_id, tree_item in tree.items():
            # Expand root nodes and make them visible.
            expanded = tree_item.node.parent is None
            visible = tree_item.node.parent is None
            parent_id = None
            if tree_item.node.parent is not None:
                parent_id = path_to_div_id[tree_item.node.parent.abspath]
                # Direct children of root nodes should be visible by default.
                if tree[parent_id].node.parent is None:
                    visible = True
            res[div_id] = {
                "parent": parent_id,
                "children": [path_to_div_id[child.abspath] for child in tree_item.node.children],
                "visible": visible,
                "expanded": expanded,
            }

        self.tree = tree
        return res

    def create_dirs_view_html(self, tabname: str) -> str:
        """Generates HTML for displaying folders."""
        # Flatten each root into a single dict. Only get the directories for buttons.
        tree = {}
        for node in self.tree_roots.values():
            subtree = {}
            node.flatten(subtree, dirs_only=True)
            tree.update(subtree)
        # Sort the tree nodes by relative paths
        dir_nodes = sorted(
            tree.values(),
            key=lambda x: shared.natural_sort_key(x.relpath),
        )
        return "".join(
            self.btn_dirs_view_item_tpl.format(
                **{
                    "extra_class": "search-all" if node.relpath == "" else "",
                    "tabname_full": f"{tabname}_{self.extra_networks_tabname}",
                    "path": html.escape(node.relpath),
                }
            )
            for node in dir_nodes
        )

    def create_html(self, tabname: str, *, empty: bool = False) -> str:
        """Generates an HTML string for the current pane.

        The generated HTML uses `extra-networks-pane.html` as a template. Also
        repopulates `self.items`, `self.metadata` and `self.tree_roots`.

        Args:
            tabname: The name of the active tab.
            empty: create an empty HTML page with no items

        Returns:
            HTML formatted string.
        """
        self.lister.reset()
        self.metadata = {}
        items_list = [] if empty else self.list_items()
        self.items = {x["name"]: x for x in items_list}
        # Populate the instance metadata for each item.
        for item in self.items.values():
            metadata = item.get("metadata")
            if metadata:
                self.metadata[item["name"]] = metadata
            if "user_metadata" not in item:
                self.read_user_metadata(item)

        # Setup the tree dictionary.
        tree_items = {os.path.normpath(v["filename"]): v for v in self.items.values()}
        show_files = shared.opts.extra_networks_tree_view_show_files
        # Create a DirectoryTreeNode for each root directory since they might not share
        # a common path. Built from scratch so roots that disappeared are dropped.
        tree_roots = {}
        for path in self.allowed_directories_for_previews():
            abspath = os.path.abspath(path)
            if not os.path.exists(abspath):
                continue
            root = DirectoryTreeNode(os.path.dirname(abspath), abspath, None)
            root.build(
                tree_items if show_files else {},
                include_hidden=shared.opts.extra_networks_show_hidden_directories,
            )
            tree_roots[abspath] = root
        self.tree_roots = tree_roots

        # Generate the html for displaying directory buttons
        dirs_html = self.create_dirs_view_html(tabname)
        sort_mode = shared.opts.extra_networks_card_order_field.lower().strip().replace(" ", "_")
        sort_dir = shared.opts.extra_networks_card_order.lower().strip()
        dirs_view_en = shared.opts.extra_networks_dirs_view_default_enabled
        tree_view_en = shared.opts.extra_networks_tree_view_default_enabled
        return self.pane_tpl.format(
            **{
                "tabname": tabname,
                "extra_networks_tabname": self.extra_networks_tabname,
                "data_sort_dir": sort_dir,
                "btn_sort_mode_path_data_attributes": "data-selected" if sort_mode == "path" else "",
                "btn_sort_mode_name_data_attributes": "data-selected" if sort_mode == "name" else "",
                "btn_sort_mode_date_created_data_attributes": "data-selected" if sort_mode == "date_created" else "",
                "btn_sort_mode_date_modified_data_attributes": "data-selected" if sort_mode == "date_modified" else "",
                "btn_dirs_view_data_attributes": "data-selected" if dirs_view_en else "",
                "btn_tree_view_data_attributes": "data-selected" if tree_view_en else "",
                "dirs_view_hidden_cls": "" if dirs_view_en else "hidden",
                "tree_view_hidden_cls": "" if tree_view_en else "hidden",
                "tree_view_style": f"flex-basis: {shared.opts.extra_networks_tree_view_default_width}px;",
                "cards_view_style": "flex-grow: 1;",
                "dirs_html": dirs_html,
            }
        )

    def create_item(self, name, index=None):
        raise NotImplementedError()

    def list_items(self):
        raise NotImplementedError()

    def allowed_directories_for_previews(self):
        return []

    def get_sort_keys(self, path):
        """
        List of default keys used for sorting in the UI.
        """
        pth = Path(path)
        # The lister returns (mtime, ctime); created <- ctime, modified <- mtime.
        mtime, ctime = self.lister.mctime(path)
        return {
            "date_created": int(ctime),
            "date_modified": int(mtime),
            "name": pth.name.lower(),
            "path": str(pth).lower(),
        }

    def find_preview(self, path):
        """
        Find a preview image for a given path (without extension) and call link_preview on it.
        """
        potential_files = [
            f"{path}{suffix}.{ext}"
            for ext in allowed_preview_extensions()
            for suffix in ("", ".preview")
        ]
        for file in potential_files:
            if self.lister.exists(file):
                return self.link_preview(file)
        return None

    def find_embedded_preview(self, path, name, metadata):
        """
        Find if embedded preview exists in safetensors metadata and return endpoint for it.

        Returns None when the .safetensors file is missing, or the cover-image entry
        is absent, empty or not valid JSON.
        """
        file = f"{path}.safetensors"
        if not self.lister.exists(file) or "ssmd_cover_images" not in metadata:
            return None
        try:
            has_images = any(json.loads(metadata["ssmd_cover_images"]))
        except (TypeError, ValueError):
            return None
        if not has_images:
            return None
        # Quote the name so characters like "&", "#" or spaces survive the query string.
        quoted_name = urllib.parse.quote(name)
        return f"./sd_extra_networks/cover-images?extra_networks_tabname={self.extra_networks_tabname}&item={quoted_name}"

    def find_description(self, path):
        """
        Find and read a description file for a given path (without extension).
        """
        for file in [f"{path}.txt", f"{path}.description.txt"]:
            if not self.lister.exists(file):
                continue
            try:
                with open(file, "r", encoding="utf-8", errors="replace") as f:
                    return f.read()
            except OSError:
                # Best effort: try the next candidate file.
                pass
        return None

    def create_user_metadata_editor(self, ui, tabname) -> "UserMetadataEditor":
        return UserMetadataEditor(ui, tabname, self)
@functools.cache
def allowed_preview_extensions_with_extra(extra_extensions=None):
    """Returns the default preview extensions plus any `extra_extensions`.

    The result is cached per argument, so `extra_extensions` must be hashable
    (e.g. a tuple) and the returned set is shared between callers.
    """
    extensions = set(default_allowed_preview_extensions)
    extensions.update(extra_extensions or [])
    return extensions
def allowed_preview_extensions():
    """Returns the allowed preview extensions, including the configured samples format."""
    samples_format = shared.opts.samples_format
    return allowed_preview_extensions_with_extra((samples_format,))
def register_page(page):
    """registers extra networks page for the UI

    recommend doing it in on_before_ui() callback for extensions

    Also recomputes `allowed_dirs` as the union of every registered page's
    `allowed_directories_for_previews()`.
    """
    extra_pages.append(page)
    # Collect first, then swap in: if a page raises, allowed_dirs keeps its previous
    # contents instead of being left empty. set.update also avoids the quadratic
    # sum(list_of_lists, []) and accepts any iterable of directories.
    new_dirs = set()
    for registered in extra_pages:
        new_dirs.update(registered.allowed_directories_for_previews())
    allowed_dirs.clear()
    allowed_dirs.update(new_dirs)
def get_page_by_name(extra_networks_tabname: str = "") -> "ExtraNetworksPage":
    """Gets a page from extra pages for the specified tabname.

    Returns the first registered page whose `extra_networks_tabname` matches.

    Raises:
        HTTPException: 404 if no page in the `extra_pages` list has that tabname.
    """
    match = next(
        (page for page in extra_pages if page.extra_networks_tabname == extra_networks_tabname),
        None,
    )
    if match is None:
        raise HTTPException(status_code=404, detail=f"Page not found: {extra_networks_tabname}")
    return match
def fetch_file(filename: str = ""):
    """Serves a preview image located inside one of the registered extra-network dirs.

    Raises:
        ValueError: If the file is not below any directory in `allowed_dirs`, or its
            extension is not an allowed preview extension.
        HTTPException: 404 if the (permitted) file does not exist.
    """
    # os.path.abspath collapses ".." lexically. Path.absolute() does not, so a request
    # such as "models/Lora/../../secret.png" would otherwise pass the parents check.
    file_path = Path(os.path.abspath(filename))
    # Check permission before existence so the endpoint does not reveal whether
    # arbitrary files outside the allowed directories exist.
    if not any(Path(os.path.abspath(x)) in file_path.parents for x in allowed_dirs):
        raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    ext = file_path.suffix.lower()[1:]
    if ext not in allowed_preview_extensions():
        raise ValueError(f"File cannot be fetched: {filename}. Extensions allowed: {allowed_preview_extensions()}.")
    # would profit from returning 304
    # Serve exactly the path that was validated.
    return FileResponse(str(file_path), headers={"Accept-Ranges": "bytes"})
def fetch_cover_images(extra_networks_tabname: str = "", item: str = "", index: int = 0):
    """Serves the `index`-th embedded cover image of `item` as an image response.

    Raises:
        HTTPException: 404 if the page or item is unknown, the item has no cover
            images, or `index` is out of range.
        ValueError: If the stored image data cannot be decoded.
    """
    page = get_page_by_name(extra_networks_tabname)
    metadata = page.metadata.get(item)
    if metadata is None:
        raise HTTPException(status_code=404, detail="File not found")
    # The entry is a JSON-encoded list; default to an empty JSON list so json.loads
    # always receives a string.
    cover_images = json.loads(metadata.get("ssmd_cover_images", "[]"))
    # Reject negative indices instead of silently indexing from the end.
    image = cover_images[index] if 0 <= index < len(cover_images) else None
    if not image:
        raise HTTPException(status_code=404, detail="File not found")
    try:
        image = Image.open(BytesIO(b64decode(image)))
        buffer = BytesIO()
        image.save(buffer, format=image.format)
        return Response(content=buffer.getvalue(), media_type=image.get_format_mimetype())
    except Exception as err:
        raise ValueError(f"File cannot be fetched: {item}. Failed to load cover image.") from err
def init_tree_data(tabname: str = "", extra_networks_tabname: str = "") -> JSONResponse:
    """Generates the initial Tree View data and returns a simplified dataset.

    The data returned does not contain any HTML strings.

    Status Codes:
        200 on success
        404 if data isn't ready or tabname doesn't exist.
    """
    tree_data = get_page_by_name(extra_networks_tabname).generate_tree_view_data(tabname)
    if tree_data is not None:
        return JSONResponse(tree_data)
    raise HTTPException(status_code=404, detail=f"data not ready: {extra_networks_tabname}")
def fetch_tree_data(
    extra_networks_tabname: str = "",
    div_ids: str = "",
) -> JSONResponse:
    """Retrieves Tree View HTML strings for the specified `div_ids`.

    Args:
        div_ids: A string with div_ids in CSV format.

    Returns:
        JSON `{"data": {div_id: html}, "missing_div_ids": [div_id, ...]}`.

    Status Codes:
        200 on success
        404 if tabname doesn't exist
    """
    tree = get_page_by_name(extra_networks_tabname).tree
    requested = div_ids.split(",")
    found = {div_id: tree[div_id].html for div_id in requested if div_id in tree}
    missing = [div_id for div_id in requested if div_id not in tree]
    return JSONResponse({"data": found, "missing_div_ids": missing})
def fetch_cards_data(
    extra_networks_tabname: str = "",
    div_ids: str = "",
) -> JSONResponse:
    """Retrieves Cards View HTML strings for the specified `div_ids`.

    Args:
        div_ids: A string with div_ids in CSV format.

    Returns:
        JSON `{"data": {div_id: html}, "missing_div_ids": [div_id, ...]}`.

    Status Codes:
        200 on success
        404 if tabname doesn't exist
    """
    cards = get_page_by_name(extra_networks_tabname).cards
    requested = div_ids.split(",")
    found = {div_id: cards[div_id].html for div_id in requested if div_id in cards}
    missing = [div_id for div_id in requested if div_id not in cards]
    return JSONResponse({"data": found, "missing_div_ids": missing})
def init_cards_data(tabname: str = "", extra_networks_tabname: str = "") -> JSONResponse:
    """Generates the initial Cards View data and returns a simplified dataset.

    The data returned does not contain any HTML strings.

    Status Codes:
        200 on success
        404 if data isn't ready or tabname doesn't exist.
    """
    cards_data = get_page_by_name(extra_networks_tabname).generate_cards_view_data(tabname)
    if cards_data is not None:
        return JSONResponse(cards_data)
    raise HTTPException(status_code=404, detail=f"data not ready: {extra_networks_tabname}")
def page_is_ready(extra_networks_tabname: str = "") -> JSONResponse:
    """Returns whether the specified page is ready for fetching data.

    The page counts as ready when its loaded item count equals the number of items
    `list_items()` currently yields.

    Status Codes:
        200 if page is ready
        404 if page isn't ready or tabname doesnt exist.
    """
    page = get_page_by_name(extra_networks_tabname)
    listed_count = len(list(page.list_items()))
    if len(page.items) != listed_count:
        raise HTTPException(status_code=404, detail=f"page not ready: {extra_networks_tabname}")
    return JSONResponse({}, status_code=200)
def get_metadata(extra_networks_tabname: str = "", item: str = "") -> JSONResponse:
    """Returns `{"metadata": <pretty-printed JSON>}` for `item`, or `{}` if unavailable."""
    try:
        page = get_page_by_name(extra_networks_tabname)
    except HTTPException:
        return JSONResponse({})

    metadata = page.metadata.get(item)
    if metadata is None:
        return JSONResponse({})

    # FIXME: cover images ("ssmd_cover_images") used to be stripped here because they
    # are too big to display in the UI as text. That filter is currently disabled, so
    # they are returned as-is. The old filter was:
    #   metadata = {i: metadata[i] for i in metadata if i != 'ssmd_cover_images'}
    pretty = json.dumps(metadata, indent=4, ensure_ascii=False)
    return JSONResponse({"metadata": pretty})
def get_single_card(tabname: str = "", extra_networks_tabname: str = "", name: str = "") -> JSONResponse:
    """Re-create a single item and return its card HTML.

    The item is rebuilt with `create_item(name, enable_filter=False)` and
    stored back into `page.items`. If that fails, the error is displayed and
    any previously cached item under `name` is used instead. User metadata
    is re-read without the cache before rendering.

    Returns `{"html": ...}`, or `{}` if no item is available for `name`.
    """
    page = get_page_by_name(extra_networks_tabname)

    try:
        page.items[name] = page.create_item(name, enable_filter=False)
    except Exception as exc:
        # Best effort: report and fall back to whatever is already cached.
        errors.display(exc, "creating item for extra network")

    card_item = page.items.get(name, None)
    if card_item is None:
        return JSONResponse({})

    page.read_user_metadata(card_item, use_cache=False)
    card_html = page.create_card_html(tabname=tabname, item=card_item)
    return JSONResponse({"html": card_html})
def add_pages_to_demo(app):
    """Register the extra-networks HTTP endpoints on `app` as GET routes.

    Routes are added in the order listed below via `app.add_api_route`.
    """
    get_routes = (
        ("/sd_extra_networks/thumb", fetch_file),
        ("/sd_extra_networks/cover-images", fetch_cover_images),
        ("/sd_extra_networks/metadata", get_metadata),
        ("/sd_extra_networks/get-single-card", get_single_card),
        ("/sd_extra_networks/init-tree-data", init_tree_data),
        ("/sd_extra_networks/init-cards-data", init_cards_data),
        ("/sd_extra_networks/fetch-tree-data", fetch_tree_data),
        ("/sd_extra_networks/fetch-cards-data", fetch_cards_data),
        ("/sd_extra_networks/page-is-ready", page_is_ready),
    )
    for route_path, handler in get_routes:
        app.add_api_route(route_path, handler, methods=["GET"])
# Escape table for quote_js: the backslash itself, the delimiting double quote,
# and line terminators, which may not appear raw inside a quoted JS string.
_JS_STRING_ESCAPES = {
    ord("\\"): "\\\\",
    ord('"'): '\\"',
    ord("\n"): "\\n",
    ord("\r"): "\\r",
    # U+2028/U+2029 are line terminators in JS engines predating ES2019.
    0x2028: "\\u2028",
    0x2029: "\\u2029",
}


def quote_js(s):
    """Return `s` as a double-quoted JavaScript string literal.

    Backslashes and double quotes are backslash-escaped, as before. Line
    terminators (CR, LF, U+2028, U+2029) are also escaped, because a raw
    line break inside a quoted JS string is a syntax error. That matters
    e.g. when multi-line text is embedded into inline JS.

    `str.translate` does all replacements in a single pass, so escapes
    introduced for one character are never re-escaped by another rule.
    """
    return '"' + s.translate(_JS_STRING_ESCAPES) + '"'
def initialize():
    """Empty the module-level `extra_pages` registry in place.

    The list object itself is kept, so existing references to it see the change.
    """
    del extra_pages[:]
def register_default_pages():
    """Instantiate and register the built-in extra-network pages.

    Registration order: textual inversion, hypernetworks, checkpoints. Each
    page is constructed immediately before it is registered. The imports are
    local to this function, so the page modules load only when it is called.
    """
    from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
    from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
    from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion

    builtin_page_classes = (
        ExtraNetworksPageTextualInversion,
        ExtraNetworksPageHypernetworks,
        ExtraNetworksPageCheckpoints,
    )
    for page_class in builtin_page_classes:
        register_page(page_class())
def pages_in_preferred_order(pages):
    """Return `pages` sorted by the user's preferred tab order.

    `shared.opts.ui_extra_networks_tab_reorder` is a comma-separated list of
    case-insensitive substrings. A page's rank is the index of the first
    entry that occurs in its (lower-cased) name. Pages matching no entry go
    last. Pages with equal rank keep their original relative order, because
    `sorted` is stable.
    """
    tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")]
    # Blank entries (from "a,,b", a trailing comma, or an empty setting) would
    # match every name, since "" is a substring of any string; drop them.
    tab_order = [x for x in tab_order if x]

    def tab_name_score(name):
        name = name.lower()
        for i, possible_match in enumerate(tab_order):
            if possible_match in name:
                return i
        # Rank unmatched pages after every possible match index. The old
        # `len(pages)` could be smaller than a real match index.
        return len(tab_order)

    return sorted(pages, key=lambda page: tab_name_score(page.name))
def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
    """Build the extra-networks Gradio tabs for one generation tab and wire their events.

    Args:
        interface: The Blocks whose `load` event fills in every page's HTML.
        unrelated_tabs: Tabs whose selection triggers the JS hook
            `extraNetworksUnrelatedTabSelected(tabname)`.
        tabname: Name of the owning generation tab; used for element ids and
            passed to the pages' HTML/editor factories.

    Returns:
        The populated `ExtraNetworksUi` instance (pages, editors, related
        tabs, hidden save-preview controls).
    """
    ui = ExtraNetworksUi(tabname)
    ui.unrelated_tabs = unrelated_tabs
    # One Gradio tab per page. NOTE(review): `stored_extra_pages` appears to be
    # populated by the ExtraNetworksUi constructor; confirm.
    for tabname_full, page in ui.stored_extra_pages.items():
        with gr.Tab(page.title, elem_id=tabname_full, elem_classes=["extra-page"]) as tab:
            # Empty column: only provides an element with a known id/class.
            with gr.Column(elem_id=f"{tabname_full}_prompts", elem_classes=["extra-page-prompts"]):
                pass
            # Initial placeholder HTML (empty=True); real content is filled in by
            # `pages_html` on interface load or by the refresh button below.
            page_elem = gr.HTML(
                page.create_html(tabname, empty=True),
                elem_id=f"{tabname_full}_pane_container",
            )
            ui.pages[tabname_full] = page_elem
            editor = page.create_user_metadata_editor(ui, tabname)
            editor.create_ui()
            ui.user_metadata_editors[tabname_full] = editor
            ui.related_tabs[tabname_full] = tab
    # Hidden controls used by the (legacy) save-preview flow wired in `setup_ui`.
    ui.button_save_preview = gr.Button(
        "Save preview",
        elem_id=f"{tabname}_save_preview",
        visible=False,
    )
    # NOTE: the first positional argument of gr.Textbox is its initial value.
    ui.preview_target_filename = gr.Textbox(
        "Preview save filename",
        elem_id=f"{tabname}_preview_filename",
        visible=False,
    )
    # JS-only handlers (fn=None): notify the frontend when a non-extra-networks tab is selected.
    for tab in ui.unrelated_tabs:
        tab.select(
            fn=None,
            _js=f"function(){{extraNetworksUnrelatedTabSelected('{ui.tabname}');}}",
            inputs=[],
            outputs=[],
            show_progress=False,
        )
    for tabname_full, page in ui.stored_extra_pages.items():
        tab = ui.related_tabs[tabname_full]
        # Tell the frontend which page was selected and whether it accepts
        # positive/negative prompt insertion (Python bools rendered as JS literals).
        tab.select(
            fn=None,
            _js=(
                "function(){extraNetworksTabSelected("
                f"'{tabname_full}', "
                f"{str(page.allow_prompt).lower()}, "
                f"{str(page.allow_negative_prompt).lower()}"
                ");}"
            ),
            inputs=[],
            outputs=[],
            show_progress=False,
        )
        # Rebuild one page's HTML and return the cached HTML of all pages
        # (outputs cover every page element).
        def refresh(tabname_full):
            page = ui.stored_extra_pages[tabname_full]
            page.refresh()
            ui.pages_contents[tabname_full] = page.create_html(ui.tabname)
            return list(ui.pages_contents.values())
        # Hidden per-page button, presumably clicked programmatically from JS (not visible in UI).
        button_refresh = gr.Button(
            "Refresh",
            elem_id=f"{tabname_full}_extra_refresh_internal",
            visible=False,
        )
        # functools.partial binds the current tabname_full, avoiding late-binding
        # of the loop variable. After refreshing, re-run the JS resize/refresh hooks.
        button_refresh.click(
            fn=functools.partial(refresh, tabname_full),
            inputs=[],
            outputs=list(ui.pages.values()),
        ).then(
            fn=lambda: None,
            _js="setupAllResizeHandles",
        ).then(
            fn=lambda: None,
            _js=f"function(){{extraNetworksRefreshTab('{tabname_full}');}}",
        )
    # Fill ui.pages_contents with fresh HTML for every page.
    def create_html():
        for tabname_full, page in ui.stored_extra_pages.items():
            ui.pages_contents[tabname_full] = page.create_html(ui.tabname)
    # Lazily build the HTML on first load; later loads reuse ui.pages_contents.
    def pages_html():
        if not ui.pages_contents:
            create_html()
        return list(ui.pages_contents.values())
    interface.load(fn=pages_html, inputs=[], outputs=list(ui.pages.values()),).then(
        fn=lambda: None,
        _js="setupAllResizeHandles",
    )
    return ui
def path_is_parent(parent_path, child_path):
    """Return True if `child_path` is `parent_path` itself or lies inside it.

    Both paths are made absolute (which also resolves ".." segments) and
    case-normalized via `os.path.normcase` (a no-op on POSIX). They are then
    compared component-wise with `os.path.commonpath`.

    A plain string-prefix test is not enough: it would accept a sibling such
    as "/models/Lora_evil/x.png" for the parent "/models/Lora".

    Returns False if the paths cannot share a common root (e.g. different
    drives on Windows, where `commonpath` raises ValueError).
    """
    parent_path = os.path.normcase(os.path.abspath(parent_path))
    child_path = os.path.normcase(os.path.abspath(child_path))
    try:
        return os.path.commonpath([parent_path, child_path]) == parent_path
    except ValueError:
        return False
def setup_ui(ui: ExtraNetworksUi, gallery: OutputPanel):
    """Wire the hidden save-preview button and the metadata editors to `gallery`.

    Args:
        ui: The object returned by `create_ui` for this generation tab.
        gallery: Output panel whose images can be saved as a card preview.
    """

    def save_preview(index, images, filename):
        """Save gallery image `index` (clamped into range) to `filename`.

        The target must lie inside one of the pages' allowed preview
        directories; otherwise PermissionError is raised. Returns the
        regenerated HTML for every page.
        """
        # this function is here for backwards compatibility and likely will be removed soon
        if not images:
            print("There is no image in gallery to save as a preview.")
            return [page.create_html(ui.tabname) for page in ui.stored_extra_pages.values()]

        # Clamp the selected index into [0, len(images) - 1].
        index = min(max(int(index), 0), len(images) - 1)
        image = image_from_url_text(images[index])
        geninfo, _ = read_info_from_image(image)

        is_allowed = any(
            path_is_parent(allowed_dir, filename)
            for page in ui.stored_extra_pages.values()
            for allowed_dir in page.allowed_directories_for_previews()
        )
        # Explicit raise rather than `assert`: this is a security check on a
        # client-supplied path, and asserts are stripped under `python -O`.
        if not is_allowed:
            raise PermissionError(f"writing to {filename} is not allowed")

        save_image_with_geninfo(image, geninfo, filename)
        return [page.create_html(ui.tabname) for page in ui.stored_extra_pages.values()]

    # The JS shim replaces the first input with the currently selected gallery index.
    ui.button_save_preview.click(
        fn=save_preview,
        _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
        inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
        outputs=list(ui.pages.values()),
    )

    for editor in ui.user_metadata_editors.values():
        editor.setup_ui(gallery)