ManualCast for 10/16 series gpu

This commit is contained in:
Kohaku-Blueleaf 2023-10-28 15:24:26 +08:00
parent 0beb131c7f
commit d4d3134f6d
3 changed files with 62 additions and 14 deletions

View File

@ -16,6 +16,23 @@ def has_mps() -> bool:
return mac_specific.has_mps
def cuda_no_autocast(device_id=None) -> bool:
if device_id is None:
device_id = get_cuda_device_id()
return (
torch.cuda.get_device_capability(device_id) == (7, 5)
and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
)
def get_cuda_device_id():
return (
int(shared.cmd_opts.device_id)
if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
else 0
) or torch.cuda.current_device()
def get_cuda_device_string():
if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
@ -60,8 +77,7 @@ def enable_tf32():
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device()
if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"):
if cuda_no_autocast():
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@ -92,15 +108,44 @@ def cond_cast_float(input):
nv_rng = None
patch_module_list = [
torch.nn.Linear,
torch.nn.Conv2d,
torch.nn.MultiheadAttention,
torch.nn.GroupNorm,
torch.nn.LayerNorm,
]
@contextlib.contextmanager
def manual_autocast():
def manual_cast_forward(self, *args, **kwargs):
org_dtype = next(self.parameters()).dtype
self.to(dtype)
result = self.org_forward(*args, **kwargs)
self.to(org_dtype)
return result
for module_type in patch_module_list:
org_forward = module_type.forward
module_type.forward = manual_cast_forward
module_type.org_forward = org_forward
try:
yield None
finally:
for module_type in patch_module_list:
module_type.forward = module_type.org_forward
def autocast(disable=False, unet=False):
def autocast(disable=False):
print(fp8, dtype, shared.cmd_opts.precision, device)
if disable:
return contextlib.nullcontext()
if unet and fp8 and device==cpu:
if fp8 and device==cpu:
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
return manual_autocast()
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()

View File

@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(unet=True):
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
if getattr(samples_ddim, 'already_decoded', False):

View File

@ -403,23 +403,26 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if enable_fp8:
devices.fp8 = True
if model.is_sdxl:
cond_stage = model.conditioner
else:
cond_stage = model.cond_stage_model
for module in cond_stage.modules():
if isinstance(module, torch.nn.Linear):
module.to(torch.float8_e4m3fn)
if devices.device == devices.cpu:
for module in model.model.diffusion_model.modules():
if isinstance(module, torch.nn.Conv2d):
module.to(torch.float8_e4m3fn)
elif isinstance(module, torch.nn.Linear):
module.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet for cpu")
else:
if model.is_sdxl:
cond_stage = model.conditioner
else:
cond_stage = model.cond_stage_model
for module in cond_stage.modules():
if isinstance(module, torch.nn.Linear):
module.to(torch.float8_e4m3fn)
model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet")
timer.record("apply fp8")
else:
devices.fp8 = False
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16