try:
    import tqdm
except ImportError:
    # tqdm is only used for the verbose schedule-change message below;
    # fall back to print() so the scheduling logic works without it.
    tqdm = None


class LearnScheduleIterator:
    """Iterate over ``(learn_rate, end_step)`` pairs parsed from a schedule string.

    The schedule is a comma-separated list of ``rate:step`` entries, e.g.
    ``"0.001:100, 0.00001:1000, 1e-5:10000"`` means a rate of 0.001 until step
    100, 0.00001 until step 1000, and 1e-5 until step 10000.  An entry with no
    step (``"0.005"``) or with step ``-1`` (``"0.005:-1"``) lasts until
    ``max_steps`` and ends the schedule; later entries are ignored.
    """

    def __init__(self, learn_rate, max_steps, cur_step=0):
        """Parse *learn_rate* into the remaining ``(rate, end_step)`` pairs.

        Args:
            learn_rate: schedule string, see the class docstring.  Blank
                entries (e.g. from a trailing comma) are ignored.
            max_steps: last training step; end steps are clamped to it.
            cur_step: step training resumes from; entries whose step is at or
                before it are skipped.

        Raises:
            ValueError: if an entry's rate/step is not a number or an entry
                has more than one ``:``.
        """
        self.rates = []
        self.it = 0
        self.maxit = 0

        for pair in learn_rate.split(','):
            if not pair.strip():
                # Tolerate empty entries such as a trailing comma.
                continue
            tmp = pair.split(':')
            if len(tmp) > 2:
                raise ValueError(f'Invalid learning rate schedule entry: {pair.strip()!r}')
            if len(tmp) == 2:
                step = int(tmp[1])
                if step > cur_step:
                    self.rates.append((float(tmp[0]), min(step, max_steps)))
                    self.maxit += 1
                    if step > max_steps:
                        # This entry already reaches the end of training.
                        return
                elif step == -1:
                    # -1 means "until the end of training".
                    self.rates.append((float(tmp[0]), max_steps))
                    self.maxit += 1
                    return
            else:
                # A bare rate applies for the rest of training.
                self.rates.append((float(tmp[0]), max_steps))
                self.maxit += 1
                return

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(rate, end_step)`` pair or raise ``StopIteration``."""
        if self.it < self.maxit:
            self.it += 1
            return self.rates[self.it - 1]
        raise StopIteration


class LearnRateScheduler:
    """Drive an optimizer's learning rate through a piecewise-constant schedule."""

    def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
        """Parse the schedule and select the segment active at *cur_step*.

        Args:
            learn_rate: schedule string understood by ``LearnScheduleIterator``.
            max_steps: last training step.
            cur_step: step training starts/resumes from.
            verbose: print a message whenever the active rate changes.

        Raises:
            ValueError: if the schedule is malformed or has no entry ending
                after *cur_step*.
        """
        self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
        try:
            (self.learn_rate, self.end_step) = next(self.schedules)
        except StopIteration:
            # Don't leak StopIteration out of a constructor.
            raise ValueError(
                f'Learning rate schedule {learn_rate!r} has no entry after step {cur_step}'
            ) from None
        self.verbose = verbose

        if self.verbose:
            print(f'Training at rate of {self.learn_rate} until step {self.end_step}')

        self.finished = False

    def apply(self, optimizer, step_number):
        """Advance the schedule once *step_number* reaches the current segment's end.

        On advancing, every dict in ``optimizer.param_groups`` gets its
        ``'lr'`` set to the new rate.  At most one segment is advanced per
        call.  When the schedule is exhausted, ``self.finished`` is set and
        the optimizer is left untouched.

        Args:
            optimizer: object with a ``param_groups`` list of dicts
                (presumably a torch optimizer).
            step_number: the current training step.
        """
        if step_number < self.end_step:
            return

        try:
            (self.learn_rate, self.end_step) = next(self.schedules)
        except StopIteration:
            self.finished = True
            return

        if self.verbose:
            message = f'Training at rate of {self.learn_rate} until step {self.end_step}'
            if tqdm is not None:
                # tqdm.write avoids garbling any active progress bar.
                tqdm.tqdm.write(message)
            else:
                print(message)

        for pg in optimizer.param_groups:
            pg['lr'] = self.learn_rate