added first version of inpainting

fixed flag option
This commit is contained in:
AUTOMATIC 2022-08-30 12:55:38 +03:00
parent 587db9c420
commit 54f74d4472

View File

@ -9,7 +9,7 @@ import torch.nn as nn
import numpy as np import numpy as np
import gradio as gr import gradio as gr
from omegaconf import OmegaConf from omegaconf import OmegaConf
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin from PIL import Image, ImageFont, ImageDraw, PngImagePlugin, ImageFilter, ImageOps
from torch import autocast from torch import autocast
import mimetypes import mimetypes
import random import random
@ -158,6 +158,7 @@ class Options:
"samples_save": OptionInfo(True, "Save indiviual samples"), "samples_save": OptionInfo(True, "Save indiviual samples"),
"samples_format": OptionInfo('png', 'File format for indiviual samples'), "samples_format": OptionInfo('png', 'File format for indiviual samples'),
"grid_save": OptionInfo(True, "Save image grids"), "grid_save": OptionInfo(True, "Save image grids"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"grid_format": OptionInfo('png', 'File format for grids'), "grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"), "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
@ -957,6 +958,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
if (p.prompt_matrix or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count: if (p.prompt_matrix or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
return_grid = opts.return_grid
if p.prompt_matrix: if p.prompt_matrix:
grid = image_grid(output_images, p.batch_size, rows=1 << ((len(prompt_matrix_parts)-1)//2)) grid = image_grid(output_images, p.batch_size, rows=1 << ((len(prompt_matrix_parts)-1)//2))
@ -967,10 +970,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
print("Error creating prompt_matrix text:", file=sys.stderr) print("Error creating prompt_matrix text:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
output_images.insert(0, grid) return_grid = True
else: else:
grid = image_grid(output_images, p.batch_size) grid = image_grid(output_images, p.batch_size)
if return_grid:
output_images.insert(0, grid)
save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename) save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
grid_count += 1 grid_count += 1
@ -1042,7 +1048,7 @@ class Flagging(gr.FlaggingCallback):
os.makedirs("log/images", exist_ok=True) os.makedirs("log/images", exist_ok=True)
# those must match the "txt2img" function # those must match the "txt2img" function
prompt, ddim_steps, sampler_name, use_gfpgan, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, code, images, seed, comment = flag_data prompt, steps, sampler_index, use_gfpgan, prompt_matrix, n_iter, batch_size, cfg_scale, seed, height, width, code, images, seed, comment = flag_data
filenames = [] filenames = []
@ -1067,7 +1073,7 @@ class Flagging(gr.FlaggingCallback):
filenames.append(filename) filenames.append(filename)
writer.writerow([prompt, seed, width, height, cfg_scale, ddim_steps, filenames[0]]) writer.writerow([prompt, seed, width, height, cfg_scale, steps, filenames[0]])
print("Logged:", filenames[0]) print("Logged:", filenames[0])
@ -1097,27 +1103,64 @@ txt2img_interface = gr.Interface(
flagging_callback=Flagging() flagging_callback=Flagging()
) )
def fill(image, mask):
    """Return an RGB copy of *image* with the masked-out area filled from its surroundings.

    The pixels selected by the inverted, greyscale ``mask`` are copied onto a
    transparent premultiplied-alpha ("RGBa") layer.  That layer is then
    Gaussian-blurred at a series of decreasing radii and each blurred version
    is alpha-composited onto an initially transparent canvas, so colour from
    the kept pixels spreads into the area that was left transparent.

    Args:
        image: PIL image to fill; its size defines the output size.
        mask: PIL image converted to mode "L" and inverted before use.
            Presumably bright mask pixels mark the region to be repainted --
            confirm against the caller.  NOTE(review): ``paste`` is given the
            mask as-is, so it is assumed to have the same size as ``image``.

    Returns:
        A new PIL image in mode "RGB".
    """
    size = image.size  # (width, height)
    canvas = Image.new('RGBA', size)

    # After inversion, bright values select the pixels that are copied from
    # the source; everything else stays fully transparent on the layer.
    keep_mask = ImageOps.invert(mask.convert('L'))
    source = image.convert("RGBA").convert("RGBa")

    kept = Image.new('RGBa', size)
    kept.paste(source, mask=keep_mask)
    kept = kept.convert('RGBa')

    # (blur radius, number of times the blurred layer is composited).
    # The last entry has radius 0, so the final layer is composited unblurred.
    blur_schedule = ((64, 1), (16, 2), (4, 4), (2, 2), (0, 1))
    for radius, repeats in blur_schedule:
        layer = kept.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
        for _ in range(repeats):
            canvas.alpha_composite(layer)

    return canvas.convert("RGB")
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None sampler = None
def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, **kwargs): def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.init_images = init_images self.init_images = init_images
self.resize_mode: int = resize_mode self.resize_mode: int = resize_mode
self.denoising_strength: float = denoising_strength self.denoising_strength: float = denoising_strength
self.init_latent = None self.init_latent = None
self.original_mask = mask
self.mask_blur = mask_blur
self.mask = None
self.nmask = None
def init(self): def init(self):
self.sampler = samplers_for_img2img[self.sampler_index].constructor() self.sampler = samplers_for_img2img[self.sampler_index].constructor()
if self.original_mask is not None:
if self.mask_blur > 0:
self.original_mask = self.original_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)).convert('L')
latmask = self.original_mask.convert('RGB').resize((64, 64))
latmask = np.moveaxis(np.array(latmask, dtype=np.float), 2, 0) / 255
latmask = latmask[0]
latmask = np.tile(latmask[None], (4, 1, 1))
self.mask = torch.asarray(1.0 - latmask).to(device).type(sd_model.dtype)
self.nmask = torch.asarray(latmask).to(device).type(sd_model.dtype)
imgs = [] imgs = []
for img in self.init_images: for img in self.init_images:
image = img.convert("RGB") image = img.convert("RGB")
image = resize_image(self.resize_mode, image, self.width, self.height) image = resize_image(self.resize_mode, image, self.width, self.height)
if self.original_mask is not None:
image = fill(image, self.original_mask)
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0) image = np.moveaxis(image, 2, 0)
imgs.append(image) imgs.append(image)
if len(imgs) == 1: if len(imgs) == 1:
@ -1139,16 +1182,33 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sigmas = self.sampler.model_wrap.get_sigmas(self.steps) sigmas = self.sampler.model_wrap.get_sigmas(self.steps)
noise = x * sigmas[self.steps - t_enc - 1] noise = x * sigmas[self.steps - t_enc - 1]
xi = self.init_latent + noise xi = self.init_latent + noise
sigma_sched = sigmas[self.steps - t_enc - 1:] sigma_sched = sigmas[self.steps - t_enc - 1:]
samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False)
#if self.mask is not None:
# xi = xi * self.mask + noise * self.nmask
def mask_cb(v):
v["denoised"][:] = v["denoised"][:] * self.nmask + self.init_latent * self.mask
samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False, callback=mask_cb if self.mask is not None else None)
if self.mask is not None:
samples_ddim = samples_ddim * self.nmask + self.init_latent * self.mask
return samples_ddim return samples_ddim
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int): def img2img(prompt: str, init_img, init_img_with_mask, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opts.outdir or "outputs/img2img-samples" outpath = opts.outdir or "outputs/img2img-samples"
if init_img_with_mask is not None:
image = init_img_with_mask['image']
mask = init_img_with_mask['mask']
else:
image = init_img
mask = None
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
p = StableDiffusionProcessingImg2Img( p = StableDiffusionProcessingImg2Img(
@ -1164,7 +1224,8 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
height=height, height=height,
prompt_matrix=prompt_matrix, prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN, use_GFPGAN=use_GFPGAN,
init_images=[init_img], init_images=[image],
mask=mask,
resize_mode=resize_mode, resize_mode=resize_mode,
denoising_strength=denoising_strength, denoising_strength=denoising_strength,
extra_generation_params={"Denoising Strength": denoising_strength} extra_generation_params={"Denoising Strength": denoising_strength}
@ -1262,7 +1323,8 @@ img2img_interface = gr.Interface(
wrap_gradio_call(img2img), wrap_gradio_call(img2img),
inputs=[ inputs=[
gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1), gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"), gr.Image(label="Image for img2img", source="upload", interactive=True, type="pil"),
gr.Image(label="Image for inpainting with mask", source="upload", interactive=True, type="pil", tool="sketch"),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20), gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20),
gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index"), gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=have_gfpgan), gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=have_gfpgan),