mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-06-07 21:20:49 +00:00
Merge pull request #14871 from v0xie/boft
Support inference with LyCORIS BOFT networks
This commit is contained in:
commit
6e4fc5e1a8
@ -22,6 +22,8 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
self.org_module: list[torch.Module] = [self.sd_module]
|
||||
|
||||
self.scale = 1.0
|
||||
self.is_kohya = False
|
||||
self.is_boft = False
|
||||
|
||||
# kohya-ss
|
||||
if "oft_blocks" in weights.w.keys():
|
||||
@ -29,13 +31,19 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
self.oft_blocks = weights.w["oft_blocks"] # (num_blocks, block_size, block_size)
|
||||
self.alpha = weights.w["alpha"] # alpha is constraint
|
||||
self.dim = self.oft_blocks.shape[0] # lora dim
|
||||
# LyCORIS
|
||||
# LyCORIS OFT
|
||||
elif "oft_diag" in weights.w.keys():
|
||||
self.is_kohya = False
|
||||
self.oft_blocks = weights.w["oft_diag"]
|
||||
# self.alpha is unused
|
||||
self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)
|
||||
|
||||
# LyCORIS BOFT
|
||||
if weights.w["oft_diag"].dim() == 4:
|
||||
self.is_boft = True
|
||||
self.rescale = weights.w.get('rescale', None)
|
||||
if self.rescale is not None:
|
||||
self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
|
||||
|
||||
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
|
||||
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
|
||||
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported
|
||||
@ -51,6 +59,13 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
self.constraint = self.alpha * self.out_dim
|
||||
self.num_blocks = self.dim
|
||||
self.block_size = self.out_dim // self.dim
|
||||
elif self.is_boft:
|
||||
self.constraint = None
|
||||
self.boft_m = weights.w["oft_diag"].shape[0]
|
||||
self.block_num = weights.w["oft_diag"].shape[1]
|
||||
self.block_size = weights.w["oft_diag"].shape[2]
|
||||
self.boft_b = self.block_size
|
||||
#self.block_size, self.block_num = butterfly_factor(self.out_dim, self.dim)
|
||||
else:
|
||||
self.constraint = None
|
||||
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
|
||||
@ -68,6 +83,7 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
|
||||
R = oft_blocks.to(orig_weight.device)
|
||||
|
||||
if not self.is_boft:
|
||||
# This errors out for MultiheadAttention, might need to be handled up-stream
|
||||
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
|
||||
merged_weight = torch.einsum(
|
||||
@ -76,6 +92,28 @@ class NetworkModuleOFT(network.NetworkModule):
|
||||
merged_weight
|
||||
)
|
||||
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
|
||||
else:
|
||||
# TODO: determine correct value for scale
|
||||
scale = 1.0
|
||||
m = self.boft_m
|
||||
b = self.boft_b
|
||||
r_b = b // 2
|
||||
inp = orig_weight
|
||||
for i in range(m):
|
||||
bi = R[i] # b_num, b_size, b_size
|
||||
if i == 0:
|
||||
# Apply multiplier/scale and rescale into first weight
|
||||
bi = bi * scale + (1 - scale) * eye
|
||||
inp = rearrange(inp, "(c g k) ... -> (c k g) ...", g=2, k=2**i * r_b)
|
||||
inp = rearrange(inp, "(d b) ... -> d b ...", b=b)
|
||||
inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)
|
||||
inp = rearrange(inp, "d b ... -> (d b) ...")
|
||||
inp = rearrange(inp, "(c k g) ... -> (c g k) ...", g=2, k=2**i * r_b)
|
||||
merged_weight = inp
|
||||
|
||||
# Rescale mechanism
|
||||
if self.rescale is not None:
|
||||
merged_weight = self.rescale.to(merged_weight) * merged_weight
|
||||
|
||||
updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
|
||||
output_shape = orig_weight.shape
|
||||
|
Loading…
Reference in New Issue
Block a user