diff --git a/modules/devices.py b/modules/devices.py index 0cd2b55d2..c05f2b35e 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -16,6 +16,23 @@ def has_mps() -> bool: return mac_specific.has_mps +def cuda_no_autocast(device_id=None) -> bool: + if device_id is None: + device_id = get_cuda_device_id() + return ( + torch.cuda.get_device_capability(device_id) == (7, 5) + and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") + ) + + +def get_cuda_device_id(): + return ( + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + else 0 + ) or torch.cuda.current_device() + + def get_cuda_device_string(): if shared.cmd_opts.device_id is not None: return f"cuda:{shared.cmd_opts.device_id}" @@ -60,8 +77,7 @@ def enable_tf32(): # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 - device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device() - if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"): + if cuda_no_autocast(): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -92,15 +108,44 @@ def cond_cast_float(input): nv_rng = None +patch_module_list = [ + torch.nn.Linear, + torch.nn.Conv2d, + torch.nn.MultiheadAttention, + torch.nn.GroupNorm, + torch.nn.LayerNorm, +] + +@contextlib.contextmanager +def manual_autocast(): + def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + for module_type in patch_module_list: + org_forward = module_type.forward + module_type.forward = manual_cast_forward + module_type.org_forward = org_forward + try: + yield None + finally: + for module_type in patch_module_list: + module_type.forward = module_type.org_forward -def autocast(disable=False, unet=False): +def autocast(disable=False): + print(fp8, dtype, shared.cmd_opts.precision, device) if disable: return contextlib.nullcontext() - if unet and fp8 and device==cpu: + if fp8 and device==cpu: return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): + return manual_autocast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/processing.py b/modules/processing.py index 2df8a7ea7..40598f5cf 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(unet=True): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) if getattr(samples_ddim, 'already_decoded', False): diff --git a/modules/sd_models.py b/modules/sd_models.py index ccb6afd2a..31bcb913a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -403,23 +403,26 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if enable_fp8: devices.fp8 = True + if model.is_sdxl: + cond_stage = model.conditioner + else: + cond_stage = model.cond_stage_model + + for module in cond_stage.modules(): + if isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + if devices.device == devices.cpu: for module in model.model.diffusion_model.modules(): if isinstance(module, torch.nn.Conv2d): module.to(torch.float8_e4m3fn) elif isinstance(module, torch.nn.Linear): module.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet for cpu") else: - if model.is_sdxl: - cond_stage = model.conditioner - else: - cond_stage = model.cond_stage_model - for module in cond_stage.modules(): - if isinstance(module, torch.nn.Linear): - module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) - timer.record("apply fp8 unet") + timer.record("apply fp8") + else: + devices.fp8 = False devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16