Source code for pywick.callbacks.ReduceLROnPlateau

from . import Callback

# Names exported by ``from <module> import *``.
__all__ = [
    'ReduceLROnPlateau',
]


class ReduceLROnPlateau(Callback):
    """
    Reduce the learning rate if the train or validation loss plateaus.

    :param monitor: (string in {'loss', 'val_loss'}): which metric to monitor
    :param factor: (float): factor to decrease learning rate by; must be < 1.0
    :param patience: (int): number of epochs to wait for loss improvement before reducing lr
    :param epsilon: (float): how much improvement must be made to reset patience
    :param cooldown: (int): number of epochs to cooldown after a lr reduction
    :param min_lr: (float): minimum value to ever let the learning rate decrease to
    :param verbose: (int): whether to print reduction to console
    :param lr_eps: (float): minimal change in learning rate worth applying; a reduction
        smaller than this (e.g. because the lr already sits at ``min_lr``) is skipped

    :raises ValueError: if ``factor >= 1.0``
    """

    def __init__(self, monitor='val_loss', factor=0.1, patience=10, epsilon=0,
                 cooldown=0, min_lr=0, verbose=0, lr_eps=1e-8):
        self.monitor = monitor
        if factor >= 1.0:
            raise ValueError('ReduceLROnPlateau does not support a factor >= 1.0.')
        self.factor = factor
        self.min_lr = min_lr
        self.epsilon = epsilon
        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.lr_eps = lr_eps
        self.cooldown_counter = 0
        self.wait = 0
        self.best_loss = float('inf')
        self._reset()
        super(ReduceLROnPlateau, self).__init__()

    def _reset(self):
        """
        Reset the best loss and the wait and cooldown counters.
        """
        # A loss counts as an improvement only if it beats the best so far by more than epsilon.
        self.monitor_op = lambda a, b: (a - b) < -self.epsilon
        # inf (rather than a large finite number) so that any finite first loss is an improvement
        self.best_loss = float('inf')
        self.cooldown_counter = 0
        self.wait = 0

    def on_train_begin(self, logs=None):
        """
        Reset internal state at the start of training.
        """
        self._reset()

    def on_epoch_end(self, epoch, logs=None):
        """
        Record the current learning rates in ``logs['lr']`` and reduce them if the
        monitored loss has not improved for more than ``patience`` epochs.

        :param epoch: (int): index of the epoch that just finished (used in the verbose message)
        :param logs: (dict or None): epoch logs; ``logs[self.monitor]`` is read and
            ``logs['lr']`` is written
        """
        # Use an identity check so a caller-supplied empty dict still receives 'lr'.
        if logs is None:
            logs = {}
        param_groups = self.trainer._optimizer.param_groups
        # Logged before any reduction below, i.e. the lr(s) used during this epoch.
        logs['lr'] = [p['lr'] for p in param_groups]

        current_loss = logs.get(self.monitor)
        if current_loss is None:
            return

        # While cooling down after a reduction, plateau epochs are not counted.
        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
            self.wait = 0

        if self.monitor_op(current_loss, self.best_loss):
            # Loss improved: remember it and reset the patience counter.
            self.best_loss = current_loss
            self.wait = 0
        elif not (self.cooldown_counter > 0):
            # Loss did not improve and we are not in the cooldown phase.
            if self.wait >= self.patience:
                self._reduce_lr(epoch, param_groups)
            self.wait += 1

    def _reduce_lr(self, epoch, param_groups):
        """
        Multiply the lr of every param group by ``factor`` (floored at ``min_lr``).

        Groups whose lr would change by no more than ``lr_eps`` are left untouched.
        If at least one group is reduced, the cooldown starts and ``wait`` is reset.
        """
        for p in param_groups:
            old_lr = p['lr']
            new_lr = max(old_lr * self.factor, self.min_lr)
            # Compare against the size of the change rather than an absolute lr threshold,
            # so small learning rates (e.g. 1e-4 with min_lr=0) can still be reduced.
            if old_lr - new_lr > self.lr_eps:
                if self.verbose > 0:
                    print('\nEpoch %05d: reducing lr from %.3e to %.3e' % (epoch, old_lr, new_lr))
                p['lr'] = new_lr
                self.cooldown_counter = self.cooldown
                self.wait = 0