diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ed221d8fe..f5e657b82 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -40,7 +40,9 @@ class NetworkModuleOFT(network.NetworkModule): self.is_boft = False if weights.w["oft_diag"].dim() == 4: self.is_boft = True - self.rescale = weight.w.get('rescale', None) + self.rescale = weights.w.get('rescale', None) + if self.rescale is not None: + self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1)) is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_conv = type(self.sd_module) in [torch.nn.Conv2d]