Update network_oft.py

This commit is contained in:
Kohaku-Blueleaf 2024-02-21 22:50:43 +08:00 committed by GitHub
parent 591470d86d
commit 64179c3221
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -72,7 +72,7 @@ class NetworkModuleOFT(network.NetworkModule):
eye = torch.eye(self.block_size, device=oft_blocks.device)
if not self.is_R:
block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
block_Q = oft_blocks - oft_blocks.transpose(-1, -2) # ensure skew-symmetric orthogonal matrix
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))