Source code for pywick.optimizers.adabelief

# Source: https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/adabelief.py (apache 2.0)

import math

import torch
from torch.optim.optimizer import Optimizer

from .a2grad import Betas2, OptFloat, OptLossClosure, Params

__all__ = ('AdaBelief',)


class AdaBelief(Optimizer):
    r"""Implements AdaBelief Optimizer Algorithm.

    It has been proposed in `AdaBelief Optimizer, adapting stepsizes by the
    belief in observed gradients`__.

    Arguments:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: learning rate (default: 1e-3)
        betas: coefficients used for computing running averages of gradient
            and its square (default: (0.9, 0.999))
        eps: term added to the denominator to improve numerical stability
            (default: 1e-8)
        weight_decay: weight decay (L2 penalty) (default: 0)
        amsgrad: whether to use the AMSGrad variant of this algorithm from
            the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        weight_decouple: If set as True, then the optimizer uses decoupled
            weight decay as in AdamW (default: False)
        fixed_decay: Only used when weight_decouple is set as True.
            When fixed_decay == True, the weight decay is performed as
            $W_{new} = W_{old} - W_{old} \times decay$.
            When fixed_decay == False, the weight decay is performed as
            $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that
            in this case the effective weight decay scales with the learning
            rate (lr). (default: False)
        rectify: If set as True, then perform the rectified update similar
            to RAdam (default: False)

    __ https://arxiv.org/abs/2010.07468

    Note:
        Reference code: https://github.com/juntang-zhuang/Adabelief-Optimizer
    """

    def __init__(
        self,
        params: Params,
        lr: float = 1e-3,
        betas: Betas2 = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0,
        amsgrad: bool = False,
        weight_decouple: bool = False,
        fixed_decay: bool = False,
        rectify: bool = False,
    ) -> None:
        if lr <= 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if eps < 0.0:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                'Invalid beta parameter at index 0: {}'.format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                'Invalid beta parameter at index 1: {}'.format(betas[1])
            )
        if weight_decay < 0:
            raise ValueError(
                'Invalid weight_decay value: {}'.format(weight_decay)
            )

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
        )
        super(AdaBelief, self).__init__(params, defaults)

        self._weight_decouple = weight_decouple
        self._rectify = rectify
        self._fixed_decay = fixed_decay

    def __setstate__(self, state):
        super(AdaBelief, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)
    def step(self, closure: OptLossClosure = None) -> OptFloat:
        r"""Performs a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model and returns the
                loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdaBelief does not support sparse gradients, '
                        'please consider SparseAdam instead'
                    )
                amsgrad = group['amsgrad']

                state = self.state[p]

                beta1, beta2 = group['betas']

                # State initialization
                if len(state) == 0:
                    state['rho_inf'] = 2.0 / (1.0 - beta2) - 1.0
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(
                        p.data, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state['exp_avg_var'] = torch.zeros_like(
                        p.data, memory_format=torch.preserve_format
                    )
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of
                        # sq. grad. values
                        state['max_exp_avg_var'] = torch.zeros_like(
                            p.data, memory_format=torch.preserve_format
                        )

                # get current state variables
                exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # perform weight decay, check if decoupled weight decay
                if self._weight_decouple:
                    if not self._fixed_decay:
                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
                    else:
                        p.data.mul_(1.0 - group['weight_decay'])
                else:
                    if group['weight_decay'] != 0:
                        grad.add_(p.data, alpha=group['weight_decay'])

                # Update first and second moment running averages. Unlike
                # Adam, the second moment tracks the variance of the
                # gradient around its EMA (the "belief" term).
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                grad_residual = grad - exp_avg
                exp_avg_var.mul_(beta2).addcmul_(
                    grad_residual, grad_residual, value=1 - beta2
                )

                if amsgrad:
                    max_exp_avg_var = state['max_exp_avg_var']
                    # Maintains the maximum of all 2nd moment running
                    # avg. till now
                    torch.max(
                        max_exp_avg_var, exp_avg_var, out=max_exp_avg_var
                    )

                    # Use the max. for normalizing running avg. of gradient
                    denom = (
                        max_exp_avg_var.add_(group['eps']).sqrt()
                        / math.sqrt(bias_correction2)
                    ).add_(group['eps'])
                else:
                    denom = (
                        exp_avg_var.add_(group['eps']).sqrt()
                        / math.sqrt(bias_correction2)
                    ).add_(group['eps'])

                if not self._rectify:
                    # Default update
                    step_size = group['lr'] / bias_correction1
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)

                else:  # Rectified update
                    # calculate rho_t
                    state['rho_t'] = state['rho_inf'] - 2 * state[
                        'step'
                    ] * beta2 ** state['step'] / (1.0 - beta2 ** state['step'])

                    if (
                        state['rho_t'] > 4
                    ):  # perform Adam style update if variance is small
                        rho_inf, rho_t = state['rho_inf'], state['rho_t']
                        rt = (
                            (rho_t - 4.0)
                            * (rho_t - 2.0)
                            * rho_inf
                            / (rho_inf - 4.0)
                            / (rho_inf - 2.0)
                            / rho_t
                        )
                        rt = math.sqrt(rt)

                        step_size = rt * group['lr'] / bias_correction1
                        p.data.addcdiv_(exp_avg, denom, value=-step_size)
                    else:  # perform SGD style update
                        p.data.add_(exp_avg, alpha=-group['lr'])

        return loss
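

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module): AdaBelief is
# a drop-in replacement for Adam-style torch optimizers. The toy model,
# random data and hyperparameter values below are assumptions chosen only to
# demonstrate construction, the plain ``step()`` call, and the optional
# closure form described in the ``step`` docstring.
if __name__ == '__main__':
    model = torch.nn.Linear(10, 1)
    optimizer = AdaBelief(
        model.parameters(),
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        weight_decouple=True,  # decoupled (AdamW-style) weight decay
        rectify=True,          # RAdam-style rectified update
    )

    inputs = torch.randn(32, 10)
    targets = torch.randn(32, 1)

    # Plain step: compute the loss, backpropagate, then update parameters.
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # Closure form: step(closure) re-evaluates the model and returns the loss.
    def closure():
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        return loss

    loss_value = optimizer.step(closure)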