Add rescale mechanism

LyCORIS will support saving oft_blocks instead of oft_diag in the near future (for both OFT and BOFT).

But this means we need to store the rescale value if the user enables it.
This commit is contained in:
Kohaku-Blueleaf 2024-02-12 14:25:09 +08:00 committed by GitHub
parent eb6f2df826
commit 90441294db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -40,6 +40,7 @@ class NetworkModuleOFT(network.NetworkModule):
self.is_boft = False self.is_boft = False
if weights.w["oft_diag"].dim() == 4: if weights.w["oft_diag"].dim() == 4:
self.is_boft = True self.is_boft = True
self.rescale = weight.w.get('rescale', None)
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
is_conv = type(self.sd_module) in [torch.nn.Conv2d] is_conv = type(self.sd_module) in [torch.nn.Conv2d]
@ -108,6 +109,10 @@ class NetworkModuleOFT(network.NetworkModule):
inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b) inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)
merged_weight = inp merged_weight = inp
# Rescale mechanism
if self.rescale is not None:
merged_weight = self.rescale.to(merged_weight) * merged_weight
updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype) updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
output_shape = orig_weight.shape output_shape = orig_weight.shape
return self.finalize_updown(updown, orig_weight, output_shape) return self.finalize_updown(updown, orig_weight, output_shape)