This commit is contained in:
Chenlei Hu 2024-05-17 17:33:25 +00:00 committed by GitHub
commit 28fd9ba7a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 146 additions and 61 deletions

View File

@ -40,7 +40,7 @@ model:
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
use_checkpoint: False
legacy: False
first_stage_config:

View File

@ -41,7 +41,7 @@ model:
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
use_checkpoint: True
use_checkpoint: False
legacy: False
first_stage_config:

View File

@ -45,7 +45,7 @@ model:
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
use_checkpoint: False
legacy: False
first_stage_config:

View File

@ -21,7 +21,7 @@ model:
params:
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: True
use_checkpoint: False
in_channels: 9
out_channels: 4
model_channels: 320

View File

@ -40,7 +40,7 @@ model:
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
use_checkpoint: False
legacy: False
first_stage_config:

View File

@ -40,7 +40,7 @@ model:
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
use_checkpoint: False
legacy: False
first_stage_config:

View File

@ -378,13 +378,18 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.network_weights_backup = weights_backup
bias_backup = getattr(self, "network_bias_backup", None)
if bias_backup is None:
if bias_backup is None and wanted_names != ():
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
elif getattr(self, 'bias', None) is not None:
bias_backup = self.bias.to(devices.cpu, copy=True)
else:
bias_backup = None
# Unlike weight which always has value, some modules don't have bias.
# Only report if bias is not None and current bias are not unchanged.
if bias_backup is not None and current_names != ():
raise RuntimeError("no backup bias found and current bias are not unchanged")
self.network_bias_backup = bias_backup
if current_names != wanted_names:

View File

@ -41,7 +41,7 @@ parser.add_argument("--lowvram", action='store_true', help="enable stable diffus
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "half", "autocast"], default="autocast")
parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
@ -69,7 +69,8 @@ parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefe
parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--disable-nan-check", action='store_true', help="[Deprecated] do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--enable-nan-check", action='store_true', help="Check if produced images/latent spaces have nans at extra performance cost. (~20ms/it)")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device")
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")

View File

@ -114,6 +114,9 @@ errors.run(enable_tf32, "Enabling TF32")
cpu: torch.device = torch.device("cpu")
fp8: bool = False
# Force fp16 for all models in inference. No casting during inference.
# This flag is controlled by "--precision half" command line arg.
force_fp16: bool = False
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
@ -127,6 +130,8 @@ unet_needs_upcast = False
def cond_cast_unet(input):
if force_fp16:
return input.to(torch.float16)
return input.to(dtype_unet) if unet_needs_upcast else input
@ -206,6 +211,11 @@ def autocast(disable=False):
if disable:
return contextlib.nullcontext()
if force_fp16:
# No casting during inference if force_fp16 is enabled.
# All tensor dtype conversion happens before inference.
return contextlib.nullcontext()
if fp8 and device==cpu:
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
@ -230,7 +240,7 @@ class NansException(Exception):
def test_for_nans(x, where):
if shared.cmd_opts.disable_nan_check:
if not shared.cmd_opts.enable_nan_check:
return
if not torch.all(torch.isnan(x)).item():
@ -250,8 +260,6 @@ def test_for_nans(x, where):
else:
message = "A tensor with all NaNs was produced."
message += " Use --disable-nan-check commandline argument to disable this check."
raise NansException(message)
@ -269,3 +277,17 @@ def first_time_calculation():
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)
def force_model_fp16():
"""
ldm and sgm has modules.diffusionmodules.util.GroupNorm32.forward, which
force conversion of input to float32. If force_fp16 is enabled, we need to
prevent this casting.
"""
assert force_fp16
import sgm.modules.diffusionmodules.util as sgm_util
import ldm.modules.diffusionmodules.util as ldm_util
sgm_util.GroupNorm32 = torch.nn.GroupNorm
ldm_util.GroupNorm32 = torch.nn.GroupNorm
print("ldm/sgm GroupNorm32 replaced with normal torch.nn.GroupNorm due to `--precision half`.")

View File

@ -440,6 +440,10 @@ def prepare_environment():
git_pull_recursive(extensions_dir)
startup_timer.record("update extensions")
if args.disable_nan_check:
print("Nan check disabled by default. --disable-nan-check argument is now ignored. "
"Use --enable-nan-check to re-enable nan check.")
if "--exit" in sys.argv:
print("Exiting because of --exit argument")
exit(0)
@ -454,8 +458,8 @@ def configure_for_tests():
sys.argv.append(os.path.join(script_path, "test/test_files/empty.pt"))
if "--skip-torch-cuda-test" not in sys.argv:
sys.argv.append("--skip-torch-cuda-test")
if "--disable-nan-check" not in sys.argv:
sys.argv.append("--disable-nan-check")
if "--enable-nan-check" in sys.argv:
sys.argv.remove("--enable-nan-check")
os.environ['COMMANDLINE_ARGS'] = ""

View File

@ -115,20 +115,17 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)
else:
sd = sd_model.model.state_dict()
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
if diffusion_model_input is not None:
if diffusion_model_input.shape[1] == 9:
# The "masked-image" in this case will just be all 0.5 since the entire image is masked.
image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
image_conditioning = images_tensor_to_samples(image_conditioning,
approximation_indexes.get(opts.sd_vae_encode_method))
if getattr(sd_model.model, "is_sdxl_inpaint", False):
# The "masked-image" in this case will just be all 0.5 since the entire image is masked.
image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5
image_conditioning = images_tensor_to_samples(image_conditioning,
approximation_indexes.get(opts.sd_vae_encode_method))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
return image_conditioning
return image_conditioning
# Dummy zero conditioning if we're not using inpainting or unclip models.
# Still takes up a bit of memory, but no encoder call.
@ -390,11 +387,8 @@ class StableDiffusionProcessing:
if self.sampler.conditioning_key == "crossattn-adm":
return self.unclip_image_conditioning(source_image)
sd = self.sampler.model_wrap.inner_model.model.state_dict()
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
if diffusion_model_input is not None:
if diffusion_model_input.shape[1] == 9:
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
if getattr(self.sampler.model_wrap.inner_model.model, "is_sdxl_inpaint", False):
return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
# Dummy zero conditioning if we're not using inpainting or depth model.
return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

View File

@ -4,16 +4,19 @@ import ldm.modules.attention
import ldm.modules.diffusionmodules.openaimodel
# Setting flag=False so that torch skips checking parameters.
# parameters checking is expensive in frequent operations.
def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context)
return checkpoint(self._forward, x, context, flag=False)
def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x)
return checkpoint(self._forward, x, flag=False)
def ResBlock_forward(self, x, emb):
return checkpoint(self._forward, x, emb)
return checkpoint(self._forward, x, emb, flag=False)
stored = []

View File

@ -486,7 +486,19 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
def _reshape(t):
"""rearrange(t, 'b n (h d) -> b n h d', h=h).
Using torch native operations to avoid overhead as this function is
called frequently. (70 times/it for SDXL)
"""
b, n, _ = t.shape # Get the batch size (b) and sequence length (n)
d = t.shape[2] // h # Determine the depth per head
return t.reshape(b, n, h, d)
q = _reshape(q_in)
k = _reshape(k_in)
v = _reshape(v_in)
del q_in, k_in, v_in
dtype = q.dtype
@ -497,7 +509,9 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
out = out.to(dtype)
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
# out = rearrange(out, 'b n h d -> b n (h d)', h=h)
b, n, h, d = out.shape
out = out.reshape(b, n, h * d)
return self.to_out(out)

View File

@ -36,7 +36,7 @@ th = TorchHijackForUnet()
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
"""Always make sure inputs to unet are in correct dtype."""
if isinstance(cond, dict):
for y in cond.keys():
if isinstance(cond[y], list):
@ -45,7 +45,11 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)
if devices.unet_needs_upcast:
return result.float()
else:
return result
class GELUHijack(torch.nn.GELU, torch.nn.Module):
@ -64,12 +68,11 @@ def hijack_ddpm_edit():
if not ddpm_edit_hijack:
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)
unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available():
CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
@ -81,5 +84,17 @@ CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_s
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)
def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):
if devices.unet_needs_upcast and timesteps.dtype == torch.int64:
dtype = torch.float32
else:
dtype = devices.dtype_unet
return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)
CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)

View File

@ -1,7 +1,11 @@
import importlib
always_true_func = lambda *args, **kwargs: True
class CondFunc:
def __new__(cls, orig_func, sub_func, cond_func):
def __new__(cls, orig_func, sub_func, cond_func=always_true_func):
self = super(CondFunc, cls).__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
@ -20,13 +24,13 @@ class CondFunc:
print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack")
pass
self.__init__(orig_func, sub_func, cond_func)
return lambda *args, **kwargs: self(*args, **kwargs)
def __init__(self, orig_func, sub_func, cond_func):
self.__orig_func = orig_func
self.__sub_func = sub_func
self.__cond_func = cond_func
def __call__(self, *args, **kwargs):
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
return self.__sub_func(self.__orig_func, *args, **kwargs)
else:
return self.__orig_func(*args, **kwargs)
return lambda *args, **kwargs: self(*args, **kwargs)
def __init__(self, orig_func, sub_func, cond_func):
self.__orig_func = orig_func
self.__sub_func = sub_func
self.__cond_func = cond_func
def __call__(self, *args, **kwargs):
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
return self.__sub_func(self.__orig_func, *args, **kwargs)
else:
return self.__orig_func(*args, **kwargs)

View File

@ -380,6 +380,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
model.is_sd1 = not model.is_sdxl and not model.is_sd2
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
# Set is_sdxl_inpaint flag.
diffusion_model_input = state_dict.get('diffusion_model.input_blocks.0.0.weight', None)
model.is_sdxl_inpaint = (
model.is_sdxl and
diffusion_model_input is not None and
diffusion_model_input.shape[1] == 9
)
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
@ -403,6 +410,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.float()
model.alphas_cumprod_original = model.alphas_cumprod
devices.dtype_unet = torch.float32
assert shared.cmd_opts.precision != "half", "Cannot use --precision half with --no-half"
timer.record("apply float()")
else:
vae = model.first_stage_model
@ -540,7 +548,7 @@ def repair_config(sd_config):
if hasattr(sd_config.model.params, 'unet_config'):
if shared.cmd_opts.no_half:
sd_config.model.params.unet_config.params.use_fp16 = False
elif shared.cmd_opts.upcast_sampling:
elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
sd_config.model.params.unet_config.params.use_fp16 = True
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
@ -551,6 +559,14 @@ def repair_config(sd_config):
karlo_path = os.path.join(paths.models_path, 'karlo')
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
# Do not use checkpoint for inference.
# This helps prevent extra performance overhead on checking parameters.
# The perf overhead is about 100ms/it on 4090 for SDXL.
if hasattr(sd_config.model.params, "network_config"):
sd_config.model.params.network_config.params.use_checkpoint = False
if hasattr(sd_config.model.params, "unet_config"):
sd_config.model.params.unet_config.params.use_checkpoint = False
def rescale_zero_terminal_snr_abar(alphas_cumprod):
alphas_bar_sqrt = alphas_cumprod.sqrt()

View File

@ -35,7 +35,7 @@ def is_using_v_parameterization_for_sd2(state_dict):
with sd_disable_initialization.DisableInitialization():
unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(
use_checkpoint=True,
use_checkpoint=False,
use_fp16=False,
image_size=32,
in_channels=4,

View File

@ -35,11 +35,10 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
sd = self.model.state_dict()
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
if diffusion_model_input is not None:
if diffusion_model_input.shape[1] == 9:
x = torch.cat([x] + cond['c_concat'], dim=1)
"""WARNING: This function is called once per denoising iteration. DO NOT add
expensive functionc calls such as `model.state_dict`. """
if self.is_sdxl_inpaint:
x = torch.cat([x] + cond['c_concat'], dim=1)
return self.model(x, t, cond)

View File

@ -31,6 +31,14 @@ def initialize():
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype
if cmd_opts.precision == "half":
msg = "--no-half and --no-half-vae conflict with --precision half"
assert devices.dtype == torch.float16, msg
assert devices.dtype_vae == torch.float16, msg
assert devices.dtype_inference == torch.float16, msg
devices.force_fp16 = True
devices.force_model_fp16()
shared.device = devices.device
shared.weight_load_location = None if cmd_opts.lowram else "cpu"