diff --git a/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py index d4cf3fda3..f56e1e226 100644 --- a/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py +++ b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py @@ -57,10 +57,14 @@ def latent_blend(settings, a, b, t): # NOTE: We use inplace operations wherever possible. - # [4][w][h] to [1][4][w][h] - t2 = t.unsqueeze(0) - # [4][w][h] to [1][1][w][h] - the [4] seem redundant. - t3 = t[0].unsqueeze(0).unsqueeze(0) + if len(t.shape) == 3: + # [4][w][h] to [1][4][w][h] + t2 = t.unsqueeze(0) + # [4][w][h] to [1][1][w][h] - the [4] seem redundant. + t3 = t[0].unsqueeze(0).unsqueeze(0) + else: + t2 = t + t3 = t[:, 0][:, None] one_minus_t2 = 1 - t2 one_minus_t3 = 1 - t3 @@ -135,7 +139,10 @@ def apply_adaptive_masks( from PIL import Image, ImageOps, ImageFilter # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. - latent_mask = nmask[0].float() + if len(nmask.shape) == 3: + latent_mask = nmask[0].float() + else: + latent_mask = nmask[:, 0].float() # convert the original mask into a form we use to scale distances for thresholding mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) mask_scalar = (0.5 * (1 - settings.composite_mask_influence) @@ -157,7 +164,14 @@ def apply_adaptive_masks( percentile_min=0.25, percentile_max=0.75, min_width=1) # The distance at which opacity of original decreases to 50% - half_weighted_distance = settings.composite_difference_threshold * mask_scalar + if len(mask_scalar.shape) == 3: + if mask_scalar.shape[0] > i: + half_weighted_distance = settings.composite_difference_threshold * mask_scalar[i] + else: + half_weighted_distance = settings.composite_difference_threshold * mask_scalar[0] + else: + half_weighted_distance = settings.composite_difference_threshold * mask_scalar + converted_mask = converted_mask / half_weighted_distance converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast)