doc: add boft comment

This commit is contained in:
v0xie 2024-02-08 21:58:59 -08:00
parent 325eaeb584
commit 613b0d9548
1 changed files with 3 additions and 3 deletions

View File

@ -29,13 +29,14 @@ class NetworkModuleOFT(network.NetworkModule):
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
self.alpha = weights.w["alpha"] # alpha is constraint
self.dim = self.oft_blocks.shape[0] # lora dim
# LyCORIS
# LyCORIS OFT
elif "oft_diag" in weights.w.keys():
self.is_kohya = False
self.oft_blocks = weights.w["oft_diag"]
# self.alpha is unused
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
# LyCORIS BOFT
self.is_boft = False
if weights.w["oft_diag"].dim() == 4:
self.is_boft = True
@ -89,6 +90,7 @@ class NetworkModuleOFT(network.NetworkModule):
)
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
else:
# TODO: determine correct value for scale
scale = 1.0
m = self.boft_m
b = self.boft_b
@ -99,8 +101,6 @@ class NetworkModuleOFT(network.NetworkModule):
if i == 0:
# Apply multiplier/scale and rescale into first weight
bi = bi * scale + (1 - scale) * eye
#if self.rescaled:
# bi = bi * self.rescale
inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)
inp = rearrange(inp, "(d b) ... -> d b ...", b=b)
inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)