2023-07-11 18:16:43 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
import sgm.models.diffusion
|
|
|
|
import sgm.modules.diffusionmodules.denoiser_scaling
|
|
|
|
import sgm.modules.diffusionmodules.discretizer
|
2023-07-12 20:52:43 +00:00
|
|
|
from modules import devices, shared, prompt_parser
|
2023-12-31 19:38:30 +00:00
|
|
|
from modules import torch_utils
|
2023-07-11 18:16:43 +00:00
|
|
|
|
|
|
|
|
2023-07-12 20:52:43 +00:00
|
|
|
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
    """Build the SDXL conditioning inputs for ``batch`` and run them through ``self.conditioner``.

    ``batch`` is passed to the conditioner as the ``"txt"`` entry. If it carries
    ``width``, ``height`` or ``is_negative_prompt`` attributes they are used;
    otherwise 1024, 1024 and ``False`` are assumed.

    Returns whatever ``self.conditioner`` returns.
    """
    # Reset ucg_rate to zero on every embedder before conditioning.
    for emb in self.conditioner.embedders:
        emb.ucg_rate = 0.0

    width = getattr(batch, 'width', 1024)
    height = getattr(batch, 'height', 1024)
    is_negative_prompt = getattr(batch, 'is_negative_prompt', False)

    # Negative prompts use the "low" aesthetic score option, positive ones the "high" one.
    if is_negative_prompt:
        aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score
    else:
        aesthetic_score = shared.opts.sdxl_refiner_high_aesthetic_score

    tensor_kwargs = dict(device=devices.device, dtype=devices.dtype)
    n_prompts = len(batch)

    def per_prompt(values):
        # One copy of `values` per entry in the batch, stacked along dim 0.
        return torch.tensor(values, **tensor_kwargs).repeat(n_prompts, 1)

    sdxl_conds = {
        "txt": batch,
        "original_size_as_tuple": per_prompt([height, width]),
        "crop_coords_top_left": per_prompt([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]),
        "target_size_as_tuple": per_prompt([height, width]),
        "aesthetic_score": per_prompt([aesthetic_score]),
    }

    # A negative prompt batch consisting only of empty strings asks the
    # conditioner to force zero embeddings for the 'txt' key.
    zeroed_keys = []
    if is_negative_prompt and all(x == '' for x in batch):
        zeroed_keys.append('txt')

    return self.conditioner(sdxl_conds, force_zero_embeddings=zeroed_keys)
|
|
|
|
|
|
|
|
|
|
|
|
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
    """Run ``self.model`` on ``x``, first appending ``cond['c_concat']`` when the model expects it.

    The model is treated as expecting concatenated conditioning when the weight
    stored under ``'diffusion_model.input_blocks.0.0.weight'`` in its state dict
    has a second dimension of 9. In that case the tensors in ``cond['c_concat']``
    (a list) are concatenated onto ``x`` along dim 1 before the call.

    Args:
        x: input tensor handed to the model.
        t: passed through to ``self.model`` unchanged.
        cond: conditioning mapping; ``cond['c_concat']`` is only read when the
            9-channel check succeeds.

    Returns:
        The result of ``self.model(x, t, cond)``.
    """
    # Building a state dict walks every parameter of the model, so do it once
    # and remember the answer on the instance instead of repeating it per call.
    # The shape of the first input weight is fixed by the architecture, so the
    # cached value stays valid for the lifetime of this object.
    needs_concat = getattr(self, '_first_conv_takes_concat', None)
    if needs_concat is None:
        first_conv_weight = self.model.state_dict().get('diffusion_model.input_blocks.0.0.weight', None)
        needs_concat = first_conv_weight is not None and first_conv_weight.shape[1] == 9
        self._first_conv_takes_concat = needs_concat

    if needs_concat:
        x = torch.cat([x] + cond['c_concat'], dim=1)

    return self.model(x, t, cond)
|
|
|
|
|
|
|
|
|
2023-07-13 13:18:39 +00:00
|
|
|
def get_first_stage_encoding(self, x):
    """Return ``x`` unchanged.

    SDXL's encode_first_stage does everything, so this method is only here for
    compatibility with callers that expect it to exist.
    """
    return x
|
|
|
|
|
2023-07-14 06:16:01 +00:00
|
|
|
|
|
|
|
# Install the functions defined above as methods of sgm's DiffusionEngine.
# Each one is attached under its own function name.
for _engine_method in (get_learned_conditioning, apply_model, get_first_stage_encoding):
    setattr(sgm.models.diffusion.DiffusionEngine, _engine_method.__name__, _engine_method)
del _engine_method  # keep the loop variable out of the module namespace
|
|
|
|
|
|
|
|
|
|
|
|
def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
    """Encode ``init_text`` with every capable embedder and join the results along dim 1.

    Only embedders that have an ``encode_embedding_init_text`` method take
    part; each is called with ``(init_text, nvpt)`` in the order they appear
    in ``self.embedders``.
    """
    capable = [e for e in self.embedders if hasattr(e, 'encode_embedding_init_text')]
    pieces = [e.encode_embedding_init_text(init_text, nvpt) for e in capable]

    return torch.cat(pieces, dim=1)
|
|
|
|
|
|
|
|
|
2023-07-29 12:15:06 +00:00
|
|
|
def tokenize(self: sgm.modules.GeneralConditioner, texts):
    """Tokenize ``texts`` using the first embedder that has a ``tokenize`` method.

    Raises:
        AssertionError: if none of ``self.embedders`` provides ``tokenize``.
    """
    capable = [e for e in self.embedders if hasattr(e, 'tokenize')]
    if not capable:
        raise AssertionError('no tokenizer available')

    return capable[0].tokenize(texts)
|
|
|
|
|
|
|
|
|
|
|
|
|
2023-07-14 06:16:01 +00:00
|
|
|
def process_texts(self, texts):
    """Delegate to the first embedder that has a ``process_texts`` method.

    Returns that embedder's result, or ``None`` when no embedder in
    ``self.embedders`` provides ``process_texts``.
    """
    capable = [e for e in self.embedders if hasattr(e, 'process_texts')]
    if capable:
        return capable[0].process_texts(texts)

    return None
|
|
|
|
|
|
|
|
|
|
|
|
def get_target_prompt_token_count(self, token_count):
    """Delegate to the first embedder that has a ``get_target_prompt_token_count`` method.

    Returns that embedder's result, or ``None`` when no embedder in
    ``self.embedders`` provides the method.
    """
    capable = [e for e in self.embedders if hasattr(e, 'get_target_prompt_token_count')]
    if capable:
        return capable[0].get_target_prompt_token_count(token_count)

    return None
|
|
|
|
|
|
|
|
|
|
|
|
# These additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in existing code
|
|
|
|
# Install the helper functions defined above as methods of sgm's
# GeneralConditioner. Each one is attached under its own function name.
for _conditioner_method in (
    encode_embedding_init_text,
    tokenize,
    process_texts,
    get_target_prompt_token_count,
):
    setattr(sgm.modules.GeneralConditioner, _conditioner_method.__name__, _conditioner_method)
del _conditioner_method  # keep the loop variable out of the module namespace
|
|
|
|
|
|
|
|
|
2023-07-11 18:16:43 +00:00
|
|
|
def extend_sdxl(model):
    """Attach extra attributes to an SDXL model so it looks a bit more like SD1.5 to the rest of the codebase.

    Sets, in place:
      * ``model.model.diffusion_model.dtype`` - dtype of the parameter returned by ``torch_utils.get_param``
      * ``model.model.conditioning_key`` - ``'crossattn'``
      * ``model.cond_stage_key`` - ``'txt'``
      * ``model.parameterization`` - ``"v"`` when the denoiser scaling is a ``VScaling``, else ``"eps"``
      * ``model.alphas_cumprod`` - float32 tensor on ``devices.device`` taken from ``LegacyDDPMDiscretization``
      * ``model.conditioner.wrapped`` - an empty ``torch.nn.Module``
    """
    unet = model.model.diffusion_model
    unet.dtype = torch_utils.get_param(unet).dtype

    model.model.conditioning_key = 'crossattn'
    model.cond_stage_key = 'txt'
    # model.cond_stage_model will be set in sd_hijack

    if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling):
        model.parameterization = "v"
    else:
        model.parameterization = "eps"

    legacy_schedule = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization().alphas_cumprod
    model.alphas_cumprod = torch.asarray(legacy_schedule, device=devices.device, dtype=torch.float32)

    # Placeholder module; presumably expected by code written for SD1.5's
    # cond_stage_model.wrapped - TODO confirm against callers.
    model.conditioner.wrapped = torch.nn.Module()
|
2023-07-12 20:52:43 +00:00
|
|
|
|
2023-07-11 18:16:43 +00:00
|
|
|
|
2023-07-31 21:24:48 +00:00
|
|
|
# Rebind the module-level name `print` inside these sgm modules to
# shared.ldm_print, so their print() calls go through it instead of the builtin.
for _sgm_module in (
    sgm.modules.attention,
    sgm.modules.diffusionmodules.model,
    sgm.modules.diffusionmodules.openaimodel,
    sgm.modules.encoders.modules,
):
    _sgm_module.print = shared.ldm_print
del _sgm_module  # keep the loop variable out of the module namespace

# this gets the code to load the vanilla attention that we override
sgm.modules.attention.SDP_IS_AVAILABLE = True
sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
|