From fe1967a4c4a02eccfa45b65ee19a5b0773ced31c Mon Sep 17 00:00:00 2001
From: v0xie <28695009+v0xie@users.noreply.github.com>
Date: Fri, 3 Nov 2023 17:52:55 -0700
Subject: [PATCH] skip multihead attn for now

---
 extensions-builtin/Lora/network_oft.py | 52 ++++++++++++++++++--------
 1 file changed, 36 insertions(+), 16 deletions(-)

diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index e102eafc1..979a20476 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -18,6 +18,7 @@ class NetworkModuleOFT(network.NetworkModule):
         super().__init__(net, weights)
 
         self.lin_module = None
+        self.org_module: list[torch.Module] = [self.sd_module]
         # kohya-ss
         if "oft_blocks" in weights.w.keys():
             self.is_kohya = True
@@ -30,7 +31,7 @@ class NetworkModuleOFT(network.NetworkModule):
             # alpha is rank if alpha is 0 or None
             if self.alpha is None:
                 pass
-            self.dim = self.oft_blocks.shape[0] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n]
+            self.dim = self.oft_blocks.shape[1] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n]
         else:
             raise ValueError("oft_blocks or oft_diag must be in weights dict")
 
@@ -46,6 +47,12 @@ class NetworkModuleOFT(network.NetworkModule):
         #    raise ValueError("Linear sd_module must have out_features or embed_dim")
         elif is_other_linear:
             self.out_dim = self.sd_module.embed_dim
+            #self.org_weight = self.org_module[0].weight
+#            if hasattr(self.sd_module, "in_proj_weight"):
+#                self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
+#            if hasattr(self.sd_module, "out_proj_weight"):
+#                self.out_proj_dim = self.sd_module.out_proj_weight.shape[0]
+#            self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
         elif is_conv:
             self.out_dim = self.sd_module.out_channels
         else:
@@ -58,10 +65,9 @@ class NetworkModuleOFT(network.NetworkModule):
             self.constraint = self.alpha * self.out_dim
         #elif is_linear or is_conv:
         else:
-            self.num_blocks, self.block_size = factorization(self.out_dim, self.dim)
+            self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
             self.constraint = None
 
-        self.org_module: list[torch.Module] = [self.sd_module]
 
         # if is_other_linear:
         #     weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1)
@@ -110,25 +116,39 @@ class NetworkModuleOFT(network.NetworkModule):
 
     def calc_updown(self, orig_weight):
         multiplier = self.multiplier() * self.calc_scale()
-        R = self.get_weight(self.oft_blocks, multiplier)
-        #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
-        merged_weight = self.merge_weight(R, orig_weight)
+        is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention]
+        if self.is_kohya and not is_other_linear:
+            R = self.get_weight(self.oft_blocks, multiplier)
+            #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+            merged_weight = self.merge_weight(R, orig_weight)
+        elif not self.is_kohya and not is_other_linear:
+            if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
+                orig_weight=orig_weight.permute(1, 0)
+            R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+            merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
+            #orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks)
+            merged_weight = torch.einsum(
+                'k n m, k n ... -> k m ...',
+                R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
+                merged_weight
+            )
+            merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
+            if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
+                orig_weight=orig_weight.permute(1, 0)
+                #merged_weight=merged_weight.permute(1, 0)
+            updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
+            #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
+            output_shape = orig_weight.shape
+        else:
+            # skip for now
+            updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype)
+            output_shape = (orig_weight.shape[1], orig_weight.shape[1])
 
         #if self.lin_module is not None:
         #    R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype)
         #    weight = torch.mul(torch.mul(R, multiplier), orig_weight)
         #else:
-        #    orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
-        #    weight = torch.einsum(
-        #        'k n m, k n ... -> k m ...',
-        #        R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
-        #        orig_weight
-        #    )
-        #    weight = rearrange(weight, 'k m ... -> (k m) ...')
 
-        updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
-        #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
-        output_shape = orig_weight.shape
         orig_weight = orig_weight
 
         return self.finalize_updown(updown, orig_weight, output_shape)
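
Note (not part of the patch): the new non-kohya branch in calc_updown merges the OFT rotation by splitting the weight's output dimension into num_blocks blocks of block_size rows, multiplying each block by (R_k * multiplier + I) via einsum, flattening the blocks back, and taking the delta against the original weight. Below is a minimal standalone sketch of that computation; the shapes and values (num_blocks, block_size, in_dim, the random tensors) are synthetic stand-ins for illustration only.

    # Standalone sketch of the block-diagonal OFT merge (synthetic values).
    import torch
    from einops import rearrange

    num_blocks, block_size = 4, 8                 # stand-ins for factorization(out_dim, dim)
    out_dim, in_dim = num_blocks * block_size, 16

    orig_weight = torch.randn(out_dim, in_dim)    # weight with shape [out_dim, in_dim]
    oft_diag = torch.randn(num_blocks, block_size, block_size) * 0.01  # per-block R
    multiplier = 1.0

    # Split the output dimension into k blocks of n rows: [(k n), ...] -> [k, n, ...]
    blocked = rearrange(orig_weight, '(k n) ... -> k n ...', k=num_blocks, n=block_size)

    # Contract each block of rows with (R_k * multiplier + I).
    rotated = torch.einsum(
        'k n m, k n ... -> k m ...',
        oft_diag * multiplier + torch.eye(block_size),
        blocked,
    )

    # Flatten the blocks back and take the delta, as calc_updown returns it.
    merged_weight = rearrange(rotated, 'k m ... -> (k m) ...')
    updown = merged_weight - orig_weight
    print(updown.shape)  # torch.Size([32, 16])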