Merge branch 'master' into master

Commit 0831ab476c by zhaohu xing, 2022-11-30 10:13:17 +08:00 (committed by GitHub)
53 changed files with 1498 additions and 900 deletions

.github/workflows/run_tests.yaml (new file, 31 lines)

@ -0,0 +1,31 @@
name: Run basic features tests on CPU with empty SD model
on:
- push
- pull_request
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
python-version: 3.10.6
- uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: ${{ runner.os }}-pip-
- name: Run tests
run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
- name: Upload main app stdout-stderr
uses: actions/upload-artifact@v3
if: always()
with:
name: stdout-stderr
path: |
test/stdout.txt
test/stderr.txt
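The workflow above can be reproduced locally. A minimal sketch, assuming it is run from the repository root with the same flags as the CI step (the subprocess invocation here is illustrative, not part of the commit):

```python
# Hypothetical local equivalent of the CI step: run the basic_features tests
# on CPU against an empty SD model, mirroring the workflow's flags.
import subprocess
import sys

result = subprocess.run([
    sys.executable, "launch.py",
    "--tests", "basic_features",
    "--no-half", "--disable-opt-split-attention",
    "--use-cpu", "all", "--skip-torch-cuda-test",
])
print("tests exited with", result.returncode)
```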

.gitignore (1 line changed)

@ -1,5 +1,6 @@
__pycache__ __pycache__
*.ckpt *.ckpt
*.safetensors
*.pth *.pth
/ESRGAN/* /ESRGAN/*
/SwinIR/* /SwinIR/*


@ -84,26 +84,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- API - API
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
## Where are Aesthetic Gradients?!?!
Aesthetic Gradients are now an extension. You can install it using git:
```commandline
git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients extensions/aesthetic-gradients
```
After running this command, make sure that you have `aesthetic-gradients` dir in webui's `extensions` directory and restart
the UI. The interface for Aesthetic Gradients should appear exactly the same as it was.
## Where is History/Image browser?!?!
Image browser is now an extension. You can install it using git:
```commandline
git clone https://github.com/yfszzx/stable-diffusion-webui-images-browser extensions/images-browser
```
After running this command, make sure that you have `images-browser` dir in webui's `extensions` directory and restart
the UI. The interface for Image browser should appear exactly the same as it was.
## Installation and Running ## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.


@ -62,8 +62,8 @@ titles = {
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.", "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.", "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.", "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle", "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
"Loopback": "Process an image, use it as an input, repeat.", "Loopback": "Process an image, use it as an input, repeat.",


@ -23,7 +23,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){ if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
if(progressbar.innerText){ if(progressbar.innerText){
let newtitle = 'Stable Diffusion - ' + progressbar.innerText let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion';
if(document.title != newtitle){ if(document.title != newtitle){
document.title = newtitle; document.title = newtitle;
} }


@ -8,8 +8,8 @@ function set_theme(theme){
} }
function selected_gallery_index(){ function selected_gallery_index(){
var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item') var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item')
var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2') var button = gradioApp().querySelector('[style="display: block;"].tabitem div[id$=_gallery] .gallery-item.\\!ring-2')
var result = -1 var result = -1
buttons.forEach(function(v, i){ if(v==button) { result = i } }) buttons.forEach(function(v, i){ if(v==button) { result = i } })


@ -5,6 +5,8 @@ import sys
import importlib.util import importlib.util
import shlex import shlex
import platform import platform
import argparse
import json
dir_repos = "repositories" dir_repos = "repositories"
dir_extensions = "extensions" dir_extensions = "extensions"
@ -17,6 +19,19 @@ def extract_arg(args, name):
return [x for x in args if x != name], name in args return [x for x in args if x != name], name in args
def extract_opt(args, name):
opt = None
is_present = False
if name in args:
is_present = True
idx = args.index(name)
del args[idx]
if idx < len(args) and args[idx][0] != "-":
opt = args[idx]
del args[idx]
return args, is_present, opt
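For illustration, a small self-contained sketch of how this new helper behaves: it removes the named flag from the argument list and, if the next token is not another flag, also consumes it as the flag's value (the sample argument lists below are hypothetical).

```python
def extract_opt(args, name):
    # returns (remaining args, whether the flag was present, its optional value)
    opt = None
    is_present = False
    if name in args:
        is_present = True
        idx = args.index(name)
        del args[idx]
        if idx < len(args) and args[idx][0] != "-":
            opt = args[idx]
            del args[idx]
    return args, is_present, opt

argv = ["launch.py", "--tests", "basic_features", "--no-half"]
print(extract_opt(list(argv), "--tests"))  # (['launch.py', '--no-half'], True, 'basic_features')

argv = ["launch.py", "--tests"]
print(extract_opt(list(argv), "--tests"))  # (['launch.py'], True, None)
```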
def run(command, desc=None, errdesc=None, custom_env=None): def run(command, desc=None, errdesc=None, custom_env=None):
if desc is not None: if desc is not None:
print(desc) print(desc)
@ -119,11 +134,26 @@ def run_extension_installer(extension_dir):
print(e, file=sys.stderr) print(e, file=sys.stderr)
def run_extensions_installers(): def list_extensions(settings_file):
settings = {}
try:
if os.path.isfile(settings_file):
with open(settings_file, "r", encoding="utf8") as file:
settings = json.load(file)
except Exception as e:
print(e, file=sys.stderr)
disabled_extensions = set(settings.get('disabled_extensions', []))
return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]
def run_extensions_installers(settings_file):
if not os.path.isdir(dir_extensions): if not os.path.isdir(dir_extensions):
return return
for dirname_extension in os.listdir(dir_extensions): for dirname_extension in list_extensions(settings_file):
run_extension_installer(os.path.join(dir_extensions, dirname_extension)) run_extension_installer(os.path.join(dir_extensions, dirname_extension))
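A brief sketch of the settings format this reads: any extension listed under `disabled_extensions` in the UI settings file is skipped by the installer loop. The file contents and extension names below are hypothetical.

```python
import json

# hypothetical contents of the ui settings file (config.json)
settings = json.loads('{"disabled_extensions": ["images-browser"]}')
disabled = set(settings.get("disabled_extensions", []))

# hypothetical contents of the extensions/ directory
all_extensions = ["aesthetic-gradients", "images-browser"]
enabled = [name for name in all_extensions if name not in disabled]
print(enabled)  # ['aesthetic-gradients']
```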
@ -134,28 +164,32 @@ def prepare_enviroment():
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl') xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/CompVis/stable-diffusion.git") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "60e5042ca0da89c14d1dd59d73883280f8fce991") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
sys.argv += shlex.split(commandline_args) sys.argv += shlex.split(commandline_args)
test_argv = [x for x in sys.argv if x != '--tests']
parser = argparse.ArgumentParser()
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
args, _ = parser.parse_known_args(sys.argv)
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, update_check = extract_arg(sys.argv, '--update-check') sys.argv, update_check = extract_arg(sys.argv, '--update-check')
sys.argv, run_tests = extract_arg(sys.argv, '--tests') sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
xformers = '--xformers' in sys.argv xformers = '--xformers' in sys.argv
ngrok = '--ngrok' in sys.argv ngrok = '--ngrok' in sys.argv
@ -179,6 +213,9 @@ def prepare_enviroment():
if not is_installed("clip"): if not is_installed("clip"):
run_pip(f"install {clip_package}", "clip") run_pip(f"install {clip_package}", "clip")
if not is_installed("open_clip"):
run_pip(f"install {openclip_package}", "open_clip")
if (not is_installed("xformers") or reinstall_xformers) and xformers: if (not is_installed("xformers") or reinstall_xformers) and xformers:
if platform.system() == "Windows": if platform.system() == "Windows":
if platform.python_version().startswith("3.10"): if platform.python_version().startswith("3.10"):
@ -196,7 +233,7 @@ def prepare_enviroment():
os.makedirs(dir_repos, exist_ok=True) os.makedirs(dir_repos, exist_ok=True)
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
@ -207,7 +244,7 @@ def prepare_enviroment():
run_pip(f"install -r {requirements_file}", "requirements for Web UI") run_pip(f"install -r {requirements_file}", "requirements for Web UI")
run_extensions_installers() run_extensions_installers(settings_file=args.ui_settings_file)
if update_check: if update_check:
version_check(commit) version_check(commit)
@ -217,24 +254,30 @@ def prepare_enviroment():
exit(0) exit(0)
if run_tests: if run_tests:
tests(test_argv) exitcode = tests(test_dir)
exit(0) exit(exitcode)
def tests(argv): def tests(test_dir):
if "--api" not in argv: if "--api" not in sys.argv:
argv.append("--api") sys.argv.append("--api")
if "--ckpt" not in sys.argv:
sys.argv.append("--ckpt")
sys.argv.append("./test/test_files/empty.pt")
if "--skip-torch-cuda-test" not in sys.argv:
sys.argv.append("--skip-torch-cuda-test")
print(f"Launching Web UI in another process for testing with arguments: {' '.join(argv[1:])}") print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr: with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
proc = subprocess.Popen([sys.executable, *argv], stdout=stdout, stderr=stderr) proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
import test.server_poll import test.server_poll
test.server_poll.run_tests() exitcode = test.server_poll.run_tests(proc, test_dir)
print(f"Stopping Web UI process with id {proc.pid}") print(f"Stopping Web UI process with id {proc.pid}")
proc.kill() proc.kill()
return exitcode
def start(): def start():


@ -3,7 +3,8 @@ import io
import time import time
import uvicorn import uvicorn
from threading import Lock from threading import Lock
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from io import BytesIO
from gradio.processing_utils import decode_base64_to_file
from fastapi import APIRouter, Depends, FastAPI, HTTPException from fastapi import APIRouter, Depends, FastAPI, HTTPException
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest from secrets import compare_digest
@ -13,7 +14,7 @@ from modules import sd_samplers, deepbooru
from modules.api.models import * from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.extras import run_extras, run_pnginfo from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models from modules.realesrgan_model import get_realesrgan_models
from typing import List from typing import List
@ -40,6 +41,10 @@ def setUpscalers(req: dict):
reqDict.pop('upscaler_2') reqDict.pop('upscaler_2')
return reqDict return reqDict
def decode_base64_to_image(encoding):
if encoding.startswith("data:image/"):
encoding = encoding.split(";")[1].split(",")[1]
return Image.open(BytesIO(base64.b64decode(encoding)))
def encode_pil_to_base64(image): def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes: with io.BytesIO() as output_bytes:
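A self-contained sketch of what the new `decode_base64_to_image` helper accepts: either a raw base64 string or one with a `data:image/...;base64,` prefix, both decoded to a PIL image. The tiny sample image below is generated in memory purely for demonstration.

```python
import base64
from io import BytesIO
from PIL import Image

def decode_base64_to_image(encoding):
    # strip a "data:image/png;base64," style prefix if present
    if encoding.startswith("data:image/"):
        encoding = encoding.split(";")[1].split(",")[1]
    return Image.open(BytesIO(base64.b64decode(encoding)))

# build a tiny in-memory PNG to demonstrate both input forms
buf = BytesIO()
Image.new("RGB", (4, 4), "red").save(buf, format="PNG")
raw_b64 = base64.b64encode(buf.getvalue()).decode()

print(decode_base64_to_image(raw_b64).size)                              # (4, 4)
print(decode_base64_to_image("data:image/png;base64," + raw_b64).size)   # (4, 4)
```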
@ -107,11 +112,13 @@ class Api:
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
populate = txt2imgreq.copy(update={ # Override __init__ params populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model, "sd_model": shared.sd_model,
"sampler_name": validate_sampler_name(txt2imgreq.sampler_index), "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True "do_not_save_grid": True
} }
) )
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
p = StableDiffusionProcessingTxt2Img(**vars(populate)) p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param # Override object param
@ -137,12 +144,14 @@ class Api:
populate = img2imgreq.copy(update={ # Override __init__ params populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model, "sd_model": shared.sd_model,
"sampler_name": validate_sampler_name(img2imgreq.sampler_index), "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True, "do_not_save_grid": True,
"mask": mask "mask": mask
} }
) )
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
p = StableDiffusionProcessingImg2Img(**vars(populate)) p = StableDiffusionProcessingImg2Img(**vars(populate))
imgs = [] imgs = []
@ -305,7 +314,7 @@ class Api:
styleList = [] styleList = []
for k in shared.prompt_styles.styles: for k in shared.prompt_styles.styles:
style = shared.prompt_styles.styles[k] style = shared.prompt_styles.styles[k]
styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]}) styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
return styleList return styleList

modules/call_queue.py (new file, 98 lines)

@ -0,0 +1,98 @@
import html
import sys
import threading
import traceback
import time
from modules import shared
queue_lock = threading.Lock()
def wrap_queued_call(func):
def f(*args, **kwargs):
with queue_lock:
res = func(*args, **kwargs)
return res
return f
def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
shared.state.begin()
with queue_lock:
res = func(*args, **kwargs)
shared.state.end()
return res
return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
shared.mem_mon.monitor()
t = time.perf_counter()
try:
res = list(func(*args, **kwargs))
except Exception as e:
# When printing out our debug argument list, do not print out more than a MB of text
max_debug_str_len = 131072 # (1024*1024)/8
print("Error completing request", file=sys.stderr)
argStr = f"Arguments: {str(args)} {str(kwargs)}"
print(argStr[:max_debug_str_len], file=sys.stderr)
if len(argStr) > max_debug_str_len:
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
shared.state.job = ""
shared.state.job_count = 0
if extra_outputs_array is None:
extra_outputs_array = [None, '']
res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
shared.state.skipped = False
shared.state.interrupted = False
shared.state.job_count = 0
if not add_stats:
return tuple(res)
elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
elapsed_text = f"{elapsed_s:.2f}s"
if elapsed_m > 0:
elapsed_text = f"{elapsed_m}m "+elapsed_text
if run_memmon:
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
active_peak = mem_stats['active_peak']
reserved_peak = mem_stats['reserved_peak']
sys_peak = mem_stats['system_peak']
sys_total = mem_stats['total']
sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
else:
vram_html = ''
# last item is always HTML
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
return tuple(res)
return f
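The queuing wrapper above serializes all generation calls behind a single lock. A minimal standalone sketch of that pattern, using a toy function instead of the webui's shared state:

```python
import threading
import time

queue_lock = threading.Lock()

def wrap_queued_call(func):
    # only one wrapped call runs at a time, mirroring the wrapper above
    def f(*args, **kwargs):
        with queue_lock:
            return func(*args, **kwargs)
    return f

@wrap_queued_call
def fake_generate(prompt):
    time.sleep(0.1)  # stand-in for a GPU-bound job
    return f"image for: {prompt}"

threads = [threading.Thread(target=fake_generate, args=(f"prompt {i}",)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print("all queued calls finished")
```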


@ -36,6 +36,7 @@ def setup_model(dirname):
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
from basicsr.utils import imwrite, img2tensor, tensor2img from basicsr.utils import imwrite, img2tensor, tensor2img
from facelib.utils.face_restoration_helper import FaceRestoreHelper from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.detection.retinaface import retinaface
from modules.shared import cmd_opts from modules.shared import cmd_opts
net_class = CodeFormer net_class = CodeFormer
@ -65,6 +66,8 @@ def setup_model(dirname):
net.load_state_dict(checkpoint) net.load_state_dict(checkpoint)
net.eval() net.eval()
if hasattr(retinaface, 'device'):
retinaface.device = devices.device_codeformer
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer) face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer)
self.net = net self.net = net


@ -58,7 +58,7 @@ class DeepDanbooru:
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255 a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
with torch.no_grad(), devices.autocast(): with torch.no_grad(), devices.autocast():
x = torch.from_numpy(a).cuda() x = torch.from_numpy(a).to(devices.device)
y = self.model(x)[0].detach().cpu().numpy() y = self.model(x)[0].detach().cpu().numpy()
probability_dict = {} probability_dict = {}


@ -2,9 +2,10 @@ import sys, os, shlex
import contextlib import contextlib
import torch import torch
from modules import errors from modules import errors
from packaging import version
# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. # has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
# check `getattr` and try it for compatibility # check `getattr` and try it for compatibility
def has_mps() -> bool: def has_mps() -> bool:
if not getattr(torch, 'has_mps', False): if not getattr(torch, 'has_mps', False):
@ -24,17 +25,18 @@ def extract_device_id(args, name):
return None return None
def get_cuda_device_string():
from modules import shared
if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
return "cuda"
def get_optimal_device(): def get_optimal_device():
if torch.cuda.is_available(): if torch.cuda.is_available():
from modules import shared return torch.device(get_cuda_device_string())
device_id = shared.cmd_opts.device_id
if device_id is not None:
cuda_device = f"cuda:{device_id}"
return torch.device(cuda_device)
else:
return torch.device("cuda")
# if has_mps(): # if has_mps():
# return torch.device("mps") # return torch.device("mps")
@ -44,8 +46,9 @@ def get_optimal_device():
def torch_gc(): def torch_gc():
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() with torch.cuda.device(get_cuda_device_string()):
torch.cuda.ipc_collect() torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def enable_tf32(): def enable_tf32():
@ -97,9 +100,25 @@ def autocast(disable=False):
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
def mps_contiguous(input_tensor, device): orig_tensor_to = torch.Tensor.to
return input_tensor.contiguous() if device.type == 'mps' else input_tensor def tensor_to_fix(self, *args, **kwargs):
if self.device.type != 'mps' and \
((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
(isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
self = self.contiguous()
return orig_tensor_to(self, *args, **kwargs)
def mps_contiguous_to(input_tensor, device): # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
return mps_contiguous(input_tensor, device).to(device) orig_layer_norm = torch.nn.functional.layer_norm
def layer_norm_fix(*args, **kwargs):
if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
args = list(args)
args[0] = args[0].contiguous()
return orig_layer_norm(*args, **kwargs)
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
torch.Tensor.to = tensor_to_fix
torch.nn.functional.layer_norm = layer_norm_fix
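For context, a minimal sketch of the revised `torch_gc` idea: cache clearing is scoped to the configured CUDA device via a device context. Here the device id is passed in directly rather than read from the command-line options, and the call is guarded so the sketch also runs on CPU-only machines.

```python
import torch

def get_cuda_device_string(device_id=None):
    # mirrors the new helper: honor an explicit device id if one is given
    return f"cuda:{device_id}" if device_id is not None else "cuda"

def torch_gc(device_id=None):
    if torch.cuda.is_available():
        # empty_cache()/ipc_collect() act on the current device,
        # so select the configured one first
        with torch.cuda.device(get_cuda_device_string(device_id)):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

torch_gc()             # no-op on CPU-only machines
torch_gc(device_id=0)  # targets cuda:0 when CUDA is available
```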


@ -199,7 +199,7 @@ def upscale_without_tiling(model, img):
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan) img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad(): with torch.no_grad():
output = model(img) output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy() output = output.squeeze().float().cpu().clamp_(0, 1).numpy()


@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
import math import math
import os import os
import sys
import traceback
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -12,7 +14,7 @@ from typing import Callable, List, OrderedDict, Tuple
from functools import partial from functools import partial
from dataclasses import dataclass from dataclasses import dataclass
from modules import processing, shared, images, devices, sd_models from modules import processing, shared, images, devices, sd_models, sd_samplers
from modules.shared import opts from modules.shared import opts
import modules.gfpgan_model import modules.gfpgan_model
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
@ -20,7 +22,7 @@ import modules.codeformer_model
import piexif import piexif
import piexif.helper import piexif.helper
import gradio as gr import gradio as gr
import safetensors.torch
class LruCache(OrderedDict): class LruCache(OrderedDict):
@dataclass(frozen=True) @dataclass(frozen=True)
@ -213,25 +215,8 @@ def run_pnginfo(image):
if image is None: if image is None:
return '', '', '' return '', '', ''
items = image.info geninfo, items = images.read_info_from_image(image)
geninfo = '' items = {**{'parameters': geninfo}, **items}
if "exif" in image.info:
exif = piexif.load(image.info["exif"])
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
try:
exif_comment = piexif.helper.UserComment.load(exif_comment)
except ValueError:
exif_comment = exif_comment.decode('utf8', errors="ignore")
items['exif comment'] = exif_comment
geninfo = exif_comment
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
'loop', 'background', 'timestamp', 'duration']:
items.pop(field, None)
geninfo = items.get('parameters', geninfo)
info = '' info = ''
for key, text in items.items(): for key, text in items.items():
@ -249,7 +234,7 @@ def run_pnginfo(image):
return '', geninfo, info return '', geninfo, info
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name): def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
def weighted_sum(theta0, theta1, alpha): def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1) return ((1 - alpha) * theta0) + (alpha * theta1)
@ -264,19 +249,15 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None) teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
print(f"Loading {primary_model_info.filename}...") print(f"Loading {primary_model_info.filename}...")
primary_model = torch.load(primary_model_info.filename, map_location='cpu') theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
print(f"Loading {secondary_model_info.filename}...") print(f"Loading {secondary_model_info.filename}...")
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu') theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
if teritary_model_info is not None: if teritary_model_info is not None:
print(f"Loading {teritary_model_info.filename}...") print(f"Loading {teritary_model_info.filename}...")
teritary_model = torch.load(teritary_model_info.filename, map_location='cpu') theta_2 = sd_models.read_state_dict(teritary_model_info.filename, map_location='cpu')
theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
else: else:
teritary_model = None
theta_2 = None theta_2 = None
theta_funcs = { theta_funcs = {
@ -295,7 +276,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
theta_1[key] = theta_func1(theta_1[key], t2) theta_1[key] = theta_func1(theta_1[key], t2)
else: else:
theta_1[key] = torch.zeros_like(theta_1[key]) theta_1[key] = torch.zeros_like(theta_1[key])
del theta_2, teritary_model del theta_2
for key in tqdm.tqdm(theta_0.keys()): for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1: if 'model' in key and key in theta_1:
@ -314,12 +295,17 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.' + checkpoint_format
filename = filename if custom_name == '' else (custom_name + '.ckpt') filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
output_modelname = os.path.join(ckpt_dir, filename) output_modelname = os.path.join(ckpt_dir, filename)
print(f"Saving to {output_modelname}...") print(f"Saving to {output_modelname}...")
torch.save(primary_model, output_modelname)
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
else:
torch.save(theta_0, output_modelname)
sd_models.list_models() sd_models.list_models()
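A minimal sketch of the new `checkpoint_format` branch: the merged state dict is written with safetensors or with `torch.save` depending on the output extension. The tensors and filenames below are toy placeholders.

```python
import os
import torch
import safetensors.torch

theta_0 = {"model.weight": torch.zeros(2, 2)}  # toy merged state dict

for output_modelname in ("merged-example.safetensors", "merged-example.ckpt"):
    _, extension = os.path.splitext(output_modelname)
    if extension.lower() == ".safetensors":
        safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
    else:
        torch.save(theta_0, output_modelname)
    print("saved", output_modelname)
```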


@ -2,6 +2,8 @@ import base64
import io import io
import os import os
import re import re
from pathlib import Path
import gradio as gr import gradio as gr
from modules.shared import script_path from modules.shared import script_path
from modules import shared from modules import shared
@ -35,9 +37,8 @@ def quote(text):
def image_from_url_text(filedata): def image_from_url_text(filedata):
if type(filedata) == dict and filedata["is_file"]: if type(filedata) == dict and filedata["is_file"]:
filename = filedata["name"] filename = filedata["name"]
tempdir = os.path.normpath(tempfile.gettempdir()) is_in_right_dir = any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in shared.demo.temp_dirs)
normfn = os.path.normpath(filename) assert is_in_right_dir, 'trying to open image file outside of allowed directories'
assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
return Image.open(filename) return Image.open(filename)
@ -75,6 +76,7 @@ def integrate_settings_paste_fields(component_dict):
'CLIP_stop_at_last_layers': 'Clip skip', 'CLIP_stop_at_last_layers': 'Clip skip',
'inpainting_mask_weight': 'Conditional mask weight', 'inpainting_mask_weight': 'Conditional mask weight',
'sd_model_checkpoint': 'Model hash', 'sd_model_checkpoint': 'Model hash',
'eta_noise_seed_delta': 'ENSD',
} }
settings_paste_fields = [ settings_paste_fields = [
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None))) (component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
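The tightened file check above compares resolved parent directories instead of string prefixes. A standalone sketch of that test, with hypothetical paths:

```python
from pathlib import Path

def is_in_allowed_dirs(filename, allowed_dirs):
    # True only if the resolved file sits somewhere below one of the allowed dirs
    return any(Path(d).resolve() in Path(filename).resolve().parents for d in allowed_dirs)

allowed = ["/tmp/gradio"]
print(is_in_allowed_dirs("/tmp/gradio/abc/image.png", allowed))   # True
print(is_in_allowed_dirs("/tmp/gradio-evil/image.png", allowed))  # False, prefix tricks don't pass
```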


@ -36,7 +36,9 @@ def gfpgann():
else: else:
print("Unable to load gfpgan model!") print("Unable to load gfpgan model!")
return None return None
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) if hasattr(facexlib.detection.retinaface, 'device'):
facexlib.detection.retinaface.device = devices.device_gfpgan
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
loaded_gfpgan_model = model loaded_gfpgan_model = model
return model return model


@ -38,7 +38,7 @@ class HypernetworkModule(torch.nn.Module):
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True): add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
super().__init__() super().__init__()
assert layer_structure is not None, "layer_structure must not be None" assert layer_structure is not None, "layer_structure must not be None"
@ -154,16 +154,28 @@ class Hypernetwork:
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
) )
self.eval_mode()
def weights(self): def weights(self):
res = [] res = []
for k, layers in self.layers.items():
for layer in layers:
res += layer.parameters()
return res
def train_mode(self):
for k, layers in self.layers.items(): for k, layers in self.layers.items():
for layer in layers: for layer in layers:
layer.train() layer.train()
res += layer.trainables() for param in layer.parameters():
param.requires_grad = True
return res def eval_mode(self):
for k, layers in self.layers.items():
for layer in layers:
layer.eval()
for param in layer.parameters():
param.requires_grad = False
def save(self, filename): def save(self, filename):
state_dict = {} state_dict = {}
@ -367,13 +379,13 @@ def report_statistics(loss_info:dict):
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem. # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images from modules import images
save_hypernetwork_every = save_hypernetwork_every or 0 save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0 create_image_every = create_image_every or 0
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
path = shared.hypernetworks.get(hypernetwork_name, None) path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork() shared.loaded_hypernetwork = Hypernetwork()
@ -403,32 +415,30 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
hypernetwork = shared.loaded_hypernetwork hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint() checkpoint = sd_models.select_checkpoint()
ititial_step = hypernetwork.step or 0 initial_step = hypernetwork.step or 0
if ititial_step >= steps: if initial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps" shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return hypernetwork, filename return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
# dataset loading may take a while, so input validations and early returns should be done before this # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
size = len(ds.indexes)
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,))
previous_mean_losses = [0]
previous_mean_loss = 0
print("Mean loss of {} elements".format(size))
weights = hypernetwork.weights() weights = hypernetwork.weights()
for weight in weights: hypernetwork.train_mode()
weight.requires_grad = True
# Here we use optimizer from saved HN, or we can specify as UI option. # Here we use optimizer from saved HN, or we can specify as UI option.
if hypernetwork.optimizer_name in optimizer_dict: if hypernetwork.optimizer_name in optimizer_dict:
@ -446,131 +456,156 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
print("Cannot resume from saved optimizer!") print("Cannot resume from saved optimizer!")
print(e) print(e)
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
gradient_step = ds.gradient_step
# n steps = batch_size * gradient_step * n image processed
steps_per_epoch = len(ds) // batch_size // gradient_step
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
loss_step = 0
_loss_step = 0 #internal
# size = len(ds.indexes)
# loss_dict = defaultdict(lambda : deque(maxlen = 1024))
# losses = torch.zeros((size,))
# previous_mean_losses = [0]
# previous_mean_loss = 0
# print("Mean loss of {} elements".format(size))
steps_without_grad = 0 steps_without_grad = 0
last_saved_file = "<none>" last_saved_file = "<none>"
last_saved_image = "<none>" last_saved_image = "<none>"
forced_filename = "<none>" forced_filename = "<none>"
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) pbar = tqdm.tqdm(total=steps - initial_step)
for i, entries in pbar: try:
hypernetwork.step = i + ititial_step for i in range((steps-initial_step) * gradient_step):
if len(loss_dict) > 0: if scheduler.finished:
previous_mean_losses = [i[-1] for i in loss_dict.values()] break
previous_mean_loss = mean(previous_mean_losses) if shared.state.interrupted:
break
scheduler.apply(optimizer, hypernetwork.step) for j, batch in enumerate(dl):
if scheduler.finished: # works as a drop_last=True for gradient accumulation
break if j == max_steps_per_epoch:
break
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
break
if shared.state.interrupted:
break
if shared.state.interrupted: with torch.autocast("cuda"):
break x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if tag_drop_out != 0 or shuffle_tags:
shared.sd_model.cond_stage_model.to(devices.device)
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
shared.sd_model.cond_stage_model.to(devices.cpu)
else:
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
loss = shared.sd_model(x, c)[0] / gradient_step
del x
del c
with torch.autocast("cuda"): _loss_step += loss.item()
c = stack_conds([entry.cond for entry in entries]).to(devices.device) scaler.scale(loss).backward()
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device) # go back until we reach gradient accumulation steps
x = torch.stack([entry.latent for entry in entries]).to(devices.device) if (j + 1) % gradient_step != 0:
loss = shared.sd_model(x, c)[0] continue
del x # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
del c # scaler.unscale_(optimizer)
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
# torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
# print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
scaler.step(optimizer)
scaler.update()
hypernetwork.step += 1
pbar.update()
optimizer.zero_grad(set_to_none=True)
loss_step = _loss_step
_loss_step = 0
losses[hypernetwork.step % losses.shape[0]] = loss.item() steps_done = hypernetwork.step + 1
for entry in entries:
loss_dict[entry.filename].append(loss.item())
optimizer.zero_grad() epoch_num = hypernetwork.step // steps_per_epoch
weights[0].grad = None epoch_step = hypernetwork.step % steps_per_epoch
loss.backward()
if weights[0].grad is None: pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
steps_without_grad += 1 if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
else: # Before saving, change name to match current checkpoint.
steps_without_grad = 0 hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
hypernetwork.optimizer_name = optimizer_name
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
optimizer.step() textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
"loss": f"{loss_step:.7f}",
"learn_rate": scheduler.learn_rate
})
steps_done = hypernetwork.step + 1 if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
hypernetwork.eval_mode()
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): p = processing.StableDiffusionProcessingTxt2Img(
raise RuntimeError("Loss diverged.") sd_model=shared.sd_model,
do_not_save_grid=True,
if len(previous_mean_losses) > 1: do_not_save_samples=True,
std = stdev(previous_mean_losses) )
else:
std = 0
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
pbar.set_description(dataset_loss_info)
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: if preview_from_txt2img:
# Before saving, change name to match current checkpoint. p.prompt = preview_prompt
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' p.negative_prompt = preview_negative_prompt
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') p.steps = preview_steps
hypernetwork.optimizer_name = optimizer_name p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
if shared.opts.save_optimizer_state: p.cfg_scale = preview_cfg_scale
hypernetwork.optimizer_state_dict = optimizer.state_dict() p.seed = preview_seed
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) p.width = preview_width
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. p.height = preview_height
else:
p.prompt = batch.cond_text[0]
p.steps = 20
p.width = training_width
p.height = training_height
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { preview_text = p.prompt
"loss": f"{previous_mean_loss:.7f}",
"learn_rate": scheduler.learn_rate
})
if images_dir is not None and steps_done % create_image_every == 0: processed = processing.process_images(p)
forced_filename = f'{hypernetwork_name}-{steps_done}' image = processed.images[0] if len(processed.images) > 0 else None
last_saved_image = os.path.join(images_dir, forced_filename)
optimizer.zero_grad() if unload:
shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.cpu)
hypernetwork.train_mode()
if image is not None:
shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
p = processing.StableDiffusionProcessingTxt2Img( shared.state.job_no = hypernetwork.step
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
)
if preview_from_txt2img: shared.state.textinfo = f"""
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entries[0].cond_text
p.steps = 20
preview_text = p.prompt
processed = processing.process_images(p)
image = processed.images[0] if len(processed.images)>0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
if image is not None:
shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
shared.state.textinfo = f"""
<p> <p>
Loss: {previous_mean_loss:.7f}<br/> Loss: {loss_step:.7f}<br/>
Step: {hypernetwork.step}<br/> Step: {steps_done}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/> Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/> Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
""" """
except Exception:
report_statistics(loss_dict) print(traceback.format_exc(), file=sys.stderr)
finally:
pbar.leave = False
pbar.close()
hypernetwork.eval_mode()
#report_statistics(loss_dict)
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.optimizer_name = optimizer_name hypernetwork.optimizer_name = optimizer_name
@ -579,6 +614,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
del optimizer del optimizer
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory. hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
return hypernetwork, filename return hypernetwork, filename
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):


@ -15,6 +15,7 @@ import piexif.helper
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto from fonts.ttf import Roboto
import string import string
import json
from modules import sd_samplers, shared, script_callbacks from modules import sd_samplers, shared, script_callbacks
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
@ -305,6 +306,7 @@ class FilenameGenerator:
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False), 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False), 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash), 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>] 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp), 'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
@ -524,6 +526,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
else: else:
image.save(fullfn, quality=opts.jpeg_quality) image.save(fullfn, quality=opts.jpeg_quality)
image.already_saved_as = fullfn
target_side_length = 4000 target_side_length = 4000
oversize = image.width > target_side_length or image.height > target_side_length oversize = image.width > target_side_length or image.height > target_side_length
if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024): if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
@ -550,10 +554,45 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
return fullfn, txt_fullfn return fullfn, txt_fullfn
def read_info_from_image(image):
items = image.info or {}
geninfo = items.pop('parameters', None)
if "exif" in items:
exif = piexif.load(items["exif"])
exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
try:
exif_comment = piexif.helper.UserComment.load(exif_comment)
except ValueError:
exif_comment = exif_comment.decode('utf8', errors="ignore")
items['exif comment'] = exif_comment
geninfo = exif_comment
for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
'loop', 'background', 'timestamp', 'duration']:
items.pop(field, None)
if items.get("Software", None) == "NovelAI":
try:
json_info = json.loads(items["Comment"])
sampler = sd_samplers.samplers_map.get(json_info["sampler"], "Euler a")
geninfo = f"""{items["Description"]}
Negative prompt: {json_info["uc"]}
Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337"""
except Exception:
print(f"Error parsing NovelAI iamge generation parameters:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return geninfo, items
 def image_data(data):
     try:
         image = Image.open(io.BytesIO(data))
-        textinfo = image.text["parameters"]
+        textinfo, _ = read_info_from_image(image)
         return textinfo, None
     except Exception:
         pass
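The new read_info_from_image helper above is what image_data (and the PNG Info tab) now goes through to recover generation parameters from saved files. As a rough, self-contained illustration (not part of this commit; the file path and fallback handling are placeholders), the same metadata can be pulled out with plain Pillow and piexif:

```python
# Hedged sketch: read the 'parameters' PNG text chunk written by the webui,
# falling back to the EXIF UserComment used for JPEG/WebP output.
import piexif
import piexif.helper
from PIL import Image

def read_parameters(path):
    image = Image.open(path)
    info = image.info or {}
    geninfo = info.get('parameters')            # PNG text chunk
    if geninfo is None and 'exif' in info:      # JPEG/WebP: EXIF UserComment
        exif = piexif.load(info['exif'])
        comment = exif.get('Exif', {}).get(piexif.ExifIFD.UserComment, b'')
        try:
            geninfo = piexif.helper.UserComment.load(comment)
        except ValueError:
            geninfo = comment.decode('utf8', errors='ignore') or None
    return geninfo

print(read_parameters("outputs/txt2img-images/00000-12345.png"))  # hypothetical path
```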


@@ -99,7 +99,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
         seed_resize_from_h=seed_resize_from_h,
         seed_resize_from_w=seed_resize_from_w,
         seed_enable_extras=seed_enable_extras,
-        sampler_index=sd_samplers.samplers_for_img2img[sampler_index].name,
+        sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name,
         batch_size=batch_size,
         n_iter=n_iter,
         steps=steps,


@@ -51,6 +51,10 @@ def setup_for_low_vram(sd_model, use_medvram):
         send_me_to_gpu(first_stage_model, None)
         return first_stage_model_decode(z)

+    # for SD1, cond_stage_model is CLIP and its NN is in the transformer field, but for SD2, it's OpenCLIP, and it's in the model field
+    if hasattr(sd_model.cond_stage_model, 'model'):
+        sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
+
     # remove three big modules, cond, first_stage, and unet from the model and then
     # send the model to GPU. Then put modules back. the modules will be in CPU.
     stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model

@@ -65,6 +69,10 @@ def setup_for_low_vram(sd_model, use_medvram):
     sd_model.first_stage_model.decode = first_stage_model_decode_wrap
     parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

+    if hasattr(sd_model.cond_stage_model, 'model'):
+        sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
+        del sd_model.cond_stage_model.transformer
+
     if use_medvram:
         sd_model.model.register_forward_pre_hook(send_me_to_gpu)
     else:
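The low-VRAM path keeps the big submodules on the CPU and registers forward pre-hooks that pull whichever module is about to run onto the GPU, evicting the previous one. A hedged toy reproduction of that hook pattern (module names and sizes are illustrative, not the webui's actual modules):

```python
# Toy sketch of the "send me to the GPU right before I run" pre-hook pattern.
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
module_in_gpu = None

def send_me_to_gpu(module, _inputs):
    # Called just before module.forward(): evict the previous occupant, move this one in.
    global module_in_gpu
    if module_in_gpu is module:
        return
    if module_in_gpu is not None:
        module_in_gpu.to('cpu')
    module.to(device)
    module_in_gpu = module

encoder = nn.Linear(16, 16)   # stand-ins for cond_stage_model / the UNet
unet = nn.Linear(16, 16)
for m in (encoder, unet):
    m.register_forward_pre_hook(send_me_to_gpu)

x = torch.randn(1, 16, device=device)
h = encoder(x)   # hook moves encoder onto the device
h = unet(h)      # hook moves unet in and pushes encoder back to the CPU
```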


@@ -82,6 +82,7 @@ def cleanup_models():
     src_path = models_path
     dest_path = os.path.join(models_path, "Stable-diffusion")
     move_files(src_path, dest_path, ".ckpt")
+    move_files(src_path, dest_path, ".safetensors")
     src_path = os.path.join(root_path, "ESRGAN")
     dest_path = os.path.join(models_path, "ESRGAN")
     move_files(src_path, dest_path)
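With .safetensors checkpoints now supported, the startup cleanup also migrates them from the legacy models/ root into models/Stable-diffusion/. A hedged sketch of an extension-filtered move in the same spirit (the real move_files helper in modules/modelloader.py differs in details; names here are illustrative):

```python
# Hedged sketch of moving files that match an extension into a target folder.
import os
import shutil

def move_files_by_ext(src_path, dest_path, ext_filter=None):
    if not os.path.isdir(src_path):
        return
    os.makedirs(dest_path, exist_ok=True)
    for name in os.listdir(src_path):
        full = os.path.join(src_path, name)
        if not os.path.isfile(full):
            continue
        if ext_filter is not None and not name.endswith(ext_filter):
            continue
        shutil.move(full, os.path.join(dest_path, name))

move_files_by_ext("models", os.path.join("models", "Stable-diffusion"), ".safetensors")
```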


@@ -15,9 +15,9 @@ def connect(token, port, region):
     )
     try:
         if account == None:
-            public_url = ngrok.connect(port, pyngrok_config=config).public_url
+            public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
         else:
-            public_url = ngrok.connect(port, pyngrok_config=config, auth=account).public_url
+            public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True, auth=account).public_url
     except exception.PyngrokNgrokError:
         print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
               f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
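Passing bind_tls=True asks the ngrok v2 agent for an https-only tunnel, so the public URL works without mixed-content issues. A hedged usage example (token and port are placeholders, not values from this commit):

```python
# Hedged example of opening an https-only pyngrok tunnel, mirroring bind_tls=True above.
from pyngrok import conf, ngrok

config = conf.PyngrokConfig(auth_token="YOUR_NGROK_TOKEN", region="us")
public_url = ngrok.connect(7860, pyngrok_config=config, bind_tls=True).public_url
print(f"webui is reachable at {public_url}")
```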


@@ -9,7 +9,7 @@ sys.path.insert(0, script_path)

 # search for directory of stable diffusion in following places
 sd_path = None
-possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
+possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
 for possible_sd_path in possible_sd_paths:
     if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
         sd_path = os.path.abspath(possible_sd_path)


@@ -74,7 +74,7 @@ class StableDiffusionProcessing():
     """
     def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
         if sampler_index is not None:
-            warnings.warn("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name")
+            print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
         self.sd_model = sd_model
         self.outpath_samples: str = outpath_samples


@@ -54,7 +54,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         img = img[:, :, ::-1]
         img = np.moveaxis(img, 2, 0) / 255
         img = torch.from_numpy(img).float()
-        img = devices.mps_contiguous_to(img.unsqueeze(0), device)
+        img = img.unsqueeze(0).to(device)

         with torch.no_grad():
             output = model(img)


@@ -8,19 +8,31 @@ from torch import einsum
 from torch.nn.functional import silu

 import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
+from modules.hypernetworks import hypernetwork
 from modules.shared import opts, device, cmd_opts
+from modules import sd_hijack_clip, sd_hijack_open_clip
 from modules.sd_hijack_optimizations import invokeAI_mps_available

 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model
 import ldm.models.diffusion.ddim
 import ldm.models.diffusion.plms
+import ldm.modules.encoders.modules

 attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
 diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
 diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

+# new memory efficient cross attention blocks do not support hypernets and we already
+# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
+ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
+ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
+
+# silence new console spam from SD2
+ldm.modules.attention.print = lambda *args: None
+ldm.modules.diffusionmodules.model.print = lambda *args: None

 def apply_optimizations():
     undo_optimizations()
@@ -49,16 +61,15 @@ def apply_optimizations():

 def undo_optimizations():
-    from modules.hypernetworks import hypernetwork
     ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
     ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

-def get_target_prompt_token_count(token_count):
-    return math.ceil(max(token_count, 1) / 75) * 75
+def fix_checkpoint():
+    ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
+    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
+    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward
 class StableDiffusionModelHijack:
     fixes = None

@@ -70,19 +81,24 @@ class StableDiffusionModelHijack:
     embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)

     def hijack(self, m):
-        if shared.text_model_name == "XLMR-Large":
+        if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
+            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+            apply_optimizations()
+        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
+            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
+            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+            apply_optimizations()
+        elif shared.text_model_name == "XLMR-Large":
             model_embeddings = m.cond_stage_model.roberta.embeddings
             model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
-        else :
-            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
-            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embeddings, self)
-        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

         self.clip = m.cond_stage_model

-        # apply_optimizations()
+        fix_checkpoint()
 def flatten(el):
     flattened = [flatten(children) for children in el.children()]

@@ -94,12 +110,15 @@ class StableDiffusionModelHijack:
         self.layers = flatten(m)

     def undo_hijack(self, m):
-        if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
+        if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
             m.cond_stage_model = m.cond_stage_model.wrapped

             model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
             if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                 model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
+        elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
+            m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
+            m.cond_stage_model = m.cond_stage_model.wrapped

         self.apply_circular(False)
         self.layers = None
@@ -119,267 +138,8 @@ class StableDiffusionModelHijack:
     def tokenize(self, text):
         _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
-        return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
+        return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
self.token_mults = {}
try:
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
except:
self.comma_token = None
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
if c == '[':
mult /= 1.1
if c == ']':
mult *= 1.1
if c == '(':
mult *= 1.1
if c == ')':
mult /= 1.1
if mult != 1.0:
self.token_mults[ident] = mult
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_end = self.wrapped.tokenizer.eos_token_id
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
parsed = [[line, 1.0]]
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
fixes = []
remade_tokens = []
multipliers = []
last_comma = -1
for tokens, (text, weight) in zip(tokenized, parsed):
i = 0
while i < len(tokens):
token = tokens[i]
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
if token == self.comma_token:
last_comma = len(remade_tokens)
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
last_comma += 1
reloc_tokens = remade_tokens[last_comma:]
reloc_mults = multipliers[last_comma:]
remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)
rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
iteration = len(remade_tokens) // 75
if (len(remade_tokens) + emb_len) // 75 != iteration:
rem = (75 * (iteration + 1) - len(remade_tokens))
remade_tokens += [id_end] * rem
multipliers += [1.0] * rem
iteration += 1
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
token_count = len(remade_tokens)
prompt_target_length = get_target_prompt_token_count(token_count)
tokens_to_add = prompt_target_length - len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * tokens_to_add
multipliers = multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count
def process_text(self, texts):
used_custom_terms = []
remade_batch_tokens = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_multipliers = []
for line in texts:
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
else:
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
token_count = max(current_token_count, token_count)
cache[line] = (remade_tokens, fixes, multipliers)
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length # you get to stay at 77
used_custom_terms = []
remade_batch_tokens = []
overflowing_words = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
batch_multipliers = []
for tokens in batch_tokens:
tuple_tokens = tuple(tokens)
if tuple_tokens in cache:
remade_tokens, fixes, multipliers = cache[tuple_tokens]
else:
fixes = []
remade_tokens = []
multipliers = []
mult = 1.0
i = 0
while i < len(tokens):
token = tokens[i]
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
mult *= mult_change
i += 1
elif embedding is None:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
fixes.append((len(remade_tokens), embedding))
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
if shared.text_model_name == "XLMR-Large":
return self.wrapped.encode(text)
use_old = opts.use_old_emphasis_implementation
if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
rem_tokens = [x[75:] for x in remade_batch_tokens]
rem_multipliers = [x[75:] for x in batch_multipliers]
self.hijack.fixes = []
for unfiltered in hijack_fixes:
fixes = []
for fix in unfiltered:
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
if len(remade_batch_tokens[j]) > 0:
tokens.append(remade_batch_tokens[j][:75])
multipliers.append(batch_multipliers[j][:75])
else:
tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
return z
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
if opts.CLIP_stop_at_last_layers > 1:
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
z = self.wrapped.transformer.text_model.final_layer_norm(z)
else:
z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
z *= original_mean / new_mean
return z
 class EmbeddingsWithFixes(torch.nn.Module):


@@ -0,0 +1,10 @@
from torch.utils.checkpoint import checkpoint
def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context)
def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x)
def ResBlock_forward(self, x, emb):
return checkpoint(self._forward, x, emb)

modules/sd_hijack_clip.py (new file, 301 lines)

@@ -0,0 +1,301 @@
import math
import torch
from modules import prompt_parser, devices
from modules.shared import opts
def get_target_prompt_token_count(token_count):
return math.ceil(max(token_count, 1) / 75) * 75
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
self.hijack = hijack
def tokenize(self, texts):
raise NotImplementedError
def encode_with_transformers(self, tokens):
raise NotImplementedError
def encode_embedding_init_text(self, init_text, nvpt):
raise NotImplementedError
def tokenize_line(self, line, used_custom_terms, hijack_comments):
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
parsed = [[line, 1.0]]
tokenized = self.tokenize([text for text, _ in parsed])
fixes = []
remade_tokens = []
multipliers = []
last_comma = -1
for tokens, (text, weight) in zip(tokenized, parsed):
i = 0
while i < len(tokens):
token = tokens[i]
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
if token == self.comma_token:
last_comma = len(remade_tokens)
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
last_comma += 1
reloc_tokens = remade_tokens[last_comma:]
reloc_mults = multipliers[last_comma:]
remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)
rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [self.id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
iteration = len(remade_tokens) // 75
if (len(remade_tokens) + emb_len) // 75 != iteration:
rem = (75 * (iteration + 1) - len(remade_tokens))
remade_tokens += [self.id_end] * rem
multipliers += [1.0] * rem
iteration += 1
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
token_count = len(remade_tokens)
prompt_target_length = get_target_prompt_token_count(token_count)
tokens_to_add = prompt_target_length - len(remade_tokens)
remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
multipliers = multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count
def process_text(self, texts):
used_custom_terms = []
remade_batch_tokens = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_multipliers = []
for line in texts:
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
else:
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
token_count = max(current_token_count, token_count)
cache[line] = (remade_tokens, fixes, multipliers)
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def process_text_old(self, texts):
id_start = self.id_start
id_end = self.id_end
maxlen = self.wrapped.max_length # you get to stay at 77
used_custom_terms = []
remade_batch_tokens = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_tokens = self.tokenize(texts)
batch_multipliers = []
for tokens in batch_tokens:
tuple_tokens = tuple(tokens)
if tuple_tokens in cache:
remade_tokens, fixes, multipliers = cache[tuple_tokens]
else:
fixes = []
remade_tokens = []
multipliers = []
mult = 1.0
i = 0
while i < len(tokens):
token = tokens[i]
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
mult *= mult_change
i += 1
elif embedding is None:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
else:
emb_len = int(embedding.vec.shape[0])
fixes.append((len(remade_tokens), embedding))
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
use_old = opts.use_old_emphasis_implementation
if use_old:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
rem_tokens = [x[75:] for x in remade_batch_tokens]
rem_multipliers = [x[75:] for x in batch_multipliers]
self.hijack.fixes = []
for unfiltered in hijack_fixes:
fixes = []
for fix in unfiltered:
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
if len(remade_batch_tokens[j]) > 0:
tokens.append(remade_batch_tokens[j][:75])
multipliers.append(batch_multipliers[j][:75])
else:
tokens.append([self.id_end] * 75)
multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
return z
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
if self.id_end != self.id_pad:
for batch_pos in range(len(remade_batch_tokens)):
index = remade_batch_tokens[batch_pos].index(self.id_end)
tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
z = self.encode_with_transformers(tokens)
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
z *= original_mean / new_mean
return z
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)
self.tokenizer = wrapped.tokenizer
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
self.token_mults = {}
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
if c == '[':
mult /= 1.1
if c == ']':
mult *= 1.1
if c == '(':
mult *= 1.1
if c == ')':
mult /= 1.1
if mult != 1.0:
self.token_mults[ident] = mult
self.id_start = self.wrapped.tokenizer.bos_token_id
self.id_end = self.wrapped.tokenizer.eos_token_id
self.id_pad = self.id_end
def tokenize(self, texts):
tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
return tokenized
def encode_with_transformers(self, tokens):
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
if opts.CLIP_stop_at_last_layers > 1:
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
z = self.wrapped.transformer.text_model.final_layer_norm(z)
else:
z = outputs.last_hidden_state
return z
def encode_embedding_init_text(self, init_text, nvpt):
embedding_layer = self.wrapped.transformer.text_model.embeddings
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
return embedded
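Both the CLIP and OpenCLIP wrappers share the base class above: prompts are tokenized, padded up to a multiple of 75 tokens by get_target_prompt_token_count, processed 75 tokens at a time (plus BOS/EOS to reach the model's 77-token window), and the per-chunk embeddings are concatenated. A hedged, standalone sketch of just that chunking arithmetic (token ids here are fake):

```python
# Self-contained sketch of the 75-token chunking idea used by the forward() above.
import math

def get_target_prompt_token_count(token_count):
    # round up to a multiple of 75 (the 77-token CLIP window minus BOS/EOS)
    return math.ceil(max(token_count, 1) / 75) * 75

def chunk_tokens(tokens, id_end):
    target = get_target_prompt_token_count(len(tokens))
    padded = tokens + [id_end] * (target - len(tokens))
    return [padded[i:i + 75] for i in range(0, target, 75)]

fake_prompt = list(range(160))                  # pretend this is a 160-token prompt
chunks = chunk_tokens(fake_prompt, id_end=49407)
print(len(chunks), [len(c) for c in chunks])    # 3 chunks of 75 tokens each
```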


@@ -199,8 +199,8 @@ def sample_plms(self,

 @torch.no_grad()
 def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                   temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
+                  unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
     b, *_, device = *x.shape, x.device

     def get_model_output(x, t):

@@ -249,6 +249,8 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
     pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
     if quantize_denoised:
         pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+    if dynamic_threshold is not None:
+        pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)

     # direction pointing to x_t
     dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
     noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
@@ -321,12 +323,16 @@ def should_hijack_inpainting(checkpoint_info):

 def do_inpainting_hijack():
-    ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
+    # most of this stuff seems to no longer be needed because it is already included into SD2.0
+    # LatentInpaintDiffusion remains because SD2.0's LatentInpaintDiffusion can't be loaded without specifying a checkpoint
+    # p_sample_plms is needed because PLMS can't work with dicts as conditionings
+    # this file should be cleaned up later if everything turns out to work fine
+    # ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
     ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
-    ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
-    ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
+    # ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
+    # ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
     ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
-    ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
+    # ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
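The dynamic_threshold plumbing added above calls norm_thresholding, a helper from the Stability-AI ldm sampling utilities, which rescales the predicted x0 whenever its per-sample norm exceeds the threshold. A hedged re-statement of that idea (this is an approximation for illustration, not the upstream implementation verbatim):

```python
# Hedged sketch of norm-based dynamic thresholding: samples whose RMS exceeds
# `value` are scaled back down to it; quieter samples pass through unchanged.
import torch

def norm_thresholding(x0, value):
    rms = x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value)   # per-sample RMS, floored at value
    rms = rms.reshape(-1, *([1] * (x0.ndim - 1)))                # broadcast over the remaining dims
    return x0 * (value / rms)

pred_x0 = torch.randn(2, 4, 64, 64) * 3.0
print(norm_thresholding(pred_x0, 1.0).flatten(1).pow(2).mean(1).sqrt())  # roughly 1.0 per sample
```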


@@ -0,0 +1,37 @@
import open_clip.tokenizer
import torch
from modules import sd_hijack_clip, devices
from modules.shared import opts
tokenizer = open_clip.tokenizer._tokenizer
class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
def __init__(self, wrapped, hijack):
super().__init__(wrapped, hijack)
self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
self.id_start = tokenizer.encoder["<start_of_text>"]
self.id_end = tokenizer.encoder["<end_of_text>"]
self.id_pad = 0
def tokenize(self, texts):
assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
tokenized = [tokenizer.encode(text) for text in texts]
return tokenized
def encode_with_transformers(self, tokens):
# set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
z = self.wrapped.encode_with_transformer(tokens)
return z
def encode_embedding_init_text(self, init_text, nvpt):
ids = tokenizer.encode(init_text)
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
return embedded
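The OpenCLIP wrapper reuses the same chunking base class but swaps in open_clip's simple tokenizer and its special-token ids. A hedged look at where those objects come from (requires the open-clip-torch package; the sample prompt is arbitrary):

```python
# Hedged snippet showing the open_clip tokenizer objects the wrapper relies on.
import open_clip.tokenizer

tokenizer = open_clip.tokenizer._tokenizer          # a SimpleTokenizer instance

ids = tokenizer.encode("a photo of an astronaut")   # raw BPE ids, no BOS/EOS added
print(ids)

print(tokenizer.encoder["<start_of_text>"],          # id_start used above
      tokenizer.encoder["<end_of_text>"])            # id_end; padding uses id 0 instead
```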


@@ -5,6 +5,7 @@ import gc
 from collections import namedtuple
 import torch
 import re
+import safetensors.torch
 from omegaconf import OmegaConf

 from ldm.util import instantiate_from_config

@@ -45,7 +46,7 @@ def checkpoint_tiles():

 def list_models():
     checkpoints_list.clear()
-    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
+    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])

     def modeltitle(path, shorthash):
         abspath = os.path.abspath(path)

@@ -143,8 +144,8 @@ def transform_checkpoint_dict_key(k):

 def get_state_dict_from_checkpoint(pl_sd):
-    if "state_dict" in pl_sd:
-        pl_sd = pl_sd["state_dict"]
+    pl_sd = pl_sd.pop("state_dict", pl_sd)
+    pl_sd.pop("state_dict", None)

     sd = {}
     for k, v in pl_sd.items():

@@ -159,6 +160,20 @@ def get_state_dict_from_checkpoint(pl_sd):
     return pl_sd

+def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
+    _, extension = os.path.splitext(checkpoint_file)
+    if extension.lower() == ".safetensors":
+        pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location)
+    else:
+        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
+
+    if print_global_state and "global_step" in pl_sd:
+        print(f"Global Step: {pl_sd['global_step']}")
+
+    sd = get_state_dict_from_checkpoint(pl_sd)
+    return sd

 def load_model_weights(model, checkpoint_info, vae_file="auto"):
     checkpoint_file = checkpoint_info.filename
     sd_model_hash = checkpoint_info.hash

@@ -173,12 +188,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
         # load from file
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")

-        pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
-        if "global_step" in pl_sd:
-            print(f"Global Step: {pl_sd['global_step']}")
-
-        sd = get_state_dict_from_checkpoint(pl_sd)
-        del pl_sd
+        sd = read_state_dict(checkpoint_file)
         model.load_state_dict(sd, strict=False)
         del sd

@@ -244,6 +254,9 @@ def load_model(checkpoint_info=None):
     do_inpainting_hijack()

+    if shared.cmd_opts.no_half:
+        sd_config.model.params.unet_config.params.use_fp16 = False
+
     sd_model = instantiate_from_config(sd_config.model)
     load_model_weights(sd_model, checkpoint_info)
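With read_state_dict, classic pickle checkpoints and the new .safetensors format go through one loading path. A hedged, standalone version of that dispatch (the checkpoint path is a placeholder and none of the webui-specific key remapping is shown):

```python
# Hedged sketch of loading either a .ckpt or a .safetensors checkpoint into a plain dict.
import os
import torch
import safetensors.torch

def load_checkpoint_state_dict(checkpoint_file, map_location="cpu"):
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        # safetensors stores raw tensors, so loading never executes pickled code
        state = safetensors.torch.load_file(checkpoint_file, device=map_location)
    else:
        state = torch.load(checkpoint_file, map_location=map_location)
    # Lightning-style checkpoints nest the weights under "state_dict"
    return state.get("state_dict", state)

sd = load_checkpoint_state_dict("models/Stable-diffusion/model.safetensors")  # hypothetical path
print(len(sd), "tensors loaded")
```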


@@ -1,4 +1,4 @@
-from collections import namedtuple
+from collections import namedtuple, deque
 import numpy as np
 from math import floor
 import torch

@@ -18,7 +18,7 @@ from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
 SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])

 samplers_k_diffusion = [
-    ('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}),
+    ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
     ('Euler', 'sample_euler', ['k_euler'], {}),
     ('LMS', 'sample_lms', ['k_lms'], {}),
     ('Heun', 'sample_heun', ['k_heun'], {}),

@@ -26,6 +26,7 @@ samplers_k_diffusion = [
     ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
     ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
     ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
+    ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {}),
     ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
     ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
     ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),

@@ -33,6 +34,7 @@ samplers_k_diffusion = [
     ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
     ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
     ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
+    ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras'}),
 ]

 samplers_data_k_diffusion = [

@@ -50,6 +52,7 @@ all_samplers_map = {x.name: x for x in all_samplers}

 samplers = []
 samplers_for_img2img = []
+samplers_map = {}

 def create_sampler(name, model):

@@ -75,6 +78,12 @@ def set_samplers():
     samplers = [x for x in all_samplers if x.name not in hidden]
     samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]

+    samplers_map.clear()
+    for sampler in all_samplers:
+        samplers_map[sampler.name.lower()] = sampler.name
+        for alias in sampler.aliases:
+            samplers_map[alias.lower()] = sampler.name

 set_samplers()
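samplers_map gives code such as the NovelAI metadata import above and the API a way to resolve a sampler from any lowercase name or alias. A hedged illustration of that lookup with a stand-in table (the real one is built by set_samplers()):

```python
# Hedged illustration of the case-insensitive sampler alias lookup.
samplers_map = {}

all_samplers = [
    ("Euler a", ["k_euler_a", "k_euler_ancestral"]),
    ("DPM++ SDE Karras", ["k_dpmpp_sde_ka"]),
]

for name, aliases in all_samplers:
    samplers_map[name.lower()] = name
    for alias in aliases:
        samplers_map[alias.lower()] = name

print(samplers_map.get("k_euler_ancestral", "Euler a"))   # -> "Euler a"
print(samplers_map.get("dpm++ sde karras", "Euler a"))    # -> "DPM++ SDE Karras"
```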
@@ -127,7 +136,8 @@ class InterruptedException(BaseException):
 class VanillaStableDiffusionSampler:
     def __init__(self, constructor, sd_model):
         self.sampler = constructor(sd_model)
-        self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
+        self.is_plms = hasattr(self.sampler, 'p_sample_plms')
+        self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
         self.mask = None
         self.nmask = None
         self.init_latent = None

@@ -218,7 +228,6 @@ class VanillaStableDiffusionSampler:
         self.mask = p.mask if hasattr(p, 'mask') else None
         self.nmask = p.nmask if hasattr(p, 'nmask') else None

     def adjust_steps_if_invalid(self, p, num_steps):
         if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
             valid_step = 999 / (1000 // num_steps)

@@ -227,7 +236,6 @@ class VanillaStableDiffusionSampler:
         return num_steps

     def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
         steps, t_enc = setup_img2img_steps(p, steps)
         steps = self.adjust_steps_if_invalid(p, steps)

@@ -260,9 +268,10 @@ class VanillaStableDiffusionSampler:
         steps = self.adjust_steps_if_invalid(p, steps or p.steps)

         # Wrap the conditioning models with additional image conditioning for inpainting model
+        # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
         if image_conditioning is not None:
-            conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
-            unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+            conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
+            unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}

         samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
@@ -335,28 +344,39 @@ class CFGDenoiser(torch.nn.Module):

 class TorchHijack:
-    def __init__(self, kdiff_sampler):
-        self.kdiff_sampler = kdiff_sampler
+    def __init__(self, sampler_noises):
+        # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
+        # implementation.
+        self.sampler_noises = deque(sampler_noises)

     def __getattr__(self, item):
         if item == 'randn_like':
-            return self.kdiff_sampler.randn_like
+            return self.randn_like

         if hasattr(torch, item):
             return getattr(torch, item)

         raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))

+    def randn_like(self, x):
+        if self.sampler_noises:
+            noise = self.sampler_noises.popleft()
+            if noise.shape == x.shape:
+                return noise
+
+        return torch.randn_like(x)

 class KDiffusionSampler:
     def __init__(self, funcname, sd_model):
-        self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
+        denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+
+        self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
         self.funcname = funcname
         self.func = getattr(k_diffusion.sampling, self.funcname)
         self.extra_params = sampler_extra_params.get(funcname, [])
         self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
         self.sampler_noises = None
-        self.sampler_noise_index = 0
         self.stop_at = None
         self.eta = None
         self.default_eta = 1.0

@@ -389,26 +409,14 @@ class KDiffusionSampler:
     def number_of_needed_noises(self, p):
         return p.steps

-    def randn_like(self, x):
-        noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
-
-        if noise is not None and x.shape == noise.shape:
-            res = noise
-        else:
-            res = torch.randn_like(x)
-
-        self.sampler_noise_index += 1
-        return res
-
     def initialize(self, p):
         self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
         self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
         self.model_wrap.step = 0
-        self.sampler_noise_index = 0
         self.eta = p.eta or opts.eta_ancestral

         if self.sampler_noises is not None:
-            k_diffusion.sampling.torch = TorchHijack(self)
+            k_diffusion.sampling.torch = TorchHijack(self.sampler_noises)

         extra_params_kwargs = {}
         for param_name in self.extra_params:
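TorchHijack now owns the pre-generated per-step noises itself and hands them out from a deque, so k-diffusion's calls to torch.randn_like transparently receive the noise tensors that keep seeds reproducible across batch sizes. A hedged toy reproduction of the pattern, runnable on its own (the queued tensors here are arbitrary stand-ins):

```python
# Toy reproduction of the TorchHijack idea: an object that masquerades as the
# `torch` module inside a sampler and serves queued noise via randn_like().
from collections import deque
import torch

class TorchHijack:
    def __init__(self, sampler_noises):
        self.sampler_noises = deque(sampler_noises)

    def __getattr__(self, item):
        # fall through to the real torch module for everything else
        if hasattr(torch, item):
            return getattr(torch, item)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def randn_like(self, x):
        if self.sampler_noises:
            noise = self.sampler_noises.popleft()
            if noise.shape == x.shape:
                return noise
        return torch.randn_like(x)

fake_torch = TorchHijack([torch.ones(2, 2), torch.full((2, 2), 2.0)])
x = torch.zeros(2, 2)
print(fake_torch.randn_like(x))   # queued tensor of ones
print(fake_torch.randn_like(x))   # queued tensor of twos
print(fake_torch.randn_like(x))   # queue empty, falls back to real randn_like
```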


@@ -11,13 +11,14 @@ import tqdm
 import modules.artists
 import modules.interrogate
 import modules.memmon
-import modules.sd_models
 import modules.styles
 import modules.devices as devices
-from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading
-from modules.hypernetworks import hypernetwork
+from modules import localization, sd_vae, extensions, script_loading
 from modules.paths import models_path, script_path, sd_path

+demo = None
+
 sd_model_file = os.path.join(script_path, 'model.ckpt')
 default_sd_model_file = sd_model_file
 parser = argparse.ArgumentParser()
@@ -80,13 +81,14 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help=
 parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
 parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
 parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
-parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
-parser.add_argument("--api-auth", type=str, help='Set authentication for api like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
-parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
+parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
+parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui")
 parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
 parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
 parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
-parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None)
+parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(s) in the form of a comma-separated list (no spaces)", default=None)
+parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
 parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
 parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
 parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
@@ -125,10 +127,12 @@ xformers_available = False
 config_filename = cmd_opts.ui_settings_file

 os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
-hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
+hypernetworks = {}
 loaded_hypernetwork = None

 def reload_hypernetworks():
+    from modules.hypernetworks import hypernetwork
+
     global hypernetworks

     hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
@@ -210,10 +214,11 @@ class State:
         if self.current_latent is None:
             return

+        import modules.sd_samplers
         if opts.show_progress_grid:
-            self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
+            self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
         else:
-            self.current_image = sd_samplers.sample_to_image(self.current_latent)
+            self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)

         self.current_image_sampling_step = self.sampling_step
@ -252,6 +257,21 @@ def options_section(section_identifier, options_dict):
return options_dict return options_dict
def list_checkpoint_tiles():
import modules.sd_models
return modules.sd_models.checkpoint_tiles()
def refresh_checkpoints():
import modules.sd_models
return modules.sd_models.list_models()
def list_samplers():
import modules.sd_samplers
return modules.sd_samplers.all_samplers
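The three wrappers above defer importing modules.sd_models / modules.sd_samplers until call time, presumably so shared.py no longer needs those modules at import time. A toy, self-contained sketch of the same deferred-import idiom (json stands in for a heavy or circularly-dependent module):

```python
import time

def heavy_feature():
    # The import cost (and any circular dependency) is paid only when the
    # feature is actually used, not when this module is first imported.
    import json
    return json.dumps({"ok": True})

start = time.perf_counter()
print(heavy_feature())                  # first call pays the import cost
print(time.perf_counter() - start)
```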
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
options_templates = {} options_templates = {}
@ -280,6 +300,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"), "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
})) }))

options_templates.update(options_section(('saving-paths', "Paths for saving"), { options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@ -304,7 +328,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
options_templates.update(options_section(('upscaling', "Upscaling"), { options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}), "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
@ -326,8 +350,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
"tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}),
"save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
@ -337,7 +360,7 @@ options_templates.update(options_section(('training', "Training"), {
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
@ -389,7 +412,7 @@ options_templates.update(options_section(('ui', "User interface"), {
})) }))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), { options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}), "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
View File
@ -111,7 +111,7 @@ def upscale(
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir) img = img.unsqueeze(0).to(devices.device_swinir)
with torch.no_grad(), precision_scope("cuda"): with torch.no_grad(), precision_scope("cuda"):
_, _, h_old, w_old = img.size() _, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old h_pad = (h_old // window_size + 1) * window_size - h_old
View File
@ -3,7 +3,7 @@ import numpy as np
import PIL import PIL
import torch import torch
from PIL import Image from PIL import Image
from torch.utils.data import Dataset from torch.utils.data import Dataset, DataLoader
from torchvision import transforms from torchvision import transforms
import random import random
@ -11,25 +11,28 @@ import tqdm
from modules import devices, shared from modules import devices, shared
import re import re
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
re_numbers_at_start = re.compile(r"^[-\d]+\s*") re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry: class DatasetEntry:
def __init__(self, filename=None, latent=None, filename_text=None): def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
self.filename = filename self.filename = filename
self.latent = latent
self.filename_text = filename_text self.filename_text = filename_text
self.cond = None self.latent_dist = latent_dist
self.cond_text = None self.latent_sample = latent_sample
self.cond = cond
self.cond_text = cond_text
self.pixel_values = pixel_values
class PersonalizedBase(Dataset): class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token self.placeholder_token = placeholder_token
self.batch_size = batch_size
self.width = width self.width = width
self.height = height self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@ -45,11 +48,16 @@ class PersonalizedBase(Dataset):
assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty" assert os.listdir(data_root), "Dataset directory is empty"
cond_model = shared.sd_model.cond_stage_model
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
print("Preparing dataset...") print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths): for path in tqdm.tqdm(self.image_paths):
if shared.state.interrupted:
raise Exception("inturrupted")
try: try:
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
except Exception: except Exception:
@ -71,58 +79,94 @@ class PersonalizedBase(Dataset):
npimage = np.array(image).astype(np.uint8) npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32) npimage = (npimage / 127.5 - 1.0).astype(np.float32)
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
torchdata = torch.moveaxis(torchdata, 2, 0) latent_sample = None
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() with torch.autocast("cuda"):
init_latent = init_latent.to(devices.cpu) latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent) if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
latent_sampling_method = "once"
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
elif latent_sampling_method == "deterministic":
# Works only for DiagonalGaussianDistribution
latent_dist.std = 0
latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
elif latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)
if include_cond: if not (self.tag_drop_out != 0 or self.shuffle_tags):
entry.cond_text = self.create_text(filename_text) entry.cond_text = self.create_text(filename_text)
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
with torch.autocast("cuda"):
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
self.dataset.append(entry) self.dataset.append(entry)
del torchdata
del latent_dist
del latent_sample
assert len(self.dataset) > 0, "No images have been found in the dataset." self.length = len(self.dataset)
self.length = len(self.dataset) * repeats // batch_size assert self.length > 0, "No images have been found in the dataset."
self.batch_size = min(batch_size, self.length)
self.dataset_length = len(self.dataset) self.gradient_step = min(gradient_step, self.length // self.batch_size)
self.indexes = None self.latent_sampling_method = latent_sampling_method
self.shuffle()
def shuffle(self):
self.indexes = np.random.permutation(self.dataset_length)
def create_text(self, filename_text): def create_text(self, filename_text):
text = random.choice(self.lines) text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
tags = filename_text.split(',') tags = filename_text.split(',')
if shared.opts.tag_drop_out != 0: if self.tag_drop_out != 0:
tags = [t for t in tags if random.random() > shared.opts.tag_drop_out] tags = [t for t in tags if random.random() > self.tag_drop_out]
if shared.opts.shuffle_tags: if self.shuffle_tags:
random.shuffle(tags) random.shuffle(tags)
text = text.replace("[filewords]", ','.join(tags)) text = text.replace("[filewords]", ','.join(tags))
text = text.replace("[name]", self.placeholder_token)
return text return text
def __len__(self): def __len__(self):
return self.length return self.length
def __getitem__(self, i): def __getitem__(self, i):
res = [] entry = self.dataset[i]
if self.tag_drop_out != 0 or self.shuffle_tags:
entry.cond_text = self.create_text(entry.filename_text)
if self.latent_sampling_method == "random":
entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
return entry
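A standalone sketch of how the three latent_sampling_method values ("once", "deterministic", "random") differ in when and how a latent is drawn; the FakeDiagonalGaussian class below only imitates ldm's DiagonalGaussianDistribution for illustration.

```python
import torch

class FakeDiagonalGaussian:
    """Stand-in for ldm's DiagonalGaussianDistribution: a per-pixel Gaussian latent."""
    def __init__(self, mean, std):
        self.mean, self.std = mean, std
    def sample(self):
        return self.mean + self.std * torch.randn_like(self.mean)

dist = FakeDiagonalGaussian(torch.zeros(1, 4, 64, 64), torch.ones(1, 4, 64, 64))

# "once": draw a single latent at dataset-build time and cache it for every epoch
cached_once = dist.sample()

# "random": keep the distribution itself and draw a fresh latent in __getitem__ each time
fresh_sample = dist.sample()

# "deterministic": zero the std so every draw collapses to the mean (only meaningful
# for a DiagonalGaussianDistribution, which is why the code above checks its type)
dist.std = 0
deterministic = dist.sample()
assert torch.equal(deterministic, dist.mean)
```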
for j in range(self.batch_size): class PersonalizedDataLoader(DataLoader):
position = i * self.batch_size + j def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
if position % len(self.indexes) == 0: super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
self.shuffle() if latent_sampling_method == "random":
self.collate_fn = collate_wrapper_random
else:
self.collate_fn = collate_wrapper
index = self.indexes[position % len(self.indexes)] class BatchLoader:
entry = self.dataset[index] def __init__(self, data):
self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
#self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device)
if entry.cond is None: def pin_memory(self):
entry.cond_text = self.create_text(entry.filename_text) self.latent_sample = self.latent_sample.pin_memory()
return self
res.append(entry) def collate_wrapper(batch):
return BatchLoader(batch)
return res class BatchLoaderRandom(BatchLoader):
def __init__(self, data):
super().__init__(data)
def pin_memory(self):
return self
def collate_wrapper_random(batch):
return BatchLoaderRandom(batch)
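The BatchLoader/collate_wrapper pair exists so that a custom batch object can take part in DataLoader's pin_memory machinery; a minimal self-contained sketch of that contract (field names are illustrative, and pinning itself requires a CUDA build of PyTorch):

```python
import torch
from torch.utils.data import DataLoader

class Batch:
    def __init__(self, samples):
        # collate_fn: stack the per-item tensors into one batch tensor
        self.latent = torch.stack([s["latent"] for s in samples])
    def pin_memory(self):
        # DataLoader calls this on the whole batch object when pin_memory=True;
        # it must return self so the pinned batch replaces the original
        self.latent = self.latent.pin_memory()
        return self

dataset = [{"latent": torch.randn(4, 8, 8)} for _ in range(16)]
loader = DataLoader(dataset, batch_size=4, collate_fn=Batch,
                    pin_memory=torch.cuda.is_available())

for batch in loader:
    # a pinned batch can be copied to the GPU with non_blocking=True,
    # overlapping the transfer with computation
    x = batch.latent
```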
View File
@ -64,7 +64,8 @@ class EmbeddingDatabase:
self.word_embeddings[embedding.name] = embedding self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0] # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
ids = model.cond_stage_model.tokenize([embedding.name])[0]
first_id = ids[0] first_id = ids[0]
if first_id not in self.ids_lookup: if first_id not in self.ids_lookup:
@ -155,13 +156,11 @@ class EmbeddingDatabase:
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
with devices.autocast(): with devices.autocast():
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token): for i in range(num_vectors_per_token):
@ -184,7 +183,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0: if shared.opts.training_write_csv_every == 0:
return return
if (step + 1) % shared.opts.training_write_csv_every != 0: if step % shared.opts.training_write_csv_every != 0:
return return
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
@ -194,21 +193,23 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if write_csv_header: if write_csv_header:
csv_writer.writeheader() csv_writer.writeheader()
epoch = step // epoch_len epoch = (step - 1) // epoch_len
epoch_step = step % epoch_len epoch_step = (step - 1) % epoch_len
csv_writer.writerow({ csv_writer.writerow({
"step": step + 1, "step": step,
"epoch": epoch, "epoch": epoch,
"epoch_step": epoch_step + 1, "epoch_step": epoch_step,
**values, **values,
}) })
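A quick worked check of the re-indexed CSV columns above (epoch_len = 100):

```python
epoch_len = 100
for step in (1, 100, 101):
    print(step, (step - 1) // epoch_len, (step - 1) % epoch_len)
# 1   -> epoch 0, epoch_step 0
# 100 -> epoch 0, epoch_step 99
# 101 -> epoch 1, epoch_step 0
```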
def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected" assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0" assert learn_rate, "Learning rate is empty or 0"
assert isinstance(batch_size, int), "Batch size must be integer" assert isinstance(batch_size, int), "Batch size must be integer"
assert batch_size > 0, "Batch size must be positive" assert batch_size > 0, "Batch size must be positive"
assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
assert gradient_step > 0, "Gradient accumulation step must be positive"
assert data_root, "Dataset directory is empty" assert data_root, "Dataset directory is empty"
assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty" assert os.listdir(data_root), "Dataset directory is empty"
@ -224,10 +225,10 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat
if save_model_every or create_image_every: if save_model_every or create_image_every:
assert log_directory, "Log directory is empty" assert log_directory, "Log directory is empty"
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0 save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0 create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
shared.state.textinfo = "Initializing textual inversion training..." shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps shared.state.job_count = steps
@ -255,161 +256,200 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
else: else:
images_embeds_dir = None images_embeds_dir = None
cond_model = shared.sd_model.cond_stage_model
hijack = sd_hijack.model_hijack hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name] embedding = hijack.embedding_db.word_embeddings[embedding_name]
checkpoint = sd_models.select_checkpoint() checkpoint = sd_models.select_checkpoint()
ititial_step = embedding.step or 0 initial_step = embedding.step or 0
if ititial_step >= steps: if initial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps" shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return embedding, filename return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) # dataset loading may take a while, so input validations and early returns should be done before this
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) pin_memory = shared.opts.pin_memory
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
if unload: if unload:
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
scaler = torch.cuda.amp.GradScaler()
losses = torch.zeros((32,)) batch_size = ds.batch_size
gradient_step = ds.gradient_step
# n steps = batch_size * gradient_step * n image processed
steps_per_epoch = len(ds) // batch_size // gradient_step
max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
loss_step = 0
_loss_step = 0 #internal
last_saved_file = "<none>" last_saved_file = "<none>"
last_saved_image = "<none>" last_saved_image = "<none>"
forced_filename = "<none>" forced_filename = "<none>"
embedding_yet_to_be_embedded = False embedding_yet_to_be_embedded = False
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps-initial_step) * gradient_step):
if scheduler.finished:
break
if shared.state.interrupted:
break
for j, batch in enumerate(dl):
# works as a drop_last=True for gradient accumulation
if j == max_steps_per_epoch:
break
scheduler.apply(optimizer, embedding.step)
if scheduler.finished:
break
if shared.state.interrupted:
break
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) with torch.autocast("cuda"):
for i, entries in pbar: # c = stack_conds(batch.cond).to(devices.device)
embedding.step = i + ititial_step # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
# print(mask)
# c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
loss = shared.sd_model(x, c)[0] / gradient_step
del x
_loss_step += loss.item()
scaler.scale(loss).backward()
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
scaler.step(optimizer)
scaler.update()
embedding.step += 1
pbar.update()
optimizer.zero_grad(set_to_none=True)
loss_step = _loss_step
_loss_step = 0
scheduler.apply(optimizer, embedding.step) steps_done = embedding.step + 1
if scheduler.finished:
break
if shared.state.interrupted: epoch_num = embedding.step // steps_per_epoch
break epoch_step = embedding.step % steps_per_epoch
with torch.autocast("cuda"): pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
c = cond_model([entry.cond_text for entry in entries]) if embedding_dir is not None and steps_done % save_embedding_every == 0:
x = torch.stack([entry.latent for entry in entries]).to(devices.device) # Before saving, change name to match current checkpoint.
loss = shared.sd_model(x, c)[0] embedding_name_every = f'{embedding_name}-{steps_done}'
del x last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
#if shared.opts.save_optimizer_state:
#embedding.optimizer_state_dict = optimizer.state_dict()
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
losses[embedding.step % losses.shape[0]] = loss.item() write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
"loss": f"{loss_step:.7f}",
"learn_rate": scheduler.learn_rate
})
optimizer.zero_grad() if images_dir is not None and steps_done % create_image_every == 0:
loss.backward() forced_filename = f'{embedding_name}-{steps_done}'
optimizer.step() last_saved_image = os.path.join(images_dir, forced_filename)
steps_done = embedding.step + 1 shared.sd_model.first_stage_model.to(devices.device)
epoch_num = embedding.step // len(ds) p = processing.StableDiffusionProcessingTxt2Img(
epoch_step = embedding.step % len(ds) sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
do_not_reload_embeddings=True,
)
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = batch.cond_text[0]
p.steps = 20
p.width = training_width
p.height = training_height
if embedding_dir is not None and steps_done % save_embedding_every == 0: preview_text = p.prompt
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { processed = processing.process_images(p)
"loss": f"{losses.mean():.7f}", image = processed.images[0] if len(processed.images) > 0 else None
"learn_rate": scheduler.learn_rate
})
if images_dir is not None and steps_done % create_image_every == 0: if unload:
forced_filename = f'{embedding_name}-{steps_done}' shared.sd_model.first_stage_model.to(devices.cpu)
last_saved_image = os.path.join(images_dir, forced_filename)
shared.sd_model.first_stage_model.to(devices.device) if image is not None:
shared.state.current_image = image
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
p = processing.StableDiffusionProcessingTxt2Img( if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
do_not_reload_embeddings=True,
)
if preview_from_txt2img: last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entries[0].cond_text
p.steps = 20
p.width = training_width
p.height = training_height
preview_text = p.prompt info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embedding_to_b64(data))
processed = processing.process_images(p) title = "<{}>".format(data.get('name', '???'))
image = processed.images[0]
if unload: try:
shared.sd_model.first_stage_model.to(devices.cpu) vectorSize = list(data['string_to_param'].values())[0].shape[0]
except Exception as e:
vectorSize = '?'
shared.state.current_image = image checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
footer_right = '{}v {}s'.format(vectorSize, steps_done)
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
embedding_yet_to_be_embedded = False
info = PngImagePlugin.PngInfo() last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
data = torch.load(last_saved_file) last_saved_image += f", prompt: {preview_text}"
info.add_text("sd-ti-embedding", embedding_to_b64(data))
title = "<{}>".format(data.get('name', '???')) shared.state.job_no = embedding.step
try: shared.state.textinfo = f"""
vectorSize = list(data['string_to_param'].values())[0].shape[0]
except Exception as e:
vectorSize = '?'
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
footer_right = '{}v {}s'.format(vectorSize, steps_done)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
embedding_yet_to_be_embedded = False
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = embedding.step
shared.state.textinfo = f"""
<p> <p>
Loss: {losses.mean():.7f}<br/> Loss: {loss_step:.7f}<br/>
Step: {embedding.step}<br/> Step: {steps_done}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/> Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/> Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
""" """
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) except Exception:
shared.sd_model.first_stage_model.to(devices.device) print(traceback.format_exc(), file=sys.stderr)
pass
finally:
pbar.leave = False
pbar.close()
shared.sd_model.first_stage_model.to(devices.device)
return embedding, filename return embedding, filename
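A compressed sketch of the gradient-accumulation-with-AMP pattern the new training loop uses, with placeholder model and data; it illustrates the technique only, not the training code itself:

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"

model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.0)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
gradient_step = 4  # micro-batches accumulated per optimizer step

data = [(torch.randn(2, 16, device=device), torch.randn(2, 1, device=device)) for _ in range(8)]
for j, (x, y) in enumerate(data):
    with torch.autocast(device, enabled=use_amp):
        loss = F.mse_loss(model(x), y) / gradient_step  # divide so accumulated grads match one big batch
    scaler.scale(loss).backward()
    if (j + 1) % gradient_step != 0:
        continue                       # keep accumulating gradients
    scaler.step(optimizer)             # one real optimizer step per gradient_step micro-batches
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
```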
View File
@ -17,7 +17,7 @@ import gradio.routes
import gradio.utils import gradio.utils
import numpy as np import numpy as np
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
from modules.paths import script_path from modules.paths import script_path
@ -157,84 +157,7 @@ def save_files(js_data, images, do_make_zip, index):
return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")
def save_pil_to_file(pil_image, dir=None):
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in pil_image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
return file_obj
# override save to file function so that it also writes PNG info
gr.processing_utils.save_pil_to_file = save_pil_to_file
def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon:
shared.mem_mon.monitor()
t = time.perf_counter()
try:
res = list(func(*args, **kwargs))
except Exception as e:
# When printing out our debug argument list, do not print out more than a MB of text
max_debug_str_len = 131072 # (1024*1024)/8
print("Error completing request", file=sys.stderr)
argStr = f"Arguments: {str(args)} {str(kwargs)}"
print(argStr[:max_debug_str_len], file=sys.stderr)
if len(argStr) > max_debug_str_len:
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
shared.state.job = ""
shared.state.job_count = 0
if extra_outputs_array is None:
extra_outputs_array = [None, '']
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
shared.state.skipped = False
shared.state.interrupted = False
shared.state.job_count = 0
if not add_stats:
return tuple(res)
elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60
elapsed_text = f"{elapsed_s:.2f}s"
if elapsed_m > 0:
elapsed_text = f"{elapsed_m}m "+elapsed_text
if run_memmon:
mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
active_peak = mem_stats['active_peak']
reserved_peak = mem_stats['reserved_peak']
sys_peak = mem_stats['system_peak']
sys_total = mem_stats['total']
sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
vram_html = f"<p class='vram'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
else:
vram_html = ''
# last item is always HTML
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
return tuple(res)
return f
def calc_time_left(progress, threshold, label, force_display): def calc_time_left(progress, threshold, label, force_display):
@ -478,9 +401,7 @@ def create_toprow(is_img2img):
if is_img2img: if is_img2img:
with gr.Column(scale=1, elem_id="interrogate_col"): with gr.Column(scale=1, elem_id="interrogate_col"):
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
if cmd_opts.deepdanbooru:
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
with gr.Column(scale=1): with gr.Column(scale=1):
with gr.Row(): with gr.Row():
@ -684,7 +605,7 @@ Requested path was: {f}
return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info
def create_ui(wrap_gradio_gpu_call): def create_ui():
import modules.img2img import modules.img2img
import modules.txt2img import modules.txt2img
@ -844,7 +765,7 @@ def create_ui(wrap_gradio_gpu_call):
height, height,
] ]
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
modules.scripts.scripts_current = modules.scripts.scripts_img2img modules.scripts.scripts_current = modules.scripts.scripts_img2img
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
@ -1004,11 +925,10 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[img2img_prompt], outputs=[img2img_prompt],
) )
if cmd_opts.deepdanbooru: img2img_deepbooru.click(
img2img_deepbooru.click( fn=interrogate_deepbooru,
fn=interrogate_deepbooru, inputs=[init_img],
inputs=[init_img], outputs=[img2img_prompt],
outputs=[img2img_prompt],
) )
@ -1063,6 +983,7 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"), (seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
(mask_blur, "Mask blur"),
*modules.scripts.scripts_img2img.infotext_fields *modules.scripts.scripts_img2img.infotext_fields
] ]
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
@ -1183,7 +1104,11 @@ def create_ui(wrap_gradio_gpu_call):
custom_name = gr.Textbox(label="Custom Name (Optional)") custom_name = gr.Textbox(label="Custom Name (Optional)")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3) interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Save as float16")
with gr.Row():
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format")
save_as_half = gr.Checkbox(value=False, label="Save as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
@ -1213,7 +1138,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tab(label="Create hypernetwork"): with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys) new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
@ -1240,7 +1165,7 @@ def create_ui(wrap_gradio_gpu_call):
process_split = gr.Checkbox(label='Split oversized images') process_split = gr.Checkbox(label='Split oversized images')
process_focal_crop = gr.Checkbox(label='Auto focal point crop') process_focal_crop = gr.Checkbox(label='Auto focal point crop')
process_caption = gr.Checkbox(label='Use BLIP for caption') process_caption = gr.Checkbox(label='Use BLIP for caption')
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True)
with gr.Row(visible=False) as process_split_extra_row: with gr.Row(visible=False) as process_split_extra_row:
process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
@ -1259,7 +1184,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column(): with gr.Column():
with gr.Row(): with gr.Row():
interrupt_preprocessing = gr.Button("Interrupt") interrupt_preprocessing = gr.Button("Interrupt")
run_preprocess = gr.Button(value="Preprocess", variant='primary') run_preprocess = gr.Button(value="Preprocess", variant='primary')
process_split.change( process_split.change(
fn=lambda show: gr_show(show), fn=lambda show: gr_show(show),
@ -1286,6 +1211,7 @@ def create_ui(wrap_gradio_gpu_call):
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
batch_size = gr.Number(label='Batch size', value=1, precision=0) batch_size = gr.Number(label='Batch size', value=1, precision=0)
gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
@ -1296,6 +1222,11 @@ def create_ui(wrap_gradio_gpu_call):
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
with gr.Row():
shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False)
tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0)
with gr.Row():
latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'])
with gr.Row(): with gr.Row():
interrupt_training = gr.Button(value="Interrupt") interrupt_training = gr.Button(value="Interrupt")
@ -1384,11 +1315,15 @@ def create_ui(wrap_gradio_gpu_call):
train_embedding_name, train_embedding_name,
embedding_learn_rate, embedding_learn_rate,
batch_size, batch_size,
gradient_step,
dataset_directory, dataset_directory,
log_directory, log_directory,
training_width, training_width,
training_height, training_height,
steps, steps,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
create_image_every, create_image_every,
save_embedding_every, save_embedding_every,
template_file, template_file,
@ -1409,11 +1344,15 @@ def create_ui(wrap_gradio_gpu_call):
train_hypernetwork_name, train_hypernetwork_name,
hypernetwork_learn_rate, hypernetwork_learn_rate,
batch_size, batch_size,
gradient_step,
dataset_directory, dataset_directory,
log_directory, log_directory,
training_width, training_width,
training_height, training_height,
steps, steps,
shuffle_tags,
tag_drop_out,
latent_sampling_method,
create_image_every, create_image_every,
save_embedding_every, save_embedding_every,
template_file, template_file,
@ -1697,6 +1636,7 @@ def create_ui(wrap_gradio_gpu_call):
interp_amount, interp_amount,
save_as_half, save_as_half,
custom_name, custom_name,
checkpoint_format,
], ],
outputs=[ outputs=[
submit_result, submit_result,
62
modules/ui_tempdir.py Normal file
View File
@ -0,0 +1,62 @@
import os
import tempfile
from collections import namedtuple
import gradio as gr
from PIL import PngImagePlugin
from modules import shared
Savedfile = namedtuple("Savedfile", ["name"])
def save_pil_to_file(pil_image, dir=None):
already_saved_as = getattr(pil_image, 'already_saved_as', None)
if already_saved_as and os.path.isfile(already_saved_as):
shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(os.path.dirname(already_saved_as))}
file_obj = Savedfile(already_saved_as)
return file_obj
if shared.opts.temp_dir != "":
dir = shared.opts.temp_dir
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in pil_image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
return file_obj
# override save to file function so that it also writes PNG info
gr.processing_utils.save_pil_to_file = save_pil_to_file
def on_tmpdir_changed():
if shared.opts.temp_dir == "" or shared.demo is None:
return
os.makedirs(shared.opts.temp_dir, exist_ok=True)
shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(shared.opts.temp_dir)}
def cleanup_tmpdr():
temp_dir = shared.opts.temp_dir
if temp_dir == "" or not os.path.isdir(temp_dir):
return
for root, dirs, files in os.walk(temp_dir, topdown=False):
for name in files:
_, extension = os.path.splitext(name)
if extension != ".png":
continue
filename = os.path.join(root, name)
os.remove(filename)
View File
@ -28,3 +28,5 @@ kornia
lark lark
inflection inflection
GitPython GitPython
torchsde
safetensors
View File
@ -25,3 +25,5 @@ kornia==0.6.7
lark==1.1.2 lark==1.1.2
inflection==0.5.1 inflection==0.5.1
GitPython==3.1.27 GitPython==3.1.27
torchsde==0.2.5
safetensors==0.2.5
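safetensors is the new dependency behind the ckpt/safetensors checkpoint-format option added to the checkpoint merger; a minimal sketch of saving and loading with it (the file path is a placeholder):

```python
import torch
from safetensors.torch import save_file, load_file

state = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

# Unlike torch.save, safetensors stores raw tensors only (no pickled Python objects),
# so loading is safe and can be memory-mapped.
save_file(state, "model.safetensors")
restored = load_file("model.safetensors")
```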
View File
@ -58,29 +58,19 @@ def apply_order(p, x, xs):
prompt_tmp += part prompt_tmp += part
prompt_tmp += x[idx] prompt_tmp += x[idx]
p.prompt = prompt_tmp + p.prompt p.prompt = prompt_tmp + p.prompt
def build_samplers_dict():
samplers_dict = {}
for i, sampler in enumerate(sd_samplers.all_samplers):
samplers_dict[sampler.name.lower()] = i
for alias in sampler.aliases:
samplers_dict[alias.lower()] = i
return samplers_dict
def apply_sampler(p, x, xs): def apply_sampler(p, x, xs):
sampler_index = build_samplers_dict().get(x.lower(), None) sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
if sampler_index is None: if sampler_name is None:
raise RuntimeError(f"Unknown sampler: {x}") raise RuntimeError(f"Unknown sampler: {x}")
p.sampler_index = sampler_index p.sampler_name = sampler_name
def confirm_samplers(p, xs): def confirm_samplers(p, xs):
samplers_dict = build_samplers_dict()
for x in xs: for x in xs:
if x.lower() not in samplers_dict.keys(): if x.lower() not in sd_samplers.samplers_map:
raise RuntimeError(f"Unknown sampler: {x}") raise RuntimeError(f"Unknown sampler: {x}")
View File

View File
@ -11,8 +11,8 @@ class TestExtrasWorking(unittest.TestCase):
"codeformer_visibility": 0, "codeformer_visibility": 0,
"codeformer_weight": 0, "codeformer_weight": 0,
"upscaling_resize": 2, "upscaling_resize": 2,
"upscaling_resize_w": 512, "upscaling_resize_w": 128,
"upscaling_resize_h": 512, "upscaling_resize_h": 128,
"upscaling_crop": True, "upscaling_crop": True,
"upscaler_1": "None", "upscaler_1": "None",
"upscaler_2": "None", "upscaler_2": "None",
View File
@ -0,0 +1,47 @@
import unittest
import requests
class TestTxt2ImgWorking(unittest.TestCase):
def setUp(self):
self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img"
self.simple_txt2img = {
"enable_hr": False,
"denoising_strength": 0,
"firstphase_width": 0,
"firstphase_height": 0,
"prompt": "example prompt",
"styles": [],
"seed": -1,
"subseed": -1,
"subseed_strength": 0,
"seed_resize_from_h": -1,
"seed_resize_from_w": -1,
"batch_size": 1,
"n_iter": 1,
"steps": 3,
"cfg_scale": 7,
"width": 64,
"height": 64,
"restore_faces": False,
"tiling": False,
"negative_prompt": "",
"eta": 0,
"s_churn": 0,
"s_tmax": 0,
"s_tmin": 0,
"s_noise": 1,
"sampler_index": "Euler a"
}
def test_txt2img_with_restore_faces_performed(self):
self.simple_txt2img["restore_faces"] = True
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
class TestTxt2ImgCorrectness(unittest.TestCase):
pass
if __name__ == "__main__":
unittest.main()
View File

View File
@ -51,9 +51,5 @@ class TestImg2ImgWorking(unittest.TestCase):
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
class TestImg2ImgCorrectness(unittest.TestCase):
pass
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
View File
@ -49,26 +49,20 @@ class TestTxt2ImgWorking(unittest.TestCase):
self.simple_txt2img["enable_hr"] = True self.simple_txt2img["enable_hr"] = True
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_with_restore_faces_performed(self): def test_txt2img_with_tiling_performed(self):
self.simple_txt2img["restore_faces"] = True
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_with_tiling_faces_performed(self):
self.simple_txt2img["tiling"] = True self.simple_txt2img["tiling"] = True
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_with_vanilla_sampler_performed(self): def test_txt2img_with_vanilla_sampler_performed(self):
self.simple_txt2img["sampler_index"] = "PLMS" self.simple_txt2img["sampler_index"] = "PLMS"
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
self.simple_txt2img["sampler_index"] = "DDIM"
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_multiple_batches_performed(self): def test_txt2img_multiple_batches_performed(self):
self.simple_txt2img["n_iter"] = 2 self.simple_txt2img["n_iter"] = 2
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
class TestTxt2ImgCorrectness(unittest.TestCase):
pass
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
View File
@@ -18,20 +18,6 @@ class UtilsTests(unittest.TestCase):
     def test_options_get(self):
         self.assertEqual(requests.get(self.url_options).status_code, 200)
 
-    def test_options_write(self):
-        response = requests.get(self.url_options)
-        self.assertEqual(response.status_code, 200)
-
-        pre_value = response.json()["send_seed"]
-
-        self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200)
-
-        response = requests.get(self.url_options)
-        self.assertEqual(response.status_code, 200)
-        self.assertEqual(response.json()["send_seed"], not pre_value)
-
-        requests.post(self.url_options, json={"send_seed": pre_value})
-
     def test_cmd_flags(self):
         self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)

@@ -60,4 +46,8 @@ class UtilsTests(unittest.TestCase):
         self.assertEqual(requests.get(self.url_artist_categories).status_code, 200)
 
     def test_artists(self):
         self.assertEqual(requests.get(self.url_artists).status_code, 200)
+
+
+if __name__ == "__main__":
+    unittest.main()


@@ -3,7 +3,7 @@ import requests
 import time
 
-def run_tests():
+def run_tests(proc, test_dir):
     timeout_threshold = 240
     start_time = time.time()
     while time.time()-start_time < timeout_threshold:
@@ -11,9 +11,14 @@ def run_tests():
             requests.head("http://localhost:7860/")
             break
         except requests.exceptions.ConnectionError:
-            pass
-    if time.time()-start_time < timeout_threshold:
-        suite = unittest.TestLoader().discover('', pattern='*_test.py')
+            if proc.poll() is not None:
+                break
+    if proc.poll() is None:
+        if test_dir is None:
+            test_dir = ""
+        suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test")
         result = unittest.TextTestRunner(verbosity=2).run(suite)
+        return len(result.failures) + len(result.errors)
     else:
         print("Launch unsuccessful")
+        return 1
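Read end to end, the reworked polling helper presumably settles into the shape below. This is a sketch assembled from the hunk above; the module's imports and the hard-coded localhost URL are taken from context rather than shown in full by the diff.

```python
import time
import unittest

import requests


def run_tests(proc, test_dir):
    # Wait for the server to come up, but stop early if the launched process already died.
    timeout_threshold = 240
    start_time = time.time()
    while time.time() - start_time < timeout_threshold:
        try:
            requests.head("http://localhost:7860/")
            break
        except requests.exceptions.ConnectionError:
            if proc.poll() is not None:
                break

    if proc.poll() is None:
        if test_dir is None:
            test_dir = ""
        suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test")
        result = unittest.TextTestRunner(verbosity=2).run(suite)
        # A non-zero return value lets the caller (e.g. the launcher in CI) exit with a failure code.
        return len(result.failures) + len(result.errors)
    else:
        print("Launch unsuccessful")
        return 1
```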

test/test_files/empty.pt (new binary file, not shown)

v1-inference.yaml Normal file

@@ -0,0 +1,70 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
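The file matches the stock CompVis v1 inference config. For orientation, configs of this shape are normally consumed with OmegaConf plus ldm's instantiate_from_config; the snippet below is only a sketch of that pattern and is not code from this repository.

```python
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config  # from the bundled latent-diffusion package

# Build the LatentDiffusion model the YAML describes, then load checkpoint weights separately.
config = OmegaConf.load("v1-inference.yaml")
model = instantiate_from_config(config.model)

state_dict = torch.load("model.ckpt", map_location="cpu")["state_dict"]  # checkpoint path is illustrative
model.load_state_dict(state_dict, strict=False)
model.eval()
```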


@@ -8,9 +8,10 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 
+from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
 from modules.paths import script_path
 
-from modules import devices, sd_samplers, upscaler, extensions, localization
+from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
 import modules.codeformer_model as codeformer
 import modules.extras
 import modules.face_restoration
@@ -23,7 +24,6 @@ import modules.scripts
 import modules.sd_hijack
 import modules.sd_models
 import modules.sd_vae
-import modules.shared as shared
 import modules.txt2img
 
 import modules.script_callbacks
@@ -32,36 +32,12 @@ from modules import modelloader
 from modules.shared import cmd_opts
 import modules.hypernetworks.hypernetwork
 
-queue_lock = threading.Lock()
-
 if cmd_opts.server_name:
     server_name = cmd_opts.server_name
 else:
     server_name = "0.0.0.0" if cmd_opts.listen else None
 
-def wrap_queued_call(func):
-    def f(*args, **kwargs):
-        with queue_lock:
-            res = func(*args, **kwargs)
-        return res
-
-    return f
-
-
-def wrap_gradio_gpu_call(func, extra_outputs=None):
-    def f(*args, **kwargs):
-        shared.state.begin()
-        with queue_lock:
-            res = func(*args, **kwargs)
-        shared.state.end()
-
-        return res
-
-    return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
-
 
 def initialize():
     extensions.list_extensions()
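The queue_lock and the two wrappers deleted here now come from modules/call_queue.py (see the new import near the top of the file). Based only on the removed code, the relocated helpers presumably look roughly like the sketch below; the actual contents of call_queue.py are not part of this diff and may differ in details such as where wrap_gradio_call lives.

```python
# modules/call_queue.py -- sketch reconstructed from the code removed above
import threading

from modules import shared
import modules.ui

queue_lock = threading.Lock()


def wrap_queued_call(func):
    def f(*args, **kwargs):
        # Serialize GPU-bound work behind a single lock.
        with queue_lock:
            res = func(*args, **kwargs)
        return res

    return f


def wrap_gradio_gpu_call(func, extra_outputs=None):
    def f(*args, **kwargs):
        shared.state.begin()
        with queue_lock:
            res = func(*args, **kwargs)
        shared.state.end()
        return res

    return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)
```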
@@ -86,8 +62,9 @@ def initialize():
     shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
     shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
     shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
-    shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
+    shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks()))
     shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
+    shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
 
     if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
@@ -111,8 +88,12 @@ def initialize():
 
 def setup_cors(app):
-    if cmd_opts.cors_allow_origins:
+    if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
+        app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
+    elif cmd_opts.cors_allow_origins:
         app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'])
+    elif cmd_opts.cors_allow_origins_regex:
+        app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
 
 
 def create_api(app):
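Assembled from the hunk above, the new setup_cors presumably reads as follows: a combined origins-plus-regex branch first, then one branch for each option on its own. This is a sketch of the post-patch function, reusing only names that appear in the diff (cmd_opts and CORSMiddleware are imported elsewhere in webui.py).

```python
def setup_cors(app):
    # Configure CORS from explicit origins, an origin regex, or both.
    if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
        app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
    elif cmd_opts.cors_allow_origins:
        app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'])
    elif cmd_opts.cors_allow_origins_regex:
        app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
```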
@@ -150,9 +131,12 @@ def webui():
     initialize()
 
     while 1:
-        demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
+        if shared.opts.clean_temp_dir_at_start:
+            ui_tempdir.cleanup_tmpdr()
 
-        app, local_url, share_url = demo.launch(
+        shared.demo = modules.ui.create_ui()
+
+        app, local_url, share_url = shared.demo.launch(
             share=cmd_opts.share,
             server_name=server_name,
             server_port=cmd_opts.port,
@@ -179,9 +163,10 @@ def webui():
         if launch_api:
             create_api(app)
 
-        modules.script_callbacks.app_started_callback(demo, app)
+        modules.script_callbacks.app_started_callback(shared.demo, app)
+        modules.script_callbacks.app_started_callback(shared.demo, app)
 
-        wait_on_server(demo)
+        wait_on_server(shared.demo)
 
         sd_samplers.set_samplers()


@@ -3,6 +3,7 @@
 # Please do not make any changes to this file, #
 # change the variables in webui-user.sh instead #
 #################################################
 
 # Read variables from webui-user.sh
 # shellcheck source=/dev/null
 if [[ -f webui-user.sh ]]
@@ -46,6 +47,17 @@ then
     LAUNCH_SCRIPT="launch.py"
 fi
 
+# this script cannot be run as root by default
+can_run_as_root=0
+
+# read any command line flags to the webui.sh script
+while getopts "f" flag
+do
+    case ${flag} in
+        f) can_run_as_root=1;;
+    esac
+done
+
 # Disable sentry logging
 export ERROR_REPORTING=FALSE
@@ -61,7 +73,7 @@ printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m"
 printf "\n%s\n" "${delimiter}"
 
 # Do not run as root
-if [[ $(id -u) -eq 0 ]]
+if [[ $(id -u) -eq 0 && can_run_as_root -eq 0 ]]
 then
     printf "\n%s\n" "${delimiter}"
     printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m"