mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Merge branch 'dev' into master
This commit is contained in:
commit
816096e642
@ -74,6 +74,7 @@ module.exports = {
|
||||
create_submit_args: "readonly",
|
||||
restart_reload: "readonly",
|
||||
updateInput: "readonly",
|
||||
onEdit: "readonly",
|
||||
//extraNetworks.js
|
||||
requestGet: "readonly",
|
||||
popup: "readonly",
|
||||
|
73
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
73
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@ -1,25 +1,45 @@
|
||||
name: Bug Report
|
||||
description: You think somethings is broken in the UI
|
||||
description: You think something is broken in the UI
|
||||
title: "[Bug]: "
|
||||
labels: ["bug-report"]
|
||||
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Is there an existing issue for this?
|
||||
description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
|
||||
options:
|
||||
- label: I have searched the existing issues and checked the recent builds/commits
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
*Please fill this form with as much information as possible, don't forget to fill "What OS..." and "What browsers" and *provide screenshots if possible**
|
||||
> The title of the bug report should be short and descriptive.
|
||||
> Use relevant keywords for searchability.
|
||||
> Do not leave it blank, but also do not put an entire error log in it.
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Checklist
|
||||
description: |
|
||||
Please perform basic debugging to see if extensions or configuration is the cause of the issue.
|
||||
Basic debug procedure
|
||||
1. Disable all third-party extensions - check if extension is the cause
|
||||
2. Update extensions and webui - sometimes things just need to be updated
|
||||
3. Backup and remove your config.json and ui-config.json - check if the issue is caused by bad configuration
|
||||
4. Delete venv with third-party extensions disabled - sometimes extensions might cause wrong libraries to be installed
|
||||
5. Try a fresh installation webui in a different directory - see if a clean installation solves the issue
|
||||
Before making a issue report please, check that the issue hasn't been reported recently.
|
||||
options:
|
||||
- label: The issue exists after disabling all extensions
|
||||
- label: The issue exists on a clean installation of webui
|
||||
- label: The issue is caused by an extension, but I believe it is caused by a bug in the webui
|
||||
- label: The issue exists in the current version of the webui
|
||||
- label: The issue has not been reported before recently
|
||||
- label: The issue has been reported before but has not been fixed yet
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
> Please fill this form with as much information as possible. Don't forget to "Upload Sysinfo" and "What browsers" and provide screenshots if possible
|
||||
- type: textarea
|
||||
id: what-did
|
||||
attributes:
|
||||
label: What happened?
|
||||
description: Tell us what happened in a very clear and simple way
|
||||
placeholder: |
|
||||
txt2img is not working as intended.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
@ -27,9 +47,9 @@ body:
|
||||
attributes:
|
||||
label: Steps to reproduce the problem
|
||||
description: Please provide us with precise step by step instructions on how to reproduce the bug
|
||||
value: |
|
||||
1. Go to ....
|
||||
2. Press ....
|
||||
placeholder: |
|
||||
1. Go to ...
|
||||
2. Press ...
|
||||
3. ...
|
||||
validations:
|
||||
required: true
|
||||
@ -38,13 +58,8 @@ body:
|
||||
attributes:
|
||||
label: What should have happened?
|
||||
description: Tell us what you think the normal behavior should be
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: sysinfo
|
||||
attributes:
|
||||
label: Sysinfo
|
||||
description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
|
||||
placeholder: |
|
||||
WebUI should ...
|
||||
validations:
|
||||
required: true
|
||||
- type: dropdown
|
||||
@ -58,12 +73,25 @@ body:
|
||||
- Brave
|
||||
- Apple Safari
|
||||
- Microsoft Edge
|
||||
- Android
|
||||
- iOS
|
||||
- Other
|
||||
- type: textarea
|
||||
id: sysinfo
|
||||
attributes:
|
||||
label: Sysinfo
|
||||
description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.
|
||||
placeholder: |
|
||||
1. Go to WebUI Settings -> Sysinfo -> Download system info.
|
||||
If WebUI fails to launch, use --dump-sysinfo commandline argument to generate the file
|
||||
2. Upload the Sysinfo as a attached file, Do NOT paste it in as plain text.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Console logs
|
||||
description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
|
||||
description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after the bug occured. If it's very long, provide a link to pastebin or similar service.
|
||||
render: Shell
|
||||
validations:
|
||||
required: true
|
||||
@ -71,4 +99,7 @@ body:
|
||||
id: misc
|
||||
attributes:
|
||||
label: Additional information
|
||||
description: Please provide us with any relevant additional info or context.
|
||||
description: |
|
||||
Please provide us with any relevant additional info or context.
|
||||
Examples:
|
||||
I have updated my GPU driver recently.
|
||||
|
@ -88,9 +88,10 @@ A browser interface based on Gradio library for Stable Diffusion.
|
||||
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
|
||||
- Now without any bad letters!
|
||||
- Load checkpoints in safetensors format
|
||||
- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64
|
||||
- Eased resolution restriction: generated image's dimensions must be a multiple of 8 rather than 64
|
||||
- Now with a license!
|
||||
- Reorder elements in the UI from settings screen
|
||||
- [Segmind Stable Diffusion](https://huggingface.co/segmind/SSD-1B) support
|
||||
|
||||
## Installation and Running
|
||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
|
||||
@ -103,7 +104,7 @@ Alternatively, use online services (like Google Colab):
|
||||
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
|
||||
|
||||
### Installation on Windows 10/11 with NVidia-GPUs using release package
|
||||
1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract it's contents.
|
||||
1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract its contents.
|
||||
2. Run `update.bat`.
|
||||
3. Run `run.bat`.
|
||||
> For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)
|
||||
|
73
configs/alt-diffusion-m18-inference.yaml
Normal file
73
configs/alt-diffusion-m18-inference.yaml
Normal file
@ -0,0 +1,73 @@
|
||||
model:
|
||||
base_learning_rate: 1.0e-04
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: false # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
|
||||
scheduler_config: # 10000 warmup steps
|
||||
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
f_min: [ 1. ]
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_head_channels: 64
|
||||
use_spatial_transformer: True
|
||||
use_linear_in_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 1024
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: modules.xlmr_m18.BertSeriesModelWithTransformation
|
||||
params:
|
||||
name: "XLMR-Large"
|
33
extensions-builtin/Lora/lora_logger.py
Normal file
33
extensions-builtin/Lora/lora_logger.py
Normal file
@ -0,0 +1,33 @@
|
||||
import sys
|
||||
import copy
|
||||
import logging
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
COLORS = {
|
||||
"DEBUG": "\033[0;36m", # CYAN
|
||||
"INFO": "\033[0;32m", # GREEN
|
||||
"WARNING": "\033[0;33m", # YELLOW
|
||||
"ERROR": "\033[0;31m", # RED
|
||||
"CRITICAL": "\033[0;37;41m", # WHITE ON RED
|
||||
"RESET": "\033[0m", # RESET COLOR
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
colored_record = copy.copy(record)
|
||||
levelname = colored_record.levelname
|
||||
seq = self.COLORS.get(levelname, self.COLORS["RESET"])
|
||||
colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
|
||||
return super().format(colored_record)
|
||||
|
||||
|
||||
logger = logging.getLogger("lora")
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setFormatter(
|
||||
ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")
|
||||
)
|
||||
logger.addHandler(handler)
|
@ -93,6 +93,7 @@ class Network: # LoraModule
|
||||
self.unet_multiplier = 1.0
|
||||
self.dyn_dim = None
|
||||
self.modules = {}
|
||||
self.bundle_embeddings = {}
|
||||
self.mtime = None
|
||||
|
||||
self.mentioned_name = None
|
||||
|
33
extensions-builtin/Lora/network_glora.py
Normal file
33
extensions-builtin/Lora/network_glora.py
Normal file
@ -0,0 +1,33 @@
|
||||
|
||||
import network
|
||||
|
||||
class ModuleTypeGLora(network.ModuleType):
|
||||
def create_module(self, net: network.Network, weights: network.NetworkWeights):
|
||||
if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
|
||||
return NetworkModuleGLora(net, weights)
|
||||
|
||||
return None
|
||||
|
||||
# adapted from https://github.com/KohakuBlueleaf/LyCORIS
|
||||
class NetworkModuleGLora(network.NetworkModule):
|
||||
def __init__(self, net: network.Network, weights: network.NetworkWeights):
|
||||
super().__init__(net, weights)
|
||||
|
||||
if hasattr(self.sd_module, 'weight'):
|
||||
self.shape = self.sd_module.weight.shape
|
||||
|
||||
self.w1a = weights.w["a1.weight"]
|
||||
self.w1b = weights.w["b1.weight"]
|
||||
self.w2a = weights.w["a2.weight"]
|
||||
self.w2b = weights.w["b2.weight"]
|
||||
|
||||
def calc_updown(self, orig_weight):
|
||||
w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
|
||||
|
||||
output_shape = [w1a.size(0), w1b.size(1)]
|
||||
updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
|
||||
|
||||
return self.finalize_updown(updown, orig_weight, output_shape)
|
@ -5,6 +5,7 @@ import re
|
||||
import lora_patches
|
||||
import network
|
||||
import network_lora
|
||||
import network_glora
|
||||
import network_hada
|
||||
import network_ia3
|
||||
import network_lokr
|
||||
@ -15,6 +16,9 @@ import torch
|
||||
from typing import Union
|
||||
|
||||
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
|
||||
import modules.textual_inversion.textual_inversion as textual_inversion
|
||||
|
||||
from lora_logger import logger
|
||||
|
||||
module_types = [
|
||||
network_lora.ModuleTypeLora(),
|
||||
@ -23,6 +27,7 @@ module_types = [
|
||||
network_lokr.ModuleTypeLokr(),
|
||||
network_full.ModuleTypeFull(),
|
||||
network_norm.ModuleTypeNorm(),
|
||||
network_glora.ModuleTypeGLora(),
|
||||
]
|
||||
|
||||
|
||||
@ -149,9 +154,19 @@ def load_network(name, network_on_disk):
|
||||
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
|
||||
|
||||
matched_networks = {}
|
||||
bundle_embeddings = {}
|
||||
|
||||
for key_network, weight in sd.items():
|
||||
key_network_without_network_parts, network_part = key_network.split(".", 1)
|
||||
if key_network_without_network_parts == "bundle_emb":
|
||||
emb_name, vec_name = network_part.split(".", 1)
|
||||
emb_dict = bundle_embeddings.get(emb_name, {})
|
||||
if vec_name.split('.')[0] == 'string_to_param':
|
||||
_, k2 = vec_name.split('.', 1)
|
||||
emb_dict['string_to_param'] = {k2: weight}
|
||||
else:
|
||||
emb_dict[vec_name] = weight
|
||||
bundle_embeddings[emb_name] = emb_dict
|
||||
|
||||
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
|
||||
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
|
||||
@ -195,6 +210,14 @@ def load_network(name, network_on_disk):
|
||||
|
||||
net.modules[key] = net_module
|
||||
|
||||
embeddings = {}
|
||||
for emb_name, data in bundle_embeddings.items():
|
||||
embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
|
||||
embedding.loaded = None
|
||||
embeddings[emb_name] = embedding
|
||||
|
||||
net.bundle_embeddings = embeddings
|
||||
|
||||
if keys_failed_to_match:
|
||||
logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
|
||||
|
||||
@ -210,11 +233,15 @@ def purge_networks_from_memory():
|
||||
|
||||
|
||||
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
|
||||
emb_db = sd_hijack.model_hijack.embedding_db
|
||||
already_loaded = {}
|
||||
|
||||
for net in loaded_networks:
|
||||
if net.name in names:
|
||||
already_loaded[net.name] = net
|
||||
for emb_name, embedding in net.bundle_embeddings.items():
|
||||
if embedding.loaded:
|
||||
emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
|
||||
|
||||
loaded_networks.clear()
|
||||
|
||||
@ -257,6 +284,21 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
|
||||
net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
|
||||
loaded_networks.append(net)
|
||||
|
||||
for emb_name, embedding in net.bundle_embeddings.items():
|
||||
if embedding.loaded is None and emb_name in emb_db.word_embeddings:
|
||||
logger.warning(
|
||||
f'Skip bundle embedding: "{emb_name}"'
|
||||
' as it was already loaded from embeddings folder'
|
||||
)
|
||||
continue
|
||||
|
||||
embedding.loaded = False
|
||||
if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
|
||||
embedding.loaded = True
|
||||
emb_db.register_embedding(embedding, shared.sd_model)
|
||||
else:
|
||||
emb_db.skipped_embeddings[name] = embedding
|
||||
|
||||
if failed_to_load_networks:
|
||||
sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
|
||||
|
||||
@ -420,6 +462,7 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
||||
self.network_weights_backup = None
|
||||
self.network_bias_backup = None
|
||||
|
||||
|
||||
def network_Linear_forward(self, input):
|
||||
if shared.opts.lora_functional:
|
||||
return network_forward(self, input, originals.Linear_forward)
|
||||
@ -564,6 +607,7 @@ extra_network_lora = None
|
||||
available_networks = {}
|
||||
available_network_aliases = {}
|
||||
loaded_networks = []
|
||||
loaded_bundle_embeddings = {}
|
||||
networks_in_memory = {}
|
||||
available_network_hash_lookup = {}
|
||||
forbidden_network_aliases = {}
|
||||
|
@ -12,6 +12,8 @@ function isMobile() {
|
||||
}
|
||||
|
||||
function reportWindowSize() {
|
||||
if (gradioApp().querySelector('.toprow-compact-tools')) return; // not applicable for compact prompt layout
|
||||
|
||||
var currentlyMobile = isMobile();
|
||||
if (currentlyMobile == isSetupForMobile) return;
|
||||
isSetupForMobile = currentlyMobile;
|
||||
|
2
javascript/dragdrop.js
vendored
2
javascript/dragdrop.js
vendored
@ -119,7 +119,7 @@ window.addEventListener('paste', e => {
|
||||
}
|
||||
|
||||
const firstFreeImageField = visibleImageFields
|
||||
.filter(el => el.querySelector('input[type=file]'))?.[0];
|
||||
.filter(el => !el.querySelector('img'))?.[0];
|
||||
|
||||
dropReplaceImage(
|
||||
firstFreeImageField ?
|
||||
|
@ -18,37 +18,43 @@ function keyupEditAttention(event) {
|
||||
const before = text.substring(0, selectionStart);
|
||||
let beforeParen = before.lastIndexOf(OPEN);
|
||||
if (beforeParen == -1) return false;
|
||||
let beforeParenClose = before.lastIndexOf(CLOSE);
|
||||
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
|
||||
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
|
||||
beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1);
|
||||
}
|
||||
|
||||
let beforeClosingParen = before.lastIndexOf(CLOSE);
|
||||
if (beforeClosingParen != -1 && beforeClosingParen > beforeParen) return false;
|
||||
|
||||
// Find closing parenthesis around current cursor
|
||||
const after = text.substring(selectionStart);
|
||||
let afterParen = after.indexOf(CLOSE);
|
||||
if (afterParen == -1) return false;
|
||||
let afterParenOpen = after.indexOf(OPEN);
|
||||
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
|
||||
afterParen = after.indexOf(CLOSE, afterParen + 1);
|
||||
afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1);
|
||||
}
|
||||
if (beforeParen === -1 || afterParen === -1) return false;
|
||||
|
||||
let afterOpeningParen = after.indexOf(OPEN);
|
||||
if (afterOpeningParen != -1 && afterOpeningParen < afterParen) return false;
|
||||
|
||||
// Set the selection to the text between the parenthesis
|
||||
const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);
|
||||
const lastColon = parenContent.lastIndexOf(":");
|
||||
selectionStart = beforeParen + 1;
|
||||
selectionEnd = selectionStart + lastColon;
|
||||
if (/.*:-?[\d.]+/s.test(parenContent)) {
|
||||
const lastColon = parenContent.lastIndexOf(":");
|
||||
selectionStart = beforeParen + 1;
|
||||
selectionEnd = selectionStart + lastColon;
|
||||
} else {
|
||||
selectionStart = beforeParen + 1;
|
||||
selectionEnd = selectionStart + parenContent.length;
|
||||
}
|
||||
|
||||
target.setSelectionRange(selectionStart, selectionEnd);
|
||||
return true;
|
||||
}
|
||||
|
||||
function selectCurrentWord() {
|
||||
if (selectionStart !== selectionEnd) return false;
|
||||
const delimiters = opts.keyedit_delimiters + " \r\n\t";
|
||||
const whitespace_delimiters = {"Tab": "\t", "Carriage Return": "\r", "Line Feed": "\n"};
|
||||
let delimiters = opts.keyedit_delimiters;
|
||||
|
||||
// seek backward until to find beggining
|
||||
for (let i of opts.keyedit_delimiters_whitespace) {
|
||||
delimiters += whitespace_delimiters[i];
|
||||
}
|
||||
|
||||
// seek backward to find beginning
|
||||
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
|
||||
selectionStart--;
|
||||
}
|
||||
@ -63,7 +69,7 @@ function keyupEditAttention(event) {
|
||||
}
|
||||
|
||||
// If the user hasn't selected anything, let's select their current parenthesis block or word
|
||||
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
|
||||
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')') && !selectCurrentParenthesisBlock('[', ']')) {
|
||||
selectCurrentWord();
|
||||
}
|
||||
|
||||
@ -71,33 +77,54 @@ function keyupEditAttention(event) {
|
||||
|
||||
var closeCharacter = ')';
|
||||
var delta = opts.keyedit_precision_attention;
|
||||
var start = selectionStart > 0 ? text[selectionStart - 1] : "";
|
||||
var end = text[selectionEnd];
|
||||
|
||||
if (selectionStart > 0 && text[selectionStart - 1] == '<') {
|
||||
if (start == '<') {
|
||||
closeCharacter = '>';
|
||||
delta = opts.keyedit_precision_extra;
|
||||
} else if (selectionStart == 0 || text[selectionStart - 1] != "(") {
|
||||
} else if (start == '(' && end == ')' || start == '[' && end == ']') { // convert old-style (((emphasis)))
|
||||
let numParen = 0;
|
||||
|
||||
while (text[selectionStart - numParen - 1] == start && text[selectionEnd + numParen] == end) {
|
||||
numParen++;
|
||||
}
|
||||
|
||||
if (start == "[") {
|
||||
weight = (1 / 1.1) ** numParen;
|
||||
} else {
|
||||
weight = 1.1 ** numParen;
|
||||
}
|
||||
|
||||
weight = Math.round(weight / opts.keyedit_precision_attention) * opts.keyedit_precision_attention;
|
||||
|
||||
text = text.slice(0, selectionStart - numParen) + "(" + text.slice(selectionStart, selectionEnd) + ":" + weight + ")" + text.slice(selectionEnd + numParen);
|
||||
selectionStart -= numParen - 1;
|
||||
selectionEnd -= numParen - 1;
|
||||
} else if (start != '(') {
|
||||
// do not include spaces at the end
|
||||
while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {
|
||||
selectionEnd -= 1;
|
||||
selectionEnd--;
|
||||
}
|
||||
|
||||
if (selectionStart == selectionEnd) {
|
||||
return;
|
||||
}
|
||||
|
||||
text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd);
|
||||
|
||||
selectionStart += 1;
|
||||
selectionEnd += 1;
|
||||
selectionStart++;
|
||||
selectionEnd++;
|
||||
}
|
||||
|
||||
var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
||||
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
|
||||
if (text[selectionEnd] != ':') return;
|
||||
var weightLength = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
||||
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + weightLength));
|
||||
if (isNaN(weight)) return;
|
||||
|
||||
weight += isPlus ? delta : -delta;
|
||||
weight = parseFloat(weight.toPrecision(12));
|
||||
if (String(weight).length == 1) weight += ".0";
|
||||
if (Number.isInteger(weight)) weight += ".0";
|
||||
|
||||
if (closeCharacter == ')' && weight == 1) {
|
||||
var endParenPos = text.substring(selectionEnd).indexOf(')');
|
||||
@ -105,7 +132,7 @@ function keyupEditAttention(event) {
|
||||
selectionStart--;
|
||||
selectionEnd--;
|
||||
} else {
|
||||
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + end);
|
||||
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + weightLength);
|
||||
}
|
||||
|
||||
target.focus();
|
||||
|
@ -26,8 +26,9 @@ function setupExtraNetworksForTab(tabname) {
|
||||
var refresh = gradioApp().getElementById(tabname + '_extra_refresh');
|
||||
var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs');
|
||||
var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input');
|
||||
var promptContainer = gradioApp().querySelector('.prompt-container-compact#' + tabname + '_prompt_container');
|
||||
var negativePrompt = gradioApp().querySelector('#' + tabname + '_neg_prompt');
|
||||
|
||||
sort.dataset.sortkey = 'sortDefault';
|
||||
tabs.appendChild(searchDiv);
|
||||
tabs.appendChild(sort);
|
||||
tabs.appendChild(sortOrder);
|
||||
@ -49,20 +50,23 @@ function setupExtraNetworksForTab(tabname) {
|
||||
|
||||
elem.style.display = visible ? "" : "none";
|
||||
});
|
||||
|
||||
applySort();
|
||||
};
|
||||
|
||||
var applySort = function() {
|
||||
var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
|
||||
|
||||
var reverse = sortOrder.classList.contains("sortReverse");
|
||||
var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim();
|
||||
sortKey = sortKey ? "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1) : "";
|
||||
var sortKeyStore = sortKey ? sortKey + (reverse ? "Reverse" : "") : "";
|
||||
if (!sortKey || sortKeyStore == sort.dataset.sortkey) {
|
||||
var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim() || "name";
|
||||
sortKey = "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1);
|
||||
var sortKeyStore = sortKey + "-" + (reverse ? "Descending" : "Ascending") + "-" + cards.length;
|
||||
|
||||
if (sortKeyStore == sort.dataset.sortkey) {
|
||||
return;
|
||||
}
|
||||
|
||||
sort.dataset.sortkey = sortKeyStore;
|
||||
|
||||
var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card');
|
||||
cards.forEach(function(card) {
|
||||
card.originalParentElement = card.parentElement;
|
||||
});
|
||||
@ -88,15 +92,13 @@ function setupExtraNetworksForTab(tabname) {
|
||||
};
|
||||
|
||||
search.addEventListener("input", applyFilter);
|
||||
applyFilter();
|
||||
["change", "blur", "click"].forEach(function(evt) {
|
||||
sort.querySelector("input").addEventListener(evt, applySort);
|
||||
});
|
||||
sortOrder.addEventListener("click", function() {
|
||||
sortOrder.classList.toggle("sortReverse");
|
||||
applySort();
|
||||
});
|
||||
applyFilter();
|
||||
|
||||
extraNetworksApplySort[tabname] = applySort;
|
||||
extraNetworksApplyFilter[tabname] = applyFilter;
|
||||
|
||||
var showDirsUpdate = function() {
|
||||
@ -109,11 +111,47 @@ function setupExtraNetworksForTab(tabname) {
|
||||
showDirsUpdate();
|
||||
}
|
||||
|
||||
function extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) {
|
||||
if (!gradioApp().querySelector('.toprow-compact-tools')) return; // only applicable for compact prompt layout
|
||||
|
||||
var promptContainer = gradioApp().getElementById(tabname + '_prompt_container');
|
||||
var prompt = gradioApp().getElementById(tabname + '_prompt_row');
|
||||
var negPrompt = gradioApp().getElementById(tabname + '_neg_prompt_row');
|
||||
var elem = id ? gradioApp().getElementById(id) : null;
|
||||
|
||||
if (showNegativePrompt && elem) {
|
||||
elem.insertBefore(negPrompt, elem.firstChild);
|
||||
} else {
|
||||
promptContainer.insertBefore(negPrompt, promptContainer.firstChild);
|
||||
}
|
||||
|
||||
if (showPrompt && elem) {
|
||||
elem.insertBefore(prompt, elem.firstChild);
|
||||
} else {
|
||||
promptContainer.insertBefore(prompt, promptContainer.firstChild);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function extraNetworksUrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate)
|
||||
extraNetworksMovePromptToTab(tabname, '', false, false);
|
||||
}
|
||||
|
||||
function extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt) { // called from python when user selects an extra networks tab
|
||||
extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt);
|
||||
|
||||
}
|
||||
|
||||
function applyExtraNetworkFilter(tabname) {
|
||||
setTimeout(extraNetworksApplyFilter[tabname], 1);
|
||||
}
|
||||
|
||||
function applyExtraNetworkSort(tabname) {
|
||||
setTimeout(extraNetworksApplySort[tabname], 1);
|
||||
}
|
||||
|
||||
var extraNetworksApplyFilter = {};
|
||||
var extraNetworksApplySort = {};
|
||||
var activePromptTextarea = {};
|
||||
|
||||
function setupExtraNetworks() {
|
||||
@ -140,14 +178,15 @@ function setupExtraNetworks() {
|
||||
|
||||
onUiLoaded(setupExtraNetworks);
|
||||
|
||||
var re_extranet = /<([^:]+:[^:]+):[\d.]+>(.*)/;
|
||||
var re_extranet_g = /\s+<([^:]+:[^:]+):[\d.]+>/g;
|
||||
var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/;
|
||||
var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g;
|
||||
|
||||
function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
|
||||
var m = text.match(re_extranet);
|
||||
var replaced = false;
|
||||
var newTextareaText;
|
||||
if (m) {
|
||||
var extraTextBeforeNet = opts.extra_networks_add_text_separator;
|
||||
var extraTextAfterNet = m[2];
|
||||
var partToSearch = m[1];
|
||||
var foundAtPosition = -1;
|
||||
@ -161,8 +200,13 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
|
||||
return found;
|
||||
});
|
||||
|
||||
if (foundAtPosition >= 0 && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
|
||||
newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
|
||||
if (foundAtPosition >= 0) {
|
||||
if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {
|
||||
newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);
|
||||
}
|
||||
if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) {
|
||||
newTextareaText = newTextareaText.substr(0, foundAtPosition - extraTextBeforeNet.length) + newTextareaText.substr(foundAtPosition);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found) {
|
||||
@ -216,27 +260,24 @@ function extraNetworksSearchButton(tabs_id, event) {
|
||||
|
||||
var globalPopup = null;
|
||||
var globalPopupInner = null;
|
||||
|
||||
function closePopup() {
|
||||
if (!globalPopup) return;
|
||||
|
||||
globalPopup.style.display = "none";
|
||||
}
|
||||
|
||||
function popup(contents) {
|
||||
if (!globalPopup) {
|
||||
globalPopup = document.createElement('div');
|
||||
globalPopup.onclick = closePopup;
|
||||
globalPopup.classList.add('global-popup');
|
||||
|
||||
var close = document.createElement('div');
|
||||
close.classList.add('global-popup-close');
|
||||
close.onclick = closePopup;
|
||||
close.addEventListener("click", closePopup);
|
||||
close.title = "Close";
|
||||
globalPopup.appendChild(close);
|
||||
|
||||
globalPopupInner = document.createElement('div');
|
||||
globalPopupInner.onclick = function(event) {
|
||||
event.stopPropagation(); return false;
|
||||
};
|
||||
globalPopupInner.classList.add('global-popup-inner');
|
||||
globalPopup.appendChild(globalPopupInner);
|
||||
|
||||
@ -335,7 +376,7 @@ function extraNetworksEditUserMetadata(event, tabname, extraPage, cardName) {
|
||||
function extraNetworksRefreshSingleCard(page, tabname, name) {
|
||||
requestGet("./sd_extra_networks/get-single-card", {page: page, tabname: tabname, name: name}, function(data) {
|
||||
if (data && data.html) {
|
||||
var card = gradioApp().querySelector('.card[data-name=' + JSON.stringify(name) + ']'); // likely using the wrong stringify function
|
||||
var card = gradioApp().querySelector(`#${tabname}_${page.replace(" ", "_")}_cards > .card[data-name="${name}"]`);
|
||||
|
||||
var newDiv = document.createElement('DIV');
|
||||
newDiv.innerHTML = data.html;
|
||||
|
@ -33,8 +33,11 @@ function updateOnBackgroundChange() {
|
||||
const modalImage = gradioApp().getElementById("modalImage");
|
||||
if (modalImage && modalImage.offsetParent) {
|
||||
let currentButton = selected_gallery_button();
|
||||
|
||||
if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
|
||||
let preview = gradioApp().querySelectorAll('.livePreview > img');
|
||||
if (preview.length > 0) {
|
||||
// show preview image if available
|
||||
modalImage.src = preview[preview.length - 1].src;
|
||||
} else if (currentButton?.children?.length > 0 && modalImage.src != currentButton.children[0].src) {
|
||||
modalImage.src = currentButton.children[0].src;
|
||||
if (modalImage.style.display === 'none') {
|
||||
const modal = gradioApp().getElementById("lightboxModal");
|
||||
|
@ -1,37 +1,68 @@
|
||||
var observerAccordionOpen = new MutationObserver(function(mutations) {
|
||||
mutations.forEach(function(mutationRecord) {
|
||||
var elem = mutationRecord.target;
|
||||
var open = elem.classList.contains('open');
|
||||
|
||||
var accordion = elem.parentNode;
|
||||
accordion.classList.toggle('input-accordion-open', open);
|
||||
|
||||
var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
|
||||
checkbox.checked = open;
|
||||
updateInput(checkbox);
|
||||
|
||||
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
|
||||
if (extra) {
|
||||
extra.style.display = open ? "" : "none";
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
function inputAccordionChecked(id, checked) {
|
||||
var label = gradioApp().querySelector('#' + id + " .label-wrap");
|
||||
if (label.classList.contains('open') != checked) {
|
||||
label.click();
|
||||
var accordion = gradioApp().getElementById(id);
|
||||
accordion.visibleCheckbox.checked = checked;
|
||||
accordion.onVisibleCheckboxChange();
|
||||
}
|
||||
|
||||
function setupAccordion(accordion) {
|
||||
var labelWrap = accordion.querySelector('.label-wrap');
|
||||
var gradioCheckbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
|
||||
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
|
||||
var span = labelWrap.querySelector('span');
|
||||
var linked = true;
|
||||
|
||||
var isOpen = function() {
|
||||
return labelWrap.classList.contains('open');
|
||||
};
|
||||
|
||||
var observerAccordionOpen = new MutationObserver(function(mutations) {
|
||||
mutations.forEach(function(mutationRecord) {
|
||||
accordion.classList.toggle('input-accordion-open', isOpen());
|
||||
|
||||
if (linked) {
|
||||
accordion.visibleCheckbox.checked = isOpen();
|
||||
accordion.onVisibleCheckboxChange();
|
||||
}
|
||||
});
|
||||
});
|
||||
observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
|
||||
|
||||
if (extra) {
|
||||
labelWrap.insertBefore(extra, labelWrap.lastElementChild);
|
||||
}
|
||||
|
||||
accordion.onChecked = function(checked) {
|
||||
if (isOpen() != checked) {
|
||||
labelWrap.click();
|
||||
}
|
||||
};
|
||||
|
||||
var visibleCheckbox = document.createElement('INPUT');
|
||||
visibleCheckbox.type = 'checkbox';
|
||||
visibleCheckbox.checked = isOpen();
|
||||
visibleCheckbox.id = accordion.id + "-visible-checkbox";
|
||||
visibleCheckbox.className = gradioCheckbox.className + " input-accordion-checkbox";
|
||||
span.insertBefore(visibleCheckbox, span.firstChild);
|
||||
|
||||
accordion.visibleCheckbox = visibleCheckbox;
|
||||
accordion.onVisibleCheckboxChange = function() {
|
||||
if (linked && isOpen() != visibleCheckbox.checked) {
|
||||
labelWrap.click();
|
||||
}
|
||||
|
||||
gradioCheckbox.checked = visibleCheckbox.checked;
|
||||
updateInput(gradioCheckbox);
|
||||
};
|
||||
|
||||
visibleCheckbox.addEventListener('click', function(event) {
|
||||
linked = false;
|
||||
event.stopPropagation();
|
||||
});
|
||||
visibleCheckbox.addEventListener('input', accordion.onVisibleCheckboxChange);
|
||||
}
|
||||
|
||||
onUiLoaded(function() {
|
||||
for (var accordion of gradioApp().querySelectorAll('.input-accordion')) {
|
||||
var labelWrap = accordion.querySelector('.label-wrap');
|
||||
observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
|
||||
|
||||
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
|
||||
if (extra) {
|
||||
labelWrap.insertBefore(extra, labelWrap.lastElementChild);
|
||||
}
|
||||
setupAccordion(accordion);
|
||||
}
|
||||
});
|
||||
|
@ -26,7 +26,11 @@ onAfterUiUpdate(function() {
|
||||
lastHeadImg = headImg;
|
||||
|
||||
// play notification sound if available
|
||||
gradioApp().querySelector('#audio_notification audio')?.play();
|
||||
const notificationAudio = gradioApp().querySelector('#audio_notification audio');
|
||||
if (notificationAudio) {
|
||||
notificationAudio.volume = opts.notification_volume / 100.0 || 1.0;
|
||||
notificationAudio.play();
|
||||
}
|
||||
|
||||
if (document.hasFocus()) return;
|
||||
|
||||
|
46
javascript/settings.js
Normal file
46
javascript/settings.js
Normal file
@ -0,0 +1,46 @@
|
||||
let settingsExcludeTabsFromShowAll = {
|
||||
settings_tab_defaults: 1,
|
||||
settings_tab_sysinfo: 1,
|
||||
settings_tab_actions: 1,
|
||||
settings_tab_licenses: 1,
|
||||
};
|
||||
|
||||
function settingsShowAllTabs() {
|
||||
gradioApp().querySelectorAll('#settings > div').forEach(function(elem) {
|
||||
if (settingsExcludeTabsFromShowAll[elem.id]) return;
|
||||
|
||||
elem.style.display = "block";
|
||||
});
|
||||
}
|
||||
|
||||
function settingsShowOneTab() {
|
||||
gradioApp().querySelector('#settings_show_one_page').click();
|
||||
}
|
||||
|
||||
onUiLoaded(function() {
|
||||
var edit = gradioApp().querySelector('#settings_search');
|
||||
var editTextarea = gradioApp().querySelector('#settings_search > label > input');
|
||||
var buttonShowAllPages = gradioApp().getElementById('settings_show_all_pages');
|
||||
var settings_tabs = gradioApp().querySelector('#settings div');
|
||||
|
||||
onEdit('settingsSearch', editTextarea, 250, function() {
|
||||
var searchText = (editTextarea.value || "").trim().toLowerCase();
|
||||
|
||||
gradioApp().querySelectorAll('#settings > div[id^=settings_] div[id^=column_settings_] > *').forEach(function(elem) {
|
||||
var visible = elem.textContent.trim().toLowerCase().indexOf(searchText) != -1;
|
||||
elem.style.display = visible ? "" : "none";
|
||||
});
|
||||
|
||||
if (searchText != "") {
|
||||
settingsShowAllTabs();
|
||||
} else {
|
||||
settingsShowOneTab();
|
||||
}
|
||||
});
|
||||
|
||||
settings_tabs.insertBefore(edit, settings_tabs.firstChild);
|
||||
settings_tabs.appendChild(buttonShowAllPages);
|
||||
|
||||
|
||||
buttonShowAllPages.addEventListener("click", settingsShowAllTabs);
|
||||
});
|
@ -1,10 +1,9 @@
|
||||
let promptTokenCountDebounceTime = 800;
|
||||
let promptTokenCountTimeouts = {};
|
||||
var promptTokenCountUpdateFunctions = {};
|
||||
let promptTokenCountUpdateFunctions = {};
|
||||
|
||||
function update_txt2img_tokens(...args) {
|
||||
// Called from Gradio
|
||||
update_token_counter("txt2img_token_button");
|
||||
update_token_counter("txt2img_negative_token_button");
|
||||
if (args.length == 2) {
|
||||
return args[0];
|
||||
}
|
||||
@ -14,6 +13,7 @@ function update_txt2img_tokens(...args) {
|
||||
function update_img2img_tokens(...args) {
|
||||
// Called from Gradio
|
||||
update_token_counter("img2img_token_button");
|
||||
update_token_counter("img2img_negative_token_button");
|
||||
if (args.length == 2) {
|
||||
return args[0];
|
||||
}
|
||||
@ -21,16 +21,7 @@ function update_img2img_tokens(...args) {
|
||||
}
|
||||
|
||||
function update_token_counter(button_id) {
|
||||
if (opts.disable_token_counters) {
|
||||
return;
|
||||
}
|
||||
if (promptTokenCountTimeouts[button_id]) {
|
||||
clearTimeout(promptTokenCountTimeouts[button_id]);
|
||||
}
|
||||
promptTokenCountTimeouts[button_id] = setTimeout(
|
||||
() => gradioApp().getElementById(button_id)?.click(),
|
||||
promptTokenCountDebounceTime,
|
||||
);
|
||||
promptTokenCountUpdateFunctions[button_id]?.();
|
||||
}
|
||||
|
||||
|
||||
@ -69,10 +60,11 @@ function setupTokenCounting(id, id_counter, id_button) {
|
||||
prompt.parentElement.insertBefore(counter, prompt);
|
||||
prompt.parentElement.style.position = "relative";
|
||||
|
||||
promptTokenCountUpdateFunctions[id] = function() {
|
||||
update_token_counter(id_button);
|
||||
};
|
||||
textarea.addEventListener("input", promptTokenCountUpdateFunctions[id]);
|
||||
var func = onEdit(id, textarea, 800, function() {
|
||||
gradioApp().getElementById(id_button)?.click();
|
||||
});
|
||||
promptTokenCountUpdateFunctions[id] = func;
|
||||
promptTokenCountUpdateFunctions[id_button] = func;
|
||||
}
|
||||
|
||||
function setupTokenCounters() {
|
||||
|
@ -263,21 +263,6 @@ onAfterUiUpdate(function() {
|
||||
json_elem.parentElement.style.display = "none";
|
||||
|
||||
setupTokenCounters();
|
||||
|
||||
var show_all_pages = gradioApp().getElementById('settings_show_all_pages');
|
||||
var settings_tabs = gradioApp().querySelector('#settings div');
|
||||
if (show_all_pages && settings_tabs) {
|
||||
settings_tabs.appendChild(show_all_pages);
|
||||
show_all_pages.onclick = function() {
|
||||
gradioApp().querySelectorAll('#settings > div').forEach(function(elem) {
|
||||
if (elem.id == "settings_tab_licenses") {
|
||||
return;
|
||||
}
|
||||
|
||||
elem.style.display = "block";
|
||||
});
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
onOptionsChanged(function() {
|
||||
@ -366,3 +351,20 @@ function switchWidthHeight(tabname) {
|
||||
updateInput(height);
|
||||
return [];
|
||||
}
|
||||
|
||||
|
||||
var onEditTimers = {};
|
||||
|
||||
// calls func after afterMs milliseconds has passed since the input elem has beed enited by user
|
||||
function onEdit(editId, elem, afterMs, func) {
|
||||
var edited = function() {
|
||||
var existingTimer = onEditTimers[editId];
|
||||
if (existingTimer) clearTimeout(existingTimer);
|
||||
|
||||
onEditTimers[editId] = setTimeout(func, afterMs);
|
||||
};
|
||||
|
||||
elem.addEventListener("input", edited);
|
||||
|
||||
return edited;
|
||||
}
|
||||
|
@ -17,19 +17,18 @@ from fastapi.encoders import jsonable_encoder
|
||||
from secrets import compare_digest
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
|
||||
from modules.api import models
|
||||
from modules.shared import opts
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||
from modules.textual_inversion.preprocess import preprocess
|
||||
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
||||
from PIL import PngImagePlugin,Image
|
||||
from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
|
||||
from PIL import PngImagePlugin, Image
|
||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||
from modules.realesrgan_model import get_realesrgan_models
|
||||
from modules import devices
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any
|
||||
import piexif
|
||||
import piexif.helper
|
||||
from contextlib import closing
|
||||
@ -103,7 +102,8 @@ def decode_base64_to_image(encoding):
|
||||
|
||||
def encode_pil_to_base64(image):
|
||||
with io.BytesIO() as output_bytes:
|
||||
|
||||
if isinstance(image, str):
|
||||
return image
|
||||
if opts.samples_format.lower() == 'png':
|
||||
use_metadata = False
|
||||
metadata = PngImagePlugin.PngInfo()
|
||||
@ -221,15 +221,15 @@ class Api:
|
||||
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
||||
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
||||
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
|
||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
||||
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
|
||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
|
||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
|
||||
self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
|
||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
|
||||
self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
|
||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
|
||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
|
||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
|
||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
|
||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
|
||||
@ -242,7 +242,8 @@ class Api:
|
||||
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
||||
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
|
||||
self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
|
||||
|
||||
if shared.cmd_opts.api_server_stop:
|
||||
self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
|
||||
@ -473,9 +474,6 @@ class Api:
|
||||
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||
|
||||
def pnginfoapi(self, req: models.PNGInfoRequest):
|
||||
if(not req.image.strip()):
|
||||
return models.PNGInfoResponse(info="")
|
||||
|
||||
image = decode_base64_to_image(req.image.strip())
|
||||
if image is None:
|
||||
return models.PNGInfoResponse(info="")
|
||||
@ -484,9 +482,10 @@ class Api:
|
||||
if geninfo is None:
|
||||
geninfo = ""
|
||||
|
||||
items = {**{'parameters': geninfo}, **items}
|
||||
params = generation_parameters_copypaste.parse_generation_parameters(geninfo)
|
||||
script_callbacks.infotext_pasted_callback(geninfo, params)
|
||||
|
||||
return models.PNGInfoResponse(info=geninfo, items=items)
|
||||
return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
|
||||
|
||||
def progressapi(self, req: models.ProgressRequest = Depends()):
|
||||
# copy from check_progress_call of ui.py
|
||||
@ -541,12 +540,12 @@ class Api:
|
||||
return {}
|
||||
|
||||
def unloadapi(self):
|
||||
unload_model_weights()
|
||||
sd_models.unload_model_weights()
|
||||
|
||||
return {}
|
||||
|
||||
def reloadapi(self):
|
||||
reload_model_weights()
|
||||
sd_models.send_model_to_device(shared.sd_model)
|
||||
|
||||
return {}
|
||||
|
||||
@ -564,9 +563,9 @@ class Api:
|
||||
|
||||
return options
|
||||
|
||||
def set_config(self, req: Dict[str, Any]):
|
||||
def set_config(self, req: dict[str, Any]):
|
||||
checkpoint_name = req.get("sd_model_checkpoint", None)
|
||||
if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
|
||||
if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
|
||||
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
||||
|
||||
for k, v in req.items():
|
||||
@ -770,6 +769,25 @@ class Api:
|
||||
cuda = {'error': f'{err}'}
|
||||
return models.MemoryResponse(ram=ram, cuda=cuda)
|
||||
|
||||
def get_extensions_list(self):
|
||||
from modules import extensions
|
||||
extensions.list_extensions()
|
||||
ext_list = []
|
||||
for ext in extensions.extensions:
|
||||
ext: extensions.Extension
|
||||
ext.read_info_from_repo()
|
||||
if ext.remote is not None:
|
||||
ext_list.append({
|
||||
"name": ext.name,
|
||||
"remote": ext.remote,
|
||||
"branch": ext.branch,
|
||||
"commit_hash":ext.commit_hash,
|
||||
"commit_date":ext.commit_date,
|
||||
"version":ext.version,
|
||||
"enabled":ext.enabled
|
||||
})
|
||||
return ext_list
|
||||
|
||||
def launch(self, server_name, port, root_path):
|
||||
self.app.include_router(self.router)
|
||||
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)
|
||||
|
@ -1,12 +1,10 @@
|
||||
import inspect
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from typing import Any, Optional
|
||||
from typing_extensions import Literal
|
||||
from typing import Any, Optional, Literal
|
||||
from inflection import underscore
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
|
||||
from modules.shared import sd_upscalers, opts, parser
|
||||
from typing import Dict, List
|
||||
|
||||
API_NOT_ALLOWED = [
|
||||
"self",
|
||||
@ -130,12 +128,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
||||
).generate_model()
|
||||
|
||||
class TextToImageResponse(BaseModel):
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
class ImageToImageResponse(BaseModel):
|
||||
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
|
||||
parameters: dict
|
||||
info: str
|
||||
|
||||
@ -168,17 +166,18 @@ class FileData(BaseModel):
|
||||
name: str = Field(title="File name")
|
||||
|
||||
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
|
||||
imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
||||
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
|
||||
|
||||
class ExtrasBatchImagesResponse(ExtraBaseResponse):
|
||||
images: List[str] = Field(title="Images", description="The generated images in base64 format.")
|
||||
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
|
||||
|
||||
class PNGInfoRequest(BaseModel):
|
||||
image: str = Field(title="Image", description="The base64 encoded PNG image")
|
||||
|
||||
class PNGInfoResponse(BaseModel):
|
||||
info: str = Field(title="Image info", description="A string with the parameters used to generate the image")
|
||||
items: dict = Field(title="Items", description="An object containing all the info the image had")
|
||||
items: dict = Field(title="Items", description="A dictionary containing all the other fields the image had")
|
||||
parameters: dict = Field(title="Parameters", description="A dictionary with parsed generation info fields")
|
||||
|
||||
class ProgressRequest(BaseModel):
|
||||
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
|
||||
@ -232,8 +231,8 @@ FlagsModel = create_model("Flags", **flags)
|
||||
|
||||
class SamplerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
aliases: List[str] = Field(title="Aliases")
|
||||
options: Dict[str, str] = Field(title="Options")
|
||||
aliases: list[str] = Field(title="Aliases")
|
||||
options: dict[str, str] = Field(title="Options")
|
||||
|
||||
class UpscalerItem(BaseModel):
|
||||
name: str = Field(title="Name")
|
||||
@ -284,8 +283,8 @@ class EmbeddingItem(BaseModel):
|
||||
vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
|
||||
|
||||
class EmbeddingsResponse(BaseModel):
|
||||
loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
|
||||
skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
|
||||
loaded: dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
|
||||
skipped: dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
|
||||
|
||||
class MemoryResponse(BaseModel):
|
||||
ram: dict = Field(title="RAM", description="System memory stats")
|
||||
@ -303,11 +302,20 @@ class ScriptArg(BaseModel):
|
||||
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
|
||||
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
|
||||
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
|
||||
choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
|
||||
choices: Optional[list[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
|
||||
|
||||
|
||||
class ScriptInfo(BaseModel):
|
||||
name: str = Field(default=None, title="Name", description="Script name")
|
||||
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
|
||||
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
|
||||
args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
|
||||
args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
|
||||
|
||||
class ExtensionItem(BaseModel):
|
||||
name: str = Field(title="Name", description="Extension name")
|
||||
remote: str = Field(title="Remote", description="Extension Repository URL")
|
||||
branch: str = Field(title="Branch", description="Extension Repository Branch")
|
||||
commit_hash: str = Field(title="Commit Hash", description="Extension Repository Commit Hash")
|
||||
version: str = Field(title="Version", description="Extension Version")
|
||||
commit_date: str = Field(title="Commit Date", description="Extension Repository Commit Date")
|
||||
enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")
|
||||
|
@ -90,7 +90,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
|
||||
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
|
||||
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
|
||||
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
|
||||
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
|
||||
parser.add_argument("--enable-console-prompts", action='store_true', help="does not do anything", default=False) # Legacy compatibility, use as default value shared.opts.enable_console_prompts
|
||||
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
|
||||
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
|
||||
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
|
||||
@ -107,13 +107,14 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req
|
||||
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
||||
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
|
||||
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier versions")
|
||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
||||
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
|
||||
parser.add_argument('--add-stop-route', action='store_true', help='does not do anything')
|
||||
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
|
||||
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
|
||||
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
|
||||
parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)
|
||||
parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
|
||||
parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", )
|
||||
|
@ -4,7 +4,6 @@ Supports saving and restoring webui and extensions from a known working set of c
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import tqdm
|
||||
|
||||
from datetime import datetime
|
||||
@ -38,7 +37,7 @@ def list_config_states():
|
||||
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
|
||||
|
||||
for cs in config_states:
|
||||
timestamp = time.asctime(time.gmtime(cs["created_at"]))
|
||||
timestamp = datetime.fromtimestamp(cs["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
|
||||
name = cs.get("name", "Config")
|
||||
full_name = f"{name}: {timestamp}"
|
||||
all_config_states[full_name] = cs
|
||||
|
@ -60,7 +60,8 @@ def enable_tf32():
|
||||
|
||||
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
||||
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
||||
if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
|
||||
device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
|
||||
if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
@ -9,7 +9,7 @@ from modules.paths import data_path
|
||||
from modules import shared, ui_tempdir, script_callbacks, processing
|
||||
from PIL import Image
|
||||
|
||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
|
||||
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
|
||||
re_param = re.compile(re_param_code)
|
||||
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
|
||||
re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
|
||||
|
@ -9,6 +9,7 @@ from modules import paths, shared, devices, modelloader, errors
|
||||
model_dir = "GFPGAN"
|
||||
user_path = None
|
||||
model_path = os.path.join(paths.models_path, model_dir)
|
||||
model_file_path = None
|
||||
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
||||
have_gfpgan = False
|
||||
loaded_gfpgan_model = None
|
||||
@ -17,6 +18,7 @@ loaded_gfpgan_model = None
|
||||
def gfpgann():
|
||||
global loaded_gfpgan_model
|
||||
global model_path
|
||||
global model_file_path
|
||||
if loaded_gfpgan_model is not None:
|
||||
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
|
||||
return loaded_gfpgan_model
|
||||
@ -24,17 +26,24 @@ def gfpgann():
|
||||
if gfpgan_constructor is None:
|
||||
return None
|
||||
|
||||
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
|
||||
models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
|
||||
|
||||
if len(models) == 1 and models[0].startswith("http"):
|
||||
model_file = models[0]
|
||||
elif len(models) != 0:
|
||||
latest_file = max(models, key=os.path.getctime)
|
||||
gfp_models = []
|
||||
for item in models:
|
||||
if 'GFPGAN' in os.path.basename(item):
|
||||
gfp_models.append(item)
|
||||
latest_file = max(gfp_models, key=os.path.getctime)
|
||||
model_file = latest_file
|
||||
else:
|
||||
print("Unable to load gfpgan model!")
|
||||
return None
|
||||
|
||||
if hasattr(facexlib.detection.retinaface, 'device'):
|
||||
facexlib.detection.retinaface.device = devices.device_gfpgan
|
||||
model_file_path = model_file
|
||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
|
||||
loaded_gfpgan_model = model
|
||||
|
||||
@ -77,19 +86,25 @@ def setup_model(dirname):
|
||||
global user_path
|
||||
global have_gfpgan
|
||||
global gfpgan_constructor
|
||||
global model_file_path
|
||||
|
||||
facexlib_path = model_path
|
||||
|
||||
if dirname is not None:
|
||||
facexlib_path = dirname
|
||||
|
||||
load_file_from_url_orig = gfpgan.utils.load_file_from_url
|
||||
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
|
||||
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
|
||||
|
||||
def my_load_file_from_url(**kwargs):
|
||||
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
|
||||
return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
|
||||
|
||||
def facex_load_file_from_url(**kwargs):
|
||||
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
|
||||
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
|
||||
|
||||
def facex_load_file_from_url2(**kwargs):
|
||||
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
|
||||
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
|
||||
|
||||
gfpgan.utils.load_file_from_url = my_load_file_from_url
|
||||
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
||||
|
@ -23,7 +23,7 @@ class Git(git.Git):
|
||||
)
|
||||
return self._parse_object_header(ret)
|
||||
|
||||
def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
|
||||
def stream_object_data(self, ref: str) -> tuple[str, str, int, Git.CatFileContentStream]:
|
||||
# Not really streaming, per se; this buffers the entire object in memory.
|
||||
# Shouldn't be a problem for our use case, since we're only using this for
|
||||
# object headers (commit objects).
|
||||
|
@ -468,7 +468,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
|
||||
shared.reload_hypernetworks()
|
||||
|
||||
|
||||
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
def train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch_size: int, gradient_step: int, data_root: str, log_directory: str, training_width: int, training_height: int, varsize: bool, steps: int, clip_grad_mode: str, clip_grad_value: float, shuffle_tags: bool, tag_drop_out: bool, latent_sampling_method: str, use_weight: bool, create_image_every: int, save_hypernetwork_every: int, template_filename: str, preview_from_txt2img: bool, preview_prompt: str, preview_negative_prompt: str, preview_steps: int, preview_sampler_name: str, preview_cfg_scale: float, preview_seed: int, preview_width: int, preview_height: int):
|
||||
from modules import images, processing
|
||||
|
||||
save_hypernetwork_every = save_hypernetwork_every or 0
|
||||
@ -698,7 +698,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
||||
p.prompt = preview_prompt
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
p.steps = preview_steps
|
||||
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||
p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
|
||||
p.cfg_scale = preview_cfg_scale
|
||||
p.seed = preview_seed
|
||||
p.width = preview_width
|
||||
|
@ -561,6 +561,8 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p
|
||||
})
|
||||
|
||||
piexif.insert(exif_bytes, filename)
|
||||
elif extension.lower() == ".gif":
|
||||
image.save(filename, format=image_format, comment=geninfo)
|
||||
else:
|
||||
image.save(filename, format=image_format, quality=opts.jpeg_quality)
|
||||
|
||||
@ -661,7 +663,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
|
||||
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)
|
||||
|
||||
os.replace(temp_file_path, filename_without_extension + extension)
|
||||
filename = filename_without_extension + extension
|
||||
if shared.opts.save_images_replace_action != "Replace":
|
||||
n = 0
|
||||
while os.path.exists(filename):
|
||||
n += 1
|
||||
filename = f"{filename_without_extension}-{n}{extension}"
|
||||
os.replace(temp_file_path, filename)
|
||||
|
||||
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
||||
if hasattr(os, 'statvfs'):
|
||||
@ -718,7 +726,12 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
|
||||
geninfo = items.pop('parameters', None)
|
||||
|
||||
if "exif" in items:
|
||||
exif = piexif.load(items["exif"])
|
||||
exif_data = items["exif"]
|
||||
try:
|
||||
exif = piexif.load(exif_data)
|
||||
except OSError:
|
||||
# memory / exif was not valid so piexif tried to read from a file
|
||||
exif = None
|
||||
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
|
||||
try:
|
||||
exif_comment = piexif.helper.UserComment.load(exif_comment)
|
||||
@ -728,6 +741,8 @@ def read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:
|
||||
if exif_comment:
|
||||
items['exif comment'] = exif_comment
|
||||
geninfo = exif_comment
|
||||
elif "comment" in items: # for gif
|
||||
geninfo = items["comment"].decode('utf8', errors="ignore")
|
||||
|
||||
for field in IGNORED_INFO_KEYS:
|
||||
items.pop(field, None)
|
||||
|
@ -10,6 +10,7 @@ from modules import images as imgutil
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
|
||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.shared import opts, state
|
||||
from modules.sd_models import get_closet_checkpoint_match
|
||||
import modules.shared as shared
|
||||
import modules.processing as processing
|
||||
from modules.ui import plaintext_to_html
|
||||
@ -41,7 +42,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
cfg_scale = p.cfg_scale
|
||||
sampler_name = p.sampler_name
|
||||
steps = p.steps
|
||||
|
||||
override_settings = p.override_settings
|
||||
sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
|
||||
for i, image in enumerate(images):
|
||||
state.job = f"{i+1} out of {len(images)}"
|
||||
if state.skipped:
|
||||
@ -104,15 +106,27 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
|
||||
p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
|
||||
p.steps = int(parsed_parameters.get("Steps", steps))
|
||||
|
||||
model_info = get_closet_checkpoint_match(parsed_parameters.get("Model hash", None))
|
||||
if model_info is not None:
|
||||
p.override_settings['sd_model_checkpoint'] = model_info.name
|
||||
elif sd_model_checkpoint_override:
|
||||
p.override_settings['sd_model_checkpoint'] = sd_model_checkpoint_override
|
||||
else:
|
||||
p.override_settings.pop("sd_model_checkpoint", None)
|
||||
|
||||
if output_dir:
|
||||
p.outpath_samples = output_dir
|
||||
p.override_settings['save_to_dirs'] = False
|
||||
p.override_settings['save_images_replace_action'] = "Add number suffix"
|
||||
if p.n_iter > 1 or p.batch_size > 1:
|
||||
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
|
||||
else:
|
||||
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
|
||||
|
||||
proc = modules.scripts.scripts_img2img.run(p, *args)
|
||||
|
||||
if proc is None:
|
||||
if output_dir:
|
||||
p.outpath_samples = output_dir
|
||||
p.override_settings['save_to_dirs'] = False
|
||||
if p.n_iter > 1 or p.batch_size > 1:
|
||||
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
|
||||
else:
|
||||
p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
|
||||
p.override_settings.pop('save_images_replace_action', None)
|
||||
process_images(p)
|
||||
|
||||
|
||||
@ -189,7 +203,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
||||
|
||||
p.user = request.username
|
||||
|
||||
if shared.cmd_opts.enable_console_prompts:
|
||||
if shared.opts.enable_console_prompts:
|
||||
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
if mask:
|
||||
|
@ -151,8 +151,8 @@ def initialize_rest(*, reload_script_modules=False):
|
||||
|
||||
from modules import devices
|
||||
devices.first_time_calculation()
|
||||
|
||||
Thread(target=load_model).start()
|
||||
if not shared.cmd_opts.skip_load_model_at_start:
|
||||
Thread(target=load_model).start()
|
||||
|
||||
from modules import shared_items
|
||||
shared_items.reload_hypernetworks()
|
||||
|
@ -150,10 +150,14 @@ def dumpstacks():
|
||||
|
||||
def configure_sigint_handler():
|
||||
# make the program just exit at ctrl+c without waiting for anything
|
||||
|
||||
from modules import shared
|
||||
|
||||
def sigint_handler(sig, frame):
|
||||
print(f'Interrupted with signal {sig} in {frame}')
|
||||
|
||||
dumpstacks()
|
||||
if shared.opts.dump_stacks_on_signal:
|
||||
dumpstacks()
|
||||
|
||||
os._exit(0)
|
||||
|
||||
|
@ -64,7 +64,7 @@ Use --skip-python-version-check to suppress this warning.
|
||||
@lru_cache()
|
||||
def commit_hash():
|
||||
try:
|
||||
return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
|
||||
return subprocess.check_output([git, "-C", script_path, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
|
||||
except Exception:
|
||||
return "<none>"
|
||||
|
||||
@ -72,7 +72,7 @@ def commit_hash():
|
||||
@lru_cache()
|
||||
def git_tag():
|
||||
try:
|
||||
return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
|
||||
return subprocess.check_output([git, "-C", script_path, "describe", "--tags"], shell=False, encoding='utf8').strip()
|
||||
except Exception:
|
||||
try:
|
||||
|
||||
|
@ -14,21 +14,24 @@ def list_localizations(dirname):
|
||||
if ext.lower() != ".json":
|
||||
continue
|
||||
|
||||
localizations[fn] = os.path.join(dirname, file)
|
||||
localizations[fn] = [os.path.join(dirname, file)]
|
||||
|
||||
for file in scripts.list_scripts("localizations", ".json"):
|
||||
fn, ext = os.path.splitext(file.filename)
|
||||
localizations[fn] = file.path
|
||||
if fn not in localizations:
|
||||
localizations[fn] = []
|
||||
localizations[fn].append(file.path)
|
||||
|
||||
|
||||
def localization_js(current_localization_name: str) -> str:
|
||||
fn = localizations.get(current_localization_name, None)
|
||||
fns = localizations.get(current_localization_name, None)
|
||||
data = {}
|
||||
if fn is not None:
|
||||
try:
|
||||
with open(fn, "r", encoding="utf8") as file:
|
||||
data = json.load(file)
|
||||
except Exception:
|
||||
errors.report(f"Error loading localization from {fn}", exc_info=True)
|
||||
if fns is not None:
|
||||
for fn in fns:
|
||||
try:
|
||||
with open(fn, "r", encoding="utf8") as file:
|
||||
data.update(json.load(file))
|
||||
except Exception:
|
||||
errors.report(f"Error loading localization from {fn}", exc_info=True)
|
||||
|
||||
return f"window.localization = {json.dumps(data)}"
|
||||
|
@ -210,6 +210,8 @@ class Options:
|
||||
|
||||
def add_option(self, key, info):
|
||||
self.data_labels[key] = info
|
||||
if key not in self.data:
|
||||
self.data[key] = info.default
|
||||
|
||||
def reorder(self):
|
||||
"""reorder settings so that all items related to section always go together"""
|
||||
|
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401
|
||||
|
||||
import modules.safe # noqa: F401
|
||||
|
||||
|
@ -8,6 +8,7 @@ import shlex
|
||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||
sys.argv += shlex.split(commandline_args)
|
||||
|
||||
cwd = os.getcwd()
|
||||
modules_path = os.path.dirname(os.path.realpath(__file__))
|
||||
script_path = os.path.dirname(modules_path)
|
||||
|
||||
|
@ -142,7 +142,7 @@ class StableDiffusionProcessing:
|
||||
overlay_images: list = None
|
||||
eta: float = None
|
||||
do_not_reload_embeddings: bool = False
|
||||
denoising_strength: float = 0
|
||||
denoising_strength: float = None
|
||||
ddim_discretize: str = None
|
||||
s_min_uncond: float = None
|
||||
s_churn: float = None
|
||||
@ -296,7 +296,7 @@ class StableDiffusionProcessing:
|
||||
return conditioning
|
||||
|
||||
def edit_image_conditioning(self, source_image):
|
||||
conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
|
||||
conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()
|
||||
|
||||
return conditioning_image
|
||||
|
||||
@ -533,6 +533,7 @@ class Processed:
|
||||
self.all_seeds = all_seeds or p.all_seeds or [self.seed]
|
||||
self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]
|
||||
self.infotexts = infotexts or [info]
|
||||
self.version = program_version()
|
||||
|
||||
def js(self):
|
||||
obj = {
|
||||
@ -567,6 +568,7 @@ class Processed:
|
||||
"job_timestamp": self.job_timestamp,
|
||||
"clip_skip": self.clip_skip,
|
||||
"is_using_inpainting_conditioning": self.is_using_inpainting_conditioning,
|
||||
"version": self.version,
|
||||
}
|
||||
|
||||
return json.dumps(obj)
|
||||
@ -709,7 +711,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||
if p.scripts is not None:
|
||||
p.scripts.before_process(p)
|
||||
|
||||
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
|
||||
stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}
|
||||
|
||||
try:
|
||||
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
||||
@ -884,6 +886,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
state.nextjob()
|
||||
|
||||
if p.scripts is not None:
|
||||
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
|
||||
|
||||
@ -956,7 +960,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
state.nextjob()
|
||||
if not infotexts:
|
||||
infotexts.append(Processed(p, []).infotext(p, 0))
|
||||
|
||||
p.color_corrections = None
|
||||
|
||||
|
@ -29,8 +29,8 @@ class ScriptSeed(scripts.ScriptBuiltinUI):
|
||||
else:
|
||||
self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0)
|
||||
|
||||
random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), label='Random seed')
|
||||
reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), label='Reuse seed')
|
||||
random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), tooltip="Set seed to -1, which will cause a new random number to be used every time")
|
||||
reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), tooltip="Reuse seed from last generation, mostly useful if it was randomized")
|
||||
|
||||
seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False)
|
||||
|
||||
|
@ -2,10 +2,9 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from typing import List
|
||||
import lark
|
||||
|
||||
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
|
||||
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]"
|
||||
# will be represented with prompt_schedule like this (assuming steps=100):
|
||||
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
|
||||
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
|
||||
@ -240,14 +239,14 @@ def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
|
||||
|
||||
class ComposableScheduledPromptConditioning:
|
||||
def __init__(self, schedules, weight=1.0):
|
||||
self.schedules: List[ScheduledPromptConditioning] = schedules
|
||||
self.schedules: list[ScheduledPromptConditioning] = schedules
|
||||
self.weight: float = weight
|
||||
|
||||
|
||||
class MulticondLearnedConditioning:
|
||||
def __init__(self, shape, batch):
|
||||
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
|
||||
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
|
||||
self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
|
||||
|
||||
|
||||
def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
|
||||
@ -278,7 +277,7 @@ class DictWithShape(dict):
|
||||
return self["crossattn"].shape
|
||||
|
||||
|
||||
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
|
||||
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
|
||||
param = c[0][0].cond
|
||||
is_dict = isinstance(param, dict)
|
||||
|
||||
|
@ -14,7 +14,9 @@ def is_restartable() -> bool:
|
||||
def restart_program() -> None:
|
||||
"""creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again"""
|
||||
|
||||
(Path(script_path) / "tmp" / "restart").touch()
|
||||
tmpdir = Path(script_path) / "tmp"
|
||||
tmpdir.mkdir(parents=True, exist_ok=True)
|
||||
(tmpdir / "restart").touch()
|
||||
|
||||
stop_program()
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import inspect
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Optional, Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
from gradio import Blocks
|
||||
@ -258,7 +258,7 @@ def image_grid_callback(params: ImageGridLoopParams):
|
||||
report_exception(c, 'image_grid')
|
||||
|
||||
|
||||
def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
|
||||
def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
|
||||
for c in callback_map['callbacks_infotext_pasted']:
|
||||
try:
|
||||
c.callback(infotext, params)
|
||||
@ -449,7 +449,7 @@ def on_infotext_pasted(callback):
|
||||
"""register a function to be called before applying an infotext.
|
||||
The callback is called with two arguments:
|
||||
- infotext: str - raw infotext.
|
||||
- result: Dict[str, any] - parsed infotext parameters.
|
||||
- result: dict[str, any] - parsed infotext parameters.
|
||||
"""
|
||||
add_callback(callback_map['callbacks_infotext_pasted'], callback)
|
||||
|
||||
|
@ -491,11 +491,15 @@ class ScriptRunner:
|
||||
|
||||
arg_info = api_models.ScriptArg(label=control.label or "")
|
||||
|
||||
for field in ("value", "minimum", "maximum", "step", "choices"):
|
||||
for field in ("value", "minimum", "maximum", "step"):
|
||||
v = getattr(control, field, None)
|
||||
if v is not None:
|
||||
setattr(arg_info, field, v)
|
||||
|
||||
choices = getattr(control, 'choices', None) # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string
|
||||
if choices is not None:
|
||||
arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices]
|
||||
|
||||
api_args.append(arg_info)
|
||||
|
||||
script.api_info = api_models.ScriptInfo(
|
||||
|
@ -2,14 +2,15 @@ import torch
|
||||
from torch.nn.functional import silu
|
||||
from types import MethodType
|
||||
|
||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.shared import cmd_opts
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
|
||||
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.model
|
||||
import ldm.modules.diffusionmodules.openaimodel
|
||||
import ldm.models.diffusion.ddpm
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
import ldm.modules.encoders.modules
|
||||
@ -37,6 +38,8 @@ ldm.models.diffusion.ddpm.print = shared.ldm_print
|
||||
optimizers = []
|
||||
current_optimizer: sd_hijack_optimizations.SdOptimization = None
|
||||
|
||||
ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
|
||||
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
|
||||
|
||||
def list_optimizers():
|
||||
new_optimizers = script_callbacks.list_optimizers_callback()
|
||||
@ -181,6 +184,20 @@ class StableDiffusionModelHijack:
|
||||
errors.display(e, "applying cross attention optimization")
|
||||
undo_optimizations()
|
||||
|
||||
def convert_sdxl_to_ssd(self, m):
|
||||
"""Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""
|
||||
|
||||
delattr(m.model.diffusion_model.middle_block, '1')
|
||||
delattr(m.model.diffusion_model.middle_block, '2')
|
||||
for i in ['9', '8', '7', '6', '5', '4']:
|
||||
delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
|
||||
delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
|
||||
delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
|
||||
delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
|
||||
delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
|
||||
delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
|
||||
devices.torch_gc()
|
||||
|
||||
def hijack(self, m):
|
||||
conditioner = getattr(m, 'conditioner', None)
|
||||
if conditioner:
|
||||
@ -208,7 +225,7 @@ class StableDiffusionModelHijack:
|
||||
else:
|
||||
m.cond_stage_model = conditioner
|
||||
|
||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
|
||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
@ -239,10 +256,17 @@ class StableDiffusionModelHijack:
|
||||
|
||||
self.layers = flatten(m)
|
||||
|
||||
if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
|
||||
ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
|
||||
import modules.models.diffusion.ddpm_edit
|
||||
|
||||
if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
|
||||
sd_unet.original_forward = ldm_original_forward
|
||||
elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):
|
||||
sd_unet.original_forward = ldm_original_forward
|
||||
elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
|
||||
sd_unet.original_forward = sgm_original_forward
|
||||
else:
|
||||
sd_unet.original_forward = None
|
||||
|
||||
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
|
||||
|
||||
def undo_hijack(self, m):
|
||||
conditioner = getattr(m, 'conditioner', None)
|
||||
@ -279,7 +303,8 @@ class StableDiffusionModelHijack:
|
||||
self.layers = None
|
||||
self.clip = None
|
||||
|
||||
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
|
||||
sd_unet.original_forward = None
|
||||
|
||||
|
||||
def apply_circular(self, enable):
|
||||
if self.circular_enabled == enable:
|
||||
|
@ -1,22 +1,22 @@
|
||||
import collections
|
||||
import os.path
|
||||
import sys
|
||||
import gc
|
||||
import threading
|
||||
|
||||
import torch
|
||||
import re
|
||||
import safetensors.torch
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf import OmegaConf, ListConfig
|
||||
from os import mkdir
|
||||
from urllib import request
|
||||
import ldm.modules.midas as midas
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
||||
from modules.timer import Timer
|
||||
import tomesd
|
||||
import numpy as np
|
||||
|
||||
model_dir = "Stable-diffusion"
|
||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||
@ -49,11 +49,12 @@ class CheckpointInfo:
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
abspath = os.path.abspath(filename)
|
||||
abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None
|
||||
|
||||
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
|
||||
|
||||
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
|
||||
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
|
||||
if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
|
||||
name = abspath.replace(abs_ckpt_dir, '')
|
||||
elif abspath.startswith(model_path):
|
||||
name = abspath.replace(model_path, '')
|
||||
else:
|
||||
@ -129,9 +130,12 @@ except Exception:
|
||||
|
||||
|
||||
def setup_model():
|
||||
"""called once at startup to do various one-time tasks related to SD models"""
|
||||
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
|
||||
enable_midas_autodownload()
|
||||
patch_given_betas()
|
||||
|
||||
|
||||
def checkpoint_tiles(use_short=False):
|
||||
@ -309,6 +313,8 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
||||
if checkpoint_info in checkpoints_loaded:
|
||||
# use checkpoint cache
|
||||
print(f"Loading weights [{sd_model_hash}] from cache")
|
||||
# move to end as latest
|
||||
checkpoints_loaded.move_to_end(checkpoint_info)
|
||||
return checkpoints_loaded[checkpoint_info]
|
||||
|
||||
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
|
||||
@ -346,16 +352,19 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
model.is_sdxl = hasattr(model, 'conditioner')
|
||||
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
|
||||
model.is_sd1 = not model.is_sdxl and not model.is_sd2
|
||||
|
||||
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
|
||||
if model.is_sdxl:
|
||||
sd_models_xl.extend_sdxl(model)
|
||||
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
timer.record("apply weights to model")
|
||||
if model.is_ssd:
|
||||
sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
|
||||
|
||||
if shared.opts.sd_checkpoint_cache > 0:
|
||||
# cache newly loaded model
|
||||
checkpoints_loaded[checkpoint_info] = state_dict
|
||||
checkpoints_loaded[checkpoint_info] = state_dict.copy()
|
||||
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
timer.record("apply weights to model")
|
||||
|
||||
del state_dict
|
||||
|
||||
@ -453,6 +462,20 @@ def enable_midas_autodownload():
|
||||
midas.api.load_model = load_model_wrapper
|
||||
|
||||
|
||||
def patch_given_betas():
|
||||
import ldm.models.diffusion.ddpm
|
||||
|
||||
def patched_register_schedule(*args, **kwargs):
|
||||
"""a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
|
||||
|
||||
if isinstance(args[1], ListConfig):
|
||||
args = (args[0], np.array(args[1]), *args[2:])
|
||||
|
||||
original_register_schedule(*args, **kwargs)
|
||||
|
||||
original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
|
||||
|
||||
|
||||
def repair_config(sd_config):
|
||||
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
@ -777,17 +800,7 @@ def reload_model_weights(sd_model=None, info=None):
|
||||
|
||||
|
||||
def unload_model_weights(sd_model=None, info=None):
|
||||
timer = Timer()
|
||||
|
||||
if model_data.sd_model:
|
||||
model_data.sd_model.to(devices.cpu)
|
||||
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
||||
model_data.sd_model = None
|
||||
sd_model = None
|
||||
gc.collect()
|
||||
devices.torch_gc()
|
||||
|
||||
print(f"Unloaded weights {timer.summary()}.")
|
||||
send_model_to_cpu(sd_model or shared.sd_model)
|
||||
|
||||
return sd_model
|
||||
|
||||
|
@ -21,7 +21,7 @@ config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inf
|
||||
config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
|
||||
config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
|
||||
config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
|
||||
|
||||
config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
|
||||
|
||||
def is_using_v_parameterization_for_sd2(state_dict):
|
||||
"""
|
||||
@ -95,7 +95,10 @@ def guess_model_config_from_state_dict(sd, filename):
|
||||
if diffusion_model_input.shape[1] == 8:
|
||||
return config_instruct_pix2pix
|
||||
|
||||
|
||||
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
|
||||
if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
|
||||
return config_alt_diffusion_m18
|
||||
return config_alt_diffusion
|
||||
|
||||
return config_default
|
||||
|
@ -22,7 +22,10 @@ class WebuiSdModel(LatentDiffusion):
|
||||
"""structure with additional information about the file with model's weights"""
|
||||
|
||||
is_sdxl: bool
|
||||
"""True if the model's architecture is SDXL"""
|
||||
"""True if the model's architecture is SDXL or SSD"""
|
||||
|
||||
is_ssd: bool
|
||||
"""True if the model is SSD"""
|
||||
|
||||
is_sd2: bool
|
||||
"""True if the model's architecture is SD 2.x"""
|
||||
|
@ -1,11 +1,11 @@
|
||||
import torch.nn
|
||||
import ldm.modules.diffusionmodules.openaimodel
|
||||
|
||||
from modules import script_callbacks, shared, devices
|
||||
|
||||
unet_options = []
|
||||
current_unet_option = None
|
||||
current_unet = None
|
||||
original_forward = None
|
||||
|
||||
|
||||
def list_unets():
|
||||
@ -88,5 +88,5 @@ def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
|
||||
if current_unet is not None:
|
||||
return current_unet.forward(x, timesteps, context, *args, **kwargs)
|
||||
|
||||
return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
|
||||
return original_forward(self, x, timesteps, context, *args, **kwargs)
|
||||
|
||||
|
@ -14,5 +14,5 @@ if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
|
||||
else:
|
||||
cmd_opts, _ = parser.parse_known_args()
|
||||
|
||||
|
||||
cmd_opts.disable_extension_access = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name]) and not cmd_opts.enable_insecure_extension_access
|
||||
cmd_opts.webui_is_non_local = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name])
|
||||
cmd_opts.disable_extension_access = cmd_opts.webui_is_non_local and not cmd_opts.enable_insecure_extension_access
|
||||
|
@ -44,9 +44,9 @@ def refresh_unet_list():
|
||||
modules.sd_unet.list_unets()
|
||||
|
||||
|
||||
def list_checkpoint_tiles():
|
||||
def list_checkpoint_tiles(use_short=False):
|
||||
import modules.sd_models
|
||||
return modules.sd_models.checkpoint_tiles()
|
||||
return modules.sd_models.checkpoint_tiles(use_short)
|
||||
|
||||
|
||||
def refresh_checkpoints():
|
||||
@ -67,6 +67,8 @@ def reload_hypernetworks():
|
||||
|
||||
|
||||
ui_reorder_categories_builtin_items = [
|
||||
"prompt",
|
||||
"image",
|
||||
"inpaint",
|
||||
"sampler",
|
||||
"accordions",
|
||||
|
@ -26,7 +26,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
||||
"samples_format": OptionInfo('png', 'File format for images'),
|
||||
"samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"),
|
||||
"save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs),
|
||||
|
||||
"save_images_replace_action": OptionInfo("Replace", "Saving the image to an existing file", gr.Radio, {"choices": ["Replace", "Add number suffix"], **hide_dirs}),
|
||||
"grid_save": OptionInfo(True, "Always save all generated image grids"),
|
||||
"grid_format": OptionInfo('png', 'File format for grids'),
|
||||
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
|
||||
@ -62,6 +62,9 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
||||
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
||||
|
||||
"save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."),
|
||||
|
||||
"notification_audio": OptionInfo(True, "Play notification sound after image generation").info("notification.mp3 should be present in the root directory").needs_reload_ui(),
|
||||
"notification_volume": OptionInfo(100, "Notification sound volume", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}).info("in %"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
||||
@ -100,6 +103,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
|
||||
|
||||
options_templates.update(options_section(('system', "System"), {
|
||||
"auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}),
|
||||
"enable_console_prompts": OptionInfo(shared.cmd_opts.enable_console_prompts, "Print prompts to console when generating with txt2img and img2img."),
|
||||
"show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(),
|
||||
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(),
|
||||
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
|
||||
@ -109,6 +113,7 @@ options_templates.update(options_section(('system', "System"), {
|
||||
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
|
||||
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
|
||||
"hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
|
||||
"dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('API', "API"), {
|
||||
@ -133,7 +138,7 @@ options_templates.update(options_section(('training', "Training"), {
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),
|
||||
"sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
|
||||
"sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
|
||||
@ -230,6 +235,8 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"),
|
||||
"extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"),
|
||||
"extra_networks_card_show_desc": OptionInfo(True, "Show description on card"),
|
||||
"extra_networks_card_order_field": OptionInfo("Name", "Default order field for Extra Networks cards", gr.Dropdown, {"choices": ['Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),
|
||||
"extra_networks_card_order": OptionInfo("Ascending", "Default order for Extra Networks cards", gr.Dropdown, {"choices": ['Ascending', 'Descending']}).needs_reload_ui(),
|
||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"),
|
||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(),
|
||||
"textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
|
||||
@ -255,15 +262,18 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(),
|
||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
||||
"keyedit_delimiters": OptionInfo(r".,\/!?%^*;:{}=`~() ", "Ctrl+up/down word delimiters"),
|
||||
"keyedit_delimiters_whitespace": OptionInfo(["Tab", "Carriage Return", "Line Feed"], "Ctrl+up/down whitespace delimiters", gr.CheckboxGroup, lambda: {"choices": ["Tab", "Carriage Return", "Line Feed"]}),
|
||||
"keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"),
|
||||
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(),
|
||||
"ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
|
||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(),
|
||||
"ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(),
|
||||
"sd_checkpoint_dropdown_use_short": OptionInfo(False, "Checkpoint dropdown: use filenames without paths").info("models in subdirectories like photo/sd15.ckpt will be listed as just sd15.ckpt"),
|
||||
"hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(),
|
||||
"hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(),
|
||||
"disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(),
|
||||
"compact_prompt_box": OptionInfo(False, "Compact prompt layout").info("puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right").needs_reload_ui(),
|
||||
}))
|
||||
|
||||
|
||||
@ -305,8 +315,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
||||
's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
|
||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
|
||||
'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
|
||||
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
||||
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
|
||||
'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule min sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
|
||||
'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule max sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
|
||||
'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"),
|
||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"),
|
||||
'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"),
|
||||
@ -329,4 +339,3 @@ options_templates.update(options_section((None, "Hidden options"), {
|
||||
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
|
||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||
}))
|
||||
|
||||
|
@ -103,6 +103,7 @@ class State:
|
||||
|
||||
def begin(self, job: str = "(unknown)"):
|
||||
self.sampling_step = 0
|
||||
self.time_start = time.time()
|
||||
self.job_count = -1
|
||||
self.processing_has_refined_job_count = False
|
||||
self.job_no = 0
|
||||
@ -114,7 +115,6 @@ class State:
|
||||
self.skipped = False
|
||||
self.interrupted = False
|
||||
self.textinfo = None
|
||||
self.time_start = time.time()
|
||||
self.job = job
|
||||
devices.torch_gc()
|
||||
log.info("Starting job %s", job)
|
||||
|
@ -15,7 +15,7 @@ import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import math
|
||||
from typing import Optional, NamedTuple, List
|
||||
from typing import Optional, NamedTuple
|
||||
|
||||
|
||||
def narrow_trunc(
|
||||
@ -97,7 +97,7 @@ def _query_chunk_attention(
|
||||
)
|
||||
return summarize_chunk(query, key_chunk, value_chunk)
|
||||
|
||||
chunks: List[AttnChunk] = [
|
||||
chunks: list[AttnChunk] = [
|
||||
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
|
||||
]
|
||||
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
|
||||
|
@ -181,40 +181,7 @@ class EmbeddingDatabase:
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
param_dict = data['string_to_param']
|
||||
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
shape = vec.shape[-1]
|
||||
vectors = vec.shape[0]
|
||||
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
|
||||
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
|
||||
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
|
||||
vectors = data['clip_g'].shape[0]
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
|
||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||
|
||||
emb = next(iter(data.values()))
|
||||
if len(emb.shape) == 1:
|
||||
emb = emb.unsqueeze(0)
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
shape = vec.shape[-1]
|
||||
vectors = vec.shape[0]
|
||||
else:
|
||||
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||
|
||||
embedding = Embedding(vec, name)
|
||||
embedding.step = data.get('step', None)
|
||||
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||
embedding.vectors = vectors
|
||||
embedding.shape = shape
|
||||
embedding.filename = path
|
||||
embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
|
||||
embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
|
||||
|
||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||
self.register_embedding(embedding, shared.sd_model)
|
||||
@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
||||
return fn
|
||||
|
||||
|
||||
def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
|
||||
if 'string_to_param' in data: # textual inversion embeddings
|
||||
param_dict = data['string_to_param']
|
||||
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
shape = vec.shape[-1]
|
||||
vectors = vec.shape[0]
|
||||
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
|
||||
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
|
||||
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
|
||||
vectors = data['clip_g'].shape[0]
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
|
||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||
|
||||
emb = next(iter(data.values()))
|
||||
if len(emb.shape) == 1:
|
||||
emb = emb.unsqueeze(0)
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
shape = vec.shape[-1]
|
||||
vectors = vec.shape[0]
|
||||
else:
|
||||
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||
|
||||
embedding = Embedding(vec, name)
|
||||
embedding.step = data.get('step', None)
|
||||
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||
embedding.vectors = vectors
|
||||
embedding.shape = shape
|
||||
|
||||
if filepath:
|
||||
embedding.filename = filepath
|
||||
embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')
|
||||
|
||||
return embedding
|
||||
|
||||
|
||||
def write_loss(log_directory, filename, step, epoch_len, values):
|
||||
if shared.opts.training_write_csv_every == 0:
|
||||
return
|
||||
@ -386,7 +392,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
||||
assert log_directory, "Log directory is empty"
|
||||
|
||||
|
||||
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_name, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
from modules import processing
|
||||
|
||||
save_embedding_every = save_embedding_every or 0
|
||||
@ -590,7 +596,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
|
||||
p.prompt = preview_prompt
|
||||
p.negative_prompt = preview_negative_prompt
|
||||
p.steps = preview_steps
|
||||
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
|
||||
p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]
|
||||
p.cfg_scale = preview_cfg_scale
|
||||
p.seed = preview_seed
|
||||
p.width = preview_width
|
||||
|
@ -3,7 +3,7 @@ from contextlib import closing
|
||||
import modules.scripts
|
||||
from modules import processing
|
||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||
from modules.shared import opts, cmd_opts
|
||||
from modules.shared import opts
|
||||
import modules.shared as shared
|
||||
from modules.ui import plaintext_to_html
|
||||
import gradio as gr
|
||||
@ -45,7 +45,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
|
||||
|
||||
p.user = request.username
|
||||
|
||||
if cmd_opts.enable_console_prompts:
|
||||
if shared.opts.enable_console_prompts:
|
||||
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
with closing(p):
|
||||
|
286
modules/ui.py
286
modules/ui.py
@ -12,7 +12,7 @@ from PIL import Image, PngImagePlugin # noqa: F401
|
||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||
|
||||
from modules import gradio_extensons # noqa: F401
|
||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks
|
||||
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow
|
||||
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow
|
||||
from modules.paths import script_path
|
||||
from modules.ui_common import create_refresh_button
|
||||
@ -25,7 +25,6 @@ import modules.hypernetworks.ui as hypernetworks_ui
|
||||
import modules.textual_inversion.ui as textual_inversion_ui
|
||||
import modules.textual_inversion.textual_inversion as textual_inversion
|
||||
import modules.shared as shared
|
||||
import modules.images
|
||||
from modules import prompt_parser
|
||||
from modules.sd_hijack import model_hijack
|
||||
from modules.generation_parameters_copypaste import image_from_url_text
|
||||
@ -151,11 +150,15 @@ def connect_clear_prompt(button):
|
||||
)
|
||||
|
||||
|
||||
def update_token_counter(text, steps):
|
||||
def update_token_counter(text, steps, *, is_positive=True):
|
||||
try:
|
||||
text, _ = extra_networks.parse_prompt(text)
|
||||
|
||||
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
|
||||
if is_positive:
|
||||
_, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
|
||||
else:
|
||||
prompt_flat_list = [text]
|
||||
|
||||
prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
|
||||
|
||||
except Exception:
|
||||
@ -169,76 +172,9 @@ def update_token_counter(text, steps):
|
||||
return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
|
||||
|
||||
|
||||
class Toprow:
|
||||
"""Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
|
||||
def update_negative_prompt_token_counter(text, steps):
|
||||
return update_token_counter(text, steps, is_positive=False)
|
||||
|
||||
def __init__(self, is_img2img):
|
||||
id_part = "img2img" if is_img2img else "txt2img"
|
||||
self.id_part = id_part
|
||||
|
||||
with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
|
||||
with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||
self.prompt_img = gr.File(label="", elem_id=f"{id_part}_prompt_image", file_count="single", type="binary", visible=False)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=80):
|
||||
with gr.Row():
|
||||
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||
|
||||
self.button_interrogate = None
|
||||
self.button_deepbooru = None
|
||||
if is_img2img:
|
||||
with gr.Column(scale=1, elem_classes="interrogate-col"):
|
||||
self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
||||
self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||
|
||||
with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
|
||||
with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
|
||||
self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
|
||||
self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
|
||||
self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
|
||||
|
||||
self.skip.click(
|
||||
fn=lambda: shared.state.skip(),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
self.interrupt.click(
|
||||
fn=lambda: shared.state.interrupt(),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
with gr.Row(elem_id=f"{id_part}_tools"):
|
||||
self.paste = ToolButton(value=paste_symbol, elem_id="paste")
|
||||
self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
|
||||
self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
|
||||
|
||||
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
|
||||
self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||
self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
|
||||
self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
|
||||
|
||||
self.clear_prompt_button.click(
|
||||
fn=lambda *x: x,
|
||||
_js="confirm_clear_prompt",
|
||||
inputs=[self.prompt, self.negative_prompt],
|
||||
outputs=[self.prompt, self.negative_prompt],
|
||||
)
|
||||
|
||||
self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
|
||||
|
||||
self.prompt_img.change(
|
||||
fn=modules.images.image_data,
|
||||
inputs=[self.prompt_img],
|
||||
outputs=[self.prompt, self.prompt_img],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
|
||||
def setup_progressbar(*args, **kwargs):
|
||||
@ -278,8 +214,8 @@ def apply_setting(key, value):
|
||||
return getattr(opts, key)
|
||||
|
||||
|
||||
def create_output_panel(tabname, outdir):
|
||||
return ui_common.create_output_panel(tabname, outdir)
|
||||
def create_output_panel(tabname, outdir, toprow=None):
|
||||
return ui_common.create_output_panel(tabname, outdir, toprow)
|
||||
|
||||
|
||||
def create_sampler_and_steps_selection(choices, tabname):
|
||||
@ -326,7 +262,7 @@ def create_ui():
|
||||
scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||
toprow = Toprow(is_img2img=False)
|
||||
toprow = ui_toprow.Toprow(is_img2img=False, is_compact=shared.opts.compact_prompt_box)
|
||||
|
||||
dummy_component = gr.Label(visible=False)
|
||||
|
||||
@ -338,6 +274,9 @@ def create_ui():
|
||||
scripts.scripts_txt2img.prepare_ui()
|
||||
|
||||
for category in ordered_ui_categories():
|
||||
if category == "prompt":
|
||||
toprow.create_inline_toprow_prompts()
|
||||
|
||||
if category == "sampler":
|
||||
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img")
|
||||
|
||||
@ -348,7 +287,7 @@ def create_ui():
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
|
||||
|
||||
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", tooltip="Switch width/height")
|
||||
|
||||
if opts.dimensions_and_batch_together:
|
||||
with gr.Column(elem_id="txt2img_column_batch"):
|
||||
@ -432,7 +371,7 @@ def create_ui():
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples)
|
||||
txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
|
||||
|
||||
txt2img_args = dict(
|
||||
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
|
||||
@ -533,7 +472,7 @@ def create_ui():
|
||||
]
|
||||
|
||||
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
|
||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
||||
toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
|
||||
|
||||
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
|
||||
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
|
||||
@ -544,7 +483,7 @@ def create_ui():
|
||||
scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||
toprow = Toprow(is_img2img=True)
|
||||
toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box)
|
||||
|
||||
extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
|
||||
extra_tabs.__enter__()
|
||||
@ -567,85 +506,89 @@ def create_ui():
|
||||
button = gr.Button(title)
|
||||
copy_image_buttons.append((button, name, elem))
|
||||
|
||||
with gr.Tabs(elem_id="mode_img2img"):
|
||||
img2img_selected_tab = gr.State(0)
|
||||
|
||||
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
|
||||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
|
||||
add_copy_image_controls('img2img', init_img)
|
||||
|
||||
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
|
||||
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
|
||||
add_copy_image_controls('sketch', sketch)
|
||||
|
||||
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
|
||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
|
||||
add_copy_image_controls('inpaint', init_img_with_mask)
|
||||
|
||||
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
|
||||
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
|
||||
inpaint_color_sketch_orig = gr.State(None)
|
||||
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
|
||||
|
||||
def update_orig(image, state):
|
||||
if image is not None:
|
||||
same_size = state is not None and state.size == image.size
|
||||
has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
|
||||
edited = same_size and has_exact_match
|
||||
return image if not edited or state is None else state
|
||||
|
||||
inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
|
||||
|
||||
with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
|
||||
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
|
||||
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
|
||||
|
||||
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
|
||||
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
||||
gr.HTML(
|
||||
"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
|
||||
"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
|
||||
f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
|
||||
f"{hidden}</p>"
|
||||
)
|
||||
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
|
||||
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
|
||||
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
|
||||
with gr.Accordion("PNG info", open=False):
|
||||
img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
|
||||
img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
|
||||
img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
|
||||
|
||||
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
|
||||
|
||||
for i, tab in enumerate(img2img_tabs):
|
||||
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
|
||||
|
||||
def copy_image(img):
|
||||
if isinstance(img, dict) and 'image' in img:
|
||||
return img['image']
|
||||
|
||||
return img
|
||||
|
||||
for button, name, elem in copy_image_buttons:
|
||||
button.click(
|
||||
fn=copy_image,
|
||||
inputs=[elem],
|
||||
outputs=[copy_image_destinations[name]],
|
||||
)
|
||||
button.click(
|
||||
fn=lambda: None,
|
||||
_js=f"switch_to_{name.replace(' ', '_')}",
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
with FormRow():
|
||||
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
||||
|
||||
scripts.scripts_img2img.prepare_ui()
|
||||
|
||||
for category in ordered_ui_categories():
|
||||
if category == "prompt":
|
||||
toprow.create_inline_toprow_prompts()
|
||||
|
||||
if category == "image":
|
||||
with gr.Tabs(elem_id="mode_img2img"):
|
||||
img2img_selected_tab = gr.State(0)
|
||||
|
||||
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
|
||||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA", height=opts.img2img_editor_height)
|
||||
add_copy_image_controls('img2img', init_img)
|
||||
|
||||
with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
|
||||
sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_sketch_default_brush_color)
|
||||
add_copy_image_controls('sketch', sketch)
|
||||
|
||||
with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
|
||||
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)
|
||||
add_copy_image_controls('inpaint', init_img_with_mask)
|
||||
|
||||
with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
|
||||
inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGB", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
|
||||
inpaint_color_sketch_orig = gr.State(None)
|
||||
add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
|
||||
|
||||
def update_orig(image, state):
|
||||
if image is not None:
|
||||
same_size = state is not None and state.size == image.size
|
||||
has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
|
||||
edited = same_size and has_exact_match
|
||||
return image if not edited or state is None else state
|
||||
|
||||
inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig)
|
||||
|
||||
with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload:
|
||||
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base")
|
||||
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", image_mode="RGBA", elem_id="img_inpaint_mask")
|
||||
|
||||
with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
|
||||
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
|
||||
gr.HTML(
|
||||
"<p style='padding-bottom: 1em;' class=\"text-gray-500\">Process images in a directory on the same machine where the server is running." +
|
||||
"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." +
|
||||
f"<br>Add inpaint batch mask directory to enable inpaint batch processing."
|
||||
f"{hidden}</p>"
|
||||
)
|
||||
img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
|
||||
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
|
||||
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
|
||||
with gr.Accordion("PNG info", open=False):
|
||||
img2img_batch_use_png_info = gr.Checkbox(label="Append png info to prompts", **shared.hide_dirs, elem_id="img2img_batch_use_png_info")
|
||||
img2img_batch_png_info_dir = gr.Textbox(label="PNG info directory", **shared.hide_dirs, placeholder="Leave empty to use input directory", elem_id="img2img_batch_png_info_dir")
|
||||
img2img_batch_png_info_props = gr.CheckboxGroup(["Prompt", "Negative prompt", "Seed", "CFG scale", "Sampler", "Steps", "Model hash"], label="Parameters to take from png info", info="Prompts from png info will be appended to prompts set in ui.")
|
||||
|
||||
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
|
||||
|
||||
for i, tab in enumerate(img2img_tabs):
|
||||
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
|
||||
|
||||
def copy_image(img):
|
||||
if isinstance(img, dict) and 'image' in img:
|
||||
return img['image']
|
||||
|
||||
return img
|
||||
|
||||
for button, name, elem in copy_image_buttons:
|
||||
button.click(
|
||||
fn=copy_image,
|
||||
inputs=[elem],
|
||||
outputs=[copy_image_destinations[name]],
|
||||
)
|
||||
button.click(
|
||||
fn=lambda: None,
|
||||
_js=f"switch_to_{name.replace(' ', '_')}",
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
with FormRow():
|
||||
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
|
||||
|
||||
if category == "sampler":
|
||||
steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img")
|
||||
|
||||
@ -661,8 +604,8 @@ def create_ui():
|
||||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
||||
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn")
|
||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn", tooltip="Switch width/height")
|
||||
detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id="img2img_detect_image_size_btn", tooltip="Auto detect size from img2img")
|
||||
|
||||
with gr.Tab(label="Resize by", elem_id="img2img_tab_resize_by") as tab_scale_by:
|
||||
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
|
||||
@ -746,20 +689,20 @@ def create_ui():
|
||||
with gr.Column(scale=4):
|
||||
inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding")
|
||||
|
||||
def select_img2img_tab(tab):
|
||||
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
|
||||
|
||||
for i, elem in enumerate(img2img_tabs):
|
||||
elem.select(
|
||||
fn=lambda tab=i: select_img2img_tab(tab),
|
||||
inputs=[],
|
||||
outputs=[inpaint_controls, mask_alpha],
|
||||
)
|
||||
|
||||
if category not in {"accordions"}:
|
||||
scripts.scripts_img2img.setup_ui_for_section(category)
|
||||
|
||||
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
|
||||
def select_img2img_tab(tab):
|
||||
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
|
||||
|
||||
for i, elem in enumerate(img2img_tabs):
|
||||
elem.select(
|
||||
fn=lambda tab=i: select_img2img_tab(tab),
|
||||
inputs=[],
|
||||
outputs=[inpaint_controls, mask_alpha],
|
||||
)
|
||||
|
||||
img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
|
||||
|
||||
img2img_args = dict(
|
||||
fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
|
||||
@ -1286,7 +1229,7 @@ def create_ui():
|
||||
|
||||
loadsave.setup_ui()
|
||||
|
||||
if os.path.exists(os.path.join(script_path, "notification.mp3")):
|
||||
if os.path.exists(os.path.join(script_path, "notification.mp3")) and shared.opts.notification_audio:
|
||||
gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)
|
||||
|
||||
footer = shared.html("footer.html")
|
||||
@ -1338,7 +1281,6 @@ checkpoint: <a id="sd_checkpoint_hash">N/A</a>
|
||||
|
||||
def setup_ui_api(app):
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List
|
||||
|
||||
class QuicksettingsHint(BaseModel):
|
||||
name: str = Field(title="Name of the quicksettings field")
|
||||
@ -1347,7 +1289,7 @@ def setup_ui_api(app):
|
||||
def quicksettings_hint():
|
||||
return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
|
||||
|
||||
app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
|
||||
app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=list[QuicksettingsHint])
|
||||
|
||||
app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
|
||||
|
||||
|
@ -104,7 +104,7 @@ def save_files(js_data, images, do_make_zip, index):
|
||||
return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
|
||||
|
||||
|
||||
def create_output_panel(tabname, outdir):
|
||||
def create_output_panel(tabname, outdir, toprow=None):
|
||||
|
||||
def open_folder(f):
|
||||
if not os.path.exists(f):
|
||||
@ -130,12 +130,15 @@ Requested path was: {f}
|
||||
else:
|
||||
sp.Popen(["xdg-open", path])
|
||||
|
||||
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
||||
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
|
||||
with gr.Column(elem_id=f"{tabname}_results"):
|
||||
if toprow:
|
||||
toprow.create_inline_toprow_image()
|
||||
|
||||
generation_info = None
|
||||
with gr.Column():
|
||||
with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"):
|
||||
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
|
||||
|
||||
generation_info = None
|
||||
with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
|
||||
open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
|
||||
|
||||
|
@ -197,7 +197,7 @@ def update_config_states_table(state_name):
|
||||
config_state = config_states.all_config_states[state_name]
|
||||
|
||||
config_name = config_state.get("name", "Config")
|
||||
created_date = time.asctime(time.gmtime(config_state["created_at"]))
|
||||
created_date = datetime.fromtimestamp(config_state["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
|
||||
filepath = config_state.get("filepath", "<unknown>")
|
||||
|
||||
try:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import os.path
|
||||
import urllib.parse
|
||||
from pathlib import Path
|
||||
@ -15,6 +16,17 @@ from modules.ui_components import ToolButton
|
||||
extra_pages = []
|
||||
allowed_dirs = set()
|
||||
|
||||
default_allowed_preview_extensions = ["png", "jpg", "jpeg", "webp", "gif"]
|
||||
|
||||
|
||||
@functools.cache
|
||||
def allowed_preview_extensions_with_extra(extra_extensions=None):
|
||||
return set(default_allowed_preview_extensions) | set(extra_extensions or [])
|
||||
|
||||
|
||||
def allowed_preview_extensions():
|
||||
return allowed_preview_extensions_with_extra((shared.opts.samples_format, ))
|
||||
|
||||
|
||||
def register_page(page):
|
||||
"""registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""
|
||||
@ -33,9 +45,9 @@ def fetch_file(filename: str = ""):
|
||||
if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")
|
||||
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Only png, jpg, webp, and gif.")
|
||||
ext = os.path.splitext(filename)[1].lower()[1:]
|
||||
if ext not in allowed_preview_extensions():
|
||||
raise ValueError(f"File cannot be fetched: {filename}. Extensions allowed: {allowed_preview_extensions()}.")
|
||||
|
||||
# would profit from returning 304
|
||||
return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
|
||||
@ -91,6 +103,7 @@ class ExtraNetworksPage:
|
||||
self.name = title.lower()
|
||||
self.id_page = self.name.replace(" ", "_")
|
||||
self.card_page = shared.html("extra-networks-card.html")
|
||||
self.allow_prompt = True
|
||||
self.allow_negative_prompt = False
|
||||
self.metadata = {}
|
||||
self.items = {}
|
||||
@ -213,9 +226,9 @@ class ExtraNetworksPage:
|
||||
metadata_button = ""
|
||||
metadata = item.get("metadata")
|
||||
if metadata:
|
||||
metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(item['name'])})'></div>"
|
||||
metadata_button = f"<div class='metadata-button card-button' title='Show internal metadata' onclick='extraNetworksRequestMetadata(event, {quote_js(self.name)}, {quote_js(html.escape(item['name']))})'></div>"
|
||||
|
||||
edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(item['name'])})'></div>"
|
||||
edit_button = f"<div class='edit-button card-button' title='Edit metadata' onclick='extraNetworksEditUserMetadata(event, {quote_js(tabname)}, {quote_js(self.id_page)}, {quote_js(html.escape(item['name']))})'></div>"
|
||||
|
||||
local_path = ""
|
||||
filename = item.get("filename", "")
|
||||
@ -235,7 +248,7 @@ class ExtraNetworksPage:
|
||||
if search_only and shared.opts.extra_networks_hidden_models == "Never":
|
||||
return ""
|
||||
|
||||
sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip()
|
||||
sort_keys = " ".join([f'data-sort-{k}="{html.escape(str(v))}"' for k, v in item.get("sort_keys", {}).items()]).strip()
|
||||
|
||||
args = {
|
||||
"background_image": background_image,
|
||||
@ -273,11 +286,7 @@ class ExtraNetworksPage:
|
||||
Find a preview PNG for a given path (without extension) and call link_preview on it.
|
||||
"""
|
||||
|
||||
preview_extensions = ["png", "jpg", "jpeg", "webp"]
|
||||
if shared.opts.samples_format not in preview_extensions:
|
||||
preview_extensions.append(shared.opts.samples_format)
|
||||
|
||||
potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in preview_extensions], [])
|
||||
potential_files = sum([[path + "." + ext, path + ".preview." + ext] for ext in allowed_preview_extensions()], [])
|
||||
|
||||
for file in potential_files:
|
||||
if os.path.isfile(file):
|
||||
@ -359,7 +368,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
||||
related_tabs = []
|
||||
|
||||
for page in ui.stored_extra_pages:
|
||||
with gr.Tab(page.title, id=page.id_page) as tab:
|
||||
with gr.Tab(page.title, elem_id=f"{tabname}_{page.id_page}", elem_classes=["extra-page"]) as tab:
|
||||
elem_id = f"{tabname}_{page.id_page}_cards_html"
|
||||
page_elem = gr.HTML('Loading...', elem_id=elem_id)
|
||||
ui.pages.append(page_elem)
|
||||
@ -373,19 +382,28 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
|
||||
related_tabs.append(tab)
|
||||
|
||||
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
|
||||
dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
||||
button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
|
||||
dropdown_sort = gr.Dropdown(choices=['Name', 'Date Created', 'Date Modified', ], value=shared.opts.extra_networks_card_order_field, elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
|
||||
button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes=["sortorder"] + ([] if shared.opts.extra_networks_card_order == "Ascending" else ["sortReverse"]), visible=False, tooltip="Invert sort order")
|
||||
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
|
||||
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
|
||||
|
||||
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
|
||||
ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
|
||||
|
||||
for tab in unrelated_tabs:
|
||||
tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
|
||||
tab_controls = [edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs]
|
||||
|
||||
for tab in related_tabs:
|
||||
tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False)
|
||||
for tab in unrelated_tabs:
|
||||
tab.select(fn=lambda: [gr.update(visible=False) for _ in tab_controls], _js='function(){ extraNetworksUrelatedTabSelected("' + tabname + '"); }', inputs=[], outputs=tab_controls, show_progress=False)
|
||||
|
||||
for page, tab in zip(ui.stored_extra_pages, related_tabs):
|
||||
allow_prompt = "true" if page.allow_prompt else "false"
|
||||
allow_negative_prompt = "true" if page.allow_negative_prompt else "false"
|
||||
|
||||
jscode = 'extraNetworksTabSelected("' + tabname + '", "' + f"{tabname}_{page.id_page}" + '", ' + allow_prompt + ', ' + allow_negative_prompt + ');'
|
||||
|
||||
tab.select(fn=lambda: [gr.update(visible=True) for _ in tab_controls], _js='function(){ ' + jscode + ' }', inputs=[], outputs=tab_controls, show_progress=False)
|
||||
|
||||
dropdown_sort.change(fn=lambda: None, _js="function(){ applyExtraNetworkSort('" + tabname + "'); }")
|
||||
|
||||
def pages_html():
|
||||
if not ui.pages_contents:
|
||||
|
@ -10,6 +10,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
||||
def __init__(self):
|
||||
super().__init__('Checkpoints')
|
||||
|
||||
self.allow_prompt = False
|
||||
|
||||
def refresh(self):
|
||||
shared.refresh_checkpoints()
|
||||
|
||||
|
@ -2,12 +2,12 @@ import os
|
||||
import gradio as gr
|
||||
|
||||
from modules import localization, shared, scripts
|
||||
from modules.paths import script_path, data_path
|
||||
from modules.paths import script_path, data_path, cwd
|
||||
|
||||
|
||||
def webpath(fn):
|
||||
if fn.startswith(script_path):
|
||||
web_path = os.path.relpath(fn, script_path).replace('\\', '/')
|
||||
if fn.startswith(cwd):
|
||||
web_path = os.path.relpath(fn, cwd)
|
||||
else:
|
||||
web_path = os.path.abspath(fn)
|
||||
|
||||
|
@ -4,7 +4,7 @@ import os
|
||||
import gradio as gr
|
||||
|
||||
from modules import errors
|
||||
from modules.ui_components import ToolButton
|
||||
from modules.ui_components import ToolButton, InputAccordion
|
||||
|
||||
|
||||
def radio_choices(comp): # gradio 3.41 changes choices from list of values to list of pairs
|
||||
@ -32,8 +32,6 @@ class UiLoadsave:
|
||||
self.error_loading = True
|
||||
errors.display(e, "loading settings")
|
||||
|
||||
|
||||
|
||||
def add_component(self, path, x):
|
||||
"""adds component to the registry of tracked components"""
|
||||
|
||||
@ -43,20 +41,24 @@ class UiLoadsave:
|
||||
key = f"{path}/{field}"
|
||||
|
||||
if getattr(obj, 'custom_script_source', None) is not None:
|
||||
key = f"customscript/{obj.custom_script_source}/{key}"
|
||||
key = f"customscript/{obj.custom_script_source}/{key}"
|
||||
|
||||
if getattr(obj, 'do_not_save_to_config', False):
|
||||
return
|
||||
|
||||
saved_value = self.ui_settings.get(key, None)
|
||||
|
||||
if isinstance(obj, gr.Accordion) and isinstance(x, InputAccordion) and field == 'value':
|
||||
field = 'open'
|
||||
|
||||
if saved_value is None:
|
||||
self.ui_settings[key] = getattr(obj, field)
|
||||
elif condition and not condition(saved_value):
|
||||
pass
|
||||
else:
|
||||
if isinstance(x, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
|
||||
if isinstance(obj, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
|
||||
saved_value = str(saved_value)
|
||||
elif isinstance(x, gr.Number) and field == 'value':
|
||||
elif isinstance(obj, gr.Number) and field == 'value':
|
||||
try:
|
||||
saved_value = float(saved_value)
|
||||
except ValueError:
|
||||
@ -67,7 +69,7 @@ class UiLoadsave:
|
||||
init_field(saved_value)
|
||||
|
||||
if field == 'value' and key not in self.component_mapping:
|
||||
self.component_mapping[key] = x
|
||||
self.component_mapping[key] = obj
|
||||
|
||||
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible:
|
||||
apply_field(x, 'visible')
|
||||
@ -100,6 +102,12 @@ class UiLoadsave:
|
||||
|
||||
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
|
||||
|
||||
if type(x) == InputAccordion:
|
||||
if x.accordion.visible:
|
||||
apply_field(x.accordion, 'visible')
|
||||
apply_field(x, 'value')
|
||||
apply_field(x.accordion, 'value')
|
||||
|
||||
def check_tab_id(tab_id):
|
||||
tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
|
||||
if type(tab_id) == str:
|
||||
|
@ -4,6 +4,7 @@ from modules import shared, ui_common, ui_components, styles
|
||||
|
||||
styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
|
||||
styles_materialize_symbol = '\U0001f4cb' # 📋
|
||||
styles_copy_symbol = '\U0001f4dd' # 📝
|
||||
|
||||
|
||||
def select_style(name):
|
||||
@ -52,6 +53,8 @@ def refresh_styles():
|
||||
class UiPromptStyles:
|
||||
def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
|
||||
self.tabname = tabname
|
||||
self.main_ui_prompt = main_ui_prompt
|
||||
self.main_ui_negative_prompt = main_ui_negative_prompt
|
||||
|
||||
with gr.Row(elem_id=f"{tabname}_styles_row"):
|
||||
self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
|
||||
@ -61,7 +64,8 @@ class UiPromptStyles:
|
||||
with gr.Row():
|
||||
self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
|
||||
ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
|
||||
self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
|
||||
self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply_dialog", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
|
||||
self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f"{tabname}_style_copy", tooltip="Copy main UI prompt to style.")
|
||||
|
||||
with gr.Row():
|
||||
self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
|
||||
@ -96,15 +100,21 @@ class UiPromptStyles:
|
||||
show_progress=False,
|
||||
).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
|
||||
|
||||
self.materialize.click(
|
||||
fn=materialize_styles,
|
||||
inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
|
||||
outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
|
||||
self.setup_apply_button(self.materialize)
|
||||
|
||||
self.copy.click(
|
||||
fn=lambda p, n: (p, n),
|
||||
inputs=[main_ui_prompt, main_ui_negative_prompt],
|
||||
outputs=[self.prompt, self.neg_prompt],
|
||||
show_progress=False,
|
||||
).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
|
||||
)
|
||||
|
||||
ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
|
||||
|
||||
|
||||
|
||||
|
||||
def setup_apply_button(self, button):
|
||||
button.click(
|
||||
fn=materialize_styles,
|
||||
inputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown],
|
||||
outputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown],
|
||||
show_progress=False,
|
||||
).then(fn=None, _js="function(){update_"+self.tabname+"_tokens(); closePopup();}", show_progress=False)
|
||||
|
@ -1,10 +1,11 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
|
||||
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer
|
||||
from modules.call_queue import wrap_gradio_call
|
||||
from modules.shared import opts
|
||||
from modules.ui_components import FormRow
|
||||
from modules.ui_gradio_extensions import reload_javascript
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
|
||||
def get_value_for_setting(key):
|
||||
@ -63,6 +64,9 @@ class UiSettings:
|
||||
quicksettings_list = None
|
||||
quicksettings_names = None
|
||||
text_settings = None
|
||||
show_all_pages = None
|
||||
show_one_page = None
|
||||
search_input = None
|
||||
|
||||
def run_settings(self, *args):
|
||||
changed = []
|
||||
@ -135,7 +139,7 @@ class UiSettings:
|
||||
gr.Group()
|
||||
current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
|
||||
current_tab.__enter__()
|
||||
current_row = gr.Column(variant='compact')
|
||||
current_row = gr.Column(elem_id=f"column_settings_{elem_id}", variant='compact')
|
||||
current_row.__enter__()
|
||||
|
||||
previous_section = item.section
|
||||
@ -173,26 +177,43 @@ class UiSettings:
|
||||
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
||||
with gr.Row():
|
||||
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
|
||||
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
|
||||
unload_sd_model = gr.Button(value='Unload SD checkpoint to RAM', elem_id="sett_unload_sd_model")
|
||||
reload_sd_model = gr.Button(value='Load SD checkpoint to VRAM from RAM', elem_id="sett_reload_sd_model")
|
||||
with gr.Row():
|
||||
calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
|
||||
calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
|
||||
|
||||
with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
|
||||
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
||||
|
||||
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
||||
self.show_all_pages = gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
||||
self.show_one_page = gr.Button(value="Show only one page", elem_id="settings_show_one_page", visible=False)
|
||||
self.show_one_page.click(lambda: None)
|
||||
|
||||
self.search_input = gr.Textbox(value="", elem_id="settings_search", max_lines=1, placeholder="Search...", show_label=False)
|
||||
|
||||
self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
|
||||
|
||||
def call_func_and_return_text(func, text):
|
||||
def handler():
|
||||
t = timer.Timer()
|
||||
func()
|
||||
t.record(text)
|
||||
|
||||
return f'{text} in {t.total:.1f}s'
|
||||
|
||||
return handler
|
||||
|
||||
unload_sd_model.click(
|
||||
fn=sd_models.unload_model_weights,
|
||||
fn=call_func_and_return_text(sd_models.unload_model_weights, 'Unloaded the checkpoint'),
|
||||
inputs=[],
|
||||
outputs=[]
|
||||
outputs=[self.result]
|
||||
)
|
||||
|
||||
reload_sd_model.click(
|
||||
fn=sd_models.reload_model_weights,
|
||||
fn=call_func_and_return_text(lambda: sd_models.send_model_to_device(shared.sd_model), 'Loaded the checkpoint'),
|
||||
inputs=[],
|
||||
outputs=[]
|
||||
outputs=[self.result]
|
||||
)
|
||||
|
||||
request_notifications.click(
|
||||
@ -241,6 +262,21 @@ class UiSettings:
|
||||
outputs=[sysinfo_check_output],
|
||||
)
|
||||
|
||||
def calculate_all_checkpoint_hash_fn(max_thread):
|
||||
checkpoints_list = sd_models.checkpoints_list.values()
|
||||
with ThreadPoolExecutor(max_workers=max_thread) as executor:
|
||||
futures = [executor.submit(checkpoint.calculate_shorthash) for checkpoint in checkpoints_list]
|
||||
completed = 0
|
||||
for _ in as_completed(futures):
|
||||
completed += 1
|
||||
print(f"{completed} / {len(checkpoints_list)} ")
|
||||
print("Finish calculating hash for all checkpoints")
|
||||
|
||||
calculate_all_checkpoint_hash.click(
|
||||
fn=calculate_all_checkpoint_hash_fn,
|
||||
inputs=[calculate_all_checkpoint_hash_threads],
|
||||
)
|
||||
|
||||
self.interface = settings_interface
|
||||
|
||||
def add_quicksettings(self):
|
||||
@ -294,3 +330,8 @@ class UiSettings:
|
||||
outputs=[self.component_dict[k] for k in component_keys],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
def search(self, text):
|
||||
print(text)
|
||||
|
||||
return [gr.update(visible=text in (comp.label or "")) for comp in self.components]
|
||||
|
141
modules/ui_toprow.py
Normal file
141
modules/ui_toprow.py
Normal file
@ -0,0 +1,141 @@
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared, ui_prompt_styles
|
||||
import modules.images
|
||||
|
||||
from modules.ui_components import ToolButton
|
||||
|
||||
|
||||
class Toprow:
|
||||
"""Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation"""
|
||||
|
||||
prompt = None
|
||||
prompt_img = None
|
||||
negative_prompt = None
|
||||
|
||||
button_interrogate = None
|
||||
button_deepbooru = None
|
||||
|
||||
interrupt = None
|
||||
skip = None
|
||||
submit = None
|
||||
|
||||
paste = None
|
||||
clear_prompt_button = None
|
||||
apply_styles = None
|
||||
restore_progress_button = None
|
||||
|
||||
token_counter = None
|
||||
token_button = None
|
||||
negative_token_counter = None
|
||||
negative_token_button = None
|
||||
|
||||
ui_styles = None
|
||||
|
||||
submit_box = None
|
||||
|
||||
def __init__(self, is_img2img, is_compact=False):
|
||||
id_part = "img2img" if is_img2img else "txt2img"
|
||||
self.id_part = id_part
|
||||
self.is_img2img = is_img2img
|
||||
self.is_compact = is_compact
|
||||
|
||||
if not is_compact:
|
||||
with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
|
||||
self.create_classic_toprow()
|
||||
else:
|
||||
self.create_submit_box()
|
||||
|
||||
def create_classic_toprow(self):
|
||||
self.create_prompts()
|
||||
|
||||
with gr.Column(scale=1, elem_id=f"{self.id_part}_actions_column"):
|
||||
self.create_submit_box()
|
||||
|
||||
self.create_tools_row()
|
||||
|
||||
self.create_styles_ui()
|
||||
|
||||
def create_inline_toprow_prompts(self):
|
||||
if not self.is_compact:
|
||||
return
|
||||
|
||||
self.create_prompts()
|
||||
|
||||
with gr.Row(elem_classes=["toprow-compact-stylerow"]):
|
||||
with gr.Column(elem_classes=["toprow-compact-tools"]):
|
||||
self.create_tools_row()
|
||||
with gr.Column():
|
||||
self.create_styles_ui()
|
||||
|
||||
def create_inline_toprow_image(self):
|
||||
if not self.is_compact:
|
||||
return
|
||||
|
||||
self.submit_box.render()
|
||||
|
||||
def create_prompts(self):
|
||||
with gr.Column(elem_id=f"{self.id_part}_prompt_container", elem_classes=["prompt-container-compact"] if self.is_compact else [], scale=6):
|
||||
with gr.Row(elem_id=f"{self.id_part}_prompt_row", elem_classes=["prompt-row"]):
|
||||
self.prompt = gr.Textbox(label="Prompt", elem_id=f"{self.id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||
self.prompt_img = gr.File(label="", elem_id=f"{self.id_part}_prompt_image", file_count="single", type="binary", visible=False)
|
||||
|
||||
with gr.Row(elem_id=f"{self.id_part}_neg_prompt_row", elem_classes=["prompt-row"]):
|
||||
self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{self.id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
|
||||
|
||||
self.prompt_img.change(
|
||||
fn=modules.images.image_data,
|
||||
inputs=[self.prompt_img],
|
||||
outputs=[self.prompt, self.prompt_img],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
def create_submit_box(self):
|
||||
with gr.Row(elem_id=f"{self.id_part}_generate_box", elem_classes=["generate-box"] + (["generate-box-compact"] if self.is_compact else []), render=not self.is_compact) as submit_box:
|
||||
self.submit_box = submit_box
|
||||
|
||||
self.interrupt = gr.Button('Interrupt', elem_id=f"{self.id_part}_interrupt", elem_classes="generate-box-interrupt")
|
||||
self.skip = gr.Button('Skip', elem_id=f"{self.id_part}_skip", elem_classes="generate-box-skip")
|
||||
self.submit = gr.Button('Generate', elem_id=f"{self.id_part}_generate", variant='primary')
|
||||
|
||||
self.skip.click(
|
||||
fn=lambda: shared.state.skip(),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
self.interrupt.click(
|
||||
fn=lambda: shared.state.interrupt(),
|
||||
inputs=[],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
def create_tools_row(self):
|
||||
with gr.Row(elem_id=f"{self.id_part}_tools"):
|
||||
from modules.ui import paste_symbol, clear_prompt_symbol, restore_progress_symbol
|
||||
|
||||
self.paste = ToolButton(value=paste_symbol, elem_id="paste", tooltip="Read generation parameters from prompt or last generation if prompt is empty into user interface.")
|
||||
self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{self.id_part}_clear_prompt", tooltip="Clear prompt")
|
||||
self.apply_styles = ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f"{self.id_part}_style_apply", tooltip="Apply all selected styles to prompts.")
|
||||
|
||||
if self.is_img2img:
|
||||
self.button_interrogate = ToolButton('📎', tooltip='Interrogate CLIP - use CLIP neural network to create a text describing the image, and put it into the prompt field', elem_id="interrogate")
|
||||
self.button_deepbooru = ToolButton('📦', tooltip='Interrogate DeepBooru - use DeepBooru neural network to create a text describing the image, and put it into the prompt field', elem_id="deepbooru")
|
||||
|
||||
self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{self.id_part}_restore_progress", visible=False, tooltip="Restore progress")
|
||||
|
||||
self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_token_counter", elem_classes=["token-counter"])
|
||||
self.token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_token_button")
|
||||
self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{self.id_part}_negative_token_counter", elem_classes=["token-counter"])
|
||||
self.negative_token_button = gr.Button(visible=False, elem_id=f"{self.id_part}_negative_token_button")
|
||||
|
||||
self.clear_prompt_button.click(
|
||||
fn=lambda *x: x,
|
||||
_js="confirm_clear_prompt",
|
||||
inputs=[self.prompt, self.negative_prompt],
|
||||
outputs=[self.prompt, self.negative_prompt],
|
||||
)
|
||||
|
||||
def create_styles_ui(self):
|
||||
self.ui_styles = ui_prompt_styles.UiPromptStyles(self.id_part, self.prompt, self.negative_prompt)
|
||||
self.ui_styles.setup_apply_button(self.apply_styles)
|
164
modules/xlmr_m18.py
Normal file
164
modules/xlmr_m18.py
Normal file
@ -0,0 +1,164 @@
|
||||
from transformers import BertPreTrainedModel,BertConfig
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
|
||||
from transformers import XLMRobertaModel,XLMRobertaTokenizer
|
||||
from typing import Optional
|
||||
|
||||
class BertSeriesConfig(BertConfig):
|
||||
def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs):
|
||||
|
||||
super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)
|
||||
self.project_dim = project_dim
|
||||
self.pooler_fn = pooler_fn
|
||||
self.learn_encoder = learn_encoder
|
||||
|
||||
class RobertaSeriesConfig(XLMRobertaConfig):
|
||||
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
self.project_dim = project_dim
|
||||
self.pooler_fn = pooler_fn
|
||||
self.learn_encoder = learn_encoder
|
||||
|
||||
|
||||
class BertSeriesModelWithTransformation(BertPreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
config_class = BertSeriesConfig
|
||||
|
||||
def __init__(self, config=None, **kargs):
|
||||
# modify initialization for autoloading
|
||||
if config is None:
|
||||
config = XLMRobertaConfig()
|
||||
config.attention_probs_dropout_prob= 0.1
|
||||
config.bos_token_id=0
|
||||
config.eos_token_id=2
|
||||
config.hidden_act='gelu'
|
||||
config.hidden_dropout_prob=0.1
|
||||
config.hidden_size=1024
|
||||
config.initializer_range=0.02
|
||||
config.intermediate_size=4096
|
||||
config.layer_norm_eps=1e-05
|
||||
config.max_position_embeddings=514
|
||||
|
||||
config.num_attention_heads=16
|
||||
config.num_hidden_layers=24
|
||||
config.output_past=True
|
||||
config.pad_token_id=1
|
||||
config.position_embedding_type= "absolute"
|
||||
|
||||
config.type_vocab_size= 1
|
||||
config.use_cache=True
|
||||
config.vocab_size= 250002
|
||||
config.project_dim = 1024
|
||||
config.learn_encoder = False
|
||||
super().__init__(config)
|
||||
self.roberta = XLMRobertaModel(config)
|
||||
self.transformation = nn.Linear(config.hidden_size,config.project_dim)
|
||||
# self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
|
||||
# self.pooler = lambda x: x[:,0]
|
||||
# self.post_init()
|
||||
|
||||
self.has_pre_transformation = True
|
||||
if self.has_pre_transformation:
|
||||
self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)
|
||||
self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.post_init()
|
||||
|
||||
def encode(self,c):
|
||||
device = next(self.parameters()).device
|
||||
text = self.tokenizer(c,
|
||||
truncation=True,
|
||||
max_length=77,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
padding="max_length",
|
||||
return_tensors="pt")
|
||||
text["input_ids"] = torch.tensor(text["input_ids"]).to(device)
|
||||
text["attention_mask"] = torch.tensor(
|
||||
text['attention_mask']).to(device)
|
||||
features = self(**text)
|
||||
return features['projection_state']
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
) :
|
||||
r"""
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
|
||||
outputs = self.roberta(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=True,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
# # last module outputs
|
||||
# sequence_output = outputs[0]
|
||||
|
||||
|
||||
# # project every module
|
||||
# sequence_output_ln = self.pre_LN(sequence_output)
|
||||
|
||||
# # pooler
|
||||
# pooler_output = self.pooler(sequence_output_ln)
|
||||
# pooler_output = self.transformation(pooler_output)
|
||||
# projection_state = self.transformation(outputs.last_hidden_state)
|
||||
|
||||
if self.has_pre_transformation:
|
||||
sequence_output2 = outputs["hidden_states"][-2]
|
||||
sequence_output2 = self.pre_LN(sequence_output2)
|
||||
projection_state2 = self.transformation_pre(sequence_output2)
|
||||
|
||||
return {
|
||||
"projection_state": projection_state2,
|
||||
"last_hidden_state": outputs.last_hidden_state,
|
||||
"hidden_states": outputs.hidden_states,
|
||||
"attentions": outputs.attentions,
|
||||
}
|
||||
else:
|
||||
projection_state = self.transformation(outputs.last_hidden_state)
|
||||
return {
|
||||
"projection_state": projection_state,
|
||||
"last_hidden_state": outputs.last_hidden_state,
|
||||
"hidden_states": outputs.hidden_states,
|
||||
"attentions": outputs.attentions,
|
||||
}
|
||||
|
||||
|
||||
# return {
|
||||
# 'pooler_output':pooler_output,
|
||||
# 'last_hidden_state':outputs.last_hidden_state,
|
||||
# 'hidden_states':outputs.hidden_states,
|
||||
# 'attentions':outputs.attentions,
|
||||
# 'projection_state':projection_state,
|
||||
# 'sequence_out': sequence_output
|
||||
# }
|
||||
|
||||
|
||||
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
|
||||
base_model_prefix = 'roberta'
|
||||
config_class= RobertaSeriesConfig
|
@ -27,6 +27,6 @@ timm==0.9.2
|
||||
tomesd==0.1.3
|
||||
torch
|
||||
torchdiffeq==0.2.3
|
||||
torchsde==0.2.5
|
||||
torchsde==0.2.6
|
||||
transformers==4.30.2
|
||||
httpx==0.24.1
|
||||
|
24
script.js
24
script.js
@ -124,16 +124,20 @@ document.addEventListener("DOMContentLoaded", function() {
|
||||
* Add a ctrl+enter as a shortcut to start a generation
|
||||
*/
|
||||
document.addEventListener('keydown', function(e) {
|
||||
var handled = false;
|
||||
if (e.key !== undefined) {
|
||||
if ((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
|
||||
} else if (e.keyCode !== undefined) {
|
||||
if ((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
|
||||
}
|
||||
if (handled) {
|
||||
var button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
|
||||
if (button) {
|
||||
button.click();
|
||||
const isEnter = e.key === 'Enter' || e.keyCode === 13;
|
||||
const isModifierKey = e.metaKey || e.ctrlKey || e.altKey;
|
||||
|
||||
const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]');
|
||||
const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
|
||||
|
||||
if (isEnter && isModifierKey) {
|
||||
if (interruptButton.style.display === 'block') {
|
||||
interruptButton.click();
|
||||
setTimeout(function() {
|
||||
generateButton.click();
|
||||
}, 500);
|
||||
} else {
|
||||
generateButton.click();
|
||||
}
|
||||
e.preventDefault();
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
|
||||
upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w")
|
||||
upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h")
|
||||
with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||
upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn")
|
||||
upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn", tooltip="Switch width/height")
|
||||
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
|
||||
|
||||
with FormRow():
|
||||
|
@ -5,11 +5,17 @@ import shlex
|
||||
import modules.scripts as scripts
|
||||
import gradio as gr
|
||||
|
||||
from modules import sd_samplers, errors
|
||||
from modules import sd_samplers, errors, sd_models
|
||||
from modules.processing import Processed, process_images
|
||||
from modules.shared import state
|
||||
|
||||
|
||||
def process_model_tag(tag):
|
||||
info = sd_models.get_closet_checkpoint_match(tag)
|
||||
assert info is not None, f'Unknown checkpoint: {tag}'
|
||||
return info.name
|
||||
|
||||
|
||||
def process_string_tag(tag):
|
||||
return tag
|
||||
|
||||
@ -27,7 +33,7 @@ def process_boolean_tag(tag):
|
||||
|
||||
|
||||
prompt_tags = {
|
||||
"sd_model": None,
|
||||
"sd_model": process_model_tag,
|
||||
"outpath_samples": process_string_tag,
|
||||
"outpath_grids": process_string_tag,
|
||||
"prompt_for_display": process_string_tag,
|
||||
@ -108,6 +114,7 @@ class Script(scripts.Script):
|
||||
def ui(self, is_img2img):
|
||||
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
|
||||
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
|
||||
prompt_position = gr.Radio(["start", "end"], label="Insert prompts at the", elem_id=self.elem_id("prompt_position"), value="start")
|
||||
|
||||
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
|
||||
file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
|
||||
@ -118,9 +125,9 @@ class Script(scripts.Script):
|
||||
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may
|
||||
# be unclear to the user that shift-enter is needed.
|
||||
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False)
|
||||
return [checkbox_iterate, checkbox_iterate_batch, prompt_txt]
|
||||
return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt]
|
||||
|
||||
def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
|
||||
def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str):
|
||||
lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x]
|
||||
|
||||
p.do_not_save_grid = True
|
||||
@ -156,7 +163,22 @@ class Script(scripts.Script):
|
||||
|
||||
copy_p = copy.copy(p)
|
||||
for k, v in args.items():
|
||||
setattr(copy_p, k, v)
|
||||
if k == "sd_model":
|
||||
copy_p.override_settings['sd_model_checkpoint'] = v
|
||||
else:
|
||||
setattr(copy_p, k, v)
|
||||
|
||||
if args.get("prompt") and p.prompt:
|
||||
if prompt_position == "start":
|
||||
copy_p.prompt = args.get("prompt") + " " + p.prompt
|
||||
else:
|
||||
copy_p.prompt = p.prompt + " " + args.get("prompt")
|
||||
|
||||
if args.get("negative_prompt") and p.negative_prompt:
|
||||
if prompt_position == "start":
|
||||
copy_p.negative_prompt = args.get("negative_prompt") + " " + p.negative_prompt
|
||||
else:
|
||||
copy_p.negative_prompt = p.negative_prompt + " " + args.get("negative_prompt")
|
||||
|
||||
proc = process_images(copy_p)
|
||||
images += proc.images
|
||||
|
@ -205,13 +205,14 @@ def csv_string_to_list_strip(data_str):
|
||||
|
||||
|
||||
class AxisOption:
|
||||
def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
|
||||
def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None, prepare=None):
|
||||
self.label = label
|
||||
self.type = type
|
||||
self.apply = apply
|
||||
self.format_value = format_value
|
||||
self.confirm = confirm
|
||||
self.cost = cost
|
||||
self.prepare = prepare
|
||||
self.choices = choices
|
||||
|
||||
|
||||
@ -536,6 +537,8 @@ class Script(scripts.Script):
|
||||
|
||||
if opt.choices is not None and not csv_mode:
|
||||
valslist = vals_dropdown
|
||||
elif opt.prepare is not None:
|
||||
valslist = opt.prepare(vals)
|
||||
else:
|
||||
valslist = csv_string_to_list_strip(vals)
|
||||
|
||||
@ -773,6 +776,8 @@ class Script(scripts.Script):
|
||||
# TODO: See previous comment about intentional data misalignment.
|
||||
adj_g = g-1 if g > 0 else g
|
||||
images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)
|
||||
if not include_sub_grids: # if not include_sub_grids then skip saving after the first grid
|
||||
break
|
||||
|
||||
if not include_sub_grids:
|
||||
# Done with sub-grids, drop all related information:
|
||||
|
50
style.css
50
style.css
@ -83,8 +83,10 @@ div.compact{
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.gradio-dropdown ul.options li.item {
|
||||
padding: 0.05em 0;
|
||||
@media (pointer:fine) {
|
||||
.gradio-dropdown ul.options li.item {
|
||||
padding: 0.05em 0;
|
||||
}
|
||||
}
|
||||
|
||||
.gradio-dropdown ul.options li.item.selected {
|
||||
@ -202,6 +204,11 @@ div.block.gradio-accordion {
|
||||
padding: 8px 8px;
|
||||
}
|
||||
|
||||
input[type="checkbox"].input-accordion-checkbox{
|
||||
vertical-align: sub;
|
||||
margin-right: 0.5em;
|
||||
}
|
||||
|
||||
|
||||
/* txt2img/img2img specific */
|
||||
|
||||
@ -289,6 +296,13 @@ div.block.gradio-accordion {
|
||||
min-height: 4.5em;
|
||||
}
|
||||
|
||||
#txt2img_generate, #img2img_generate {
|
||||
min-height: 4.5em;
|
||||
}
|
||||
.generate-box-compact #txt2img_generate, .generate-box-compact #img2img_generate {
|
||||
min-height: 3em;
|
||||
}
|
||||
|
||||
@media screen and (min-width: 2500px) {
|
||||
#txt2img_gallery, #img2img_gallery {
|
||||
min-height: 768px;
|
||||
@ -396,6 +410,15 @@ div#extras_scale_to_tab div.form{
|
||||
min-width: 0.5em;
|
||||
}
|
||||
|
||||
div.toprow-compact-stylerow{
|
||||
margin: 0.5em 0;
|
||||
}
|
||||
|
||||
div.toprow-compact-tools{
|
||||
min-width: fit-content !important;
|
||||
max-width: fit-content;
|
||||
}
|
||||
|
||||
/* settings */
|
||||
#quicksettings {
|
||||
align-items: end;
|
||||
@ -421,6 +444,7 @@ div#extras_scale_to_tab div.form{
|
||||
#settings > div{
|
||||
border: none;
|
||||
margin-left: 10em;
|
||||
padding: 0 var(--spacing-xl);
|
||||
}
|
||||
|
||||
#settings > div.tab-nav{
|
||||
@ -435,6 +459,7 @@ div#extras_scale_to_tab div.form{
|
||||
border: none;
|
||||
text-align: left;
|
||||
white-space: initial;
|
||||
padding: 4px;
|
||||
}
|
||||
|
||||
#settings_result{
|
||||
@ -516,7 +541,8 @@ table.popup-table .link{
|
||||
height: 20px;
|
||||
background: #b4c0cc;
|
||||
border-radius: 3px !important;
|
||||
top: -20px;
|
||||
top: -14px;
|
||||
left: 0px;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
@ -581,7 +607,6 @@ table.popup-table .link{
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
overflow: auto;
|
||||
background-color: rgba(20, 20, 20, 0.95);
|
||||
}
|
||||
|
||||
.global-popup *{
|
||||
@ -590,9 +615,6 @@ table.popup-table .link{
|
||||
|
||||
.global-popup-close:before {
|
||||
content: "×";
|
||||
}
|
||||
|
||||
.global-popup-close{
|
||||
position: fixed;
|
||||
right: 0.25em;
|
||||
top: 0;
|
||||
@ -601,10 +623,20 @@ table.popup-table .link{
|
||||
font-size: 32pt;
|
||||
}
|
||||
|
||||
.global-popup-close{
|
||||
position: fixed;
|
||||
left: 0;
|
||||
top: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background-color: rgba(20, 20, 20, 0.95);
|
||||
}
|
||||
|
||||
.global-popup-inner{
|
||||
display: inline-block;
|
||||
margin: auto;
|
||||
padding: 2em;
|
||||
z-index: 1001;
|
||||
}
|
||||
|
||||
/* fullpage image viewer */
|
||||
@ -808,6 +840,10 @@ footer {
|
||||
|
||||
/* extra networks UI */
|
||||
|
||||
.extra-page .prompt{
|
||||
margin: 0 0 0.5em 0;
|
||||
}
|
||||
|
||||
.extra-network-cards{
|
||||
height: calc(100vh - 24rem);
|
||||
overflow: clip scroll;
|
||||
|
@ -1,6 +1,11 @@
|
||||
@echo off
|
||||
|
||||
if exist webui.settings.bat (
|
||||
call webui.settings.bat
|
||||
)
|
||||
|
||||
if not defined PYTHON (set PYTHON=python)
|
||||
if defined GIT (set "GIT_PYTHON_GIT_EXECUTABLE=%GIT%")
|
||||
if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")
|
||||
|
||||
set SD_WEBUI_RESTART=tmp/restart
|
||||
|
2
webui.py
2
webui.py
@ -74,7 +74,7 @@ def webui():
|
||||
if shared.opts.auto_launch_browser == "Remote" or cmd_opts.autolaunch:
|
||||
auto_launch_browser = True
|
||||
elif shared.opts.auto_launch_browser == "Local":
|
||||
auto_launch_browser = not any([cmd_opts.listen, cmd_opts.share, cmd_opts.ngrok, cmd_opts.server_name])
|
||||
auto_launch_browser = not cmd_opts.webui_is_non_local
|
||||
|
||||
app, local_url, share_url = shared.demo.launch(
|
||||
share=cmd_opts.share,
|
||||
|
19
webui.sh
19
webui.sh
@ -4,12 +4,6 @@
|
||||
# change the variables in webui-user.sh instead #
|
||||
#################################################
|
||||
|
||||
|
||||
use_venv=1
|
||||
if [[ $venv_dir == "-" ]]; then
|
||||
use_venv=0
|
||||
fi
|
||||
|
||||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
|
||||
|
||||
@ -28,6 +22,12 @@ then
|
||||
source "$SCRIPT_DIR"/webui-user.sh
|
||||
fi
|
||||
|
||||
# If $venv_dir is "-", then disable venv support
|
||||
use_venv=1
|
||||
if [[ $venv_dir == "-" ]]; then
|
||||
use_venv=0
|
||||
fi
|
||||
|
||||
# Set defaults
|
||||
# Install directory without trailing slash
|
||||
if [[ -z "${install_dir}" ]]
|
||||
@ -51,6 +51,8 @@ fi
|
||||
if [[ -z "${GIT}" ]]
|
||||
then
|
||||
export GIT="git"
|
||||
else
|
||||
export GIT_PYTHON_GIT_EXECUTABLE="${GIT}"
|
||||
fi
|
||||
|
||||
# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
|
||||
@ -141,9 +143,8 @@ case "$gpu_info" in
|
||||
*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0
|
||||
;;
|
||||
*"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \
|
||||
export TORCH_COMMAND="pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.6"
|
||||
# Navi 3 needs at least 5.5 which is only on the nightly chain, previous versions are no longer online (torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 torchaudio==2.1.0.dev-20230614+rocm5.5)
|
||||
# so switch to nightly rocm5.6 without explicit versions this time
|
||||
export TORCH_COMMAND="pip install torch torchvision --index-url https://download.pytorch.org/whl/test/rocm5.6"
|
||||
# Navi 3 needs at least 5.5 which is only on the torch 2.1.0 release candidates right now
|
||||
;;
|
||||
*"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0
|
||||
printf "\n%s\n" "${delimiter}"
|
||||
|
Loading…
Reference in New Issue
Block a user