From c4b9b07db6272768428fa8efeb7d7a9f22eca0b1 Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 26 Jan 2023 09:00:15 -0500 Subject: [PATCH] Fix embeddings dtype mismatch --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f9652d215..531790f36 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -171,7 +171,7 @@ class EmbeddingsWithFixes(torch.nn.Module): vecs = [] for fixes, tensor in zip(batch_fixes, inputs_embeds): for offset, embedding in fixes: - emb = embedding.vec + emb = embedding.vec.to(devices.dtype_unet) if devices.unet_needs_upcast else embedding.vec emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])