Fix bugs when arg dtype doesn't match

This commit is contained in:
KohakuBlueleaf 2024-01-09 22:39:39 +08:00
parent 209c26a1cb
commit 42e6df723c

View File

@ -134,24 +134,19 @@ patch_module_list = [
def manual_cast_forward(target_dtype):
def forward_wrapper(self, *args, **kwargs):
org_dtype = torch_utils.get_param(self).dtype
if not target_dtype == org_dtype == dtype_inference:
self.to(target_dtype)
args = [
arg.to(target_dtype)
if isinstance(arg, torch.Tensor)
else arg
for arg in args
]
kwargs = {
k: v.to(target_dtype)
if isinstance(v, torch.Tensor)
else v
for k, v in kwargs.items()
}
if any(
isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
for arg in args
):
args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
org_dtype = torch_utils.get_param(self).dtype
if org_dtype != target_dtype:
self.to(target_dtype)
result = self.org_forward(*args, **kwargs)
self.to(org_dtype)
if org_dtype != target_dtype:
self.to(org_dtype)
if target_dtype != dtype_inference:
if isinstance(result, tuple):