refactor: remove unused function

This commit is contained in:
v0xie 2023-11-04 14:56:47 -07:00
parent 329c8bacce
commit bbf00a96af

View File

@ -2,7 +2,6 @@ import torch
import network import network
from lyco_helpers import factorization from lyco_helpers import factorization
from einops import rearrange from einops import rearrange
from modules import devices
class ModuleTypeOFT(network.ModuleType): class ModuleTypeOFT(network.ModuleType):
@ -54,58 +53,12 @@ class NetworkModuleOFT(network.NetworkModule):
raise ValueError("sd_module must be Linear or Conv") raise ValueError("sd_module must be Linear or Conv")
if self.is_kohya: if self.is_kohya:
#self.num_blocks = self.dim
#self.block_size = self.out_dim // self.num_blocks
#self.block_size = self.dim
#self.num_blocks = self.out_dim // self.block_size
self.constraint = self.alpha * self.out_dim self.constraint = self.alpha * self.out_dim
self.num_blocks, self.block_size = factorization(self.out_dim, self.dim) self.num_blocks, self.block_size = factorization(self.out_dim, self.dim)
else: else:
self.constraint = None self.constraint = None
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
if is_other_linear:
self.lin_module = self.create_module(weights.w, "oft_diag", none_ok=True)
def create_module(self, weights, key, none_ok=False):
weight = weights.get(key)
if weight is None and none_ok:
return None
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
if is_linear:
weight = weight.reshape(weight.shape[0], -1)
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif is_conv and key == "lora_down.weight" or key == "dyn_up":
if len(weight.shape) == 2:
weight = weight.reshape(weight.shape[0], -1, 1, 1)
if weight.shape[2] != 1 or weight.shape[3] != 1:
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
else:
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
elif is_conv and key == "lora_mid.weight":
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
elif is_conv and key == "lora_up.weight" or key == "dyn_down":
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
else:
raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')
with torch.no_grad():
if weight.shape != module.weight.shape:
weight = weight.reshape(module.weight.shape)
module.weight.copy_(weight)
module.to(device=devices.cpu, dtype=devices.dtype)
module.weight.requires_grad_(False)
return module
def merge_weight(self, R_weight, org_weight): def merge_weight(self, R_weight, org_weight):
R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype) R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype)
if org_weight.dim() == 4: if org_weight.dim() == 4: