mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Add support Stable Diffusion 2.0
This commit is contained in:
parent
828438b4a1
commit
ce6911158b
21
README.md
21
README.md
@ -84,26 +84,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
|
|||||||
- API
|
- API
|
||||||
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
|
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
|
||||||
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
|
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
|
||||||
|
- [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
|
||||||
## Where are Aesthetic Gradients?!?!
|
|
||||||
Aesthetic Gradients are now an extension. You can install it using git:
|
|
||||||
|
|
||||||
```commandline
|
|
||||||
git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients extensions/aesthetic-gradients
|
|
||||||
```
|
|
||||||
|
|
||||||
After running this command, make sure that you have `aesthetic-gradients` dir in webui's `extensions` directory and restart
|
|
||||||
the UI. The interface for Aesthetic Gradients should appear exactly the same as it was.
|
|
||||||
|
|
||||||
## Where is History/Image browser?!?!
|
|
||||||
Image browser is now an extension. You can install it using git:
|
|
||||||
|
|
||||||
```commandline
|
|
||||||
git clone https://github.com/yfszzx/stable-diffusion-webui-images-browser extensions/images-browser
|
|
||||||
```
|
|
||||||
|
|
||||||
After running this command, make sure that you have `images-browser` dir in webui's `extensions` directory and restart
|
|
||||||
the UI. The interface for Image browser should appear exactly the same as it was.
|
|
||||||
|
|
||||||
## Installation and Running
|
## Installation and Running
|
||||||
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
|
||||||
|
12
launch.py
12
launch.py
@ -134,18 +134,19 @@ def prepare_enviroment():
|
|||||||
|
|
||||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
||||||
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
||||||
|
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
|
||||||
|
|
||||||
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
|
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
|
||||||
|
|
||||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/CompVis/stable-diffusion.git")
|
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||||
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
||||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||||
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
||||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||||
|
|
||||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
|
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
|
||||||
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
||||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "60e5042ca0da89c14d1dd59d73883280f8fce991")
|
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
|
||||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||||
|
|
||||||
@ -179,6 +180,9 @@ def prepare_enviroment():
|
|||||||
if not is_installed("clip"):
|
if not is_installed("clip"):
|
||||||
run_pip(f"install {clip_package}", "clip")
|
run_pip(f"install {clip_package}", "clip")
|
||||||
|
|
||||||
|
if not is_installed("open_clip"):
|
||||||
|
run_pip(f"install {openclip_package}", "open_clip")
|
||||||
|
|
||||||
if (not is_installed("xformers") or reinstall_xformers) and xformers:
|
if (not is_installed("xformers") or reinstall_xformers) and xformers:
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
if platform.python_version().startswith("3.10"):
|
if platform.python_version().startswith("3.10"):
|
||||||
@ -196,7 +200,7 @@ def prepare_enviroment():
|
|||||||
|
|
||||||
os.makedirs(dir_repos, exist_ok=True)
|
os.makedirs(dir_repos, exist_ok=True)
|
||||||
|
|
||||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
|
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||||
git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
|
git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
|
||||||
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||||
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||||
|
@ -9,7 +9,7 @@ sys.path.insert(0, script_path)
|
|||||||
|
|
||||||
# search for directory of stable diffusion in following places
|
# search for directory of stable diffusion in following places
|
||||||
sd_path = None
|
sd_path = None
|
||||||
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
|
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
|
||||||
for possible_sd_path in possible_sd_paths:
|
for possible_sd_path in possible_sd_paths:
|
||||||
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
|
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
|
||||||
sd_path = os.path.abspath(possible_sd_path)
|
sd_path = os.path.abspath(possible_sd_path)
|
||||||
|
@ -9,18 +9,29 @@ from torch.nn.functional import silu
|
|||||||
|
|
||||||
import modules.textual_inversion.textual_inversion
|
import modules.textual_inversion.textual_inversion
|
||||||
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
|
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
|
||||||
from modules.shared import opts, device, cmd_opts
|
from modules.shared import cmd_opts
|
||||||
|
from modules import sd_hijack_clip, sd_hijack_open_clip
|
||||||
|
|
||||||
from modules.sd_hijack_optimizations import invokeAI_mps_available
|
from modules.sd_hijack_optimizations import invokeAI_mps_available
|
||||||
|
|
||||||
import ldm.modules.attention
|
import ldm.modules.attention
|
||||||
import ldm.modules.diffusionmodules.model
|
import ldm.modules.diffusionmodules.model
|
||||||
import ldm.models.diffusion.ddim
|
import ldm.models.diffusion.ddim
|
||||||
import ldm.models.diffusion.plms
|
import ldm.models.diffusion.plms
|
||||||
|
import ldm.modules.encoders.modules
|
||||||
|
|
||||||
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||||
|
|
||||||
|
# new memory efficient cross attention blocks do not support hypernets and we already
|
||||||
|
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
|
||||||
|
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
|
||||||
|
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||||
|
|
||||||
|
# silence new console spam from SD2
|
||||||
|
ldm.modules.attention.print = lambda *args: None
|
||||||
|
ldm.modules.diffusionmodules.model.print = lambda *args: None
|
||||||
|
|
||||||
def apply_optimizations():
|
def apply_optimizations():
|
||||||
undo_optimizations()
|
undo_optimizations()
|
||||||
@ -49,16 +60,11 @@ def apply_optimizations():
|
|||||||
|
|
||||||
|
|
||||||
def undo_optimizations():
|
def undo_optimizations():
|
||||||
from modules.hypernetworks import hypernetwork
|
ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward # this stops hypernets from working
|
||||||
|
|
||||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||||
|
|
||||||
|
|
||||||
def get_target_prompt_token_count(token_count):
|
|
||||||
return math.ceil(max(token_count, 1) / 75) * 75
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionModelHijack:
|
class StableDiffusionModelHijack:
|
||||||
fixes = None
|
fixes = None
|
||||||
@ -70,10 +76,13 @@ class StableDiffusionModelHijack:
|
|||||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
||||||
|
|
||||||
def hijack(self, m):
|
def hijack(self, m):
|
||||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
|
||||||
|
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||||
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
|
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
|
||||||
|
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
|
||||||
|
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||||
|
|
||||||
self.clip = m.cond_stage_model
|
self.clip = m.cond_stage_model
|
||||||
|
|
||||||
@ -89,12 +98,15 @@ class StableDiffusionModelHijack:
|
|||||||
self.layers = flatten(m)
|
self.layers = flatten(m)
|
||||||
|
|
||||||
def undo_hijack(self, m):
|
def undo_hijack(self, m):
|
||||||
if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
|
if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
|
|
||||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||||
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
||||||
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
||||||
|
elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
|
||||||
|
m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
|
||||||
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
|
|
||||||
self.apply_circular(False)
|
self.apply_circular(False)
|
||||||
self.layers = None
|
self.layers = None
|
||||||
@ -114,262 +126,9 @@ class StableDiffusionModelHijack:
|
|||||||
|
|
||||||
def tokenize(self, text):
|
def tokenize(self, text):
|
||||||
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
||||||
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
|
return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
|
||||||
|
|
||||||
|
|
||||||
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
|
||||||
def __init__(self, wrapped, hijack):
|
|
||||||
super().__init__()
|
|
||||||
self.wrapped = wrapped
|
|
||||||
self.hijack: StableDiffusionModelHijack = hijack
|
|
||||||
self.tokenizer = wrapped.tokenizer
|
|
||||||
self.token_mults = {}
|
|
||||||
|
|
||||||
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
|
|
||||||
|
|
||||||
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
|
||||||
for text, ident in tokens_with_parens:
|
|
||||||
mult = 1.0
|
|
||||||
for c in text:
|
|
||||||
if c == '[':
|
|
||||||
mult /= 1.1
|
|
||||||
if c == ']':
|
|
||||||
mult *= 1.1
|
|
||||||
if c == '(':
|
|
||||||
mult *= 1.1
|
|
||||||
if c == ')':
|
|
||||||
mult /= 1.1
|
|
||||||
|
|
||||||
if mult != 1.0:
|
|
||||||
self.token_mults[ident] = mult
|
|
||||||
|
|
||||||
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
|
||||||
id_end = self.wrapped.tokenizer.eos_token_id
|
|
||||||
|
|
||||||
if opts.enable_emphasis:
|
|
||||||
parsed = prompt_parser.parse_prompt_attention(line)
|
|
||||||
else:
|
|
||||||
parsed = [[line, 1.0]]
|
|
||||||
|
|
||||||
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
|
|
||||||
|
|
||||||
fixes = []
|
|
||||||
remade_tokens = []
|
|
||||||
multipliers = []
|
|
||||||
last_comma = -1
|
|
||||||
|
|
||||||
for tokens, (text, weight) in zip(tokenized, parsed):
|
|
||||||
i = 0
|
|
||||||
while i < len(tokens):
|
|
||||||
token = tokens[i]
|
|
||||||
|
|
||||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
|
||||||
|
|
||||||
if token == self.comma_token:
|
|
||||||
last_comma = len(remade_tokens)
|
|
||||||
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
|
|
||||||
last_comma += 1
|
|
||||||
reloc_tokens = remade_tokens[last_comma:]
|
|
||||||
reloc_mults = multipliers[last_comma:]
|
|
||||||
|
|
||||||
remade_tokens = remade_tokens[:last_comma]
|
|
||||||
length = len(remade_tokens)
|
|
||||||
|
|
||||||
rem = int(math.ceil(length / 75)) * 75 - length
|
|
||||||
remade_tokens += [id_end] * rem + reloc_tokens
|
|
||||||
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
|
|
||||||
|
|
||||||
if embedding is None:
|
|
||||||
remade_tokens.append(token)
|
|
||||||
multipliers.append(weight)
|
|
||||||
i += 1
|
|
||||||
else:
|
|
||||||
emb_len = int(embedding.vec.shape[0])
|
|
||||||
iteration = len(remade_tokens) // 75
|
|
||||||
if (len(remade_tokens) + emb_len) // 75 != iteration:
|
|
||||||
rem = (75 * (iteration + 1) - len(remade_tokens))
|
|
||||||
remade_tokens += [id_end] * rem
|
|
||||||
multipliers += [1.0] * rem
|
|
||||||
iteration += 1
|
|
||||||
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
|
|
||||||
remade_tokens += [0] * emb_len
|
|
||||||
multipliers += [weight] * emb_len
|
|
||||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
|
||||||
i += embedding_length_in_tokens
|
|
||||||
|
|
||||||
token_count = len(remade_tokens)
|
|
||||||
prompt_target_length = get_target_prompt_token_count(token_count)
|
|
||||||
tokens_to_add = prompt_target_length - len(remade_tokens)
|
|
||||||
|
|
||||||
remade_tokens = remade_tokens + [id_end] * tokens_to_add
|
|
||||||
multipliers = multipliers + [1.0] * tokens_to_add
|
|
||||||
|
|
||||||
return remade_tokens, fixes, multipliers, token_count
|
|
||||||
|
|
||||||
def process_text(self, texts):
|
|
||||||
used_custom_terms = []
|
|
||||||
remade_batch_tokens = []
|
|
||||||
hijack_comments = []
|
|
||||||
hijack_fixes = []
|
|
||||||
token_count = 0
|
|
||||||
|
|
||||||
cache = {}
|
|
||||||
batch_multipliers = []
|
|
||||||
for line in texts:
|
|
||||||
if line in cache:
|
|
||||||
remade_tokens, fixes, multipliers = cache[line]
|
|
||||||
else:
|
|
||||||
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
|
|
||||||
token_count = max(current_token_count, token_count)
|
|
||||||
|
|
||||||
cache[line] = (remade_tokens, fixes, multipliers)
|
|
||||||
|
|
||||||
remade_batch_tokens.append(remade_tokens)
|
|
||||||
hijack_fixes.append(fixes)
|
|
||||||
batch_multipliers.append(multipliers)
|
|
||||||
|
|
||||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
|
||||||
|
|
||||||
def process_text_old(self, text):
|
|
||||||
id_start = self.wrapped.tokenizer.bos_token_id
|
|
||||||
id_end = self.wrapped.tokenizer.eos_token_id
|
|
||||||
maxlen = self.wrapped.max_length # you get to stay at 77
|
|
||||||
used_custom_terms = []
|
|
||||||
remade_batch_tokens = []
|
|
||||||
overflowing_words = []
|
|
||||||
hijack_comments = []
|
|
||||||
hijack_fixes = []
|
|
||||||
token_count = 0
|
|
||||||
|
|
||||||
cache = {}
|
|
||||||
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
|
||||||
batch_multipliers = []
|
|
||||||
for tokens in batch_tokens:
|
|
||||||
tuple_tokens = tuple(tokens)
|
|
||||||
|
|
||||||
if tuple_tokens in cache:
|
|
||||||
remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
|
||||||
else:
|
|
||||||
fixes = []
|
|
||||||
remade_tokens = []
|
|
||||||
multipliers = []
|
|
||||||
mult = 1.0
|
|
||||||
|
|
||||||
i = 0
|
|
||||||
while i < len(tokens):
|
|
||||||
token = tokens[i]
|
|
||||||
|
|
||||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
|
||||||
|
|
||||||
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
|
||||||
if mult_change is not None:
|
|
||||||
mult *= mult_change
|
|
||||||
i += 1
|
|
||||||
elif embedding is None:
|
|
||||||
remade_tokens.append(token)
|
|
||||||
multipliers.append(mult)
|
|
||||||
i += 1
|
|
||||||
else:
|
|
||||||
emb_len = int(embedding.vec.shape[0])
|
|
||||||
fixes.append((len(remade_tokens), embedding))
|
|
||||||
remade_tokens += [0] * emb_len
|
|
||||||
multipliers += [mult] * emb_len
|
|
||||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
|
||||||
i += embedding_length_in_tokens
|
|
||||||
|
|
||||||
if len(remade_tokens) > maxlen - 2:
|
|
||||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
|
||||||
ovf = remade_tokens[maxlen - 2:]
|
|
||||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
|
||||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
|
||||||
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
|
||||||
|
|
||||||
token_count = len(remade_tokens)
|
|
||||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
|
||||||
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
|
||||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
|
||||||
|
|
||||||
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
|
||||||
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
|
||||||
|
|
||||||
remade_batch_tokens.append(remade_tokens)
|
|
||||||
hijack_fixes.append(fixes)
|
|
||||||
batch_multipliers.append(multipliers)
|
|
||||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
|
||||||
|
|
||||||
def forward(self, text):
|
|
||||||
use_old = opts.use_old_emphasis_implementation
|
|
||||||
if use_old:
|
|
||||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
|
||||||
else:
|
|
||||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
|
||||||
|
|
||||||
self.hijack.comments += hijack_comments
|
|
||||||
|
|
||||||
if len(used_custom_terms) > 0:
|
|
||||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
|
||||||
|
|
||||||
if use_old:
|
|
||||||
self.hijack.fixes = hijack_fixes
|
|
||||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
|
||||||
|
|
||||||
z = None
|
|
||||||
i = 0
|
|
||||||
while max(map(len, remade_batch_tokens)) != 0:
|
|
||||||
rem_tokens = [x[75:] for x in remade_batch_tokens]
|
|
||||||
rem_multipliers = [x[75:] for x in batch_multipliers]
|
|
||||||
|
|
||||||
self.hijack.fixes = []
|
|
||||||
for unfiltered in hijack_fixes:
|
|
||||||
fixes = []
|
|
||||||
for fix in unfiltered:
|
|
||||||
if fix[0] == i:
|
|
||||||
fixes.append(fix[1])
|
|
||||||
self.hijack.fixes.append(fixes)
|
|
||||||
|
|
||||||
tokens = []
|
|
||||||
multipliers = []
|
|
||||||
for j in range(len(remade_batch_tokens)):
|
|
||||||
if len(remade_batch_tokens[j]) > 0:
|
|
||||||
tokens.append(remade_batch_tokens[j][:75])
|
|
||||||
multipliers.append(batch_multipliers[j][:75])
|
|
||||||
else:
|
|
||||||
tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
|
|
||||||
multipliers.append([1.0] * 75)
|
|
||||||
|
|
||||||
z1 = self.process_tokens(tokens, multipliers)
|
|
||||||
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
|
||||||
|
|
||||||
remade_batch_tokens = rem_tokens
|
|
||||||
batch_multipliers = rem_multipliers
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
return z
|
|
||||||
|
|
||||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
|
||||||
if not opts.use_old_emphasis_implementation:
|
|
||||||
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
|
|
||||||
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
|
|
||||||
|
|
||||||
tokens = torch.asarray(remade_batch_tokens).to(device)
|
|
||||||
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
|
||||||
|
|
||||||
if opts.CLIP_stop_at_last_layers > 1:
|
|
||||||
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
|
|
||||||
z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
|
||||||
else:
|
|
||||||
z = outputs.last_hidden_state
|
|
||||||
|
|
||||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
|
||||||
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
|
|
||||||
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
|
|
||||||
original_mean = z.mean()
|
|
||||||
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
|
||||||
new_mean = z.mean()
|
|
||||||
z *= original_mean / new_mean
|
|
||||||
|
|
||||||
return z
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsWithFixes(torch.nn.Module):
|
class EmbeddingsWithFixes(torch.nn.Module):
|
||||||
def __init__(self, wrapped, embeddings):
|
def __init__(self, wrapped, embeddings):
|
||||||
|
301
modules/sd_hijack_clip.py
Normal file
301
modules/sd_hijack_clip.py
Normal file
@ -0,0 +1,301 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from modules import prompt_parser, devices
|
||||||
|
from modules.shared import opts
|
||||||
|
|
||||||
|
|
||||||
|
def get_target_prompt_token_count(token_count):
|
||||||
|
return math.ceil(max(token_count, 1) / 75) * 75
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||||
|
def __init__(self, wrapped, hijack):
|
||||||
|
super().__init__()
|
||||||
|
self.wrapped = wrapped
|
||||||
|
self.hijack = hijack
|
||||||
|
|
||||||
|
def tokenize(self, texts):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def encode_with_transformers(self, tokens):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def encode_embedding_init_text(self, init_text, nvpt):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||||
|
if opts.enable_emphasis:
|
||||||
|
parsed = prompt_parser.parse_prompt_attention(line)
|
||||||
|
else:
|
||||||
|
parsed = [[line, 1.0]]
|
||||||
|
|
||||||
|
tokenized = self.tokenize([text for text, _ in parsed])
|
||||||
|
|
||||||
|
fixes = []
|
||||||
|
remade_tokens = []
|
||||||
|
multipliers = []
|
||||||
|
last_comma = -1
|
||||||
|
|
||||||
|
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||||
|
i = 0
|
||||||
|
while i < len(tokens):
|
||||||
|
token = tokens[i]
|
||||||
|
|
||||||
|
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||||
|
|
||||||
|
if token == self.comma_token:
|
||||||
|
last_comma = len(remade_tokens)
|
||||||
|
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||||
|
last_comma += 1
|
||||||
|
reloc_tokens = remade_tokens[last_comma:]
|
||||||
|
reloc_mults = multipliers[last_comma:]
|
||||||
|
|
||||||
|
remade_tokens = remade_tokens[:last_comma]
|
||||||
|
length = len(remade_tokens)
|
||||||
|
|
||||||
|
rem = int(math.ceil(length / 75)) * 75 - length
|
||||||
|
remade_tokens += [self.id_end] * rem + reloc_tokens
|
||||||
|
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
|
||||||
|
|
||||||
|
if embedding is None:
|
||||||
|
remade_tokens.append(token)
|
||||||
|
multipliers.append(weight)
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
emb_len = int(embedding.vec.shape[0])
|
||||||
|
iteration = len(remade_tokens) // 75
|
||||||
|
if (len(remade_tokens) + emb_len) // 75 != iteration:
|
||||||
|
rem = (75 * (iteration + 1) - len(remade_tokens))
|
||||||
|
remade_tokens += [self.id_end] * rem
|
||||||
|
multipliers += [1.0] * rem
|
||||||
|
iteration += 1
|
||||||
|
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
|
||||||
|
remade_tokens += [0] * emb_len
|
||||||
|
multipliers += [weight] * emb_len
|
||||||
|
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||||
|
i += embedding_length_in_tokens
|
||||||
|
|
||||||
|
token_count = len(remade_tokens)
|
||||||
|
prompt_target_length = get_target_prompt_token_count(token_count)
|
||||||
|
tokens_to_add = prompt_target_length - len(remade_tokens)
|
||||||
|
|
||||||
|
remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
|
||||||
|
multipliers = multipliers + [1.0] * tokens_to_add
|
||||||
|
|
||||||
|
return remade_tokens, fixes, multipliers, token_count
|
||||||
|
|
||||||
|
def process_text(self, texts):
|
||||||
|
used_custom_terms = []
|
||||||
|
remade_batch_tokens = []
|
||||||
|
hijack_comments = []
|
||||||
|
hijack_fixes = []
|
||||||
|
token_count = 0
|
||||||
|
|
||||||
|
cache = {}
|
||||||
|
batch_multipliers = []
|
||||||
|
for line in texts:
|
||||||
|
if line in cache:
|
||||||
|
remade_tokens, fixes, multipliers = cache[line]
|
||||||
|
else:
|
||||||
|
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
|
||||||
|
token_count = max(current_token_count, token_count)
|
||||||
|
|
||||||
|
cache[line] = (remade_tokens, fixes, multipliers)
|
||||||
|
|
||||||
|
remade_batch_tokens.append(remade_tokens)
|
||||||
|
hijack_fixes.append(fixes)
|
||||||
|
batch_multipliers.append(multipliers)
|
||||||
|
|
||||||
|
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||||
|
|
||||||
|
def process_text_old(self, texts):
|
||||||
|
id_start = self.id_start
|
||||||
|
id_end = self.id_end
|
||||||
|
maxlen = self.wrapped.max_length # you get to stay at 77
|
||||||
|
used_custom_terms = []
|
||||||
|
remade_batch_tokens = []
|
||||||
|
hijack_comments = []
|
||||||
|
hijack_fixes = []
|
||||||
|
token_count = 0
|
||||||
|
|
||||||
|
cache = {}
|
||||||
|
batch_tokens = self.tokenize(texts)
|
||||||
|
batch_multipliers = []
|
||||||
|
for tokens in batch_tokens:
|
||||||
|
tuple_tokens = tuple(tokens)
|
||||||
|
|
||||||
|
if tuple_tokens in cache:
|
||||||
|
remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
||||||
|
else:
|
||||||
|
fixes = []
|
||||||
|
remade_tokens = []
|
||||||
|
multipliers = []
|
||||||
|
mult = 1.0
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
while i < len(tokens):
|
||||||
|
token = tokens[i]
|
||||||
|
|
||||||
|
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||||
|
|
||||||
|
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
||||||
|
if mult_change is not None:
|
||||||
|
mult *= mult_change
|
||||||
|
i += 1
|
||||||
|
elif embedding is None:
|
||||||
|
remade_tokens.append(token)
|
||||||
|
multipliers.append(mult)
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
emb_len = int(embedding.vec.shape[0])
|
||||||
|
fixes.append((len(remade_tokens), embedding))
|
||||||
|
remade_tokens += [0] * emb_len
|
||||||
|
multipliers += [mult] * emb_len
|
||||||
|
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||||
|
i += embedding_length_in_tokens
|
||||||
|
|
||||||
|
if len(remade_tokens) > maxlen - 2:
|
||||||
|
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||||
|
ovf = remade_tokens[maxlen - 2:]
|
||||||
|
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||||
|
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||||
|
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||||
|
|
||||||
|
token_count = len(remade_tokens)
|
||||||
|
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||||
|
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
||||||
|
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||||
|
|
||||||
|
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||||
|
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||||
|
|
||||||
|
remade_batch_tokens.append(remade_tokens)
|
||||||
|
hijack_fixes.append(fixes)
|
||||||
|
batch_multipliers.append(multipliers)
|
||||||
|
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||||
|
|
||||||
|
def forward(self, text):
|
||||||
|
use_old = opts.use_old_emphasis_implementation
|
||||||
|
if use_old:
|
||||||
|
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
||||||
|
else:
|
||||||
|
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||||
|
|
||||||
|
self.hijack.comments += hijack_comments
|
||||||
|
|
||||||
|
if len(used_custom_terms) > 0:
|
||||||
|
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||||
|
|
||||||
|
if use_old:
|
||||||
|
self.hijack.fixes = hijack_fixes
|
||||||
|
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||||
|
|
||||||
|
z = None
|
||||||
|
i = 0
|
||||||
|
while max(map(len, remade_batch_tokens)) != 0:
|
||||||
|
rem_tokens = [x[75:] for x in remade_batch_tokens]
|
||||||
|
rem_multipliers = [x[75:] for x in batch_multipliers]
|
||||||
|
|
||||||
|
self.hijack.fixes = []
|
||||||
|
for unfiltered in hijack_fixes:
|
||||||
|
fixes = []
|
||||||
|
for fix in unfiltered:
|
||||||
|
if fix[0] == i:
|
||||||
|
fixes.append(fix[1])
|
||||||
|
self.hijack.fixes.append(fixes)
|
||||||
|
|
||||||
|
tokens = []
|
||||||
|
multipliers = []
|
||||||
|
for j in range(len(remade_batch_tokens)):
|
||||||
|
if len(remade_batch_tokens[j]) > 0:
|
||||||
|
tokens.append(remade_batch_tokens[j][:75])
|
||||||
|
multipliers.append(batch_multipliers[j][:75])
|
||||||
|
else:
|
||||||
|
tokens.append([self.id_end] * 75)
|
||||||
|
multipliers.append([1.0] * 75)
|
||||||
|
|
||||||
|
z1 = self.process_tokens(tokens, multipliers)
|
||||||
|
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
||||||
|
|
||||||
|
remade_batch_tokens = rem_tokens
|
||||||
|
batch_multipliers = rem_multipliers
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return z
|
||||||
|
|
||||||
|
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||||
|
if not opts.use_old_emphasis_implementation:
|
||||||
|
remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
|
||||||
|
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
|
||||||
|
|
||||||
|
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
|
||||||
|
|
||||||
|
if self.id_end != self.id_pad:
|
||||||
|
for batch_pos in range(len(remade_batch_tokens)):
|
||||||
|
index = remade_batch_tokens[batch_pos].index(self.id_end)
|
||||||
|
tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
|
||||||
|
|
||||||
|
z = self.encode_with_transformers(tokens)
|
||||||
|
|
||||||
|
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||||
|
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
|
||||||
|
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
|
||||||
|
original_mean = z.mean()
|
||||||
|
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||||
|
new_mean = z.mean()
|
||||||
|
z *= original_mean / new_mean
|
||||||
|
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
||||||
|
def __init__(self, wrapped, hijack):
|
||||||
|
super().__init__(wrapped, hijack)
|
||||||
|
self.tokenizer = wrapped.tokenizer
|
||||||
|
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
|
||||||
|
|
||||||
|
self.token_mults = {}
|
||||||
|
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||||
|
for text, ident in tokens_with_parens:
|
||||||
|
mult = 1.0
|
||||||
|
for c in text:
|
||||||
|
if c == '[':
|
||||||
|
mult /= 1.1
|
||||||
|
if c == ']':
|
||||||
|
mult *= 1.1
|
||||||
|
if c == '(':
|
||||||
|
mult *= 1.1
|
||||||
|
if c == ')':
|
||||||
|
mult /= 1.1
|
||||||
|
|
||||||
|
if mult != 1.0:
|
||||||
|
self.token_mults[ident] = mult
|
||||||
|
|
||||||
|
self.id_start = self.wrapped.tokenizer.bos_token_id
|
||||||
|
self.id_end = self.wrapped.tokenizer.eos_token_id
|
||||||
|
self.id_pad = self.id_end
|
||||||
|
|
||||||
|
def tokenize(self, texts):
|
||||||
|
tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||||
|
|
||||||
|
return tokenized
|
||||||
|
|
||||||
|
def encode_with_transformers(self, tokens):
|
||||||
|
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||||
|
|
||||||
|
if opts.CLIP_stop_at_last_layers > 1:
|
||||||
|
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||||
|
z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
||||||
|
else:
|
||||||
|
z = outputs.last_hidden_state
|
||||||
|
|
||||||
|
return z
|
||||||
|
|
||||||
|
def encode_embedding_init_text(self, init_text, nvpt):
|
||||||
|
embedding_layer = self.wrapped.transformer.text_model.embeddings
|
||||||
|
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||||
|
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
|
||||||
|
|
||||||
|
return embedded
|
@ -199,8 +199,8 @@ def sample_plms(self,
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
def get_model_output(x, t):
|
def get_model_output(x, t):
|
||||||
@ -249,6 +249,8 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
|||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
if quantize_denoised:
|
if quantize_denoised:
|
||||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
if dynamic_threshold is not None:
|
||||||
|
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||||
# direction pointing to x_t
|
# direction pointing to x_t
|
||||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
@ -321,12 +323,16 @@ def should_hijack_inpainting(checkpoint_info):
|
|||||||
|
|
||||||
|
|
||||||
def do_inpainting_hijack():
|
def do_inpainting_hijack():
|
||||||
ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
|
# most of this stuff seems to no longer be needed because it is already included into SD2.0
|
||||||
|
# LatentInpaintDiffusion remains because SD2.0's LatentInpaintDiffusion can't be loaded without specifying a checkpoint
|
||||||
|
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
||||||
|
# this file should be cleaned up later if weverything tuens out to work fine
|
||||||
|
|
||||||
|
# ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
|
||||||
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
|
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
|
||||||
|
|
||||||
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
# ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
||||||
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
# ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
||||||
|
|
||||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
||||||
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
|
# ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
|
||||||
|
|
||||||
|
37
modules/sd_hijack_open_clip.py
Normal file
37
modules/sd_hijack_open_clip.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
import open_clip.tokenizer
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from modules import sd_hijack_clip, devices
|
||||||
|
from modules.shared import opts
|
||||||
|
|
||||||
|
tokenizer = open_clip.tokenizer._tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
|
||||||
|
def __init__(self, wrapped, hijack):
|
||||||
|
super().__init__(wrapped, hijack)
|
||||||
|
|
||||||
|
self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
|
||||||
|
self.id_start = tokenizer.encoder["<start_of_text>"]
|
||||||
|
self.id_end = tokenizer.encoder["<end_of_text>"]
|
||||||
|
self.id_pad = 0
|
||||||
|
|
||||||
|
def tokenize(self, texts):
|
||||||
|
assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
|
||||||
|
|
||||||
|
tokenized = [tokenizer.encode(text) for text in texts]
|
||||||
|
|
||||||
|
return tokenized
|
||||||
|
|
||||||
|
def encode_with_transformers(self, tokens):
|
||||||
|
# set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
|
||||||
|
z = self.wrapped.encode_with_transformer(tokens)
|
||||||
|
|
||||||
|
return z
|
||||||
|
|
||||||
|
def encode_embedding_init_text(self, init_text, nvpt):
|
||||||
|
ids = tokenizer.encode(init_text)
|
||||||
|
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
||||||
|
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
|
||||||
|
|
||||||
|
return embedded
|
@ -127,7 +127,8 @@ class InterruptedException(BaseException):
|
|||||||
class VanillaStableDiffusionSampler:
|
class VanillaStableDiffusionSampler:
|
||||||
def __init__(self, constructor, sd_model):
|
def __init__(self, constructor, sd_model):
|
||||||
self.sampler = constructor(sd_model)
|
self.sampler = constructor(sd_model)
|
||||||
self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
|
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
||||||
|
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
|
||||||
self.mask = None
|
self.mask = None
|
||||||
self.nmask = None
|
self.nmask = None
|
||||||
self.init_latent = None
|
self.init_latent = None
|
||||||
@ -218,7 +219,6 @@ class VanillaStableDiffusionSampler:
|
|||||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||||
|
|
||||||
|
|
||||||
def adjust_steps_if_invalid(self, p, num_steps):
|
def adjust_steps_if_invalid(self, p, num_steps):
|
||||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||||
valid_step = 999 / (1000 // num_steps)
|
valid_step = 999 / (1000 // num_steps)
|
||||||
@ -227,7 +227,6 @@ class VanillaStableDiffusionSampler:
|
|||||||
|
|
||||||
return num_steps
|
return num_steps
|
||||||
|
|
||||||
|
|
||||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||||
steps, t_enc = setup_img2img_steps(p, steps)
|
steps, t_enc = setup_img2img_steps(p, steps)
|
||||||
steps = self.adjust_steps_if_invalid(p, steps)
|
steps = self.adjust_steps_if_invalid(p, steps)
|
||||||
@ -260,9 +259,10 @@ class VanillaStableDiffusionSampler:
|
|||||||
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
||||||
|
|
||||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||||
|
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
|
||||||
if image_conditioning is not None:
|
if image_conditioning is not None:
|
||||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
||||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
||||||
|
|
||||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||||
|
|
||||||
@ -350,7 +350,9 @@ class TorchHijack:
|
|||||||
|
|
||||||
class KDiffusionSampler:
|
class KDiffusionSampler:
|
||||||
def __init__(self, funcname, sd_model):
|
def __init__(self, funcname, sd_model):
|
||||||
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
|
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||||
|
|
||||||
|
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||||
self.funcname = funcname
|
self.funcname = funcname
|
||||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||||
|
@ -11,17 +11,15 @@ import tqdm
|
|||||||
import modules.artists
|
import modules.artists
|
||||||
import modules.interrogate
|
import modules.interrogate
|
||||||
import modules.memmon
|
import modules.memmon
|
||||||
import modules.sd_models
|
|
||||||
import modules.styles
|
import modules.styles
|
||||||
import modules.devices as devices
|
import modules.devices as devices
|
||||||
from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading
|
from modules import localization, sd_vae, extensions, script_loading
|
||||||
from modules.hypernetworks import hypernetwork
|
|
||||||
from modules.paths import models_path, script_path, sd_path
|
from modules.paths import models_path, script_path, sd_path
|
||||||
|
|
||||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||||
default_sd_model_file = sd_model_file
|
default_sd_model_file = sd_model_file
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",)
|
||||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||||
@ -121,10 +119,12 @@ xformers_available = False
|
|||||||
config_filename = cmd_opts.ui_settings_file
|
config_filename = cmd_opts.ui_settings_file
|
||||||
|
|
||||||
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
hypernetworks = {}
|
||||||
loaded_hypernetwork = None
|
loaded_hypernetwork = None
|
||||||
|
|
||||||
|
|
||||||
def reload_hypernetworks():
|
def reload_hypernetworks():
|
||||||
|
from modules.hypernetworks import hypernetwork
|
||||||
global hypernetworks
|
global hypernetworks
|
||||||
|
|
||||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||||
@ -206,10 +206,11 @@ class State:
|
|||||||
if self.current_latent is None:
|
if self.current_latent is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
import modules.sd_samplers
|
||||||
if opts.show_progress_grid:
|
if opts.show_progress_grid:
|
||||||
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
|
self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
|
||||||
else:
|
else:
|
||||||
self.current_image = sd_samplers.sample_to_image(self.current_latent)
|
self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
|
||||||
|
|
||||||
self.current_image_sampling_step = self.sampling_step
|
self.current_image_sampling_step = self.sampling_step
|
||||||
|
|
||||||
@ -248,6 +249,21 @@ def options_section(section_identifier, options_dict):
|
|||||||
return options_dict
|
return options_dict
|
||||||
|
|
||||||
|
|
||||||
|
def list_checkpoint_tiles():
|
||||||
|
import modules.sd_models
|
||||||
|
return modules.sd_models.checkpoint_tiles()
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_checkpoints():
|
||||||
|
import modules.sd_models
|
||||||
|
return modules.sd_models.list_models()
|
||||||
|
|
||||||
|
|
||||||
|
def list_samplers():
|
||||||
|
import modules.sd_samplers
|
||||||
|
return modules.sd_samplers.all_samplers
|
||||||
|
|
||||||
|
|
||||||
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
||||||
|
|
||||||
options_templates = {}
|
options_templates = {}
|
||||||
@ -333,7 +349,7 @@ options_templates.update(options_section(('training', "Training"), {
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
|
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||||
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
|
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
|
||||||
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||||
@ -385,7 +401,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
|
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
|
||||||
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||||
|
@ -64,7 +64,8 @@ class EmbeddingDatabase:
|
|||||||
|
|
||||||
self.word_embeddings[embedding.name] = embedding
|
self.word_embeddings[embedding.name] = embedding
|
||||||
|
|
||||||
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
|
# TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
|
||||||
|
ids = model.cond_stage_model.tokenize([embedding.name])[0]
|
||||||
|
|
||||||
first_id = ids[0]
|
first_id = ids[0]
|
||||||
if first_id not in self.ids_lookup:
|
if first_id not in self.ids_lookup:
|
||||||
@ -155,13 +156,11 @@ class EmbeddingDatabase:
|
|||||||
|
|
||||||
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
||||||
cond_model = shared.sd_model.cond_stage_model
|
cond_model = shared.sd_model.cond_stage_model
|
||||||
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
|
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
|
||||||
|
|
||||||
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
|
||||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
|
|
||||||
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
||||||
|
|
||||||
for i in range(num_vectors_per_token):
|
for i in range(num_vectors_per_token):
|
||||||
|
@ -478,9 +478,7 @@ def create_toprow(is_img2img):
|
|||||||
if is_img2img:
|
if is_img2img:
|
||||||
with gr.Column(scale=1, elem_id="interrogate_col"):
|
with gr.Column(scale=1, elem_id="interrogate_col"):
|
||||||
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
||||||
|
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||||
if cmd_opts.deepdanbooru:
|
|
||||||
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
|
||||||
|
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -1004,11 +1002,10 @@ def create_ui(wrap_gradio_gpu_call):
|
|||||||
outputs=[img2img_prompt],
|
outputs=[img2img_prompt],
|
||||||
)
|
)
|
||||||
|
|
||||||
if cmd_opts.deepdanbooru:
|
img2img_deepbooru.click(
|
||||||
img2img_deepbooru.click(
|
fn=interrogate_deepbooru,
|
||||||
fn=interrogate_deepbooru,
|
inputs=[init_img],
|
||||||
inputs=[init_img],
|
outputs=[img2img_prompt],
|
||||||
outputs=[img2img_prompt],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,3 +28,4 @@ kornia
|
|||||||
lark
|
lark
|
||||||
inflection
|
inflection
|
||||||
GitPython
|
GitPython
|
||||||
|
torchsde
|
||||||
|
@ -25,3 +25,4 @@ kornia==0.6.7
|
|||||||
lark==1.1.2
|
lark==1.1.2
|
||||||
inflection==0.5.1
|
inflection==0.5.1
|
||||||
GitPython==3.1.27
|
GitPython==3.1.27
|
||||||
|
torchsde==0.2.5
|
||||||
|
70
v1-inference.yaml
Normal file
70
v1-inference.yaml
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false # Note: different from the one we trained before
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 10000 ]
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
5
webui.py
5
webui.py
@ -10,7 +10,7 @@ from fastapi.middleware.gzip import GZipMiddleware
|
|||||||
|
|
||||||
from modules.paths import script_path
|
from modules.paths import script_path
|
||||||
|
|
||||||
from modules import devices, sd_samplers, upscaler, extensions, localization
|
from modules import shared, devices, sd_samplers, upscaler, extensions, localization
|
||||||
import modules.codeformer_model as codeformer
|
import modules.codeformer_model as codeformer
|
||||||
import modules.extras
|
import modules.extras
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
@ -23,7 +23,6 @@ import modules.scripts
|
|||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
import modules.sd_models
|
import modules.sd_models
|
||||||
import modules.sd_vae
|
import modules.sd_vae
|
||||||
import modules.shared as shared
|
|
||||||
import modules.txt2img
|
import modules.txt2img
|
||||||
import modules.script_callbacks
|
import modules.script_callbacks
|
||||||
|
|
||||||
@ -86,7 +85,7 @@ def initialize():
|
|||||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
||||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||||
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||||
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
|
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks()))
|
||||||
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
|
||||||
|
|
||||||
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
|
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user