replace duplicate code with a function

AUTOMATIC 2022-10-11 11:09:51 +03:00
parent 5e2627a1a6
commit 948533950c
2 changed files with 29 additions and 38 deletions

modules/hypernetwork.py

@@ -64,21 +64,26 @@ def load_hypernetwork(filename):
         shared.loaded_hypernetwork = None
 
 
+def apply_hypernetwork(hypernetwork, context):
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is None:
+        return context, context
+
+    context_k = hypernetwork_layers[0](context)
+    context_v = hypernetwork_layers[1](context)
+    return context_k, context_v
+
+
 def attention_CrossAttention_forward(self, x, context=None, mask=None):
     h = self.heads
 
     q = self.to_q(x)
     context = default(context, x)
 
-    hypernetwork = shared.loaded_hypernetwork
-    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
-    if hypernetwork_layers is not None:
-        k = self.to_k(hypernetwork_layers[0](context))
-        v = self.to_v(hypernetwork_layers[1](context))
-    else:
-        k = self.to_k(context)
-        v = self.to_v(context)
+    context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context)
+    k = self.to_k(context_k)
+    v = self.to_v(context_v)
 
     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
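For illustration only (not part of the commit), a minimal self-contained sketch of how the new apply_hypernetwork helper behaves; DummyHypernetwork, the (1, 77, 768) context shape and the nn.Linear layer pair are assumptions made for the demo, while apply_hypernetwork mirrors the function added above:

# Minimal sketch, for illustration only -- not code from this commit.
# DummyHypernetwork and the tensor shapes are assumptions for the demo;
# apply_hypernetwork mirrors the function added in the hunk above.
import torch
import torch.nn as nn


def apply_hypernetwork(hypernetwork, context):
    # look up the (k, v) module pair registered for this context dimensionality
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)

    if hypernetwork_layers is None:
        return context, context

    context_k = hypernetwork_layers[0](context)
    context_v = hypernetwork_layers[1](context)
    return context_k, context_v


class DummyHypernetwork:
    def __init__(self, dim):
        # context dimensionality -> (module applied before to_k, module applied before to_v)
        self.layers = {dim: (nn.Linear(dim, dim), nn.Linear(dim, dim))}


context = torch.randn(1, 77, 768)  # (batch, tokens, dim)

# nothing loaded: both return values are the untouched context
k_ctx, v_ctx = apply_hypernetwork(None, context)
assert k_ctx is context and v_ctx is context

# hypernetwork loaded: context is transformed separately for k and for v
k_ctx, v_ctx = apply_hypernetwork(DummyHypernetwork(768), context)
print(k_ctx.shape, v_ctx.shape)  # torch.Size([1, 77, 768]) torch.Size([1, 77, 768])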

modules/sd_hijack_optimizations.py

@@ -8,7 +8,8 @@ from torch import einsum
 from ldm.util import default
 from einops import rearrange
 
-from modules import shared
+from modules import shared, hypernetwork
 
 
 if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
     try:
@@ -26,16 +27,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
     q_in = self.to_q(x)
     context = default(context, x)
 
-    hypernetwork = shared.loaded_hypernetwork
-    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
-    if hypernetwork_layers is not None:
-        k_in = self.to_k(hypernetwork_layers[0](context))
-        v_in = self.to_v(hypernetwork_layers[1](context))
-    else:
-        k_in = self.to_k(context)
-        v_in = self.to_v(context)
-    del context, x
+    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+    k_in = self.to_k(context_k)
+    v_in = self.to_v(context_v)
+    del context, context_k, context_v, x
 
     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
     del q_in, k_in, v_in
@@ -59,22 +54,16 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
     return self.to_out(r2)
 
 
-# taken from https://github.com/Doggettx/stable-diffusion
+# taken from https://github.com/Doggettx/stable-diffusion and modified
 def split_cross_attention_forward(self, x, context=None, mask=None):
     h = self.heads
 
     q_in = self.to_q(x)
     context = default(context, x)
 
-    hypernetwork = shared.loaded_hypernetwork
-    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
-    if hypernetwork_layers is not None:
-        k_in = self.to_k(hypernetwork_layers[0](context))
-        v_in = self.to_v(hypernetwork_layers[1](context))
-    else:
-        k_in = self.to_k(context)
-        v_in = self.to_v(context)
+    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+    k_in = self.to_k(context_k)
+    v_in = self.to_v(context_v)
 
     k_in *= self.scale
@@ -130,14 +119,11 @@ def xformers_attention_forward(self, x, context=None, mask=None):
     h = self.heads
     q_in = self.to_q(x)
     context = default(context, x)
-    hypernetwork = shared.loaded_hypernetwork
-    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-    if hypernetwork_layers is not None:
-        k_in = self.to_k(hypernetwork_layers[0](context))
-        v_in = self.to_v(hypernetwork_layers[1](context))
-    else:
-        k_in = self.to_k(context)
-        v_in = self.to_v(context)
+
+    context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+    k_in = self.to_k(context_k)
+    v_in = self.to_v(context_v)
+
     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
     del q_in, k_in, v_in
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
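
A quick equivalence check, continuing the sketch after the first file (again illustrative only; to_k and to_v are stand-in nn.Linear projections, and torch, nn, apply_hypernetwork and DummyHypernetwork are reused from that sketch): the extracted helper reproduces the inline branch that each of the three attention functions carried before this commit.

# Continues the earlier sketch (reuses torch, nn, apply_hypernetwork, DummyHypernetwork).
to_k = nn.Linear(768, 320, bias=False)  # stand-ins for the attention projections
to_v = nn.Linear(768, 320, bias=False)
loaded = DummyHypernetwork(768)
context = torch.randn(2, 77, 768)

# old inline logic (the branch removed by this commit)
hypernetwork_layers = (loaded.layers if loaded is not None else {}).get(context.shape[2], None)
if hypernetwork_layers is not None:
    k_old = to_k(hypernetwork_layers[0](context))
    v_old = to_v(hypernetwork_layers[1](context))
else:
    k_old = to_k(context)
    v_old = to_v(context)

# new call-site pattern (the three lines added at each call site)
context_k, context_v = apply_hypernetwork(loaded, context)
k_new = to_k(context_k)
v_new = to_v(context_v)

assert torch.equal(k_old, k_new) and torch.equal(v_old, v_new)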