add pbar to unipc

This commit is contained in:
Vladimir Mandic 2023-03-13 12:35:30 -04:00 committed by GitHub
parent dfeee786f9
commit 03a80f198e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 3 deletions

View File

@ -71,7 +71,7 @@ class UniPCSampler(object):
# sampling # sampling
C, H, W = shape C, H, W = shape
size = (batch_size, C, H, W) size = (batch_size, C, H, W)
print(f'Data shape for UniPC sampling is {size}') # print(f'Data shape for UniPC sampling is {size}')
device = self.model.betas.device device = self.model.betas.device
if x_T is None: if x_T is None:

View File

@ -1,6 +1,7 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import math import math
from tqdm.auto import trange
class NoiseScheduleVP: class NoiseScheduleVP:
@ -750,7 +751,7 @@ class UniPC:
if method == 'multistep': if method == 'multistep':
assert steps >= order, "UniPC order must be < sampling steps" assert steps >= order, "UniPC order must be < sampling steps"
timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}") #print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}")
assert timesteps.shape[0] - 1 == steps assert timesteps.shape[0] - 1 == steps
with torch.no_grad(): with torch.no_grad():
vec_t = timesteps[0].expand((x.shape[0])) vec_t = timesteps[0].expand((x.shape[0]))
@ -766,7 +767,7 @@ class UniPC:
self.after_update(x, model_x) self.after_update(x, model_x)
model_prev_list.append(model_x) model_prev_list.append(model_x)
t_prev_list.append(vec_t) t_prev_list.append(vec_t)
for step in range(order, steps + 1): for step in trange(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0]) vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final: if lower_order_final:
step_order = min(order, steps + 1 - step) step_order = min(order, steps + 1 - step)