Source code for pywick.optimizers.a2grad

# Source: https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/a2grad.py (Apache 2.0)

import copy
import math
from typing import Optional, Tuple, Dict, Any, Callable, Union, Iterable

import torch
from torch.optim.optimizer import Optimizer

from torch import Tensor

Params = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]

LossClosure = Callable[[], float]
OptLossClosure = Optional[LossClosure]
Betas2 = Tuple[float, float]
State = Dict[str, Any]
OptFloat = Optional[float]
Nus2 = Tuple[float, float]


__all__ = ('A2GradUni', 'A2GradInc', 'A2GradExp', 'Betas2', 'OptFloat', 'OptLossClosure', 'Params', 'State', 'Nus2')


class A2GradUni(Optimizer):
    r"""Implements A2GradUni Optimizer Algorithm.

    It has been proposed in `Optimal Adaptive and Accelerated Stochastic
    Gradient Descent`__.

    Arguments:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: not used for this optimizer (default: None)
        beta: (default: 10)
        lips: Lipschitz constant (default: 10)

    __ https://arxiv.org/abs/1810.00553

    Note:
        Reference code: https://github.com/severilov/A2Grad_optimizer
    """

    def __init__(
        self,
        params: Params,
        lr: Optional[float] = None,
        beta: float = 10,
        lips: float = 10,
    ):
        # lr is not used by this optimizer, but is kept so that tests and
        # LR schedulers do not fail.
        defaults = dict(beta=beta, lips=lips, lr=lr)
        if beta < 0.0:
            raise ValueError('Invalid beta value: {}'.format(beta))
        if lips < 0.0:
            raise ValueError('Invalid lips value: {}'.format(lips))
        super().__init__(params, defaults)

    def step(self, closure: OptLossClosure = None) -> OptFloat:
        r"""Performs a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['alpha_k'] = 1
                    state['v_k'] = 0
                    state['avg_grad'] = copy.deepcopy(grad)
                    state['x_k'] = copy.deepcopy(p.data)

                gamma_k = 2 * group['lips'] / (state['step'] + 1)

                # Running average of the gradients seen so far.
                avg_grad = state['avg_grad']
                avg_grad.mul_(state['step'])
                avg_grad.add_(grad)
                avg_grad.div_(state['step'] + 1)

                # Deviation of the current gradient from the running average;
                # its squared norm is accumulated uniformly over steps.
                delta_k = torch.add(grad, avg_grad, alpha=-1)
                state['v_k'] += torch.sum(delta_k * delta_k).item()
                h_k = math.sqrt(state['v_k'])

                alpha_k_1 = 2 / (state['step'] + 3)
                coef = 1 / (gamma_k + group['beta'] * h_k)

                x_k_1 = state['x_k']
                x_k_1.add_(grad, alpha=-coef)

                # Blend the current iterate with the auxiliary sequence x_k
                # and apply a gradient correction.
                p.data.mul_(1 - alpha_k_1)
                p.data.add_(x_k_1, alpha=alpha_k_1)
                p.data.add_(
                    grad, alpha=-(1 - alpha_k_1) * state['alpha_k'] * coef
                )

                state['alpha_k'] = alpha_k_1
                state['step'] += 1

        return loss
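
A minimal usage sketch for A2GradUni (not part of the original module; the model and batch below are hypothetical placeholders). The optimizer drops into a standard PyTorch training loop, and lr is accepted only so schedulers and tests keep working:

    import torch
    from pywick.optimizers.a2grad import A2GradUni

    model = torch.nn.Linear(10, 2)                        # hypothetical model
    optimizer = A2GradUni(model.parameters(), beta=10.0, lips=10.0)

    data = torch.randn(16, 10)                            # hypothetical batch
    target = torch.randint(0, 2, (16,))

    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(data), target)
    loss.backward()
    optimizer.step()
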
class A2GradInc(Optimizer):
    r"""Implements A2GradInc Optimizer Algorithm.

    It has been proposed in `Optimal Adaptive and Accelerated Stochastic
    Gradient Descent`__.

    Arguments:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: not used for this optimizer (default: None)
        beta: (default: 10)
        lips: Lipschitz constant (default: 10)

    __ https://arxiv.org/abs/1810.00553

    Note:
        Reference code: https://github.com/severilov/A2Grad_optimizer
    """

    def __init__(
        self,
        params: Params,
        lr: Optional[float] = None,
        beta: float = 10,
        lips: float = 10,
    ):
        if beta < 0.0:
            raise ValueError('Invalid beta value: {}'.format(beta))
        if lips < 0.0:
            raise ValueError('Invalid lips value: {}'.format(lips))
        defaults = dict(beta=beta, lips=lips, lr=lr)
        super().__init__(params, defaults)

    def step(self, closure: OptLossClosure = None) -> OptFloat:
        r"""Performs a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['alpha_k'] = 1
                    state['v_k'] = 0
                    state['avg_grad'] = copy.deepcopy(grad)
                    state['x_k'] = copy.deepcopy(p.data)

                gamma_k = 2 * group['lips'] / (state['step'] + 1)

                # Running average of the gradients seen so far.
                avg_grad = state['avg_grad']
                avg_grad.mul_(state['step'])
                avg_grad.add_(grad)
                avg_grad.div_(state['step'] + 1)

                # Incrementally weighted accumulation of squared deviations:
                # the previous mass is rescaled by (k / (k + 1))**2 before
                # adding the new term.
                delta_k = torch.add(grad, avg_grad, alpha=-1)
                state['v_k'] *= (state['step'] / (state['step'] + 1)) ** 2
                state['v_k'] += torch.sum(delta_k * delta_k).item()
                h_k = math.sqrt(state['v_k'])

                alpha_k_1 = 2 / (state['step'] + 3)
                coef = 1 / (gamma_k + group['beta'] * h_k)

                x_k_1 = state['x_k']
                x_k_1.add_(grad, alpha=-coef)

                # Blend the current iterate with the auxiliary sequence x_k
                # and apply a gradient correction.
                p.data.mul_(1 - alpha_k_1)
                p.data.add_(x_k_1, alpha=alpha_k_1)
                p.data.add_(
                    grad, alpha=-(1 - alpha_k_1) * state['alpha_k'] * coef
                )

                state['alpha_k'] = alpha_k_1
                state['step'] += 1

        return loss
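
A2GradInc is constructed the same way as A2GradUni; the sketch below (again with a hypothetical model and batch) exercises the optional closure argument of step, which re-evaluates the loss inside the optimizer:

    import torch
    from pywick.optimizers.a2grad import A2GradInc

    model = torch.nn.Linear(10, 1)                        # hypothetical model
    optimizer = A2GradInc(model.parameters(), beta=10.0, lips=10.0)

    data = torch.randn(16, 10)                            # hypothetical batch
    target = torch.randn(16, 1)

    def closure():
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(data), target)
        loss.backward()
        return loss

    loss = optimizer.step(closure)                        # loss returned by the closure
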
class A2GradExp(Optimizer):
    r"""Implements A2GradExp Optimizer Algorithm.

    It has been proposed in `Optimal Adaptive and Accelerated Stochastic
    Gradient Descent`__.

    Arguments:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: not used for this optimizer (default: None)
        beta: (default: 10)
        lips: Lipschitz constant (default: 10)
        rho: represents the degree of weighting decrease, a constant
            smoothing factor between 0 and 1 (default: 0.5)

    __ https://arxiv.org/abs/1810.00553

    Note:
        Reference code: https://github.com/severilov/A2Grad_optimizer
    """

    def __init__(
        self,
        params: Params,
        lr: Optional[float] = None,
        beta: float = 10,
        lips: float = 10,
        rho: float = 0.5,
    ):
        if beta < 0.0:
            raise ValueError('Invalid beta value: {}'.format(beta))
        if lips < 0.0:
            raise ValueError('Invalid lips value: {}'.format(lips))
        if rho < 0.0 or rho > 1.0:
            raise ValueError('Invalid rho value: {}'.format(rho))
        defaults = dict(beta=beta, lips=lips, rho=rho, lr=lr)
        super().__init__(params, defaults)

    def step(self, closure: OptLossClosure = None) -> OptFloat:
        r"""Performs a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['alpha_k'] = 1
                    state['v_k'] = 0
                    state['avg_grad'] = copy.deepcopy(grad)
                    state['x_k'] = copy.deepcopy(p.data)

                gamma_k = 2 * group['lips'] / (state['step'] + 1)

                # Running average of the gradients seen so far.
                avg_grad = state['avg_grad']
                avg_grad.mul_(state['step'])
                avg_grad.add_(grad)
                avg_grad.div_(state['step'] + 1)

                delta_k = torch.add(grad, avg_grad, alpha=-1)

                # Exponential moving average of squared deviations with
                # smoothing factor rho; v_k keeps the running maximum.
                if state['step'] == 0:
                    state['v_kk'] = torch.sum(delta_k * delta_k).item()
                else:
                    state['v_kk'] *= group['rho']
                    state['v_kk'] += (1 - group['rho']) * torch.sum(
                        delta_k * delta_k
                    ).item()
                state['v_k'] = max(state['v_kk'], state['v_k'])
                h_k = math.sqrt((state['step'] + 1) * state['v_k'])

                alpha_k_1 = 2 / (state['step'] + 3)
                coef = -1 / (gamma_k + group['beta'] * h_k)

                x_k_1 = state['x_k']
                x_k_1.add_(grad, alpha=coef)

                # Blend the current iterate with the auxiliary sequence x_k
                # and apply a gradient correction (coef is negative here).
                p.data.mul_(1 - alpha_k_1)
                p.data.add_(x_k_1, alpha=alpha_k_1)
                p.data.add_(
                    grad, alpha=(1 - alpha_k_1) * state['alpha_k'] * coef
                )

                state['alpha_k'] = alpha_k_1
                state['step'] += 1

        return loss
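
A2GradExp additionally takes the smoothing factor rho; a short sketch (hypothetical model) showing construction with a custom rho and the ValueError raised when rho falls outside [0, 1]:

    import torch
    from pywick.optimizers.a2grad import A2GradExp

    model = torch.nn.Linear(10, 2)                        # hypothetical model
    optimizer = A2GradExp(model.parameters(), beta=10.0, lips=10.0, rho=0.9)

    try:
        A2GradExp(model.parameters(), rho=1.5)            # rho must lie in [0, 1]
    except ValueError as err:
        print(err)                                        # Invalid rho value: 1.5
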