From 9feb034e343d6d7ef63395821658fb3774b30a24 Mon Sep 17 00:00:00 2001 From: wangqyqq Date: Thu, 21 Dec 2023 20:15:51 +0800 Subject: [PATCH] support for sdxl-inpaint model --- configs/sd_xl_inpaint.yaml | 98 +++++++++++++++++++++++++++++++++++++ modules/processing.py | 19 +++++++ modules/sd_models_config.py | 6 ++- modules/sd_models_xl.py | 5 ++ 4 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 configs/sd_xl_inpaint.yaml diff --git a/configs/sd_xl_inpaint.yaml b/configs/sd_xl_inpaint.yaml new file mode 100644 index 000000000..3bad37218 --- /dev/null +++ b/configs/sd_xl_inpaint.yaml @@ -0,0 +1,98 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + weighting_config: + target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2816 + num_classes: sequential + use_checkpoint: True + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 + context_dim: 2048 + spatial_transformer_attn_type: softmax-xformers + legacy: False + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + # crossattn cond + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenCLIPEmbedder + params: + layer: hidden + layer_idx: 11 + # crossattn and vector cond + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 + params: + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + freeze: True + layer: penultimate + always_return_pooled: True + legacy: False + # vector cond + - is_trainable: False + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: crop_coords_top_left + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: target_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 # multiplied by two + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity diff --git a/modules/processing.py b/modules/processing.py index 6f01c95f5..159548dba 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -106,6 +106,20 @@ def txt2img_image_conditioning(sd_model, x, width, height): return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) else: + sd = sd_model.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input.shape[1] == 9: + # The "masked-image" in this case will just be all 0.5 since the entire image is masked. + image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5 + image_conditioning = images_tensor_to_samples(image_conditioning, + approximation_indexes.get(opts.sd_vae_encode_method)) + + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) + + return image_conditioning + # Dummy zero conditioning if we're not using inpainting or unclip models. # Still takes up a bit of memory, but no encoder call. # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. @@ -362,6 +376,11 @@ class StableDiffusionProcessing: if self.sampler.conditioning_key == "crossattn-adm": return self.unclip_image_conditioning(source_image) + sd = self.sampler.model_wrap.inner_model.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input.shape[1] == 9: + return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index deab2f6e2..b38137eb5 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -15,6 +15,7 @@ config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") +config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") @@ -71,7 +72,10 @@ def guess_model_config_from_state_dict(sd, filename): sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: - return config_sdxl + if diffusion_model_input.shape[1] == 9: + return config_sdxl_inpainting + else: + return config_sdxl if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: return config_sdxl_refiner elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 011233216..d8a9a73bc 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -34,6 +34,11 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): + sd = self.model.state_dict() + diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) + if diffusion_model_input.shape[1] == 9: + x = torch.cat([x] + cond['c_concat'], dim=1) + return self.model(x, t, cond)