Fix rescale

This commit is contained in:
Kohaku-Blueleaf 2024-02-18 14:58:41 +08:00 committed by GitHub
parent 90441294db
commit 5a8dd0c549
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -40,7 +40,9 @@ class NetworkModuleOFT(network.NetworkModule):
self.is_boft = False
if weights.w["oft_diag"].dim() == 4:
self.is_boft = True
self.rescale = weight.w.get('rescale', None)
self.rescale = weights.w.get('rescale', None)
if self.rescale is not None:
self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]