mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Fix bugs when arg dtype doesn't match
This commit is contained in:
parent
209c26a1cb
commit
42e6df723c
@ -134,24 +134,19 @@ patch_module_list = [
|
||||
|
||||
def manual_cast_forward(target_dtype):
|
||||
def forward_wrapper(self, *args, **kwargs):
|
||||
org_dtype = torch_utils.get_param(self).dtype
|
||||
if not target_dtype == org_dtype == dtype_inference:
|
||||
self.to(target_dtype)
|
||||
args = [
|
||||
arg.to(target_dtype)
|
||||
if isinstance(arg, torch.Tensor)
|
||||
else arg
|
||||
for arg in args
|
||||
]
|
||||
kwargs = {
|
||||
k: v.to(target_dtype)
|
||||
if isinstance(v, torch.Tensor)
|
||||
else v
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
if any(
|
||||
isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
|
||||
for arg in args
|
||||
):
|
||||
args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
|
||||
kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
|
||||
|
||||
org_dtype = torch_utils.get_param(self).dtype
|
||||
if org_dtype != target_dtype:
|
||||
self.to(target_dtype)
|
||||
result = self.org_forward(*args, **kwargs)
|
||||
self.to(org_dtype)
|
||||
if org_dtype != target_dtype:
|
||||
self.to(org_dtype)
|
||||
|
||||
if target_dtype != dtype_inference:
|
||||
if isinstance(result, tuple):
|
||||
|
Loading…
Reference in New Issue
Block a user