Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git, synced 2024-06-07 21:20:49 +00:00.
Add general forward method for all modules.

commit 18ca987c92
parent a06dab8d7a
@@ -3,6 +3,10 @@ import os
 from collections import namedtuple
 import enum
 
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
 from modules import sd_models, cache, errors, hashes, shared
 
 NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +119,29 @@ class NetworkModule:
         if hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
 
+        self.ops = None
+        self.extra_kwargs = {}
+        if isinstance(self.sd_module, nn.Conv2d):
+            self.ops = F.conv2d
+            self.extra_kwargs = {
+                'stride': self.sd_module.stride,
+                'padding': self.sd_module.padding
+            }
+        elif isinstance(self.sd_module, nn.Linear):
+            self.ops = F.linear
+        elif isinstance(self.sd_module, nn.LayerNorm):
+            self.ops = F.layer_norm
+            self.extra_kwargs = {
+                'normalized_shape': self.sd_module.normalized_shape,
+                'eps': self.sd_module.eps
+            }
+        elif isinstance(self.sd_module, nn.GroupNorm):
+            self.ops = F.group_norm
+            self.extra_kwargs = {
+                'num_groups': self.sd_module.num_groups,
+                'eps': self.sd_module.eps
+            }
+
         self.dim = None
         self.bias = weights.w.get("bias")
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
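Note on the hunk above: each supported layer type is paired with its torch.nn.functional counterpart plus the constructor arguments needed to reproduce that layer's own behaviour, so one generic call can stand in for any of them. A minimal sketch of that equivalence for the Conv2d case (illustrative only, not part of the commit):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical layer, only to show the pairing the commit stores in
# self.ops / self.extra_kwargs.
conv = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
x = torch.randn(1, 3, 16, 16)

ops = F.conv2d
extra_kwargs = {'stride': conv.stride, 'padding': conv.padding}

# The functional call with the stored kwargs matches the module's own forward.
assert torch.allclose(conv(x), ops(x, weight=conv.weight, bias=conv.bias, **extra_kwargs), atol=1e-6)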
@@ -155,5 +182,10 @@ class NetworkModule:
         raise NotImplementedError()
 
     def forward(self, x, y):
-        raise NotImplementedError()
+        """A general forward implementation for all modules"""
+        if self.ops is None:
+            raise NotImplementedError()
+        else:
+            updown, ex_bias = self.calc_updown(self.sd_module.weight)
+            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
 
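The new forward applies the computed delta as a second functional pass over the input and adds it to the layer's original output y. For a Linear layer this is the same as merging the delta into the weight up front; a small sketch with placeholder tensors standing in for a calc_updown() result (assumptions, not the commit's code):

import torch
import torch.nn.functional as F

x = torch.randn(4, 16)
weight = torch.randn(32, 16)           # base layer weight
updown = 0.01 * torch.randn(32, 16)    # stand-in for what calc_updown() returns

y = F.linear(x, weight)                        # what the original layer produces
y_general = y + F.linear(x, weight=updown)     # the forward(x, y) pattern above
y_merged = F.linear(x, weight + updown)        # delta merged into the weight

assert torch.allclose(y_general, y_merged, atol=1e-5)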
@@ -458,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     self.network_current_names = wanted_names
 
 
-def network_forward(module, input, original_forward):
+def network_forward(org_module, input, original_forward):
     """
     Old way of applying Lora by executing operations during layer's forward.
     Stacking many loras this way results in big performance degradation.
     """
 
     if len(loaded_networks) == 0:
-        return original_forward(module, input)
+        return original_forward(org_module, input)
 
     input = devices.cond_cast_unet(input)
 
-    network_restore_weights_from_backup(module)
-    network_reset_cached_weight(module)
+    network_restore_weights_from_backup(org_module)
+    network_reset_cached_weight(org_module)
 
-    y = original_forward(module, input)
+    y = original_forward(org_module, input)
 
-    network_layer_name = getattr(module, 'network_layer_name', None)
+    network_layer_name = getattr(org_module, 'network_layer_name', None)
     for lora in loaded_networks:
         module = lora.modules.get(network_layer_name, None)
         if module is None:
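The rename also frees the name module for the loop, where it is rebound each iteration to the NetworkModule that lora.modules holds for this layer; the torch layer being patched now travels as org_module. The hunk ends at the None check, so the following is only a simplified, self-contained sketch of this "old way" shape, with hypothetical stand-ins for the real objects:

import torch
import torch.nn.functional as F

class FakeNetworkModule:
    """Illustrative stand-in for a per-layer network module holding one delta."""
    def __init__(self, updown):
        self.updown = updown

    def forward(self, x, y):
        # Same shape as the general forward added above: y + ops(x, weight=updown)
        return y + F.linear(x, weight=self.updown)

def fake_network_forward(org_module, input, loaded):
    # `org_module` is the torch layer; `module` below is rebound per network,
    # which is why the parameter could no longer also be called `module`.
    y = org_module(input)
    for net in loaded:
        module = net  # stand-in for lora.modules.get(network_layer_name, None)
        if module is None:
            continue
        y = module.forward(input, y)
    return y

layer = torch.nn.Linear(16, 32, bias=False)
nets = [FakeNetworkModule(0.01 * torch.randn(32, 16)) for _ in range(2)]
print(fake_network_forward(layer, torch.randn(4, 16), nets).shape)  # torch.Size([4, 32])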