stable-diffusion-webui/modules/hypernetwork.py

89 lines
2.5 KiB
Python

import glob
import os
import sys
import traceback
import torch
from ldm.util import default
from modules import devices, shared
import torch
from torch import einsum
from einops import rearrange, repeat
class HypernetworkModule(torch.nn.Module):
    """Residual two-layer linear transform used by a hypernetwork.

    Computes ``x + linear2(linear1(x))`` with widths ``dim -> 2*dim -> dim``.
    There is no activation between the two layers.

    Args:
        dim: feature size of the input and output (last dimension of ``x``).
        state_dict: parameters for ``linear1`` and ``linear2``; loaded with
            ``strict=True``, so missing or unexpected keys raise.
    """

    def __init__(self, dim, state_dict):
        super().__init__()
        hidden_dim = dim * 2
        # Attribute names matter: they define the state-dict keys
        # (linear1.weight, linear1.bias, linear2.weight, linear2.bias).
        self.linear1 = torch.nn.Linear(dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, dim)
        self.load_state_dict(state_dict, strict=True)
        self.to(devices.device)

    def forward(self, x):
        """Return ``x`` plus its projection through both linear layers."""
        hidden = self.linear1(x)
        delta = self.linear2(hidden)
        return x + delta
class Hypernetwork:
    """A hypernetwork checkpoint loaded from a ``.pt`` file.

    The checkpoint is a dict mapping a context feature size to a pair of
    state dicts. For each size, two ``HypernetworkModule`` objects are built:
    index 0 is applied to the context before the key projection and index 1
    before the value projection in the cross-attention forward pass.

    Attributes:
        filename: path the checkpoint was loaded from.
        name: file name without directory or extension; used as the lookup key.
        layers: dict of feature size -> (key-side module, value-side module).
    """

    filename = None
    name = None

    def __init__(self, filename):
        self.filename = filename
        self.name = os.path.splitext(os.path.basename(filename))[0]

        # SECURITY NOTE(review): torch.load unpickles the file, and unpickling
        # can execute arbitrary code. Hypernetwork files are often downloaded
        # from third parties, so only load files from trusted sources. Newer
        # PyTorch releases accept ``weights_only=True`` to restrict
        # deserialization to tensors/containers. Consider it once the minimum
        # supported torch version allows it.
        # map_location='cpu' avoids depending on the device the file was saved from.
        state_dict = torch.load(filename, map_location='cpu')

        self.layers = {
            size: (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
            for size, sd in state_dict.items()
        }
def load_hypernetworks(path):
    """Load every ``*.pt`` hypernetwork found under *path*, recursively.

    Args:
        path: directory to search. A trailing path separator is optional.

    Returns:
        dict mapping each hypernetwork's ``name`` (file stem) to its
        ``Hypernetwork``. If two files share a stem, the one glob yields
        later replaces the earlier one.

    Files that fail to load are reported on stderr with a traceback and
    skipped; one bad file does not abort loading the rest.
    """
    res = {}
    # Join with os.path.join instead of string concatenation. Without a
    # trailing separator, "dir" + "**/*.pt" became "dir**/*.pt". That pattern
    # did not recurse into subdirectories and could also match sibling
    # directories such as "dir_old".
    pattern = os.path.join(path, '**', '*.pt')
    for filename in glob.iglob(pattern, recursive=True):
        try:
            hn = Hypernetwork(filename)
        except Exception:
            # Best-effort loading: report and move on to the next file.
            print(f"Error loading hypernetwork {filename}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            continue
        res[hn.name] = hn
    return res
def attention_CrossAttention_forward(self, x, context=None, mask=None):
    """Cross-attention forward pass with optional hypernetwork conditioning.

    Presumably installed as a CrossAttention ``forward`` (per its name). It
    reads ``heads``, ``scale``, ``to_q``, ``to_k``, ``to_v`` and ``to_out``
    from ``self``.

    If ``shared.selected_hypernetwork()`` returns a hypernetwork whose
    ``layers`` has an entry for the context's feature size
    (``context.shape[2]``), the context is passed through the first module of
    that pair before ``to_k`` and through the second before ``to_v``.
    Otherwise keys and values are projected from the raw context.

    Args:
        x: query source. The projections must be 3-D ``(b, n, h*d)`` for the
            einops head-splitting patterns below.
        context: key/value source. ``default(context, x)`` presumably falls
            back to ``x`` when it is None (self-attention).
        mask: optional mask. Positions where it is False are filled with the
            most negative finite score, so they get ~zero attention weight.
            NOTE(review): ``~mask`` assumes a boolean tensor; confirm callers.

    Returns:
        ``self.to_out`` applied to the attention output, with heads merged
        back into the last dimension.
    """
    heads = self.heads
    query = self.to_q(x)
    context = default(context, x)

    selected = shared.selected_hypernetwork()
    layer_pair = None
    if selected is not None:
        layer_pair = selected.layers.get(context.shape[2], None)

    if layer_pair is None:
        key = self.to_k(context)
        value = self.to_v(context)
    else:
        key = self.to_k(layer_pair[0](context))
        value = self.to_v(layer_pair[1](context))

    # Fold the head dimension into the batch so each head attends independently.
    query, key, value = (
        rearrange(t, 'b n (h d) -> (b h) n d', h=heads) for t in (query, key, value)
    )

    scores = einsum('b i d, b j d -> b i j', query, key) * self.scale

    if mask is not None:
        flat_mask = rearrange(mask, 'b ... -> b (...)')
        fill_value = -torch.finfo(scores.dtype).max
        # Broadcast the per-sample mask across heads and query positions.
        flat_mask = repeat(flat_mask, 'b j -> (b h) () j', h=heads)
        scores.masked_fill_(~flat_mask, fill_value)

    weights = scores.softmax(dim=-1)
    merged = einsum('b i j, b j d -> b i d', weights, value)
    merged = rearrange(merged, '(b h) n d -> b n (h d)', h=heads)
    return self.to_out(merged)