mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Merge pull request #13948 from aria1th/hypertile-in-sample
support HyperTile optimization
This commit is contained in:
commit
fd8674a4bc
371
modules/hypertile.py
Normal file
371
modules/hypertile.py
Normal file
@ -0,0 +1,371 @@
|
||||
"""
|
||||
Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
|
||||
Warn : The patch works well only if the input image has a width and height that are multiples of 128
|
||||
Author : @tfernd Github : https://github.com/tfernd/HyperTile
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Callable
|
||||
from typing_extensions import Literal
|
||||
|
||||
import logging
|
||||
from functools import wraps, cache
|
||||
from contextlib import contextmanager
|
||||
|
||||
import math
|
||||
import torch.nn as nn
|
||||
import random
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
# TODO add SD-XL layers
|
||||
DEPTH_LAYERS = {
|
||||
0: [
|
||||
# SD 1.5 U-Net (diffusers)
|
||||
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
|
||||
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
|
||||
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
|
||||
# SD 1.5 U-Net (ldm)
|
||||
"input_blocks.1.1.transformer_blocks.0.attn1",
|
||||
"input_blocks.2.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.9.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.10.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.11.1.transformer_blocks.0.attn1",
|
||||
# SD 1.5 VAE
|
||||
"decoder.mid_block.attentions.0",
|
||||
"decoder.mid.attn_1",
|
||||
],
|
||||
1: [
|
||||
# SD 1.5 U-Net (diffusers)
|
||||
"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
|
||||
"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
|
||||
"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
|
||||
# SD 1.5 U-Net (ldm)
|
||||
"input_blocks.4.1.transformer_blocks.0.attn1",
|
||||
"input_blocks.5.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.6.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.7.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.8.1.transformer_blocks.0.attn1",
|
||||
],
|
||||
2: [
|
||||
# SD 1.5 U-Net (diffusers)
|
||||
"down_blocks.2.attentions.0.transformer_blocks.0.attn1",
|
||||
"down_blocks.2.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.1.attentions.0.transformer_blocks.0.attn1",
|
||||
"up_blocks.1.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.1.attentions.2.transformer_blocks.0.attn1",
|
||||
# SD 1.5 U-Net (ldm)
|
||||
"input_blocks.7.1.transformer_blocks.0.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.3.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.4.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.5.1.transformer_blocks.0.attn1",
|
||||
],
|
||||
3: [
|
||||
# SD 1.5 U-Net (diffusers)
|
||||
"mid_block.attentions.0.transformer_blocks.0.attn1",
|
||||
# SD 1.5 U-Net (ldm)
|
||||
"middle_block.1.transformer_blocks.0.attn1",
|
||||
],
|
||||
}
|
||||
# XL layers, thanks for GitHub@gel-crabs for the help
|
||||
DEPTH_LAYERS_XL = {
|
||||
0: [
|
||||
# SD 1.5 U-Net (diffusers)
|
||||
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
|
||||
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
|
||||
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
|
||||
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
|
||||
# SD 1.5 U-Net (ldm)
|
||||
"input_blocks.4.1.transformer_blocks.0.attn1",
|
||||
"input_blocks.5.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.3.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.4.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.5.1.transformer_blocks.0.attn1",
|
||||
# SD 1.5 VAE
|
||||
"decoder.mid_block.attentions.0",
|
||||
"decoder.mid.attn_1",
|
||||
],
|
||||
1: [
|
||||
# SD 1.5 U-Net (diffusers)
|
||||
#"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
|
||||
#"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
|
||||
#"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
|
||||
#"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
|
||||
#"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
|
||||
# SD 1.5 U-Net (ldm)
|
||||
"input_blocks.4.1.transformer_blocks.1.attn1",
|
||||
"input_blocks.5.1.transformer_blocks.1.attn1",
|
||||
"output_blocks.3.1.transformer_blocks.1.attn1",
|
||||
"output_blocks.4.1.transformer_blocks.1.attn1",
|
||||
"output_blocks.5.1.transformer_blocks.1.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.0.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.0.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.0.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.1.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.1.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.1.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.1.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.1.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.2.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.2.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.2.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.2.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.2.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.3.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.3.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.3.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.3.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.3.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.4.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.4.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.4.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.4.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.4.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.5.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.5.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.5.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.5.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.5.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.6.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.6.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.6.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.6.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.6.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.7.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.7.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.7.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.7.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.7.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.8.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.8.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.8.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.8.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.8.attn1",
|
||||
"input_blocks.7.1.transformer_blocks.9.attn1",
|
||||
"input_blocks.8.1.transformer_blocks.9.attn1",
|
||||
"output_blocks.0.1.transformer_blocks.9.attn1",
|
||||
"output_blocks.1.1.transformer_blocks.9.attn1",
|
||||
"output_blocks.2.1.transformer_blocks.9.attn1",
|
||||
],
|
||||
2: [
|
||||
# SD 1.5 U-Net (diffusers)
|
||||
"mid_block.attentions.0.transformer_blocks.0.attn1",
|
||||
# SD 1.5 U-Net (ldm)
|
||||
"middle_block.1.transformer_blocks.0.attn1",
|
||||
"middle_block.1.transformer_blocks.1.attn1",
|
||||
"middle_block.1.transformer_blocks.2.attn1",
|
||||
"middle_block.1.transformer_blocks.3.attn1",
|
||||
"middle_block.1.transformer_blocks.4.attn1",
|
||||
"middle_block.1.transformer_blocks.5.attn1",
|
||||
"middle_block.1.transformer_blocks.6.attn1",
|
||||
"middle_block.1.transformer_blocks.7.attn1",
|
||||
"middle_block.1.transformer_blocks.8.attn1",
|
||||
"middle_block.1.transformer_blocks.9.attn1",
|
||||
],
|
||||
3 : [] # TODO - separate layers for SD-XL
|
||||
}
|
||||
|
||||
|
||||
RNG_INSTANCE = random.Random()
|
||||
|
||||
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
||||
"""
|
||||
Returns a random divisor of value that
|
||||
x * min_value <= value
|
||||
if max_options is 1, the behavior is deterministic
|
||||
"""
|
||||
min_value = min(min_value, value)
|
||||
|
||||
# All big divisors of value (inclusive)
|
||||
divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order
|
||||
|
||||
ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order
|
||||
|
||||
idx = RNG_INSTANCE.randint(0, len(ns) - 1)
|
||||
|
||||
return ns[idx]
|
||||
|
||||
def set_hypertile_seed(seed: int) -> None:
|
||||
RNG_INSTANCE.seed(seed)
|
||||
|
||||
def largest_tile_size_available(width:int, height:int) -> int:
|
||||
"""
|
||||
Calculates the largest tile size available for a given width and height
|
||||
Tile size is always a power of 2
|
||||
"""
|
||||
gcd = math.gcd(width, height)
|
||||
largest_tile_size_available = 1
|
||||
while gcd % (largest_tile_size_available * 2) == 0:
|
||||
largest_tile_size_available *= 2
|
||||
return largest_tile_size_available
|
||||
|
||||
def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
|
||||
"""
|
||||
Finds h and w such that h*w = hw and h/w = aspect_ratio
|
||||
We check all possible divisors of hw and return the closest to the aspect ratio
|
||||
"""
|
||||
divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw
|
||||
pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw
|
||||
ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw
|
||||
closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio
|
||||
closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
|
||||
return closest_pair
|
||||
|
||||
@cache
|
||||
def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
|
||||
"""
|
||||
Finds h and w such that h*w = hw and h/w = aspect_ratio
|
||||
"""
|
||||
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
|
||||
# find h and w such that h*w = hw and h/w = aspect_ratio
|
||||
if h * w != hw:
|
||||
w_candidate = hw / h
|
||||
# check if w is an integer
|
||||
if not w_candidate.is_integer():
|
||||
h_candidate = hw / w
|
||||
# check if h is an integer
|
||||
if not h_candidate.is_integer():
|
||||
return iterative_closest_divisors(hw, aspect_ratio)
|
||||
else:
|
||||
h = int(h_candidate)
|
||||
else:
|
||||
w = int(w_candidate)
|
||||
return h, w
|
||||
|
||||
@contextmanager
|
||||
def split_attention(
|
||||
layer: nn.Module,
|
||||
/,
|
||||
aspect_ratio: float, # width/height
|
||||
tile_size: int = 128, # 128 for VAE
|
||||
swap_size: int = 1, # 1 for VAE
|
||||
*,
|
||||
disable: bool = False,
|
||||
max_depth: Literal[0, 1, 2, 3] = 0, # ! Try 0 or 1
|
||||
scale_depth: bool = True, # scale the tile-size depending on the depth
|
||||
is_sdxl: bool = False, # is the model SD-XL
|
||||
):
|
||||
# Hijacks AttnBlock from ldm and Attention from diffusers
|
||||
|
||||
if disable:
|
||||
logging.info(f"Attention for {layer.__class__.__qualname__} not splitted")
|
||||
yield
|
||||
return
|
||||
|
||||
latent_tile_size = max(128, tile_size) // 8
|
||||
|
||||
def self_attn_forward(forward: Callable, depth: int, layer_name: str, module: nn.Module) -> Callable:
|
||||
@wraps(forward)
|
||||
def wrapper(*args, **kwargs):
|
||||
x = args[0]
|
||||
|
||||
# VAE
|
||||
if x.ndim == 4:
|
||||
b, c, h, w = x.shape
|
||||
|
||||
nh = random_divisor(h, latent_tile_size, swap_size)
|
||||
nw = random_divisor(w, latent_tile_size, swap_size)
|
||||
|
||||
if nh * nw > 1:
|
||||
x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles
|
||||
|
||||
out = forward(x, *args[1:], **kwargs)
|
||||
|
||||
if nh * nw > 1:
|
||||
out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
|
||||
|
||||
# U-Net
|
||||
else:
|
||||
hw: int = x.size(1)
|
||||
h, w = find_hw_candidates(hw, aspect_ratio)
|
||||
assert h * w == hw, f"Invalid aspect ratio {aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"
|
||||
|
||||
factor = 2**depth if scale_depth else 1
|
||||
nh = random_divisor(h, latent_tile_size * factor, swap_size)
|
||||
nw = random_divisor(w, latent_tile_size * factor, swap_size)
|
||||
|
||||
module._split_sizes_hypertile.append((nh, nw)) # type: ignore
|
||||
|
||||
if nh * nw > 1:
|
||||
x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
|
||||
|
||||
out = forward(x, *args[1:], **kwargs)
|
||||
|
||||
if nh * nw > 1:
|
||||
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
|
||||
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
|
||||
|
||||
return out
|
||||
|
||||
return wrapper
|
||||
|
||||
# Handle hijacking the forward method and recovering afterwards
|
||||
try:
|
||||
if is_sdxl:
|
||||
layers = DEPTH_LAYERS_XL
|
||||
else:
|
||||
layers = DEPTH_LAYERS
|
||||
for depth in range(max_depth + 1):
|
||||
for layer_name, module in layer.named_modules():
|
||||
if any(layer_name.endswith(try_name) for try_name in layers[depth]):
|
||||
# print input shape for debugging
|
||||
logging.debug(f"HyperTile hijacking attention layer at depth {depth}: {layer_name}")
|
||||
# hijack
|
||||
module._original_forward_hypertile = module.forward
|
||||
module.forward = self_attn_forward(module.forward, depth, layer_name, module)
|
||||
module._split_sizes_hypertile = []
|
||||
yield
|
||||
finally:
|
||||
for layer_name, module in layer.named_modules():
|
||||
# remove hijack
|
||||
if hasattr(module, "_original_forward_hypertile"):
|
||||
if module._split_sizes_hypertile:
|
||||
logging.debug(f"layer {layer_name} splitted with ({module._split_sizes_hypertile})")
|
||||
# recover
|
||||
module.forward = module._original_forward_hypertile
|
||||
del module._original_forward_hypertile
|
||||
del module._split_sizes_hypertile
|
||||
|
||||
def hypertile_context_vae(model:nn.Module, aspect_ratio:float, tile_size:int, opts):
|
||||
"""
|
||||
Returns context manager for VAE
|
||||
"""
|
||||
enabled = opts.hypertile_split_vae_attn
|
||||
swap_size = opts.hypertile_swap_size_vae
|
||||
max_depth = opts.hypertile_max_depth_vae
|
||||
tile_size_max = opts.hypertile_max_tile_vae
|
||||
return split_attention(
|
||||
model,
|
||||
aspect_ratio=aspect_ratio,
|
||||
tile_size=min(tile_size, tile_size_max),
|
||||
swap_size=swap_size,
|
||||
disable=not enabled,
|
||||
max_depth=max_depth,
|
||||
is_sdxl=False,
|
||||
)
|
||||
|
||||
def hypertile_context_unet(model:nn.Module, aspect_ratio:float, tile_size:int, opts, is_sdxl:bool):
|
||||
"""
|
||||
Returns context manager for U-Net
|
||||
"""
|
||||
enabled = opts.hypertile_split_unet_attn
|
||||
swap_size = opts.hypertile_swap_size_unet
|
||||
max_depth = opts.hypertile_max_depth_unet
|
||||
tile_size_max = opts.hypertile_max_tile_unet
|
||||
return split_attention(
|
||||
model,
|
||||
aspect_ratio=aspect_ratio,
|
||||
tile_size=min(tile_size, tile_size_max),
|
||||
swap_size=swap_size,
|
||||
disable=not enabled,
|
||||
max_depth=max_depth,
|
||||
is_sdxl=is_sdxl,
|
||||
)
|
@ -24,6 +24,7 @@ from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
import modules.paths as paths
|
||||
import modules.face_restoration
|
||||
from modules.hypertile import set_hypertile_seed, largest_tile_size_available, hypertile_context_unet, hypertile_context_vae
|
||||
import modules.images as images
|
||||
import modules.styles
|
||||
import modules.sd_models as sd_models
|
||||
@ -799,7 +800,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
|
||||
infotexts = []
|
||||
output_images = []
|
||||
|
||||
with torch.no_grad(), p.sd_model.ema_scope():
|
||||
with devices.autocast():
|
||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||
@ -861,7 +861,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
p.comment(comment)
|
||||
|
||||
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
||||
|
||||
set_hypertile_seed(p.seed)
|
||||
# add batch size + hypertile status to information to reproduce the run
|
||||
if p.n_iter > 1:
|
||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||
|
||||
@ -873,8 +874,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
||||
else:
|
||||
if opts.sd_vae_decode_method != 'Full':
|
||||
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
|
||||
|
||||
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
||||
with hypertile_context_vae(p.sd_model.first_stage_model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), opts=shared.opts):
|
||||
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
|
||||
|
||||
x_samples_ddim = torch.stack(x_samples_ddim).float()
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
@ -1140,24 +1141,25 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
|
||||
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
|
||||
|
||||
aspect_ratio = self.width / self.height
|
||||
x = self.rng.next()
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||
tile_size = largest_tile_size_available(self.width, self.height)
|
||||
with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
|
||||
with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
|
||||
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
|
||||
del x
|
||||
|
||||
if not self.enable_hr:
|
||||
return samples
|
||||
devices.torch_gc()
|
||||
|
||||
if self.latent_scale_mode is None:
|
||||
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
||||
with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
|
||||
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
|
||||
else:
|
||||
decoded_samples = None
|
||||
|
||||
with sd_models.SkipWritingToConfig():
|
||||
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
|
||||
|
||||
devices.torch_gc()
|
||||
|
||||
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
|
||||
|
||||
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
|
||||
@ -1165,7 +1167,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
return samples
|
||||
|
||||
self.is_hr_pass = True
|
||||
|
||||
target_width = self.hr_upscale_to_x
|
||||
target_height = self.hr_upscale_to_y
|
||||
|
||||
@ -1243,18 +1244,20 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
||||
|
||||
if self.scripts is not None:
|
||||
self.scripts.before_hr(self)
|
||||
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||
tile_size = largest_tile_size_available(target_width, target_height)
|
||||
aspect_ratio = self.width / self.height
|
||||
with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
|
||||
with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
|
||||
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||
|
||||
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||
|
||||
self.sampler = None
|
||||
devices.torch_gc()
|
||||
|
||||
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
||||
with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
|
||||
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
|
||||
|
||||
self.is_hr_pass = False
|
||||
|
||||
return decoded_samples
|
||||
|
||||
def close(self):
|
||||
@ -1529,8 +1532,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
||||
if self.initial_noise_multiplier != 1.0:
|
||||
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
|
||||
x *= self.initial_noise_multiplier
|
||||
|
||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
||||
aspect_ratio = self.width / self.height
|
||||
tile_size = largest_tile_size_available(self.width, self.height)
|
||||
with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
|
||||
with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
|
||||
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
||||
|
||||
if self.mask is not None:
|
||||
samples = samples * self.nmask + self.init_latent * self.mask
|
||||
|
@ -201,6 +201,14 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
|
||||
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
|
||||
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
|
||||
"hypertile_split_unet_attn" : OptionInfo(False, "Split attention in Unet with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"),
|
||||
"hypertile_split_vae_attn": OptionInfo(False, "Split attention in VAE with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"),
|
||||
"hypertile_max_depth_vae" : OptionInfo(3, "Max depth for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
|
||||
"hypertile_max_depth_unet" : OptionInfo(3, "Max depth for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
|
||||
"hypertile_max_tile_vae" : OptionInfo(128, "Max tile size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"),
|
||||
"hypertile_max_tile_unet" : OptionInfo(256, "Max tile size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"),
|
||||
"hypertile_swap_size_unet": OptionInfo(3, "Swap size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
|
||||
"hypertile_swap_size_vae": OptionInfo(3, "Swap size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||
|
Loading…
Reference in New Issue
Block a user