From f299645aeeb65fcddde2d136fd550b6b01ffebb3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 4 Sep 2022 18:54:12 +0300 Subject: [PATCH] ESRGAN support --- ESRGAN/Put ESRGAN models here.txt | 0 README.md | 10 ++- modules/esrgam_model_arch.py | 80 ++++++++++++++++++ modules/esrgan_model.py | 134 ++++++++++++++++++++++++++++++ modules/images.py | 38 ++++++++- modules/img2img.py | 6 +- modules/realesrgan_model.py | 15 ++++ modules/shared.py | 7 +- modules/ui.py | 25 ++++-- webui.py | 43 +++++++--- 10 files changed, 327 insertions(+), 31 deletions(-) create mode 100644 ESRGAN/Put ESRGAN models here.txt create mode 100644 modules/esrgam_model_arch.py create mode 100644 modules/esrgan_model.py diff --git a/ESRGAN/Put ESRGAN models here.txt b/ESRGAN/Put ESRGAN models here.txt new file mode 100644 index 000000000..e69de29bb diff --git a/README.md b/README.md index 610826c29..6cf246d22 100644 --- a/README.md +++ b/README.md @@ -19,11 +19,14 @@ Original script with Gradio UI was written by a kind anonymous user. This is a m - Loopback - X/Y plot - Textual Inversion -- Resizing options +- Extras tab with: + - GFPGAN, neural network that fixes faces + - RealESRGAN, neural network upscaler + - ESRGAN, neural network with a lot of third party models +- Resizing aspect ratio options - Sampling method selection - Interrupt processing at any time - 4GB videocard support -- Option to use GFPGAN - Correct seeds for batches - Prompt length validation - Generation parameters added as text to PNG @@ -49,6 +52,9 @@ can obtain it from the following places: You optionally can use GPFGAN to improve faces, then you'll need to download the model from [here](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth). +To use ESRGAN models, put them into ESRGAN directory in the same location as webui.py. A file will be loaded +as model if it has .pth extension. Grab models from the [Model Database](https://upscale.wiki/wiki/Model_Database). + ### Automatic installation/launch - install [Python 3.10.6](https://www.python.org/downloads/windows/) diff --git a/modules/esrgam_model_arch.py b/modules/esrgam_model_arch.py new file mode 100644 index 000000000..e413d36ed --- /dev/null +++ b/modules/esrgam_model_arch.py @@ -0,0 +1,80 @@ +# this file is taken from https://github.com/xinntao/ESRGAN + +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + + +class ResidualDenseBlock_5C(nn.Module): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB(nn.Module): + '''Residual in Residual Dense Block''' + + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class RRDBNet(nn.Module): + def __init__(self, in_nc, out_nc, nf, nb, gc=32): + super(RRDBNet, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + #### upsampling + self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.conv_first(x) + trunk = self.trunk_conv(self.RRDB_trunk(fea)) + fea = fea + trunk + + fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) + fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.HRconv(fea))) + + return out diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py new file mode 100644 index 000000000..3dcef5a6e --- /dev/null +++ b/modules/esrgan_model.py @@ -0,0 +1,134 @@ +import os +import sys +import traceback + +import numpy as np +import torch +from PIL import Image + +import modules.esrgam_model_arch as arch +from modules import shared +from modules.shared import opts +import modules.images + + +def load_model(filename): + # this code is adapted from https://github.com/xinntao/ESRGAN + + pretrained_net = torch.load(filename) + crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) + + if 'conv_first.weight' in pretrained_net: + crt_model.load_state_dict(pretrained_net) + return crt_model + + crt_net = crt_model.state_dict() + load_net_clean = {} + for k, v in pretrained_net.items(): + if k.startswith('module.'): + load_net_clean[k[7:]] = v + else: + load_net_clean[k] = v + pretrained_net = load_net_clean + + tbd = [] + for k, v in crt_net.items(): + tbd.append(k) + + # directly copy + for k, v in crt_net.items(): + if k in pretrained_net and pretrained_net[k].size() == v.size(): + crt_net[k] = pretrained_net[k] + tbd.remove(k) + + crt_net['conv_first.weight'] = pretrained_net['model.0.weight'] + crt_net['conv_first.bias'] = pretrained_net['model.0.bias'] + + for k in tbd.copy(): + if 'RDB' in k: + ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') + if '.weight' in k: + ori_k = ori_k.replace('.weight', '.0.weight') + elif '.bias' in k: + ori_k = ori_k.replace('.bias', '.0.bias') + crt_net[k] = pretrained_net[ori_k] + tbd.remove(k) + + crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight'] + crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias'] + crt_net['upconv1.weight'] = pretrained_net['model.3.weight'] + crt_net['upconv1.bias'] = pretrained_net['model.3.bias'] + crt_net['upconv2.weight'] = pretrained_net['model.6.weight'] + crt_net['upconv2.bias'] = pretrained_net['model.6.bias'] + crt_net['HRconv.weight'] = pretrained_net['model.8.weight'] + crt_net['HRconv.bias'] = pretrained_net['model.8.bias'] + crt_net['conv_last.weight'] = pretrained_net['model.10.weight'] + crt_net['conv_last.bias'] = pretrained_net['model.10.bias'] + + crt_model.load_state_dict(crt_net) + crt_model.eval() + return crt_model + +def upscale_without_tiling(model, img): + img = np.array(img) + img = img[:, :, ::-1] + img = np.moveaxis(img, 2, 0) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(shared.device) + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. * np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + return Image.fromarray(output, 'RGB') + + +def esrgan_upscale(model, img): + if opts.ESRGAN_tile == 0: + return upscale_without_tiling(model, img) + + grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap) + newtiles = [] + scale_factor = 1 + + for y, h, row in grid.tiles: + newrow = [] + for tiledata in row: + x, w, tile = tiledata + + output = upscale_without_tiling(model, tile) + scale_factor = output.width // tile.width + + newrow.append([x * scale_factor, w * scale_factor, output]) + newtiles.append([y * scale_factor, h * scale_factor, newrow]) + + newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor) + output = modules.images.combine_grid(newgrid) + return output + + +class UpscalerESRGAN(modules.images.Upscaler): + def __init__(self, filename, title): + self.name = title + self.model = load_model(filename) + + def do_upscale(self, img): + model = self.model.to(shared.device) + img = esrgan_upscale(model, img) + return img + + +def load_models(dirname): + for file in os.listdir(dirname): + path = os.path.join(dirname, file) + model_name, extension = os.path.splitext(file) + + if extension != '.pt' and extension != '.pth': + continue + + try: + modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name)) + except Exception: + print(f"Error loading ESRGAN model: {path}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) diff --git a/modules/images.py b/modules/images.py index 4b9667d20..4226db00c 100644 --- a/modules/images.py +++ b/modules/images.py @@ -6,6 +6,7 @@ import re import numpy as np from PIL import Image, ImageFont, ImageDraw, PngImagePlugin +import modules.shared from modules.shared import opts LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) @@ -45,20 +46,20 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64): cols = math.ceil((w - overlap) / non_overlap_width) rows = math.ceil((h - overlap) / non_overlap_height) - dx = (w - tile_w) // (cols-1) if cols > 1 else 0 - dy = (h - tile_h) // (rows-1) if rows > 1 else 0 + dx = (w - tile_w) / (cols-1) if cols > 1 else 0 + dy = (h - tile_h) / (rows-1) if rows > 1 else 0 grid = Grid([], tile_w, tile_h, w, h, overlap) for row in range(rows): row_images = [] - y = row * dy + y = int(row * dy) if y + tile_h >= h: y = h - tile_h for col in range(cols): - x = col * dx + x = int(col * dx) if x+tile_w >= w: x = w - tile_w @@ -291,3 +292,32 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file: file.write(info + "\n") + +class Upscaler: + name = "Lanczos" + + def do_upscale(self, img): + return img + + def upscale(self, img, w, h): + for i in range(3): + if img.width >= w and img.height >= h: + break + + img = self.do_upscale(img) + + if img.width != w or img.height != h: + img = img.resize((w, h), resample=LANCZOS) + + return img + + +class UpscalerNone(Upscaler): + name = "None" + + def upscale(self, img, w, h): + return img + + +modules.shared.sd_upscalers.append(UpscalerNone()) +modules.shared.sd_upscalers.append(Upscaler()) diff --git a/modules/img2img.py b/modules/img2img.py index d5787dd37..b1ef13267 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html import modules.images as images import modules.scripts -def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_name: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args): +def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args): is_inpaint = mode == 1 is_loopback = mode == 2 is_upscale = mode == 3 @@ -81,8 +81,8 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index initial_seed = None initial_info = None - upscaler = shared.sd_upscalers.get(upscaler_name, next(iter(shared.sd_upscalers.values()))) - img = upscaler(init_img) + upscaler = shared.sd_upscalers[upscaler_index] + img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2) processing.torch_gc() diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 5a6666a33..e480887f1 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -4,6 +4,7 @@ from collections import namedtuple import numpy as np from PIL import Image +import modules.images from modules.shared import cmd_opts RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"]) @@ -12,6 +13,17 @@ realesrgan_models = [] have_realesrgan = False RealESRGANer_constructor = None + +class UpscalerRealESRGAN(modules.images.Upscaler): + def __init__(self, upscaling, model_index): + self.upscaling = upscaling + self.model_index = model_index + self.name = realesrgan_models[model_index].name + + def do_upscale(self, img): + return upscale_with_realesrgan(img, self.upscaling, self.model_index) + + def setup_realesrgan(): global realesrgan_models global have_realesrgan @@ -42,6 +54,9 @@ def setup_realesrgan(): have_realesrgan = True RealESRGANer_constructor = RealESRGANer + for i, model in enumerate(realesrgan_models): + modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i)) + except Exception: print("Error importing Real-ESRGAN:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) diff --git a/modules/shared.py b/modules/shared.py index c8c2749a9..72e92eb9a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -28,6 +28,7 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="a w parser.add_argument("--unload-gfpgan", action='store_true', help="unload GFPGAN every time after processing images. Warning: seems to cause memory leaks") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") +parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN')) cmd_opts = parser.parse_args() cpu = torch.device("cpu") @@ -79,7 +80,8 @@ class Options: "font": OptionInfo("arial.ttf", "Font for image grids that have text"), "enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text text and [text] to make it pay less attention"), "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), - + "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscaling. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), + "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscaling. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), } def __init__(self): @@ -115,7 +117,6 @@ opts = Options() if os.path.exists(config_filename): opts.load(config_filename) - -sd_upscalers = {} +sd_upscalers = [] sd_model = None diff --git a/modules/ui.py b/modules/ui.py index d6b39c2fd..4119369e9 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -256,10 +256,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): with gr.Row(): use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan) + sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False) with gr.Row(): - sd_upscale_upscaler_name = gr.Radio(label='Upscaler', choices=list(shared.sd_upscalers.keys()), value=list(shared.sd_upscalers.keys())[0], visible=False) - sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False) + sd_upscale_upscaler_name = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", visible=False) with gr.Row(): batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1) @@ -401,9 +401,18 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): with gr.Column(variant='panel'): with gr.Group(): image = gr.Image(label="Source", source="upload", interactive=True, type="pil") - gfpgan_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=gfpgan.have_gfpgan) - realesrgan_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Real-ESRGAN upscaling", value=2, interactive=realesrgan.have_realesrgan) - realesrgan_model = gr.Radio(label='Real-ESRGAN model', choices=[x.name for x in realesrgan.realesrgan_models], value=realesrgan.realesrgan_models[0].name, type="index", interactive=realesrgan.have_realesrgan) + + upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2) + + with gr.Group(): + extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + + with gr.Group(): + extras_upscaler_2 = gr.Radio(label='Upscaler 2', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) + + with gr.Group(): + gfpgan_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=0, interactive=gfpgan.have_gfpgan) submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') @@ -417,8 +426,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): inputs=[ image, gfpgan_strength, - realesrgan_resize, - realesrgan_model, + upscaling_resize, + extras_upscaler_1, + extras_upscaler_2, + extras_upscaler_2_visibility, ], outputs=[ result_image, diff --git a/webui.py b/webui.py index d79b59660..dbc9dd541 100644 --- a/webui.py +++ b/webui.py @@ -21,17 +21,14 @@ import modules.processing as processing import modules.sd_hijack import modules.gfpgan_model as gfpgan import modules.realesrgan_model as realesrgan +import modules.esrgan_model as esrgan import modules.images as images import modules.lowvram import modules.txt2img import modules.img2img -shared.sd_upscalers = { - "RealESRGAN": lambda img: realesrgan.upscale_with_realesrgan(img, 2, 0), - "Lanczos": lambda img: img.resize((img.width*2, img.height*2), resample=images.LANCZOS), - "None": lambda img: img -} +esrgan.load_models(cmd_opts.esrgan_models_path) realesrgan.setup_realesrgan() gfpgan.setup_gfpgan() @@ -54,26 +51,48 @@ def load_model_from_config(config, ckpt, verbose=False): model.eval() return model +cached_images = {} -def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index): +def run_extras(image, gfpgan_strength, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility): processing.torch_gc() image = image.convert("RGB") outpath = opts.outdir_samples or opts.outdir_extras_samples - if gfpgan.have_gfpgan is not None and GFPGAN_strength > 0: - + if gfpgan.have_gfpgan is not None and gfpgan_strength > 0: restored_img = gfpgan.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) res = Image.fromarray(restored_img) - if GFPGAN_strength < 1.0: - res = Image.blend(image, res, GFPGAN_strength) + if gfpgan_strength < 1.0: + res = Image.blend(image, res, gfpgan_strength) image = res - if realesrgan.have_realesrgan and RealESRGAN_upscaling != 1.0: - image = realesrgan.upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index) + if upscaling_resize != 1.0: + def upscale(image, scaler_index, resize): + small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) + pixels = tuple(np.array(small).flatten().tolist()) + key = (resize, scaler_index, image.width, image.height) + pixels + + c = cached_images.get(key) + if c is None: + upscaler = shared.sd_upscalers[scaler_index] + c = upscaler.upscale(image, image.width * resize, image.height * resize) + cached_images[key] = c + + return c + + res = upscale(image, extras_upscaler_1, upscaling_resize) + + if extras_upscaler_2 != 0 and extras_upscaler_2_visibility>0: + res2 = upscale(image, extras_upscaler_2, upscaling_resize) + res = Image.blend(res, res2, extras_upscaler_2_visibility) + + image = res + + while len(cached_images) > 2: + del cached_images[next(iter(cached_images.keys()))] images.save_image(image, outpath, "", None, '', opts.samples_format, short_filename=True, no_prompt=True)