Merge pull request #635 from C43H66N12O12S2/attention

Move scale multiplication to the front
This commit is contained in:
AUTOMATIC1111 2022-09-18 07:28:53 +03:00 committed by GitHub
commit 17b60490fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -50,7 +50,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
     q_in = self.to_q(x)
     context = default(context, x)
-    k_in = self.to_k(context)
+    k_in = self.to_k(context) * self.scale
     v_in = self.to_v(context)
     del context, x
@@ -85,7 +85,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
     slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
     for i in range(0, q.shape[1], slice_size):
         end = i + slice_size
-        s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
+        s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
         s2 = s1.softmax(dim=-1, dtype=q.dtype)
         del s1