mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
add built-in extension system
add support for adding upscalers in extensions move LDSR, ScuNET and SwinIR to built-in extensions
This commit is contained in:
parent
46b0d230e7
commit
b6e5edd746
6
extensions-builtin/LDSR/preload.py
Normal file
6
extensions-builtin/LDSR/preload.py
Normal file
@ -0,0 +1,6 @@
|
||||
import os
|
||||
from modules import paths
|
||||
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
|
@ -5,8 +5,8 @@ import traceback
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.ldsr_model_arch import LDSR
|
||||
from modules import shared
|
||||
from ldsr_model_arch import LDSR
|
||||
from modules import shared, script_callbacks
|
||||
|
||||
|
||||
class UpscalerLDSR(Upscaler):
|
||||
@ -52,3 +52,12 @@ class UpscalerLDSR(Upscaler):
|
||||
return img
|
||||
ddim_steps = shared.opts.ldsr_steps
|
||||
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
import gradio as gr
|
||||
|
||||
shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
|
||||
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
6
extensions-builtin/ScuNET/preload.py
Normal file
6
extensions-builtin/ScuNET/preload.py
Normal file
@ -0,0 +1,6 @@
|
||||
import os
|
||||
from modules import paths
|
||||
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
|
@ -9,7 +9,7 @@ from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.upscaler
|
||||
from modules import devices, modelloader
|
||||
from modules.scunet_model_arch import SCUNet as net
|
||||
from scunet_model_arch import SCUNet as net
|
||||
|
||||
|
||||
class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||
@ -49,7 +49,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||
if model is None:
|
||||
return img
|
||||
|
||||
device = devices.device_scunet
|
||||
device = devices.get_device_for('scunet')
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
@ -66,7 +66,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||
return PIL.Image.fromarray(output, 'RGB')
|
||||
|
||||
def load_model(self, path: str):
|
||||
device = devices.device_scunet
|
||||
device = devices.get_device_for('scunet')
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
||||
progress=True)
|
6
extensions-builtin/SwinIR/preload.py
Normal file
6
extensions-builtin/SwinIR/preload.py
Normal file
@ -0,0 +1,6 @@
|
||||
import os
|
||||
from modules import paths
|
||||
|
||||
|
||||
def preload(parser):
|
||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
|
@ -7,13 +7,16 @@ from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules import modelloader, devices
|
||||
from modules import modelloader, devices, script_callbacks, shared
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules.swinir_model_arch import SwinIR as net
|
||||
from modules.swinir_model_arch_v2 import Swin2SR as net2
|
||||
from swinir_model_arch import SwinIR as net
|
||||
from swinir_model_arch_v2 import Swin2SR as net2
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
|
||||
|
||||
device_swinir = devices.get_device_for('swinir')
|
||||
|
||||
|
||||
class UpscalerSwinIR(Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "SwinIR"
|
||||
@ -38,7 +41,7 @@ class UpscalerSwinIR(Upscaler):
|
||||
model = self.load_model(model_file)
|
||||
if model is None:
|
||||
return img
|
||||
model = model.to(devices.device_swinir)
|
||||
model = model.to(device_swinir, dtype=devices.dtype)
|
||||
img = upscale(img, model)
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
@ -90,8 +93,6 @@ class UpscalerSwinIR(Upscaler):
|
||||
model.load_state_dict(pretrained_model[params], strict=True)
|
||||
else:
|
||||
model.load_state_dict(pretrained_model, strict=True)
|
||||
if not cmd_opts.no_half:
|
||||
model = model.half()
|
||||
return model
|
||||
|
||||
|
||||
@ -107,7 +108,7 @@ def upscale(
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(devices.device_swinir)
|
||||
img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
|
||||
with torch.no_grad(), devices.autocast():
|
||||
_, _, h_old, w_old = img.size()
|
||||
h_pad = (h_old // window_size + 1) * window_size - h_old
|
||||
@ -135,8 +136,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
||||
stride = tile - tile_overlap
|
||||
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
||||
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
||||
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img)
|
||||
W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir)
|
||||
E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
|
||||
W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
|
||||
|
||||
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
|
||||
for h_idx in h_idx_list:
|
||||
@ -155,3 +156,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
||||
output = E.div_(W)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
import gradio as gr
|
||||
|
||||
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
|
||||
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
|
||||
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
@ -44,6 +44,15 @@ def get_optimal_device():
|
||||
return cpu
|
||||
|
||||
|
||||
def get_device_for(task):
|
||||
from modules import shared
|
||||
|
||||
if task in shared.cmd_opts.use_cpu:
|
||||
return cpu
|
||||
|
||||
return get_optimal_device()
|
||||
|
||||
|
||||
def torch_gc():
|
||||
if torch.cuda.is_available():
|
||||
with torch.cuda.device(get_cuda_device_string()):
|
||||
@ -67,7 +76,7 @@ def enable_tf32():
|
||||
errors.run(enable_tf32, "Enabling TF32")
|
||||
|
||||
cpu = torch.device("cpu")
|
||||
device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
|
||||
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
|
||||
dtype = torch.float16
|
||||
dtype_vae = torch.float16
|
||||
|
||||
|
@ -8,6 +8,7 @@ from modules import paths, shared
|
||||
|
||||
extensions = []
|
||||
extensions_dir = os.path.join(paths.script_path, "extensions")
|
||||
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
|
||||
|
||||
|
||||
def active():
|
||||
@ -15,12 +16,13 @@ def active():
|
||||
|
||||
|
||||
class Extension:
|
||||
def __init__(self, name, path, enabled=True):
|
||||
def __init__(self, name, path, enabled=True, is_builtin=False):
|
||||
self.name = name
|
||||
self.path = path
|
||||
self.enabled = enabled
|
||||
self.status = ''
|
||||
self.can_update = False
|
||||
self.is_builtin = is_builtin
|
||||
|
||||
repo = None
|
||||
try:
|
||||
@ -79,11 +81,19 @@ def list_extensions():
|
||||
if not os.path.isdir(extensions_dir):
|
||||
return
|
||||
|
||||
for dirname in sorted(os.listdir(extensions_dir)):
|
||||
path = os.path.join(extensions_dir, dirname)
|
||||
paths = []
|
||||
for dirname in [extensions_dir, extensions_builtin_dir]:
|
||||
if not os.path.isdir(dirname):
|
||||
return
|
||||
|
||||
for extension_dirname in sorted(os.listdir(dirname)):
|
||||
path = os.path.join(dirname, extension_dirname)
|
||||
if not os.path.isdir(path):
|
||||
continue
|
||||
|
||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
|
||||
paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
|
||||
|
||||
for dirname, path, is_builtin in paths:
|
||||
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
|
||||
extensions.append(extension)
|
||||
|
||||
|
@ -124,10 +124,9 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
||||
|
||||
|
||||
def load_upscalers():
|
||||
sd = shared.script_path
|
||||
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
||||
# so we'll try to import any _model.py files before looking in __subclasses__
|
||||
modules_dir = os.path.join(sd, "modules")
|
||||
modules_dir = os.path.join(shared.script_path, "modules")
|
||||
for file in os.listdir(modules_dir):
|
||||
if "_model.py" in file:
|
||||
model_name = file.replace("_model.py", "")
|
||||
@ -136,22 +135,13 @@ def load_upscalers():
|
||||
importlib.import_module(full_model)
|
||||
except:
|
||||
pass
|
||||
|
||||
datas = []
|
||||
c_o = vars(shared.cmd_opts)
|
||||
commandline_options = vars(shared.cmd_opts)
|
||||
for cls in Upscaler.__subclasses__():
|
||||
name = cls.__name__
|
||||
module_name = cls.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
class_ = getattr(module, name)
|
||||
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
||||
opt_string = None
|
||||
try:
|
||||
if cmd_name in c_o:
|
||||
opt_string = c_o[cmd_name]
|
||||
except:
|
||||
pass
|
||||
scaler = class_(opt_string)
|
||||
for child in scaler.scalers:
|
||||
datas.append(child)
|
||||
scaler = cls(commandline_options.get(cmd_name, None))
|
||||
datas += scaler.scalers
|
||||
|
||||
shared.sd_upscalers = datas
|
||||
|
@ -50,9 +50,6 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
|
||||
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
|
||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
|
||||
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
|
||||
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
|
||||
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
@ -61,7 +58,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
|
||||
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
|
||||
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
|
||||
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
|
||||
@ -95,6 +92,7 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req
|
||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||
|
||||
script_loading.preload_extensions(extensions.extensions_dir, parser)
|
||||
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
|
||||
|
||||
cmd_opts = parser.parse_args()
|
||||
|
||||
@ -112,8 +110,8 @@ restricted_opts = {
|
||||
|
||||
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
||||
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
|
||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
|
||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
||||
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
|
||||
|
||||
device = devices.device
|
||||
weight_load_location = None if cmd_opts.lowram else "cpu"
|
||||
@ -326,9 +324,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
||||
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
|
||||
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||
"use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
|
||||
}))
|
||||
|
@ -28,7 +28,6 @@ import modules.codeformer_model
|
||||
import modules.generation_parameters_copypaste as parameters_copypaste
|
||||
import modules.gfpgan_model
|
||||
import modules.hypernetworks.ui
|
||||
import modules.ldsr_model
|
||||
import modules.scripts
|
||||
import modules.shared as shared
|
||||
import modules.styles
|
||||
|
@ -78,6 +78,12 @@ def extension_table():
|
||||
"""
|
||||
|
||||
for ext in extensions.extensions:
|
||||
remote = ""
|
||||
if ext.is_builtin:
|
||||
remote = "built-in"
|
||||
elif ext.remote:
|
||||
remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
|
||||
|
||||
if ext.can_update:
|
||||
ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
|
||||
else:
|
||||
@ -86,7 +92,7 @@ def extension_table():
|
||||
code += f"""
|
||||
<tr>
|
||||
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
|
||||
<td><a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape(ext.remote or '')}</a></td>
|
||||
<td>{remote}</td>
|
||||
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
|
||||
</tr>
|
||||
"""
|
||||
|
5
webui.py
5
webui.py
@ -53,10 +53,11 @@ def initialize():
|
||||
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
||||
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
||||
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
||||
modelloader.load_upscalers()
|
||||
|
||||
modules.scripts.load_scripts()
|
||||
|
||||
modelloader.load_upscalers()
|
||||
|
||||
modules.sd_vae.refresh_vae_list()
|
||||
modules.sd_models.load_model()
|
||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
||||
@ -177,6 +178,8 @@ def webui():
|
||||
|
||||
print('Reloading custom scripts')
|
||||
modules.scripts.reload_scripts()
|
||||
modelloader.load_upscalers()
|
||||
|
||||
print('Reloading modules: modules.ui')
|
||||
importlib.reload(modules.ui)
|
||||
print('Refreshing Model List')
|
||||
|
Loading…
Reference in New Issue
Block a user