Add Pad conds v0 option

This commit is contained in:
AUTOMATIC1111 2024-01-27 22:30:12 +03:00
parent e717eaff86
commit 757dda9ade
6 changed files with 78 additions and 19 deletions

View File

@ -31,9 +31,12 @@ def backcompat(d):
if ver is None:
return
if ver < v160:
if ver < v160 and '[' in d.get('Prompt', ''):
d["Old prompt editing timelines"] = True
if ver < v160 and d.get('Sampler', '') in ('DDIM', 'PLMS'):
d["Pad conds v0"] = True
if ver < v170_tsnr:
d["Downcast alphas_cumprod"] = True

View File

@ -53,6 +53,7 @@ class CFGDenoiser(torch.nn.Module):
self.step = 0
self.image_cfg_scale = None
self.padded_cond_uncond = False
self.padded_cond_uncond_v0 = False
self.sampler = sampler
self.model_wrap = None
self.p = None
@ -91,6 +92,62 @@ class CFGDenoiser(torch.nn.Module):
self.sampler.sampler_extra_args['cond'] = c
self.sampler.sampler_extra_args['uncond'] = uc
def pad_cond_uncond(self, cond, uncond):
    """Pad the shorter of 'cond' / 'uncond' with repeats of the empty-prompt embedding.

    The length difference along dimension 1 is expressed as a whole number of
    empty-prompt chunks (shared.sd_model.cond_stage_model_empty_prompt), and the
    shorter input is extended by that many chunks using pad_cond.

    Sets self.padded_cond_uncond to True whenever padding is actually applied.

    Args:
        cond: conditioning tensor; only its shape[1] is inspected here.
        uncond: unconditional conditioning tensor; only its shape[1] is inspected here.

    Returns:
        tuple: (cond, uncond), with at most one of them padded.
    """
    empty = shared.sd_model.cond_stage_model_empty_prompt

    # The difference must be taken between cond and uncond. Subtracting
    # cond's length from itself always yields 0, which silently disables
    # both padding branches below.
    num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]

    if num_repeats < 0:
        # cond is shorter than uncond: extend cond.
        cond = pad_cond(cond, -num_repeats, empty)
        self.padded_cond_uncond = True
    elif num_repeats > 0:
        # uncond is shorter than cond: extend uncond.
        uncond = pad_cond(uncond, num_repeats, empty)
        self.padded_cond_uncond = True

    return cond, uncond
def pad_cond_uncond_v0(self, cond, uncond):
    """Force 'uncond' to have the same length along dimension 1 as 'cond'.

    'uncond' may be a plain tensor or a dict; for a dict, the tensor stored
    under the 'crossattn' key is the one adjusted and the dict is updated
    in place.

    * If 'uncond' is shorter than 'cond', its last slice along dimension 1
      is repeated until the lengths match.
    * If 'uncond' is longer than 'cond', it is cut down to cond's length.
    * If the lengths already match, nothing is changed.

    self.padded_cond_uncond_v0 is set to True whenever 'uncond' is
    extended or cut.

    Args:
        cond: tensor whose shape[1] is the target length; returned untouched.
        uncond: tensor, or dict holding the tensor under 'crossattn'.

    Returns:
        tuple: (cond, uncond) with 'uncond' adjusted as described above.

    Note:
        This is the padding that was always used in DDIM before version 1.6.0
    """
    wrapped_in_dict = isinstance(uncond, dict)
    vec = uncond['crossattn'] if wrapped_in_dict else uncond

    # Positive: uncond is too short; negative: uncond is too long.
    missing = cond.shape[1] - vec.shape[1]

    if missing > 0:
        tail = vec[:, -1:].repeat([1, missing, 1])
        vec = torch.hstack([vec, tail])
        self.padded_cond_uncond_v0 = True
    elif missing < 0:
        vec = vec[:, :cond.shape[1]]
        self.padded_cond_uncond_v0 = True

    if not wrapped_in_dict:
        return cond, vec

    uncond['crossattn'] = vec
    return cond, uncond
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@ -162,16 +219,11 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = sigma_in[:-batch_size]
self.padded_cond_uncond = False
self.padded_cond_uncond_v0 = False
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
empty = shared.sd_model.cond_stage_model_empty_prompt
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
if num_repeats < 0:
tensor = pad_cond(tensor, -num_repeats, empty)
self.padded_cond_uncond = True
elif num_repeats > 0:
uncond = pad_cond(uncond, num_repeats, empty)
self.padded_cond_uncond = True
tensor, uncond = self.pad_cond_uncond(tensor, uncond)
elif shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:

View File

@ -335,3 +335,10 @@ class Sampler:
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
    """Placeholder for img2img sampling; this implementation always raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError
def add_infotext(self, p):
    """Record which cond/uncond padding was applied in p.extra_generation_params.

    For each padding flag on self.model_wrap_cfg that is truthy, the
    matching infotext key is set to True. Nothing is written for flags
    that are falsy.

    Args:
        p: object exposing a mutable 'extra_generation_params' mapping.
    """
    # (flag attribute on model_wrap_cfg, infotext key written when the flag is set)
    flag_to_key = (
        ("padded_cond_uncond", "Pad conds"),
        ("padded_cond_uncond_v0", "Pad conds v0"),
    )

    for flag_name, infotext_key in flag_to_key:
        if getattr(self.model_wrap_cfg, flag_name):
            p.extra_generation_params[infotext_key] = True

View File

@ -187,8 +187,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
self.add_infotext(p)
return samples
@ -234,8 +233,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
self.add_infotext(p)
return samples

View File

@ -133,8 +133,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
self.add_infotext(p)
return samples
@ -158,8 +157,7 @@ class CompVisSampler(sd_samplers_common.Sampler):
}
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True
self.add_infotext(p)
return samples

View File

@ -210,7 +210,8 @@ options_templates.update(options_section(('optimizations', "Optimizations", "sd"
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio hr').info("only applies if non-zero and overrides above"),
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
"pad_cond_uncond_v0": OptionInfo(False, "Pad prompt/negative prompt (v0)", infotext='Pad conds v0').info("alternative implementation for the above; used prior to 1.6.0 for DDIM sampler; ignored if the above is set; changes seeds"),
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
"fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Radio, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."),