From 294f8a514f982248cda1cafda30d35566f3a0321 Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Sat, 11 Nov 2023 23:28:12 +0900 Subject: [PATCH 1/9] add hyperTile https://github.com/tfernd/HyperTile --- modules/processing.py | 27 ++++++++++++++++++++++++--- modules/shared_options.py | 2 ++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index b0e240a46..e23095343 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -799,6 +799,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: infotexts = [] output_images = [] + unet_object = p.sd_model.model + vae_model = p.sd_model.first_stage_model + try: + from hyper_tile import split_attention, flush + except (ImportError, ModuleNotFoundError): # pip install git+https://github.com/tfernd/HyperTile@2ef64b2800d007d305755c33550537410310d7df + split_attention = lambda *args, **kwargs: lambda x: x # return a no-op context manager + flush = lambda: None + import random + saved_rng_state = random.getstate() + random.seed(p.seed) # hyper_tile uses random, so we need to seed it with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): @@ -866,15 +876,25 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: shared.state.job = f"Batch {n+1} out of {p.n_iter}" with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): - samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) + # get largest tile size available, which is 2^x which is factor of gcd of p.width and p.height + gcd = math.gcd(p.width, p.height) + largest_tile_size_available = 1 + while gcd % (largest_tile_size_available * 2) == 0: + largest_tile_size_available *= 2 + aspect_ratio = p.width / p.height + with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn): + with split_attention(unet_object, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn): + flush() + samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - - x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn): + flush() + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -980,6 +1000,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True) + random.setstate(saved_rng_state) if not p.disable_extra_networks and p.extra_network_data: extra_networks.deactivate(p, p.extra_network_data) diff --git a/modules/shared_options.py b/modules/shared_options.py index d40db5306..d96502656 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -200,6 +200,8 @@ options_templates.update(options_section(('optimizations', "Optimizations"), { "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"), "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), + "hypertile_split_unet_attn" : OptionInfo(False, "Split attention in Unet with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"), + "hypertile_split_vae_attn": OptionInfo(False, "Split attention in VAE with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"), })) options_templates.update(options_section(('compatibility', "Compatibility"), { From b29fc6d4de8812b25c520a46676cda13c3fe64ca Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Sat, 11 Nov 2023 23:43:13 +0900 Subject: [PATCH 2/9] Implement Hypertile Co-Authored-By: Kieran Hunt --- modules/hypertile.py | 333 ++++++++++++++++++++++++++++++++++++++++++ modules/processing.py | 65 ++++----- 2 files changed, 358 insertions(+), 40 deletions(-) create mode 100644 modules/hypertile.py diff --git a/modules/hypertile.py b/modules/hypertile.py new file mode 100644 index 000000000..ab1c74c02 --- /dev/null +++ b/modules/hypertile.py @@ -0,0 +1,333 @@ +""" +Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE +Warn : The patch works well only if the input image has a width and height that are multiples of 128 +Author : @tfernd Github : https://github.com/tfernd/HyperTile +""" + +from __future__ import annotations +from typing import Callable +from typing_extensions import Literal + +import logging +from functools import wraps, cache +from contextlib import contextmanager + +import math +import torch.nn as nn +import random + +from einops import rearrange + +# TODO add SD-XL layers +DEPTH_LAYERS = { + 0: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.0.attentions.0.transformer_blocks.0.attn1", + "down_blocks.0.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.0.transformer_blocks.0.attn1", + "up_blocks.3.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.1.1.transformer_blocks.0.attn1", + "input_blocks.2.1.transformer_blocks.0.attn1", + "output_blocks.9.1.transformer_blocks.0.attn1", + "output_blocks.10.1.transformer_blocks.0.attn1", + "output_blocks.11.1.transformer_blocks.0.attn1", + # SD 1.5 VAE + "decoder.mid_block.attentions.0", + ], + 1: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.1.attentions.0.transformer_blocks.0.attn1", + "down_blocks.1.attentions.1.transformer_blocks.0.attn1", + "up_blocks.2.attentions.0.transformer_blocks.0.attn1", + "up_blocks.2.attentions.1.transformer_blocks.0.attn1", + "up_blocks.2.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.4.1.transformer_blocks.0.attn1", + "input_blocks.5.1.transformer_blocks.0.attn1", + "output_blocks.6.1.transformer_blocks.0.attn1", + "output_blocks.7.1.transformer_blocks.0.attn1", + "output_blocks.8.1.transformer_blocks.0.attn1", + ], + 2: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.2.attentions.0.transformer_blocks.0.attn1", + "down_blocks.2.attentions.1.transformer_blocks.0.attn1", + "up_blocks.1.attentions.0.transformer_blocks.0.attn1", + "up_blocks.1.attentions.1.transformer_blocks.0.attn1", + "up_blocks.1.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.7.1.transformer_blocks.0.attn1", + "input_blocks.8.1.transformer_blocks.0.attn1", + "output_blocks.3.1.transformer_blocks.0.attn1", + "output_blocks.4.1.transformer_blocks.0.attn1", + "output_blocks.5.1.transformer_blocks.0.attn1", + ], + 3: [ + # SD 1.5 U-Net (diffusers) + "mid_block.attentions.0.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "middle_block.1.transformer_blocks.0.attn1", + ], +} +# XL layers, thanks for GitHub@gel-crabs for the help +DEPTH_LAYERS_XL = { + 0: [ + # SD 1.5 U-Net (diffusers) + "down_blocks.0.attentions.0.transformer_blocks.0.attn1", + "down_blocks.0.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.0.transformer_blocks.0.attn1", + "up_blocks.3.attentions.1.transformer_blocks.0.attn1", + "up_blocks.3.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.4.1.transformer_blocks.0.attn1", + "input_blocks.5.1.transformer_blocks.0.attn1", + "output_blocks.3.1.transformer_blocks.0.attn1", + "output_blocks.4.1.transformer_blocks.0.attn1", + "output_blocks.5.1.transformer_blocks.0.attn1", + # SD 1.5 VAE + "decoder.mid_block.attentions.0", + "decoder.mid.attn_1", + ], + 1: [ + # SD 1.5 U-Net (diffusers) + #"down_blocks.1.attentions.0.transformer_blocks.0.attn1", + #"down_blocks.1.attentions.1.transformer_blocks.0.attn1", + #"up_blocks.2.attentions.0.transformer_blocks.0.attn1", + #"up_blocks.2.attentions.1.transformer_blocks.0.attn1", + #"up_blocks.2.attentions.2.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "input_blocks.4.1.transformer_blocks.1.attn1", + "input_blocks.5.1.transformer_blocks.1.attn1", + "output_blocks.3.1.transformer_blocks.1.attn1", + "output_blocks.4.1.transformer_blocks.1.attn1", + "output_blocks.5.1.transformer_blocks.1.attn1", + "input_blocks.7.1.transformer_blocks.0.attn1", + "input_blocks.8.1.transformer_blocks.0.attn1", + "output_blocks.0.1.transformer_blocks.0.attn1", + "output_blocks.1.1.transformer_blocks.0.attn1", + "output_blocks.2.1.transformer_blocks.0.attn1", + "input_blocks.7.1.transformer_blocks.1.attn1", + "input_blocks.8.1.transformer_blocks.1.attn1", + "output_blocks.0.1.transformer_blocks.1.attn1", + "output_blocks.1.1.transformer_blocks.1.attn1", + "output_blocks.2.1.transformer_blocks.1.attn1", + "input_blocks.7.1.transformer_blocks.2.attn1", + "input_blocks.8.1.transformer_blocks.2.attn1", + "output_blocks.0.1.transformer_blocks.2.attn1", + "output_blocks.1.1.transformer_blocks.2.attn1", + "output_blocks.2.1.transformer_blocks.2.attn1", + "input_blocks.7.1.transformer_blocks.3.attn1", + "input_blocks.8.1.transformer_blocks.3.attn1", + "output_blocks.0.1.transformer_blocks.3.attn1", + "output_blocks.1.1.transformer_blocks.3.attn1", + "output_blocks.2.1.transformer_blocks.3.attn1", + "input_blocks.7.1.transformer_blocks.4.attn1", + "input_blocks.8.1.transformer_blocks.4.attn1", + "output_blocks.0.1.transformer_blocks.4.attn1", + "output_blocks.1.1.transformer_blocks.4.attn1", + "output_blocks.2.1.transformer_blocks.4.attn1", + "input_blocks.7.1.transformer_blocks.5.attn1", + "input_blocks.8.1.transformer_blocks.5.attn1", + "output_blocks.0.1.transformer_blocks.5.attn1", + "output_blocks.1.1.transformer_blocks.5.attn1", + "output_blocks.2.1.transformer_blocks.5.attn1", + "input_blocks.7.1.transformer_blocks.6.attn1", + "input_blocks.8.1.transformer_blocks.6.attn1", + "output_blocks.0.1.transformer_blocks.6.attn1", + "output_blocks.1.1.transformer_blocks.6.attn1", + "output_blocks.2.1.transformer_blocks.6.attn1", + "input_blocks.7.1.transformer_blocks.7.attn1", + "input_blocks.8.1.transformer_blocks.7.attn1", + "output_blocks.0.1.transformer_blocks.7.attn1", + "output_blocks.1.1.transformer_blocks.7.attn1", + "output_blocks.2.1.transformer_blocks.7.attn1", + "input_blocks.7.1.transformer_blocks.8.attn1", + "input_blocks.8.1.transformer_blocks.8.attn1", + "output_blocks.0.1.transformer_blocks.8.attn1", + "output_blocks.1.1.transformer_blocks.8.attn1", + "output_blocks.2.1.transformer_blocks.8.attn1", + "input_blocks.7.1.transformer_blocks.9.attn1", + "input_blocks.8.1.transformer_blocks.9.attn1", + "output_blocks.0.1.transformer_blocks.9.attn1", + "output_blocks.1.1.transformer_blocks.9.attn1", + "output_blocks.2.1.transformer_blocks.9.attn1", + ], + 2: [ + # SD 1.5 U-Net (diffusers) + "mid_block.attentions.0.transformer_blocks.0.attn1", + # SD 1.5 U-Net (ldm) + "middle_block.1.transformer_blocks.0.attn1", + "middle_block.1.transformer_blocks.1.attn1", + "middle_block.1.transformer_blocks.2.attn1", + "middle_block.1.transformer_blocks.3.attn1", + "middle_block.1.transformer_blocks.4.attn1", + "middle_block.1.transformer_blocks.5.attn1", + "middle_block.1.transformer_blocks.6.attn1", + "middle_block.1.transformer_blocks.7.attn1", + "middle_block.1.transformer_blocks.8.attn1", + "middle_block.1.transformer_blocks.9.attn1", + ], +} + + +RNG_INSTANCE = random.Random() + +def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: + """ + Returns a random divisor of value that + x * min_value <= value + if max_options is 1, the behavior is deterministic + """ + min_value = min(min_value, value) + + # All big divisors of value (inclusive) + divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order + + ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order + + idx = RNG_INSTANCE.randint(0, len(ns) - 1) + + return ns[idx] + +def set_hypertile_seed(seed: int) -> None: + RNG_INSTANCE.seed(seed) + +def largest_tile_size_available(width:int, height:int) -> int: + """ + Calculates the largest tile size available for a given width and height + Tile size is always a power of 2 + """ + gcd = math.gcd(width, height) + largest_tile_size_available = 1 + while gcd % (largest_tile_size_available * 2) == 0: + largest_tile_size_available *= 2 + return largest_tile_size_available + +def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]: + """ + Finds h and w such that h*w = hw and h/w = aspect_ratio + We check all possible divisors of hw and return the closest to the aspect ratio + """ + divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw + pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw + ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw + closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio + closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio + return closest_pair + +@cache +def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]: + """ + Finds h and w such that h*w = hw and h/w = aspect_ratio + """ + h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) + # find h and w such that h*w = hw and h/w = aspect_ratio + if h * w != hw: + w_candidate = hw / h + # check if w is an integer + if not w_candidate.is_integer(): + h_candidate = hw / w + # check if h is an integer + if not h_candidate.is_integer(): + return iterative_closest_divisors(hw, aspect_ratio) + else: + h = int(h_candidate) + else: + w = int(w_candidate) + return h, w + +@contextmanager +def split_attention( + layer: nn.Module, + /, + aspect_ratio: float, # width/height + tile_size: int = 128, # 128 for VAE + swap_size: int = 1, # 1 for VAE + *, + disable: bool = False, + max_depth: Literal[0, 1, 2, 3] = 0, # ! Try 0 or 1 + scale_depth: bool = True, # scale the tile-size depending on the depth + is_sdxl: bool = False, # is the model SD-XL +): + # Hijacks AttnBlock from ldm and Attention from diffusers + + if disable: + logging.info(f"Attention for {layer.__class__.__qualname__} not splitted") + yield + return + + latent_tile_size = max(128, tile_size) // 8 + + def self_attn_forward(forward: Callable, depth: int, layer_name: str, module: nn.Module) -> Callable: + @wraps(forward) + def wrapper(*args, **kwargs): + x = args[0] + + # VAE + if x.ndim == 4: + b, c, h, w = x.shape + + nh = random_divisor(h, latent_tile_size, swap_size) + nw = random_divisor(w, latent_tile_size, swap_size) + + if nh * nw > 1: + x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles + + out = forward(x, *args[1:], **kwargs) + + if nh * nw > 1: + out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw) + + # U-Net + else: + hw: int = x.size(1) + h, w = find_hw_candidates(hw, aspect_ratio) + assert h * w == hw, f"Invalid aspect ratio {aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}" + + factor = 2**depth if scale_depth else 1 + nh = random_divisor(h, latent_tile_size * factor, swap_size) + nw = random_divisor(w, latent_tile_size * factor, swap_size) + + module._split_sizes_hypertile.append((nh, nw)) # type: ignore + + if nh * nw > 1: + x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) + + out = forward(x, *args[1:], **kwargs) + + if nh * nw > 1: + out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) + out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) + + return out + + return wrapper + + # Handle hijacking the forward method and recovering afterwards + try: + if is_sdxl: + layers = DEPTH_LAYERS_XL + else: + layers = DEPTH_LAYERS + for depth in range(max_depth + 1): + for layer_name, module in layer.named_modules(): + if any(layer_name.endswith(try_name) for try_name in layers[depth]): + # print input shape for debugging + logging.debug(f"HyperTile hijacking attention layer at depth {depth}: {layer_name}") + # hijack + module._original_forward_hypertile = module.forward + module.forward = self_attn_forward(module.forward, depth, layer_name, module) + module._split_sizes_hypertile = [] + yield + finally: + for layer_name, module in layer.named_modules(): + # remove hijack + if hasattr(module, "_original_forward_hypertile"): + if module._split_sizes_hypertile: + logging.debug(f"layer {layer_name} splitted with ({module._split_sizes_hypertile})") + # recover + module.forward = module._original_forward_hypertile + del module._original_forward_hypertile + del module._split_sizes_hypertile diff --git a/modules/processing.py b/modules/processing.py index e23095343..e19a09a3c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -24,6 +24,7 @@ from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.paths as paths import modules.face_restoration +from modules.hypertile import split_attention, set_hypertile_seed, largest_tile_size_available import modules.images as images import modules.styles import modules.sd_models as sd_models @@ -799,17 +800,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: infotexts = [] output_images = [] - unet_object = p.sd_model.model - vae_model = p.sd_model.first_stage_model - try: - from hyper_tile import split_attention, flush - except (ImportError, ModuleNotFoundError): # pip install git+https://github.com/tfernd/HyperTile@2ef64b2800d007d305755c33550537410310d7df - split_attention = lambda *args, **kwargs: lambda x: x # return a no-op context manager - flush = lambda: None - import random - saved_rng_state = random.getstate() - random.seed(p.seed) # hyper_tile uses random, so we need to seed it - with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) @@ -871,29 +861,20 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.comment(comment) p.extra_generation_params.update(model_hijack.extra_generation_params) - + set_hypertile_seed(p.seed) + # add batch size + hypertile status to information to reproduce the run if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): - # get largest tile size available, which is 2^x which is factor of gcd of p.width and p.height - gcd = math.gcd(p.width, p.height) - largest_tile_size_available = 1 - while gcd % (largest_tile_size_available * 2) == 0: - largest_tile_size_available *= 2 - aspect_ratio = p.width / p.height - with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn): - with split_attention(unet_object, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn): - flush() - samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) + samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if getattr(samples_ddim, 'already_decoded', False): x_samples_ddim = samples_ddim else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn): - flush() + with split_attention(p.sd_model.first_stage_model, aspect_ratio = p.width / p.height, tile_size=min(largest_tile_size_available(p.width, p.height), 128), disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -1000,7 +981,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True) - random.setstate(saved_rng_state) if not p.disable_extra_networks and p.extra_network_data: extra_networks.deactivate(p, p.extra_network_data) @@ -1161,24 +1141,25 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - + aspect_ratio = self.width / self.height x = self.rng.next() - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + tile_size = largest_tile_size_available(self.width, self.height) + with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl): + devices.torch_gc() + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) del x - if not self.enable_hr: return samples if self.latent_scale_mode is None: - decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) + with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) else: decoded_samples = None with sd_models.SkipWritingToConfig(): sd_models.reload_model_weights(info=self.hr_checkpoint_info) - - devices.torch_gc() - return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts) def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts): @@ -1186,7 +1167,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): return samples self.is_hr_pass = True - target_width = self.hr_upscale_to_x target_height = self.hr_upscale_to_y @@ -1264,18 +1244,19 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.scripts is not None: self.scripts.before_hr(self) - - samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) + tile_size = largest_tile_size_available(target_width, target_height) + with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + with split_attention(self.sd_model.model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=3, max_depth=1,scale_depth=True, disable=not opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl): + samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) self.sampler = None devices.torch_gc() - - decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) + with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) self.is_hr_pass = False - return decoded_samples def close(self): @@ -1550,8 +1531,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.initial_noise_multiplier != 1.0: self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier x *= self.initial_noise_multiplier - - samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) + aspect_ratio = self.width / self.height + tile_size = largest_tile_size_available(self.width, self.height) + with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl): + devices.torch_gc() + samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) if self.mask is not None: samples = samples * self.nmask + self.init_latent * self.mask From af45872fdb8a66ffd6a405d99120e0bacbb4a170 Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Wed, 15 Nov 2023 15:15:14 +0900 Subject: [PATCH 3/9] copy LDM VAE key from XL --- modules/hypertile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/hypertile.py b/modules/hypertile.py index ab1c74c02..32d8604cc 100644 --- a/modules/hypertile.py +++ b/modules/hypertile.py @@ -35,6 +35,7 @@ DEPTH_LAYERS = { "output_blocks.11.1.transformer_blocks.0.attn1", # SD 1.5 VAE "decoder.mid_block.attentions.0", + "decoder.mid.attn_1", ], 1: [ # SD 1.5 U-Net (diffusers) From bcfaf3979a9f93e37c418b58c75b02d9570b4354 Mon Sep 17 00:00:00 2001 From: AngelBottomless Date: Thu, 16 Nov 2023 18:43:16 +0900 Subject: [PATCH 4/9] convert/add hypertile options --- modules/hypertile.py | 36 ++++++++++++++++++++++++++++++++++++ modules/processing.py | 21 +++++++++++---------- modules/shared_options.py | 6 ++++++ 3 files changed, 53 insertions(+), 10 deletions(-) diff --git a/modules/hypertile.py b/modules/hypertile.py index 32d8604cc..fee24a8ca 100644 --- a/modules/hypertile.py +++ b/modules/hypertile.py @@ -332,3 +332,39 @@ def split_attention( module.forward = module._original_forward_hypertile del module._original_forward_hypertile del module._split_sizes_hypertile + +def hypertile_context_vae(model:nn.Module, aspect_ratio:float, tile_size:int, opts): + """ + Returns context manager for VAE + """ + enabled = not opts.hypertile_split_vae_attn + swap_size = opts.hypertile_swap_size_vae + max_depth = opts.hypertile_max_depth_vae + tile_size_max = opts.hypertile_max_tile_vae + return split_attention( + model, + aspect_ratio=aspect_ratio, + tile_size=min(tile_size, tile_size_max), + swap_size=swap_size, + disable=not enabled, + max_depth=max_depth, + is_sdxl=False, + ) + +def hypertile_context_unet(model:nn.Module, aspect_ratio:float, tile_size:int, opts, is_sdxl:bool): + """ + Returns context manager for U-Net + """ + enabled = not opts.hypertile_split_unet_attn + swap_size = opts.hypertile_swap_size_unet + max_depth = opts.hypertile_max_depth_unet + tile_size_max = opts.hypertile_max_tile_unet + return split_attention( + model, + aspect_ratio=aspect_ratio, + tile_size=min(tile_size, tile_size_max), + swap_size=swap_size, + disable=not enabled, + max_depth=max_depth, + is_sdxl=is_sdxl, + ) \ No newline at end of file diff --git a/modules/processing.py b/modules/processing.py index e19a09a3c..c622ff337 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -24,7 +24,7 @@ from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.paths as paths import modules.face_restoration -from modules.hypertile import split_attention, set_hypertile_seed, largest_tile_size_available +from modules.hypertile import set_hypertile_seed, largest_tile_size_available, hypertile_context_unet, hypertile_context_vae import modules.images as images import modules.styles import modules.sd_models as sd_models @@ -874,7 +874,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - with split_attention(p.sd_model.first_stage_model, aspect_ratio = p.width / p.height, tile_size=min(largest_tile_size_available(p.width, p.height), 128), disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + with hypertile_context_unet(p.sd_model.first_stage_model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -1144,8 +1144,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): aspect_ratio = self.width / self.height x = self.rng.next() tile_size = largest_tile_size_available(self.width, self.height) - with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): - with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl): + with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): + with hypertile_context_unet(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): devices.torch_gc() samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) del x @@ -1153,7 +1153,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): return samples if self.latent_scale_mode is None: - with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) else: decoded_samples = None @@ -1245,15 +1245,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.scripts is not None: self.scripts.before_hr(self) tile_size = largest_tile_size_available(target_width, target_height) - with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): - with split_attention(self.sd_model.model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=3, max_depth=1,scale_depth=True, disable=not opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl): + aspect_ratio = self.width / self.height + with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): + with hypertile_context_unet(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) self.sampler = None devices.torch_gc() - with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): + with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) self.is_hr_pass = False @@ -1533,8 +1534,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): x *= self.initial_noise_multiplier aspect_ratio = self.width / self.height tile_size = largest_tile_size_available(self.width, self.height) - with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl): - with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl): + with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): + with hypertile_context_unet(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): devices.torch_gc() samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) diff --git a/modules/shared_options.py b/modules/shared_options.py index d96502656..28a489069 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -202,6 +202,12 @@ options_templates.update(options_section(('optimizations', "Optimizations"), { "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), "hypertile_split_unet_attn" : OptionInfo(False, "Split attention in Unet with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"), "hypertile_split_vae_attn": OptionInfo(False, "Split attention in VAE with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"), + "hypertile_max_depth_vae" : OptionInfo(3, "Max depth for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), + "hypertile_max_depth_unet" : OptionInfo(3, "Max depth for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), + "hypertile_max_tile_vae" : OptionInfo(128, "Max tile size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"), + "hypertile_max_tile_unet" : OptionInfo(256, "Max tile size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"), + "hypertile_swap_size_unet": OptionInfo(3, "Swap size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), + "hypertile_swap_size_vae": OptionInfo(3, "Swap size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"), })) options_templates.update(options_section(('compatibility', "Compatibility"), { From 472c22cc8a46b825545d5c86bd2745269430d7b0 Mon Sep 17 00:00:00 2001 From: AngelBottomless Date: Thu, 16 Nov 2023 19:03:45 +0900 Subject: [PATCH 5/9] fix ruff - add newline --- modules/hypertile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/hypertile.py b/modules/hypertile.py index fee24a8ca..86acecdc0 100644 --- a/modules/hypertile.py +++ b/modules/hypertile.py @@ -367,4 +367,4 @@ def hypertile_context_unet(model:nn.Module, aspect_ratio:float, tile_size:int, o disable=not enabled, max_depth=max_depth, is_sdxl=is_sdxl, - ) \ No newline at end of file + ) From c40be2252ab1c8c289562db208c5ac6618bd8545 Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Fri, 17 Nov 2023 09:22:27 +0900 Subject: [PATCH 6/9] Fix critical issue - unet apply --- modules/processing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index c622ff337..2fda7f332 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -874,7 +874,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - with hypertile_context_unet(p.sd_model.first_stage_model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): + with hypertile_context_unet(p.sd_model.model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -1145,7 +1145,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = self.rng.next() tile_size = largest_tile_size_available(self.width, self.height) with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - with hypertile_context_unet(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): + with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): devices.torch_gc() samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) del x @@ -1247,7 +1247,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): tile_size = largest_tile_size_available(target_width, target_height) aspect_ratio = self.width / self.height with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - with hypertile_context_unet(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): + with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) @@ -1535,7 +1535,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): aspect_ratio = self.width / self.height tile_size = largest_tile_size_available(self.width, self.height) with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): - with hypertile_context_unet(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): + with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): devices.torch_gc() samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) From c0725ba2d098a6a78610e7d96ee75f63a32d4e52 Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Fri, 17 Nov 2023 09:34:50 +0900 Subject: [PATCH 7/9] Fix inverted option issue I'm pretty sure I was sleepy while implementing this --- modules/hypertile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/hypertile.py b/modules/hypertile.py index 86acecdc0..3a1468c6b 100644 --- a/modules/hypertile.py +++ b/modules/hypertile.py @@ -337,7 +337,7 @@ def hypertile_context_vae(model:nn.Module, aspect_ratio:float, tile_size:int, op """ Returns context manager for VAE """ - enabled = not opts.hypertile_split_vae_attn + enabled = opts.hypertile_split_vae_attn swap_size = opts.hypertile_swap_size_vae max_depth = opts.hypertile_max_depth_vae tile_size_max = opts.hypertile_max_tile_vae @@ -355,7 +355,7 @@ def hypertile_context_unet(model:nn.Module, aspect_ratio:float, tile_size:int, o """ Returns context manager for U-Net """ - enabled = not opts.hypertile_split_unet_attn + enabled = opts.hypertile_split_unet_attn swap_size = opts.hypertile_swap_size_unet max_depth = opts.hypertile_max_depth_unet tile_size_max = opts.hypertile_max_tile_unet From ffd0f8ddc309688636ac1ac10d82b72ab6b466bf Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Fri, 17 Nov 2023 09:54:33 +0900 Subject: [PATCH 8/9] set empty value for SD XL 3rd layer --- modules/hypertile.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/hypertile.py b/modules/hypertile.py index 3a1468c6b..be898fce4 100644 --- a/modules/hypertile.py +++ b/modules/hypertile.py @@ -170,6 +170,7 @@ DEPTH_LAYERS_XL = { "middle_block.1.transformer_blocks.8.attn1", "middle_block.1.transformer_blocks.9.attn1", ], + 3 : [] # TODO - separate layers for SD-XL } From 97431f29feb17ffc96ca95e9b3efec87be9d8b3a Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Fri, 17 Nov 2023 10:05:28 +0900 Subject: [PATCH 9/9] fix double gc and decoding with unet context --- modules/processing.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 2fda7f332..36c2be5e5 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -874,7 +874,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else: if opts.sd_vae_decode_method != 'Full': p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method - with hypertile_context_unet(p.sd_model.model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): + with hypertile_context_vae(p.sd_model.first_stage_model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), opts=shared.opts): x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = torch.stack(x_samples_ddim).float() @@ -1146,11 +1146,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): tile_size = largest_tile_size_available(self.width, self.height) with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): - devices.torch_gc() samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) del x if not self.enable_hr: return samples + devices.torch_gc() if self.latent_scale_mode is None: with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): @@ -1536,7 +1536,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): tile_size = largest_tile_size_available(self.width, self.height) with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts): with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts): - devices.torch_gc() samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) if self.mask is not None: