# Source code for pywick.optimizers.lookahead

# Source: https://github.com/alphadl/lookahead.pytorch/blob/master/lookahead.py (MIT)

from collections import defaultdict
from torch.optim.optimizer import Optimizer
import torch


class Lookahead(Optimizer):
    r"""
    Implementation of `Lookahead Optimizer: k steps forward, 1 step back
    <https://arxiv.org/abs/1907.08610>`_

    The wrapped ("fast") optimizer performs every step. Once every ``k``
    steps the "slow" copy of each parameter is moved towards the fast
    weights by a factor ``alpha`` and the fast weights are reset to it.

    Args:
        :param optimizer: - the optimizer to work with (sgd, adam etc)
        :param k: (int) - number of steps to look ahead (default=5)
        :param alpha: (float) - slow weights step size (default=0.5)
    """

    def __init__(self, optimizer, k=5, alpha=0.5):
        """
        :param optimizer: - the optimizer to work with (sgd, adam etc)
        :param k: (int) - number of steps to look ahead (default=5)
        :param alpha: (float) - slow weights step size
        :raises ValueError: if ``k < 1`` or ``alpha`` is outside ``[0, 1]``
        """
        if k < 1:
            raise ValueError("Invalid lookahead steps: {}".format(k))
        if not 0.0 <= alpha <= 1.0:
            raise ValueError("Invalid slow update rate: {}".format(alpha))

        # NOTE: Optimizer.__init__ is deliberately not called. It would
        # rebuild the param groups through add_param_group() and try to
        # register the same parameters with the wrapped optimizer twice.
        self.optimizer = optimizer
        self.k = k
        self.alpha = alpha
        # Share (not copy) the group list so LR schedulers acting on the
        # wrapper also drive the wrapped optimizer.
        self.param_groups = self.optimizer.param_groups
        # Base-class helpers and several LR schedulers read `defaults`.
        self.defaults = getattr(self.optimizer, "defaults", {})
        self.state = defaultdict(dict)
        self.fast_state = self.optimizer.state
        for group in self.param_groups:
            group["counter"] = 0

    def update(self, group):
        """
        Pull the slow weights of ``group`` towards the fast weights and
        reset the fast weights to the result.

        :param group: (dict) - one entry of ``self.param_groups``
        """
        for fast in group["params"]:
            param_state = self.state[fast]
            if "slow_param" not in param_state:
                # First synchronisation: slow weights start at the fast ones.
                param_state["slow_param"] = fast.data.clone()
            slow = param_state["slow_param"]
            slow += (fast.data - slow) * self.alpha
            fast.data.copy_(slow)

    def update_lookahead(self):
        """Run the slow-weight update for every parameter group."""
        for group in self.param_groups:
            self.update(group)

    def zero_grad(self, *args, **kwargs):
        """
        Clear the gradients through the wrapped optimizer.

        Both optimizers share the same parameters. The inherited
        implementation is bypassed because it relies on attributes that
        ``Optimizer.__init__`` would have created.
        """
        return self.optimizer.zero_grad(*args, **kwargs)

    def step(self, closure=None):
        """
        Perform one step of the wrapped optimizer and, once every ``k``
        steps per group, the slow-weight update.

        :param closure: (callable, optional) - re-evaluates the model and
            returns the loss; forwarded to the wrapped optimizer
        :return: whatever the wrapped optimizer's ``step`` returns
        """
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            if group["counter"] == 0:
                self.update(group)
            group["counter"] += 1
            if group["counter"] >= self.k:
                group["counter"] = 0
        return loss

    def state_dict(self):
        """
        Return the state of both the fast and the slow optimizer.

        The slow state uses the same parameter keys as the wrapped
        optimizer's ``param_groups[*]["params"]`` so it can be matched back
        to the parameters by :meth:`load_state_dict`.

        :return: dict with ``fast_state``, ``slow_state``, ``param_groups``
        """
        fast_state_dict = self.optimizer.state_dict()
        param_groups = fast_state_dict["param_groups"]

        # Map live parameters to the keys the wrapped optimizer saved them
        # under, whatever scheme this torch version uses.
        key_of = {}
        for live_group, saved_group in zip(self.param_groups, param_groups):
            for param, key in zip(live_group["params"], saved_group["params"]):
                key_of[id(param)] = key

        slow_state = {}
        for param, value in self.state.items():
            if isinstance(param, torch.Tensor):
                slow_state[key_of.get(id(param), id(param))] = value
            else:
                slow_state[param] = value

        return {
            "fast_state": fast_state_dict["state"],
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict):
        """
        Restore the state produced by :meth:`state_dict`.

        :param state_dict: (dict) - with ``fast_state``, ``slow_state`` and
            ``param_groups`` entries
        """
        saved_groups = state_dict["param_groups"]
        self.optimizer.load_state_dict(
            {"state": state_dict["fast_state"], "param_groups": saved_groups}
        )
        # Loading replaces the wrapped optimizer's containers; re-share them
        # so the wrapper and the wrapped optimizer stay in sync.
        self.param_groups = self.optimizer.param_groups
        self.fast_state = self.optimizer.state
        for group in self.param_groups:
            group.setdefault("counter", 0)

        param_of = {}
        for live_group, saved_group in zip(self.param_groups, saved_groups):
            for param, key in zip(live_group["params"], saved_group["params"]):
                param_of[key] = param

        restored = defaultdict(dict)
        for key, value in state_dict["slow_state"].items():
            param = param_of.get(key)
            if param is None:
                # Unknown key: keep as-is, like the base Optimizer does.
                restored[key] = value
                continue
            if isinstance(value, dict):
                value = {
                    name: (
                        # Clone so in-place slow updates never touch the
                        # tensors held by the checkpoint.
                        item.detach().to(device=param.device, dtype=param.dtype).clone()
                        if isinstance(item, torch.Tensor)
                        else item
                    )
                    for name, item in value.items()
                }
            restored[param] = value
        self.state = restored

    def add_param_group(self, param_group):
        """
        Add a parameter group to the wrapped optimizer (and therefore to
        the shared ``param_groups`` list).

        :param param_group: (dict) - parameters and group-specific options
        """
        param_group["counter"] = 0
        self.optimizer.add_param_group(param_group)