added prompt matrix feature

all images in batches now have proper seeds, not just the first one
added code to remove bad characters from filenames
added code to flag output which writes it to csv and saves images
renamed some fields in UI for clarity
This commit is contained in:
AUTOMATIC 2022-08-23 00:34:49 +03:00
parent b63d0726cd
commit 3395c29127
1 changed files with 126 additions and 40 deletions

166
webui.py
View File

@ -8,12 +8,12 @@ from omegaconf import OmegaConf
from PIL import Image
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid
from torch import autocast
from contextlib import contextmanager, nullcontext
import mimetypes
import random
import math
import csv
import k_diffusion as K
from ldm.util import instantiate_from_config
@ -28,6 +28,8 @@ mimetypes.add_type('application/javascript', '.js')
opt_C = 4
opt_f = 8
invalid_filename_chars = '<>:"/\|?*'
parser = argparse.ArgumentParser()
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None)
parser.add_argument("--skip_grid", action='store_true', help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",)
@ -127,13 +129,14 @@ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cp
model = model.half().to(device)
def image_grid(imgs, batch_size):
def image_grid(imgs, batch_size, round_down=False):
if opt.n_rows > 0:
rows = opt.n_rows
elif opt.n_rows == 0:
rows = batch_size
else:
rows = round(math.sqrt(len(imgs)))
rows = math.sqrt(len(imgs))
rows = int(rows) if round_down else round(rows)
cols = math.ceil(len(imgs) / rows)
@ -146,7 +149,7 @@ def image_grid(imgs, batch_size):
return grid
def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int):
torch.cuda.empty_cache()
outpath = opt.outdir or "outputs/txt2img-samples"
@ -155,6 +158,7 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
seed = random.randrange(4294967294)
seed = int(seed)
keep_same_seed = False
is_PLMS = sampler_name == 'PLMS'
is_DDIM = sampler_name == 'DDIM'
@ -177,59 +181,99 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
batch_size = n_samples
assert prompt is not None
data = [batch_size * [prompt]]
prompts = batch_size * [prompt]
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1
prompt_matrix_prompts = []
comment = ""
if prompt_matrix:
keep_same_seed = True
comment = "Image prompts:\n\n"
items = prompt.split("|")
combination_count = 2 ** (len(items)-1)
for combination_num in range(combination_count):
current = items[0]
label = 'A'
for n, text in enumerate(items[1:]):
if combination_num & (2**n) > 0:
current += ("" if text.strip().startswith(",") else ", ") + text
label += chr(ord('B') + n)
comment += " - " + label + "\n"
prompt_matrix_prompts.append(current)
n_iter = math.ceil(len(prompt_matrix_prompts) / batch_size)
comment += "\nwhere:\n"
for n, text in enumerate(items):
comment += " " + chr(ord('A') + n) + " = " + items[n] + "\n"
precision_scope = autocast if opt.precision == "autocast" else nullcontext
output_images = []
with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
for n in range(n_iter):
for batch_index, prompts in enumerate(data):
uc = None
if cfg_scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [opt_C, height // opt_f, width // opt_f]
if prompt_matrix:
prompts = prompt_matrix_prompts[n*batch_size:(n+1)*batch_size]
current_seed = seed + n * len(data) + batch_index
uc = None
if cfg_scale != 1.0:
uc = model.get_learned_conditioning(len(prompts) * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [opt_C, height // opt_f, width // opt_f]
batch_seed = seed if keep_same_seed else seed + n * len(prompts)
# we manually generate all input noises because each one should have a specific seed
xs = []
for i in range(len(prompts)):
current_seed = seed if keep_same_seed else batch_seed + i
torch.manual_seed(current_seed)
xs.append(torch.randn(shape, device=device))
x = torch.stack(xs)
if is_Kdif:
sigmas = model_wrap.get_sigmas(ddim_steps)
x = torch.randn([n_samples, *shape], device=device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(model_wrap)
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}, disable=False)
if is_Kdif:
sigmas = model_wrap.get_sigmas(ddim_steps)
x = x * sigmas[0]
model_wrap_cfg = CFGDenoiser(model_wrap)
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args={'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}, disable=False)
elif sampler is not None:
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=n_samples, shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=None)
elif sampler is not None:
samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=len(prompts), shape=shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, x_T=x)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if not opt.skip_save or not opt.skip_grid:
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8)
if not opt.skip_save or not opt.skip_grid:
for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8)
if use_GFPGAN and GFPGAN is not None:
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
x_sample = restored_img
image = Image.fromarray(x_sample)
filename = f"{base_count:05}-{seed if keep_same_seed else batch_seed + i}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"
image.save(os.path.join(sample_path, filename))
output_images.append(image)
base_count += 1
if use_GFPGAN and GFPGAN is not None:
cropped_faces, restored_faces, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
x_sample = restored_img
image = Image.fromarray(x_sample)
image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_')[:128]}.png"))
output_images.append(image)
base_count += 1
if not opt.skip_grid:
# additionally, save as grid
grid = image_grid(output_images, batch_size)
grid = image_grid(output_images, batch_size, round_down=prompt_matrix)
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid_count += 1
@ -242,8 +286,49 @@ def dream(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, ddi
Steps: {ddim_steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
""".strip()
if len(comment) > 0:
info += "\n\n" + comment
return output_images, seed, info
class Flagging(gr.FlaggingCallback):
def setup(self, components, flagging_dir: str):
pass
def flag(self, flag_data, flag_option=None, flag_index=None, username=None) -> int:
os.makedirs("log/images", exist_ok=True)
# those must match the "dream" function
prompt, ddim_steps, sampler_name, use_GFPGAN, prompt_matrix, ddim_eta, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data
filenames = []
with open("log/log.csv", "a", encoding="utf8", newline='') as file:
import time
import base64
at_start = file.tell() == 0
writer = csv.writer(file)
if at_start:
writer.writerow(["prompt", "seed", "width", "height", "cfgs", "steps", "filename"])
filename_base = str(int(time.time() * 1000))
for i, filedata in enumerate(images):
filename = "log/images/"+filename_base + ("" if len(images) == 1 else "-"+str(i+1)) + ".png"
if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):]
with open(filename, "wb") as imgfile:
imgfile.write(base64.decodebytes(filedata.encode('utf-8')))
filenames.append(filename)
writer.writerow([prompt, seed, width, height, cfg_scale, ddim_steps, filenames[0]])
print("Logged:", filenames[0])
dream_interface = gr.Interface(
dream,
@ -252,10 +337,11 @@ dream_interface = gr.Interface(
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
gr.Radio(label='Sampling method', choices=["DDIM", "PLMS", "k-diffusion"], value="k-diffusion"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
gr.Slider(minimum=1, maximum=16, step=1, label='Sampling iterations', value=1),
gr.Slider(minimum=1, maximum=4, step=1, label='Samples per iteration', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale', value=7.0),
gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=4, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly should the image follow the prompt)', value=7.0),
gr.Number(label='Seed', value=-1),
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
@ -267,7 +353,7 @@ dream_interface = gr.Interface(
],
title="Stable Diffusion Text-to-Image K",
description="Generate images from text with Stable Diffusion (using K-LMS)",
allow_flagging="never"
flagging_callback=Flagging()
)
@ -346,8 +432,8 @@ def translation(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, ddim_e
x_sample = restored_img
image = Image.fromarray(x_sample)
image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.png"))
image.save(os.path.join(sample_path, f"{base_count:05}-{current_seed}_{prompt.replace(' ', '_')[:128]}.png"))
output_images.append(image)
base_count += 1