mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
ESRGAN support
This commit is contained in:
parent
78278ce695
commit
f299645aee
0
ESRGAN/Put ESRGAN models here.txt
Normal file
0
ESRGAN/Put ESRGAN models here.txt
Normal file
10
README.md
10
README.md
@ -19,11 +19,14 @@ Original script with Gradio UI was written by a kind anonymous user. This is a m
|
||||
- Loopback
|
||||
- X/Y plot
|
||||
- Textual Inversion
|
||||
- Resizing options
|
||||
- Extras tab with:
|
||||
- GFPGAN, neural network that fixes faces
|
||||
- RealESRGAN, neural network upscaler
|
||||
- ESRGAN, neural network with a lot of third party models
|
||||
- Resizing aspect ratio options
|
||||
- Sampling method selection
|
||||
- Interrupt processing at any time
|
||||
- 4GB videocard support
|
||||
- Option to use GFPGAN
|
||||
- Correct seeds for batches
|
||||
- Prompt length validation
|
||||
- Generation parameters added as text to PNG
|
||||
@ -49,6 +52,9 @@ can obtain it from the following places:
|
||||
|
||||
You optionally can use GPFGAN to improve faces, then you'll need to download the model from [here](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth).
|
||||
|
||||
To use ESRGAN models, put them into ESRGAN directory in the same location as webui.py. A file will be loaded
|
||||
as model if it has .pth extension. Grab models from the [Model Database](https://upscale.wiki/wiki/Model_Database).
|
||||
|
||||
### Automatic installation/launch
|
||||
|
||||
- install [Python 3.10.6](https://www.python.org/downloads/windows/)
|
||||
|
80
modules/esrgam_model_arch.py
Normal file
80
modules/esrgam_model_arch.py
Normal file
@ -0,0 +1,80 @@
|
||||
# this file is taken from https://github.com/xinntao/ESRGAN
|
||||
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
# initialization
|
||||
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
|
||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
return out
|
134
modules/esrgan_model.py
Normal file
134
modules/esrgan_model.py
Normal file
@ -0,0 +1,134 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import modules.esrgam_model_arch as arch
|
||||
from modules import shared
|
||||
from modules.shared import opts
|
||||
import modules.images
|
||||
|
||||
|
||||
def load_model(filename):
|
||||
# this code is adapted from https://github.com/xinntao/ESRGAN
|
||||
|
||||
pretrained_net = torch.load(filename)
|
||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||
|
||||
if 'conv_first.weight' in pretrained_net:
|
||||
crt_model.load_state_dict(pretrained_net)
|
||||
return crt_model
|
||||
|
||||
crt_net = crt_model.state_dict()
|
||||
load_net_clean = {}
|
||||
for k, v in pretrained_net.items():
|
||||
if k.startswith('module.'):
|
||||
load_net_clean[k[7:]] = v
|
||||
else:
|
||||
load_net_clean[k] = v
|
||||
pretrained_net = load_net_clean
|
||||
|
||||
tbd = []
|
||||
for k, v in crt_net.items():
|
||||
tbd.append(k)
|
||||
|
||||
# directly copy
|
||||
for k, v in crt_net.items():
|
||||
if k in pretrained_net and pretrained_net[k].size() == v.size():
|
||||
crt_net[k] = pretrained_net[k]
|
||||
tbd.remove(k)
|
||||
|
||||
crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
|
||||
crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
|
||||
|
||||
for k in tbd.copy():
|
||||
if 'RDB' in k:
|
||||
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
|
||||
if '.weight' in k:
|
||||
ori_k = ori_k.replace('.weight', '.0.weight')
|
||||
elif '.bias' in k:
|
||||
ori_k = ori_k.replace('.bias', '.0.bias')
|
||||
crt_net[k] = pretrained_net[ori_k]
|
||||
tbd.remove(k)
|
||||
|
||||
crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
|
||||
crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
|
||||
crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
|
||||
crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
|
||||
crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
|
||||
crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
|
||||
crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
|
||||
crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
|
||||
crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
|
||||
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
|
||||
|
||||
crt_model.load_state_dict(crt_net)
|
||||
crt_model.eval()
|
||||
return crt_model
|
||||
|
||||
def upscale_without_tiling(model, img):
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(shared.device)
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output = 255. * np.moveaxis(output, 0, 2)
|
||||
output = output.astype(np.uint8)
|
||||
output = output[:, :, ::-1]
|
||||
return Image.fromarray(output, 'RGB')
|
||||
|
||||
|
||||
def esrgan_upscale(model, img):
|
||||
if opts.ESRGAN_tile == 0:
|
||||
return upscale_without_tiling(model, img)
|
||||
|
||||
grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
|
||||
newtiles = []
|
||||
scale_factor = 1
|
||||
|
||||
for y, h, row in grid.tiles:
|
||||
newrow = []
|
||||
for tiledata in row:
|
||||
x, w, tile = tiledata
|
||||
|
||||
output = upscale_without_tiling(model, tile)
|
||||
scale_factor = output.width // tile.width
|
||||
|
||||
newrow.append([x * scale_factor, w * scale_factor, output])
|
||||
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
||||
|
||||
newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
|
||||
output = modules.images.combine_grid(newgrid)
|
||||
return output
|
||||
|
||||
|
||||
class UpscalerESRGAN(modules.images.Upscaler):
|
||||
def __init__(self, filename, title):
|
||||
self.name = title
|
||||
self.model = load_model(filename)
|
||||
|
||||
def do_upscale(self, img):
|
||||
model = self.model.to(shared.device)
|
||||
img = esrgan_upscale(model, img)
|
||||
return img
|
||||
|
||||
|
||||
def load_models(dirname):
|
||||
for file in os.listdir(dirname):
|
||||
path = os.path.join(dirname, file)
|
||||
model_name, extension = os.path.splitext(file)
|
||||
|
||||
if extension != '.pt' and extension != '.pth':
|
||||
continue
|
||||
|
||||
try:
|
||||
modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name))
|
||||
except Exception:
|
||||
print(f"Error loading ESRGAN model: {path}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
@ -6,6 +6,7 @@ import re
|
||||
import numpy as np
|
||||
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
||||
|
||||
import modules.shared
|
||||
from modules.shared import opts
|
||||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
@ -45,20 +46,20 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
||||
cols = math.ceil((w - overlap) / non_overlap_width)
|
||||
rows = math.ceil((h - overlap) / non_overlap_height)
|
||||
|
||||
dx = (w - tile_w) // (cols-1) if cols > 1 else 0
|
||||
dy = (h - tile_h) // (rows-1) if rows > 1 else 0
|
||||
dx = (w - tile_w) / (cols-1) if cols > 1 else 0
|
||||
dy = (h - tile_h) / (rows-1) if rows > 1 else 0
|
||||
|
||||
grid = Grid([], tile_w, tile_h, w, h, overlap)
|
||||
for row in range(rows):
|
||||
row_images = []
|
||||
|
||||
y = row * dy
|
||||
y = int(row * dy)
|
||||
|
||||
if y + tile_h >= h:
|
||||
y = h - tile_h
|
||||
|
||||
for col in range(cols):
|
||||
x = col * dx
|
||||
x = int(col * dx)
|
||||
|
||||
if x+tile_w >= w:
|
||||
x = w - tile_w
|
||||
@ -291,3 +292,32 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file:
|
||||
file.write(info + "\n")
|
||||
|
||||
|
||||
class Upscaler:
|
||||
name = "Lanczos"
|
||||
|
||||
def do_upscale(self, img):
|
||||
return img
|
||||
|
||||
def upscale(self, img, w, h):
|
||||
for i in range(3):
|
||||
if img.width >= w and img.height >= h:
|
||||
break
|
||||
|
||||
img = self.do_upscale(img)
|
||||
|
||||
if img.width != w or img.height != h:
|
||||
img = img.resize((w, h), resample=LANCZOS)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class UpscalerNone(Upscaler):
|
||||
name = "None"
|
||||
|
||||
def upscale(self, img, w, h):
|
||||
return img
|
||||
|
||||
|
||||
modules.shared.sd_upscalers.append(UpscalerNone())
|
||||
modules.shared.sd_upscalers.append(Upscaler())
|
||||
|
@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
|
||||
import modules.images as images
|
||||
import modules.scripts
|
||||
|
||||
def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_name: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
|
||||
def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
|
||||
is_inpaint = mode == 1
|
||||
is_loopback = mode == 2
|
||||
is_upscale = mode == 3
|
||||
@ -81,8 +81,8 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
|
||||
initial_seed = None
|
||||
initial_info = None
|
||||
|
||||
upscaler = shared.sd_upscalers.get(upscaler_name, next(iter(shared.sd_upscalers.values())))
|
||||
img = upscaler(init_img)
|
||||
upscaler = shared.sd_upscalers[upscaler_index]
|
||||
img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2)
|
||||
|
||||
processing.torch_gc()
|
||||
|
||||
|
@ -4,6 +4,7 @@ from collections import namedtuple
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import modules.images
|
||||
from modules.shared import cmd_opts
|
||||
|
||||
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
|
||||
@ -12,6 +13,17 @@ realesrgan_models = []
|
||||
have_realesrgan = False
|
||||
RealESRGANer_constructor = None
|
||||
|
||||
|
||||
class UpscalerRealESRGAN(modules.images.Upscaler):
|
||||
def __init__(self, upscaling, model_index):
|
||||
self.upscaling = upscaling
|
||||
self.model_index = model_index
|
||||
self.name = realesrgan_models[model_index].name
|
||||
|
||||
def do_upscale(self, img):
|
||||
return upscale_with_realesrgan(img, self.upscaling, self.model_index)
|
||||
|
||||
|
||||
def setup_realesrgan():
|
||||
global realesrgan_models
|
||||
global have_realesrgan
|
||||
@ -42,6 +54,9 @@ def setup_realesrgan():
|
||||
have_realesrgan = True
|
||||
RealESRGANer_constructor = RealESRGANer
|
||||
|
||||
for i, model in enumerate(realesrgan_models):
|
||||
modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i))
|
||||
|
||||
except Exception:
|
||||
print("Error importing Real-ESRGAN:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
@ -28,6 +28,7 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="a w
|
||||
parser.add_argument("--unload-gfpgan", action='store_true', help="unload GFPGAN every time after processing images. Warning: seems to cause memory leaks")
|
||||
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
|
||||
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
|
||||
cmd_opts = parser.parse_args()
|
||||
|
||||
cpu = torch.device("cpu")
|
||||
@ -79,7 +80,8 @@ class Options:
|
||||
"font": OptionInfo("arial.ttf", "Font for image grids that have text"),
|
||||
"enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text text and [text] to make it pay less attention"),
|
||||
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
|
||||
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscaling. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscaling. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
@ -115,7 +117,6 @@ opts = Options()
|
||||
if os.path.exists(config_filename):
|
||||
opts.load(config_filename)
|
||||
|
||||
|
||||
sd_upscalers = {}
|
||||
sd_upscalers = []
|
||||
|
||||
sd_model = None
|
||||
|
@ -256,10 +256,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
||||
|
||||
with gr.Row():
|
||||
use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)
|
||||
sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
|
||||
|
||||
with gr.Row():
|
||||
sd_upscale_upscaler_name = gr.Radio(label='Upscaler', choices=list(shared.sd_upscalers.keys()), value=list(shared.sd_upscalers.keys())[0], visible=False)
|
||||
sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
|
||||
sd_upscale_upscaler_name = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", visible=False)
|
||||
|
||||
with gr.Row():
|
||||
batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
|
||||
@ -401,9 +401,18 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
||||
with gr.Column(variant='panel'):
|
||||
with gr.Group():
|
||||
image = gr.Image(label="Source", source="upload", interactive=True, type="pil")
|
||||
gfpgan_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=gfpgan.have_gfpgan)
|
||||
realesrgan_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Real-ESRGAN upscaling", value=2, interactive=realesrgan.have_realesrgan)
|
||||
realesrgan_model = gr.Radio(label='Real-ESRGAN model', choices=[x.name for x in realesrgan.realesrgan_models], value=realesrgan.realesrgan_models[0].name, type="index", interactive=realesrgan.have_realesrgan)
|
||||
|
||||
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2)
|
||||
|
||||
with gr.Group():
|
||||
extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
||||
|
||||
with gr.Group():
|
||||
extras_upscaler_2 = gr.Radio(label='Upscaler 2', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
|
||||
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
|
||||
|
||||
with gr.Group():
|
||||
gfpgan_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=0, interactive=gfpgan.have_gfpgan)
|
||||
|
||||
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
|
||||
|
||||
@ -417,8 +426,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
|
||||
inputs=[
|
||||
image,
|
||||
gfpgan_strength,
|
||||
realesrgan_resize,
|
||||
realesrgan_model,
|
||||
upscaling_resize,
|
||||
extras_upscaler_1,
|
||||
extras_upscaler_2,
|
||||
extras_upscaler_2_visibility,
|
||||
],
|
||||
outputs=[
|
||||
result_image,
|
||||
|
43
webui.py
43
webui.py
@ -21,17 +21,14 @@ import modules.processing as processing
|
||||
import modules.sd_hijack
|
||||
import modules.gfpgan_model as gfpgan
|
||||
import modules.realesrgan_model as realesrgan
|
||||
import modules.esrgan_model as esrgan
|
||||
import modules.images as images
|
||||
import modules.lowvram
|
||||
import modules.txt2img
|
||||
import modules.img2img
|
||||
|
||||
|
||||
shared.sd_upscalers = {
|
||||
"RealESRGAN": lambda img: realesrgan.upscale_with_realesrgan(img, 2, 0),
|
||||
"Lanczos": lambda img: img.resize((img.width*2, img.height*2), resample=images.LANCZOS),
|
||||
"None": lambda img: img
|
||||
}
|
||||
esrgan.load_models(cmd_opts.esrgan_models_path)
|
||||
realesrgan.setup_realesrgan()
|
||||
gfpgan.setup_gfpgan()
|
||||
|
||||
@ -54,26 +51,48 @@ def load_model_from_config(config, ckpt, verbose=False):
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
cached_images = {}
|
||||
|
||||
def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index):
|
||||
def run_extras(image, gfpgan_strength, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
|
||||
processing.torch_gc()
|
||||
|
||||
image = image.convert("RGB")
|
||||
|
||||
outpath = opts.outdir_samples or opts.outdir_extras_samples
|
||||
|
||||
if gfpgan.have_gfpgan is not None and GFPGAN_strength > 0:
|
||||
|
||||
if gfpgan.have_gfpgan is not None and gfpgan_strength > 0:
|
||||
restored_img = gfpgan.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if GFPGAN_strength < 1.0:
|
||||
res = Image.blend(image, res, GFPGAN_strength)
|
||||
if gfpgan_strength < 1.0:
|
||||
res = Image.blend(image, res, gfpgan_strength)
|
||||
|
||||
image = res
|
||||
|
||||
if realesrgan.have_realesrgan and RealESRGAN_upscaling != 1.0:
|
||||
image = realesrgan.upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
|
||||
if upscaling_resize != 1.0:
|
||||
def upscale(image, scaler_index, resize):
|
||||
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
|
||||
pixels = tuple(np.array(small).flatten().tolist())
|
||||
key = (resize, scaler_index, image.width, image.height) + pixels
|
||||
|
||||
c = cached_images.get(key)
|
||||
if c is None:
|
||||
upscaler = shared.sd_upscalers[scaler_index]
|
||||
c = upscaler.upscale(image, image.width * resize, image.height * resize)
|
||||
cached_images[key] = c
|
||||
|
||||
return c
|
||||
|
||||
res = upscale(image, extras_upscaler_1, upscaling_resize)
|
||||
|
||||
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility>0:
|
||||
res2 = upscale(image, extras_upscaler_2, upscaling_resize)
|
||||
res = Image.blend(res, res2, extras_upscaler_2_visibility)
|
||||
|
||||
image = res
|
||||
|
||||
while len(cached_images) > 2:
|
||||
del cached_images[next(iter(cached_images.keys()))]
|
||||
|
||||
images.save_image(image, outpath, "", None, '', opts.samples_format, short_filename=True, no_prompt=True)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user