From 73ab982d1b7394574d1cf2e0a151bc457eeed769 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Sat, 2 Dec 2023 21:07:02 -0700 Subject: [PATCH] Blend masks are now produced afterward, based on an estimate of the visual difference between the original and modified latent images. This should remove ghosting and clipping artifacts from masks, while preserving the details of largely unchanged content. --- modules/processing.py | 119 ++++++++++++++++++++++++++++++++---------- 1 file changed, 90 insertions(+), 29 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 92fdebadd..ad716e11f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field import torch import numpy as np -from PIL import Image, ImageOps +from PIL import Image, ImageOps, ImageFilter import random import cv2 from skimage import exposure @@ -62,6 +62,16 @@ def apply_color_correction(correction, original_image): return image.convert('RGB') +def uncrop(image, dest_size, paste_loc): + x, y, w, h = paste_loc + base_image = Image.new('RGBA', dest_size) + image = images.resize_image(1, image, w, h) + base_image.paste(image, (x, y)) + image = base_image + + return image + + def apply_overlay(image, paste_loc, index, overlays): if overlays is None or index >= len(overlays): return image @@ -69,11 +79,7 @@ def apply_overlay(image, paste_loc, index, overlays): overlay = overlays[index] if paste_loc is not None: - x, y, w, h = paste_loc - base_image = Image.new('RGBA', (overlay.width, overlay.height)) - image = images.resize_image(1, image, w, h) - base_image.paste(image, (x, y)) - image = base_image + image = uncrop(image, (overlay.width, overlay.height), paste_loc) image = image.convert('RGBA') image.alpha_composite(overlay) @@ -140,6 +146,7 @@ class StableDiffusionProcessing: do_not_save_grid: bool = False extra_generation_params: dict[str, Any] = None overlay_images: list = None + masks_for_overlay: list = None eta: float = None do_not_reload_embeddings: bool = False denoising_strength: float = 0 @@ -865,11 +872,66 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim + # todo: generate masks the old fashioned way else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + # Generate the mask(s) based on similarity between the original and denoised latent vectors + if getattr(p, "image_mask", None) is not None: + # latent_mask = p.nmask[0].float().cpu() + + # convert the original mask into a form we use to scale distances for thresholding + # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2)) + # mask_scalar = mask_scalar / (1.00001-mask_scalar) + # mask_scalar = mask_scalar.numpy() + + latent_orig = p.init_latent + latent_proc = samples_ddim + latent_distance = torch.norm(latent_proc - latent_orig, p=2, dim=1) + + kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + + for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, p.overlay_images)): + converted_mask = distance_map.float().cpu().numpy() + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.9, percentile_max=1, min_width=1) + converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.25, percentile_max=0.75, min_width=1) + + # The distance at which opacity of original decreases to 50% + # half_weighted_distance = 1 # * mask_scalar + # converted_mask = converted_mask / half_weighted_distance + + converted_mask = 1 / (1 + converted_mask ** 2) + converted_mask = images.smootherstep(converted_mask) + converted_mask = 1 - converted_mask + converted_mask = 255. * converted_mask + converted_mask = converted_mask.astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, p.width, p.height) + converted_mask = create_binary_mask(converted_mask) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if p.paste_to is not None: + converted_mask = uncrop(converted_mask, + (overlay_image.width, overlay_image.height), + p.paste_to) + + p.masks_for_overlay[i] = converted_mask + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + p.overlay_images[i] = image_masked.convert('RGBA') + + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, + target_device=devices.cpu, + check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -892,7 +954,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: x_samples_ddim = batch_params.images def infotext(index=0, use_main_prompt=False): - return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts) + return create_infotext(p, p.prompts, p.seeds, p.subseeds, + use_main_prompt=use_main_prompt, index=index, + all_negative_prompts=p.negative_prompts) save_samples = p.save_samples() @@ -923,19 +987,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) + # If the intention is to show the output from the model + # that is being composited over the original image, + # we need to keep the original image around + # and use it in the composite step. + original_denoised_image = image.copy() image = apply_overlay(image, p.paste_to, i, p.overlay_images) if save_samples: - images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) + images.save_image(image, p.outpath_samples, "", p.seeds[i], + p.prompts[i], opts.samples_format, info=infotext(i), p=p) text = infotext(i) infotexts.append(text) if opts.enable_pnginfo: image.info["parameters"] = text output_images.append(image) - if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): - image_mask = p.mask_for_overlay.convert('RGB') - image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') + if save_samples and hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): + image_mask = p.masks_for_overlay[i].convert('RGB') + image_mask_composite = Image.composite( + original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), + images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA') if opts.save_mask: images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask") @@ -1364,7 +1436,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): nmask: torch.Tensor = field(default=None, init=False) image_conditioning: torch.Tensor = field(default=None, init=False) init_img_hash: str = field(default=None, init=False) - mask_for_overlay: Image = field(default=None, init=False) init_latent: torch.Tensor = field(default=None, init=False) def __post_init__(self): @@ -1415,12 +1486,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: - np_mask = np.array(image_mask).astype(np.float32) - np_mask /= 255 - np_mask = 1-pow(1-np_mask, 100) - np_mask *= 255 - np_mask = np.clip(np_mask, 0, 255).astype(np.uint8) - self.mask_for_overlay = Image.fromarray(np_mask) mask = image_mask.convert('L') crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding) crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) @@ -1431,13 +1496,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.paste_to = (x1, y1, x2-x1, y2-y1) else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) - np_mask = np.array(image_mask).astype(np.float32) - np_mask /= 255 - np_mask = 1-pow(1-np_mask, 100) - np_mask *= 255 - np_mask = np.clip(np_mask, 0, 255).astype(np.uint8) - self.mask_for_overlay = Image.fromarray(np_mask) + self.masks_for_overlay = [] self.overlay_images = [] latent_mask = self.latent_mask if self.latent_mask is not None else image_mask @@ -1459,10 +1519,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = images.resize_image(self.resize_mode, image, self.width, self.height) if image_mask is not None: - image_masked = Image.new('RGBa', (image.width, image.height)) - image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) - - self.overlay_images.append(image_masked.convert('RGBA')) + self.overlay_images.append(image) + self.masks_for_overlay.append(image_mask) # crop_region is not None if we are doing inpaint full res if crop_region is not None: @@ -1486,6 +1544,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.overlay_images is not None: self.overlay_images = self.overlay_images * self.batch_size + if self.masks_for_overlay is not None: + self.masks_for_overlay = self.masks_for_overlay * self.batch_size + if self.color_corrections is not None and len(self.color_corrections) == 1: self.color_corrections = self.color_corrections * self.batch_size