import functools
import html
import json
import os.path
import re
import urllib.parse
from base64 import b64decode
from io import BytesIO
from pathlib import Path
from typing import Callable, Optional

import gradio as gr
from fastapi.exceptions import HTTPException
from PIL import Image
from starlette.responses import FileResponse, JSONResponse, Response

from modules import (
    errors,
    extra_networks,
    shared,
    ui_extra_networks_user_metadata,
    util,
)
from modules.images import read_info_from_image, save_image_with_geninfo
from modules.infotext_utils import image_from_url_text

# Pages added through `register_page()`.
extra_pages = []
# Directories `fetch_file()` may serve previews from; rebuilt by `register_page()`
# from every registered page's `allowed_directories_for_previews()`.
allowed_dirs = set()

default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"]


@functools.cache
def allowed_preview_extensions_with_extra(extra_extensions=None):
    """Return the default preview extensions plus `extra_extensions`.

    The result is cached per argument, so `extra_extensions` must be hashable
    (e.g. a tuple).
    """
    return set(default_allowed_preview_extensions) | set(extra_extensions or [])


def allowed_preview_extensions():
    """Return the allowed preview extensions, including the configured samples format."""
    return allowed_preview_extensions_with_extra((shared.opts.samples_format,))


class ListItem:
    """
    Attributes:
        id [str]: The ID of this list item.
        html [str]: The HTML string for this item.
    """

    def __init__(self, _id: str, _html: str) -> None:
        self.id = _id
        self.html = _html


class CardListItem(ListItem):
    """
    Attributes:
        visible [bool]: Whether the item should be shown in the list.
        sort_keys [dict]: Nested dict where keys are sort modes and values are sort keys.
        search_terms [str]: String containing multiple search terms joined with spaces.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.visible: bool = False
        self.sort_keys = {}
        self.search_terms = ""


class TreeListItem(ListItem):
    """
    Attributes:
        node [Optional[DirectoryTreeNode]]: Tree node associated with this item;
            None until assigned by the caller.
        visible [bool]: Whether the item should be shown in the list.
        expanded [bool]: Whether the item children should be shown.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.node: Optional[DirectoryTreeNode] = None
        self.visible: bool = False
        self.expanded: bool = False


class DirectoryTreeNode:
    """
    Attributes:
        root_dir [str]: The root directory used to generate a relative path for this node.
        abspath [str]: The absolute path of this node.
        parent [DirectoryTreeNode]: The parent node of this node.
        depth [int]: The depth of this node in the tree. (folder level)
        is_dir [bool]: Whether this node is a directory or file.
        item [Optional[dict]]: The item data dictionary.
        relpath [str]: Relative path from `root_dir` to this node.
        children [list[DirectoryTreeNode]]: List of direct child nodes of this node.
    """

    def __init__(
        self,
        root_dir: str,
        abspath: str,
        parent: Optional["DirectoryTreeNode"] = None,
    ) -> None:
        self.root_dir = root_dir
        self.abspath = abspath
        self.parent = parent

        self.depth = 0
        self.is_dir = False
        self.item = None
        self.relpath = os.path.relpath(self.abspath, self.root_dir)
        self.children: list["DirectoryTreeNode"] = []

        # If a parent is passed, then we add this instance to the parent's children.
        if self.parent is not None:
            self.depth = self.parent.depth + 1
            self.parent.add_child(self)

    def add_child(self, child: "DirectoryTreeNode") -> None:
        """Append `child` to this node's direct children."""
        self.children.append(child)

    def build(self, items: dict[str, dict]) -> None:
        """Builds a tree of nodes as children of this instance.

        Every sub-directory is added (recursively); files are added only when
        their path is a key of `items`.

        Args:
            items: A dictionary where keys are absolute filepaths for directories/files.
                The values are dictionaries representing extra networks items.
        """
        self.is_dir = os.path.isdir(self.abspath)
        if self.is_dir:
            for x in os.listdir(self.abspath):
                child_path = os.path.join(self.abspath, x)
                # Add all directories but only add files if they are in the items dict.
                if os.path.isdir(child_path) or child_path in items:
                    DirectoryTreeNode(self.root_dir, child_path, self).build(items)
        else:
            self.item = items.get(self.abspath, None)

    def flatten(self, res: dict, dirs_only: bool = False) -> None:
        """Flattens the keys/values of the tree nodes into a dictionary.

        Args:
            res: The dictionary result updated in place, keyed by `abspath`.
                On initial call, should be passed as an empty dictionary.
            dirs_only: Whether to only add directories to the result.

        Raises:
            KeyError: If this node's `abspath` is already a key of `res`.
        """
        if self.abspath in res:
            raise KeyError(f"duplicate key: {self.abspath}")

        if not dirs_only or self.is_dir:
            res[self.abspath] = self

        for child in self.children:
            child.flatten(res, dirs_only)

    def apply(self, fn: Callable) -> None:
        """Recursively calls passed function with instance for entire tree."""
        fn(self)
        for child in self.children:
            child.apply(fn)


def register_page(page):
    """Register an extra networks page for the UI.

    Extensions are recommended to call this from an on_before_ui() callback.
    Also rebuilds `allowed_dirs` from `allowed_directories_for_previews()` of
    every registered page.
    """
    extra_pages.append(page)
    allowed_dirs.clear()
    # Flatten all pages' directory lists in one linear pass
    # (sum(lists, []) re-copies the accumulator for every page).
    allowed_dirs.update(
        directory
        for registered in extra_pages
        for directory in registered.allowed_directories_for_previews()
    )


def get_page_by_name(extra_networks_tabname: str = "") -> "ExtraNetworksPage":
    """Gets a page from extra pages for the specified tabname.

    Raises:
        HTTPException: 404 if no page in the `extra_pages` list has this tabname.
    """
    for page in extra_pages:
        if page.extra_networks_tabname == extra_networks_tabname:
            return page

    raise HTTPException(status_code=404, detail=f"Page not found: {extra_networks_tabname}")


def _is_in_allowed_dir(filename: str) -> bool:
    """Return whether `filename` lies strictly inside one of `allowed_dirs`.

    Both sides are normalized with `os.path.abspath`, which collapses `..`
    lexically. `Path.absolute()` alone does not, so "<allowed>/../x.png" would
    otherwise pass. Symlinks are deliberately not resolved, so linked model
    folders keep working.
    """
    file_parents = Path(os.path.abspath(filename)).parents
    return any(Path(os.path.abspath(d)) in file_parents for d in allowed_dirs)


def fetch_file(filename: str = ""):
    """Serve a preview image located in one of the registered preview directories.

    Raises:
        ValueError: If `filename` is outside every directory in `allowed_dirs`,
            or its extension is not an allowed preview extension.
        HTTPException: 404 if the file does not exist.
    """
    # Check the directory first so this endpoint cannot be used to probe
    # the existence of files outside the allowed directories.
    if not _is_in_allowed_dir(filename):
        raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")

    if not os.path.isfile(filename):
        raise HTTPException(status_code=404, detail="File not found")

    ext = os.path.splitext(filename)[1].lower()[1:]
    if ext not in allowed_preview_extensions():
        raise ValueError(f"File cannot be fetched: {filename}. Extensions allowed: {allowed_preview_extensions()}.")

    # would profit from returning 304
    return FileResponse(filename, headers={"Accept-Ranges": "bytes"})


def fetch_cover_images(extra_networks_tabname: str = "", item: str = "", index: int = 0):
    """Serve the `index`-th cover image stored in an item's metadata.

    The metadata value "ssmd_cover_images" is JSON-decoded and each entry is
    base64-decoded into an image, which is re-encoded in its own format.

    Raises:
        HTTPException: 404 if the page, the item, or an image at `index` is missing.
        ValueError: If the stored image data cannot be decoded as an image.
    """
    page = get_page_by_name(extra_networks_tabname)

    metadata = page.metadata.get(item)
    if metadata is None:
        raise HTTPException(status_code=404, detail="File not found")

    # Fall back to an empty JSON list: the former default `{}` is not a string,
    # so json.loads raised TypeError (HTTP 500) for items without cover images.
    cover_images = json.loads(metadata.get("ssmd_cover_images") or "[]")
    # Reject negative indices as well; they would otherwise pick images from the end
    # or raise IndexError.
    image = cover_images[index] if 0 <= index < len(cover_images) else None
    if not image:
        raise HTTPException(status_code=404, detail="File not found")

    try:
        image = Image.open(BytesIO(b64decode(image)))
        buffer = BytesIO()
        image.save(buffer, format=image.format)
        return Response(content=buffer.getvalue(), media_type=image.get_format_mimetype())
    except Exception as err:
        raise ValueError(f"File cannot be fetched: {item}. Failed to load cover image.") from err


def init_tree_data(tabname: str = "", extra_networks_tabname: str = "") -> JSONResponse:
    """Generates the initial Tree View data and returns a simplified dataset.

    The data returned does not contain any HTML strings.

    Status Codes:
        200 on success
        503 when data is not ready
        500 on any other error
    """
    page = get_page_by_name(extra_networks_tabname)

    data = page.generate_tree_view_data(tabname)

    if data is None:
        return JSONResponse({}, status_code=503)

    return JSONResponse(data, status_code=200)


def _html_for_div_ids(items: dict, div_ids: str) -> dict:
    """Map each id of the CSV string `div_ids` that exists in `items` to its `.html`."""
    return {div_id: items[div_id].html for div_id in div_ids.split(",") if div_id in items}


def fetch_tree_data(
    extra_networks_tabname: str = "",
    div_ids: str = "",
) -> JSONResponse:
    """Retrieves Tree View HTML strings for the specified `div_ids`.

    Args:
        div_ids: A string with div_ids in CSV format. Unknown ids are ignored.
    """
    page = get_page_by_name(extra_networks_tabname)
    return JSONResponse(_html_for_div_ids(page.tree, div_ids))


def fetch_cards_data(
    extra_networks_tabname: str = "",
    div_ids: str = "",
) -> JSONResponse:
    """Retrieves Cards View HTML strings for the specified `div_ids`.

    Args:
        div_ids: A string with div_ids in CSV format. Unknown ids are ignored.
    """
    page = get_page_by_name(extra_networks_tabname)
    return JSONResponse(_html_for_div_ids(page.cards, div_ids))


def init_cards_data(tabname: str = "", extra_networks_tabname: str = "") -> JSONResponse:
    """Generates the initial Cards View data and returns a simplified dataset.

    The data returned does not contain any HTML strings.

    Status Codes:
        200 on success
        503 when data is not ready
        500 on any other error
    """
    page = get_page_by_name(extra_networks_tabname)

    data = page.generate_cards_view_data(tabname)

    if data is None:
        return JSONResponse({}, status_code=503)

    return JSONResponse(data, status_code=200)


def page_is_ready(extra_networks_tabname: str = "") -> JSONResponse:
    """Returns whether the specified page is ready for fetching data.

    The page counts as ready when `page.items` has as many entries as
    `page.list_items()` yields.

    Status Codes:
        200 ready
        503 not ready
        500 on any other error
    """
    page = get_page_by_name(extra_networks_tabname)

    try:
        items_list = list(page.list_items())
        if len(page.items) == len(items_list):
            return JSONResponse({}, status_code=200)

        return JSONResponse({"error": "page not ready"}, status_code=503)
    except Exception as exc:
        return JSONResponse({"error": str(exc)}, status_code=500)


def get_metadata(extra_networks_tabname: str = "", item: str = "") -> JSONResponse:
    """Return an item's metadata as a pretty-printed JSON string.

    Returns an empty JSON object when the page or the item does not exist.
    """
    try:
        page = get_page_by_name(extra_networks_tabname)
    except HTTPException:
        return JSONResponse({})

    metadata = page.metadata.get(item)
    if metadata is None:
        return JSONResponse({})

    # TODO(review): cover images ("ssmd_cover_images") can be too big to display
    # as text. A filter that dropped that key here was disabled with an open
    # question about why it existed, so they are currently included.

    return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)})


def get_single_card(tabname: str = "", extra_networks_tabname: str = "", name: str = "") -> JSONResponse:
    """Re-create a single item and return its card HTML.

    If re-creating the item fails, the error is displayed and the previously
    cached item (if any) is used instead. Returns an empty JSON object when no
    item is available.
    """
    page = get_page_by_name(extra_networks_tabname)

    try:
        item = page.create_item(name, enable_filter=False)
        page.items[name] = item
    except Exception as exc:
        errors.display(exc, "creating item for extra network")
        item = page.items.get(name, None)

    if item is None:
        return JSONResponse({})

    page.read_user_metadata(item, use_cache=False)
    item_html = page.create_card_html(tabname=tabname, item=item)

    return JSONResponse({"html": item_html})


def add_pages_to_demo(app):
    """Register the extra networks HTTP GET endpoints on `app`."""
    app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
    app.add_api_route("/sd_extra_networks/cover-images", fetch_cover_images, methods=["GET"])
    app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
    app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"])
    app.add_api_route("/sd_extra_networks/init-tree-data", init_tree_data, methods=["GET"])
    app.add_api_route("/sd_extra_networks/init-cards-data", init_cards_data, methods=["GET"])
    app.add_api_route("/sd_extra_networks/fetch-tree-data", fetch_tree_data, methods=["GET"])
    app.add_api_route("/sd_extra_networks/fetch-cards-data", fetch_cards_data, methods=["GET"])
    app.add_api_route("/sd_extra_networks/page-is-ready", page_is_ready, methods=["GET"])


def quote_js(s):
    """Wrap `s` in double quotes, escaping backslashes and double quotes.

    Note: newlines and other control characters are not escaped.
    """
    s = s.replace("\\", "\\\\")
    s = s.replace('"', '\\"')
    return f'"{s}"'


class ExtraNetworksPage:
    def __init__(self, title):
        self.title = title
        self.name = title.lower()
        # This is the actual name of the extra networks tab (not txt2img/img2img).
self.extra_networks_tabname = self.name.replace(" ", "_") self.allow_prompt = True self.allow_negative_prompt = False self.metadata = {} self.items = {} self.cards = {} self.tree = {} self.tree_roots = {} self.lister = util.MassFileLister() # HTML Templates self.pane_tpl = shared.html("extra-networks-pane.html") self.card_tpl = shared.html("extra-networks-card.html") self.tree_row_tpl = shared.html("extra-networks-tree-row.html") self.btn_copy_path_tpl = shared.html("extra-networks-btn-copy-path.html") self.btn_show_metadata_tpl = shared.html("extra-networks-btn-show-metadata.html") self.btn_edit_metadata_tpl = shared.html("extra-networks-btn-edit-metadata.html") self.btn_dirs_view_item_tpl = shared.html("extra-networks-btn-dirs-view-item.html") # Sorted lists # These just store ints so it won't use hardly any memory to just sort ahead # of time for each sort mode. These are lists of keys for each file. self.keys_sorted = {} self.keys_by_name = [] self.keys_by_path = [] self.keys_by_created = [] self.keys_by_modified = [] def refresh(self): # Whenever we refresh, we want to build our datasets from scratch. 
self.items = {} self.cards = {} self.tree = {} self.tree_roots = {} def read_user_metadata(self, item, use_cache=True): filename = item.get("filename", None) metadata = extra_networks.get_user_metadata(filename, lister=self.lister if use_cache else None) desc = metadata.get("description", None) if desc is not None: item["description"] = desc item["user_metadata"] = metadata def link_preview(self, filename): quoted_filename = urllib.parse.quote(filename.replace("\\", "/")) mtime, _ = self.lister.mctime(filename) return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}" def search_terms_from_path(self, filename, possible_directories=None): abspath = os.path.abspath(filename) for parentdir in possible_directories if possible_directories is not None else self.allowed_directories_for_previews(): parentdir = os.path.dirname(os.path.abspath(parentdir)) if abspath.startswith(parentdir): return os.path.relpath(abspath, parentdir) return "" def build_tree_html_row( self, tabname: str, label: str, btn_type: str, btn_title: Optional[str] = None, data_attributes: Optional[dict] = None, dir_is_empty: bool = False, item: Optional[dict] = None, onclick_extra: Optional[str] = None, ) -> str: """Generates HTML for a single row of the Tree View Args: tabname: "txt2img" or "img2img" label: The text to display for this row. btn_type: "dir" or "file" btn_title: Optional hover text for the row. Defaults to `label`. data_attributes: Dictionary defining data attributes to add to the row's tag. Ex: {"one": "1"} would generate
dir_is_empty: Whether the directory is empty. Only useful if btn_type=="dir". item: Dictionary containing item data such as filename, hash, etc. onclick_extra: Additional javascript code to add to the row's `onclick` attribute. """ if btn_type not in ["file", "dir"]: raise ValueError("Invalid button type:", btn_type) if data_attributes is None: data_attributes = {} label = label.strip() # If not specified, title will just reflect the label btn_title = btn_title.strip() if btn_title else label action_list_item_action_leading = "" action_list_item_visual_leading = "🗀" action_list_item_visual_trailing = "" action_list_item_action_trailing = "" if dir_is_empty: action_list_item_action_leading = "" if btn_type == "file": action_list_item_visual_leading = "🗎" # Action buttons if item is not None: action_list_item_action_trailing += self.get_button_row(tabname, item) data_attributes_str = "" for k, v in data_attributes.items(): if isinstance(v, (bool,)): # Boolean data attributes only need a key when true. 
if v: data_attributes_str += f"{k} " elif v not in [None, "", "''", '""']: data_attributes_str += f"{k}={v} " res = self.tree_row_tpl.format( **{ "data_attributes": data_attributes_str, "search_terms": "", "btn_type": btn_type, "btn_title": btn_title, "tabname": tabname, "onclick_extra": onclick_extra if onclick_extra else "", "extra_networks_tabname": self.extra_networks_tabname, "action_list_item_action_leading": action_list_item_action_leading, "action_list_item_visual_leading": action_list_item_visual_leading, "action_list_item_label": label, "action_list_item_visual_trailing": action_list_item_visual_trailing, "action_list_item_action_trailing": action_list_item_action_trailing, } ) res = res.strip() res = re.sub(" +", " ", res.replace("\n", "")) return res def get_button_row(self, tabname: str, item: dict) -> str: """Generates a row of buttons for use in Tree/Cards View items.""" metadata = item.get("metadata", None) name = item.get("name", "") filename = item.get("filename", "") button_row_tpl = ' ' btn_copy_path = self.btn_copy_path_tpl.format(filename=filename) btn_edit_item = self.btn_edit_metadata_tpl.format( tabname=tabname, extra_networks_tabname=self.extra_networks_tabname, name=name, ) btn_metadata = "" if metadata: btn_metadata = self.btn_show_metadata_tpl.format( extra_networks_tabname=self.extra_networks_tabname, name=name, ) return button_row_tpl.format( btn_copy_path=btn_copy_path, btn_edit_item=btn_edit_item, btn_metadata=btn_metadata, ) def create_card_html( self, tabname: str, item: dict, div_id: Optional[str] = None, ) -> str: """Generates HTML for a single ExtraNetworks Item. Args: tabname: The name of the active tab. item: Dictionary containing item information. template: Optional template string to use. Returns: HTML string generated for this item. Can be empty if the item is not meant to be shown. 
""" style = f"font-size: {shared.opts.extra_networks_card_text_scale*100}%;" if shared.opts.extra_networks_card_height: style += f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_width: style += f"width: {shared.opts.extra_networks_card_width}px;" background_image = "" preview = html.escape(item.get("preview", "") or "") if preview: background_image = f'' onclick = item.get("onclick", None) if onclick is None: onclick = html.escape(f"extraNetworksCardOnClick(event, '{tabname}_{self.extra_networks_tabname}');") button_row = self.get_button_row(tabname, item) filename = item.get("filename", "") # if this is true, the item must not be shown in the default view, # and must instead only be shown when searching for it if shared.opts.extra_networks_hidden_models == "Always": search_only = False else: search_only = filename.startswith(".") if search_only and shared.opts.extra_networks_hidden_models == "Never": return "" sort_keys = {} for sort_mode, sort_key in item.get("sort_keys", {}).items(): sort_keys[sort_mode.strip().lower()] = html.escape(str(sort_key)) search_terms_html = "" search_terms_tpl = "{search_term}" for search_term in item.get("search_terms", []): search_terms_html += search_terms_tpl.format( **{ "class": f"search_terms{' search_only' if search_only else ''}", "search_term": search_term, } ) description = "" if shared.opts.extra_networks_card_show_desc: description = item.get("description", "") or "" if not shared.opts.extra_networks_card_description_is_html: description = html.escape(description) data_attributes = { "data-div-id": div_id if div_id else "", "data-name": item.get("name", "").strip(), "data-path": item.get("filename", "").strip(), "data-hash": item.get("shorthash", None), "data-prompt": item.get("prompt", "").strip(), "data-neg-prompt": item.get("negative_prompt", "").strip(), "data-allow-neg": self.allow_negative_prompt, **{f"data-sort-{sort_mode}": sort_key for sort_mode, sort_key in 
sort_keys.items()}, } data_attributes_str = "" for k, v in data_attributes.items(): if isinstance(v, (bool,)): # Boolean data attributes only need a key when true. if v: data_attributes_str += f"{k} " elif v not in [None, "", "''", '""']: data_attributes_str += f"{k}={v} " return self.card_tpl.format( style=style, onclick=onclick, data_attributes=data_attributes_str, background_image=background_image, button_row=button_row, search_terms=search_terms_html, name=html.escape(item["name"].strip()), description=description, ) def generate_cards_view_data(self, tabname: str) -> dict: """Generates the datasets and HTML used to display the Cards View. Returns: A dictionary containing necessary info for the client. { search_keys: array of strings, sort_