fix dims typo in unipc

This commit is contained in:
EllangoK 2023-03-11 15:52:14 -05:00
parent 27e319dc4f
commit 48f4abd2e6

View File

@ -719,7 +719,7 @@ class UniPC:
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
else:
x_t_ = (
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dimss) * x
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
)
if x_t is None: