Replace einops.rearrange with torch native

This commit is contained in:
huchenlei 2024-05-15 15:46:53 -04:00
parent 1c0a0c4c26
commit 0e98529365
1 changed files with 16 additions and 2 deletions

View File

@ -486,7 +486,19 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
def _reshape(t):
"""rearrange(t, 'b n (h d) -> b n h d', h=h).
Using torch native operations to avoid overhead as this function is
called frequently. (70 times/it for SDXL)
"""
b, n, _ = t.shape # Get the batch size (b) and sequence length (n)
d = t.shape[2] // h # Determine the depth per head
return t.reshape(b, n, h, d)
q = _reshape(q_in)
k = _reshape(k_in)
v = _reshape(v_in)
del q_in, k_in, v_in
dtype = q.dtype
@ -497,7 +509,9 @@ def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
out = out.to(dtype)
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
# out = rearrange(out, 'b n h d -> b n (h d)', h=h)
b, n, h, d = out.shape
out = out.reshape(b, n, h * d)
return self.to_out(out)