diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index b824b5bff..ce5839504 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -182,11 +182,7 @@ def register_buffer(self, name, attr): if type(attr) == torch.Tensor: if attr.device != devices.device: - - if devices.has_mps(): - attr = attr.to(device="mps", dtype=torch.float32) - else: - attr = attr.to(devices.device) + attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None)) setattr(self, name, attr)