diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index d7b317029..ed221d8fe 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -40,6 +40,7 @@ class NetworkModuleOFT(network.NetworkModule):
         self.is_boft = False
         if weights.w["oft_diag"].dim() == 4:
             self.is_boft = True
+        self.rescale = weights.w.get('rescale', None)
 
         is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]
@@ -108,6 +109,10 @@ class NetworkModuleOFT(network.NetworkModule):
                 inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)
             merged_weight = inp
 
+        # Rescale mechanism
+        if self.rescale is not None:
+            merged_weight = self.rescale.to(merged_weight) * merged_weight
+
         updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
         output_shape = orig_weight.shape
         return self.finalize_updown(updown, orig_weight, output_shape)
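
Note on the rescale mechanism: rescaled-OFT checkpoints can store a `rescale` tensor alongside the orthogonal blocks, with one scale per output dimension and trailing singleton dims for broadcasting. When present, it is multiplied into the rotated weight before the delta against the original weight is taken. A minimal standalone sketch of that broadcast, using hypothetical shapes chosen for illustration (the real tensors come from the loaded network weights):

    import torch

    # Hypothetical shapes for illustration only: a conv weight after the
    # OFT/BOFT rotation, and a per-output-channel rescale of shape
    # (out_ch, 1, 1, 1) so it broadcasts over the remaining weight dims.
    merged_weight = torch.randn(8, 4, 3, 3)   # rotated weight (out_ch, in_ch, kh, kw)
    rescale = torch.rand(8, 1, 1, 1)          # one scale per output channel

    # Tensor.to(other) casts to other's dtype and device in one call, which
    # is what self.rescale.to(merged_weight) relies on in the patch.
    merged_weight = rescale.to(merged_weight) * merged_weight

Applying the scale to the merged weight (rather than to the blocks) keeps the rotation orthogonal and confines the magnitude change to a single broadcasted multiply.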