stable-diffusion-webui/modules/upscaler_utils.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

142 lines
4.3 KiB
Python
Raw Normal View History

import logging
from typing import Callable
import numpy as np
import torch
import tqdm
from PIL import Image
from modules import images, shared
from modules.torch_utils import get_param
logger = logging.getLogger(__name__)
def upscale_without_tiling(model, img: Image.Image):
    """Upscale a whole PIL image with a single forward pass of ``model``.

    The image is turned into a BGR, channels-first float tensor scaled to
    [0, 1], moved to the device and dtype of the parameter returned by
    ``get_param(model)``, passed through the model under ``torch.no_grad()``
    and converted back into an RGB PIL image.

    Args:
        model: Callable that receives a ``(1, 3, H, W)`` tensor and returns
            the upscaled tensor.
        img: Source image. Any PIL mode is accepted; it is converted to RGB
            before being handed to the model.

    Returns:
        The model output as an RGB ``PIL.Image.Image``.
    """
    # Force three channels: greyscale/palette images would otherwise give a
    # 2-D array and alpha images a 4-channel tensor.
    pixels = np.array(img.convert("RGB"))
    pixels = pixels[:, :, ::-1]  # RGB -> BGR
    # HWC -> CHW, then rescale to [0, 1].
    pixels = np.ascontiguousarray(np.transpose(pixels, (2, 0, 1))) / 255
    tensor = torch.from_numpy(pixels).float()

    param = get_param(model)
    tensor = tensor.unsqueeze(0).to(device=param.device, dtype=param.dtype)

    with torch.no_grad():
        output = model(tensor)

    # Drop only the batch dimension; a bare squeeze() would also remove a
    # height or width of 1 and break the axis shuffle below.
    if output.ndim == 4:
        output = output.squeeze(0)
    output = output.float().cpu().clamp_(0, 1).numpy()
    output = 255. * np.moveaxis(output, 0, 2)  # CHW -> HWC
    # Round before the integer cast; plain truncation biases every pixel
    # downwards (e.g. 254.7 would become 254 instead of 255).
    output = output.round().astype(np.uint8)
    output = output[:, :, ::-1]  # BGR -> RGB
    return Image.fromarray(output, 'RGB')
def upscale_with_model(
    model: Callable[[torch.Tensor], torch.Tensor],
    img: Image.Image,
    *,
    tile_size: int,
    tile_overlap: int = 0,
    desc="tiled upscale",
) -> Image.Image:
    """Upscale ``img`` with ``model``, optionally processing it tile by tile.

    Args:
        model: Upscaling model, forwarded to ``upscale_without_tiling``.
        img: Image to upscale.
        tile_size: Width and height passed to ``images.split_grid``. A value
            of zero or less disables tiling and upscales the image in one go.
        tile_overlap: Overlap passed to ``images.split_grid``.
        desc: Label shown on the progress bar.

    Returns:
        The upscaled image: either the direct model output (no tiling) or
        the result of ``images.combine_grid`` over the upscaled tiles.
    """
    tiling_enabled = tile_size > 0
    if not tiling_enabled:
        logger.debug("Upscaling %s without tiling", img)
        whole_image = upscale_without_tiling(model, img)
        logger.debug("=> %s", whole_image)
        return whole_image

    source_grid = images.split_grid(img, tile_size, tile_size, tile_overlap)
    upscaled_rows = []

    with tqdm.tqdm(total=source_grid.tile_count, desc=desc) as progress:
        for row_y, row_h, source_row in source_grid.tiles:
            upscaled_row = []
            for tile_x, tile_w, tile in source_row:
                logger.debug("Tile (%d, %d) %s...", tile_x, row_y, tile)
                upscaled_tile = upscale_without_tiling(model, tile)
                # The scale factor is derived from the actual output size
                # rather than being supplied by the caller.
                scale_factor = upscaled_tile.width // tile.width
                logger.debug("=> %s (scale factor %s)", upscaled_tile, scale_factor)
                upscaled_row.append(
                    [tile_x * scale_factor, tile_w * scale_factor, upscaled_tile]
                )
                progress.update(1)
            # Row geometry is scaled with the factor of the row's last tile.
            upscaled_rows.append(
                [row_y * scale_factor, row_h * scale_factor, upscaled_row]
            )

    # Describe the output grid: every dimension of the source grid scaled up.
    upscaled_grid = images.Grid(
        upscaled_rows,
        tile_w=source_grid.tile_w * scale_factor,
        tile_h=source_grid.tile_h * scale_factor,
        image_w=source_grid.image_w * scale_factor,
        image_h=source_grid.image_h * scale_factor,
        overlap=source_grid.overlap * scale_factor,
    )
    return images.combine_grid(upscaled_grid)
def tiled_upscale_2(
    img,
    model,
    *,
    tile_size: int,
    tile_overlap: int,
    scale: int,
    device,
    desc="Tiled upscale",
):
    """Upscale a 4-D tensor tile by tile, averaging overlapping regions.

    Alternative implementation of `upscale_with_model` originally used by
    SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling
    and weighting is done in PyTorch space, as opposed to `images.Grid`
    doing it in Pillow space without weighting.

    Args:
        img: Input tensor with four dimensions (unpacked as ``b, c, h, w``).
        model: Callable applied to each patch; its output for a
            ``tile_size`` square patch must be ``scale`` times larger along
            the last two dimensions.
        tile_size: Edge length of the square tiles. It is clamped to the
            image height/width; zero or less disables tiling.
        tile_overlap: Number of input pixels shared by neighbouring tiles.
        scale: Integer upscaling factor of ``model``.
        device: Device on which patches are processed and the result is
            accumulated.
        desc: Label shown on the progress bar.

    Returns:
        The upscaled tensor. When processing is interrupted or skipped via
        ``shared.state``, regions that no tile reached are left at zero.
    """
    b, c, h, w = img.size()
    tile_size = min(tile_size, h, w)

    if tile_size <= 0:
        logger.debug("Upscaling %s without tiling", img.shape)
        return model(img)

    # An overlap as large as the (possibly clamped) tile would make the
    # stride zero or negative: range() then raises or yields nothing, and
    # tiles get skipped. Always advance by at least one pixel.
    stride = max(1, tile_size - tile_overlap)
    # The final index pins the last tile flush against the image border.
    h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size]
    w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size]

    # Match only the dtype of the input: type_as() casts to the input's
    # full tensor type, which can also pull the buffer off `device`.
    result = torch.zeros(
        b,
        c,
        h * scale,
        w * scale,
        device=device,
        dtype=img.dtype,
    )
    # Per-pixel count of how many tiles contributed to `result`.
    weights = torch.zeros_like(result)
    logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)

    with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar:
        for h_idx in h_idx_list:
            if shared.state.interrupted or shared.state.skipped:
                break

            for w_idx in w_idx_list:
                if shared.state.interrupted or shared.state.skipped:
                    break

                # Moving the patch is a no-op when it already lives on
                # `device`.
                in_patch = img[
                    ...,
                    h_idx : h_idx + tile_size,
                    w_idx : w_idx + tile_size,
                ].to(device=device)

                out_patch = model(in_patch)

                # Same output window for the accumulated values and the
                # contribution counts.
                out_region = (
                    ...,
                    slice(h_idx * scale, (h_idx + tile_size) * scale),
                    slice(w_idx * scale, (w_idx + tile_size) * scale),
                )
                result[out_region].add_(out_patch)
                weights[out_region].add_(1)

                pbar.update(1)

    # Pixels no tile reached (e.g. after an interruption) have weight 0;
    # clamp so they stay 0 instead of turning into NaN through 0 / 0.
    # Covered pixels have integer weights >= 1 and are unaffected.
    weights.clamp_(min=1)
    return result.div_(weights)