merge CFGDenoiserEdit and CFGDenoiser into single object

This commit is contained in:
AUTOMATIC 2023-02-04 11:06:17 +03:00
parent 127bfb6c41
commit 72dd5785d9

View File

@ -41,90 +41,6 @@ sampler_extra_params = {
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}
class CFGDenoiserEdit(torch.nn.Module):
"""
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
negative prompt.
"""
def __init__(self, model):
super().__init__()
self.inner_model = model
self.mask = None
self.nmask = None
self.init_latent = None
self.step = 0
def combine_denoised(self, x_out, conds_list, uncond, cond_scale, image_cfg_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
denoised[i] = out_uncond[cond_index] + cond_scale * (out_cond[cond_index] - out_img_cond[cond_index]) + image_cfg_scale * (out_img_cond[cond_index] - out_uncond[cond_index])
return denoised
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond, image_cfg_scale):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond
sigma_in = denoiser_params.sigma
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond, uncond])
if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": torch.cat([tensor[a:b]], uncond) , "c_concat": [image_cond_in[a:b]]})
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
devices.test_for_nans(x_out, "unet")
if opts.live_preview_content == "Prompt":
sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
elif opts.live_preview_content == "Negative prompt":
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, image_cfg_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
self.step += 1
return denoised
class CFGDenoiser(torch.nn.Module):
"""
@ -141,6 +57,7 @@ class CFGDenoiser(torch.nn.Module):
self.nmask = None
self.init_latent = None
self.step = 0
self.image_cfg_scale = None
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
@ -152,19 +69,36 @@ class CFGDenoiser(torch.nn.Module):
return denoised
def combine_denoised_for_edit_model(self, x_out, cond_scale):
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
return denoised
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
# so is_edit_model is set to False to support AND composition.
is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
else:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
cfg_denoiser_callback(denoiser_params)
@ -173,7 +107,10 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = denoiser_params.sigma
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
if not is_edit_model:
cond_in = torch.cat([tensor, uncond])
else:
cond_in = torch.cat([tensor, uncond, uncond])
if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
@ -189,7 +126,13 @@ class CFGDenoiser(torch.nn.Module):
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
if not is_edit_model:
c_crossattn = [tensor[a:b]]
else:
c_crossattn = torch.cat([tensor[a:b]], uncond)
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
@ -200,7 +143,10 @@ class CFGDenoiser(torch.nn.Module):
elif opts.live_preview_content == "Negative prompt":
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if not is_edit_model:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
else:
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
@ -280,12 +226,10 @@ class KDiffusionSampler:
return p.steps
def initialize(self, p):
if shared.sd_model.cond_stage_key == "edit" and getattr(p, 'image_cfg_scale', None) != 1:
self.model_wrap_cfg = CFGDenoiserEdit(self.model_wrap)
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap_cfg.step = 0
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
@ -355,9 +299,6 @@ class KDiffusionSampler:
'cond_scale': p.cfg_scale,
}
if hasattr(p, 'image_cfg_scale') and p.image_cfg_scale != 1 and p.image_cfg_scale != None:
extra_args['image_cfg_scale'] = p.image_cfg_scale
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples