# Source code for pywick.constraints

"""
Constraints can be selectively applied to layers by matching module names against Unix shell-style wildcard patterns (``fnmatch``).
Constraints can be explicit (hard) constraints applied at an arbitrary batch or epoch frequency, or they can be implicit (soft)
constraints similar to regularizers, where the constraint deviation is added as a penalty to the total model loss.
"""

from fnmatch import fnmatch

import torch as th
from .callbacks import Callback


class ConstraintContainer:
    """
    Group hard constraints and apply them to the matching modules of a model.

    Each constraint is expected to expose ``unit`` (``'batch'`` or ``'epoch'``,
    case-insensitive), ``frequency`` (apply every ``frequency`` batches/epochs)
    and ``module_filter`` (an ``fnmatch`` pattern on module names), and to be
    callable with a single module. Constraints whose ``unit`` is neither
    ``'batch'`` nor ``'epoch'`` are kept in ``self.constraints`` but never applied.
    """

    def __init__(self, constraints):
        self.constraints = constraints
        self.batch_constraints = [c for c in self.constraints if c.unit.upper() == 'BATCH']
        self.epoch_constraints = [c for c in self.constraints if c.unit.upper() == 'EPOCH']
        # Empty until register_constraints() runs, so applying constraints
        # before registration is a no-op instead of an AttributeError.
        self._batch_c_ptrs = {}
        self._epoch_c_ptrs = {}

    @staticmethod
    def _match_modules(model, constraint):
        """Return the modules of ``model`` that have a ``weight`` and whose name matches ``constraint.module_filter``."""
        return [module for name, module in model.named_modules()
                if fnmatch(name, constraint.module_filter) and hasattr(module, 'weight')]

    def register_constraints(self, model):
        """
        Grab pointers to the weights which will be modified by constraints so
        that we don't have to search through the entire network using `apply`
        each time.

        :param model: object providing ``named_modules()`` that yields
            ``(name, module)`` pairs
        """
        # Keys are positions in batch_constraints / epoch_constraints respectively.
        self._batch_c_ptrs = {c_idx: self._match_modules(model, constraint)
                              for c_idx, constraint in enumerate(self.batch_constraints)}
        self._epoch_c_ptrs = {c_idx: self._match_modules(model, constraint)
                              for c_idx, constraint in enumerate(self.epoch_constraints)}

    @staticmethod
    def _apply(constraints, ptrs, idx):
        """Call every constraint in ``constraints`` whose frequency divides ``idx + 1`` on its registered modules."""
        for c_idx, modules in ptrs.items():
            # c_idx indexes the unit-specific list, not self.constraints.
            constraint = constraints[c_idx]
            if (idx + 1) % constraint.frequency == 0:
                for module in modules:
                    constraint(module)

    def apply_batch_constraints(self, batch_idx):
        """Apply batch-unit constraints due at the end of batch ``batch_idx`` (0-based)."""
        self._apply(self.batch_constraints, self._batch_c_ptrs, batch_idx)

    def apply_epoch_constraints(self, epoch_idx):
        """Apply epoch-unit constraints due at the end of epoch ``epoch_idx`` (0-based)."""
        self._apply(self.epoch_constraints, self._epoch_c_ptrs, epoch_idx)


class ConstraintCallback(Callback):
    """
    Training callback that forwards batch and epoch boundaries to a constraint container.

    The container's ``apply_batch_constraints`` / ``apply_epoch_constraints``
    methods decide which constraints actually fire for a given index.
    """

    def __init__(self, container):
        # Object exposing apply_batch_constraints(idx) and apply_epoch_constraints(idx).
        self.container = container

    def on_batch_end(self, batch_idx, logs):
        """Run the container's batch constraints for ``batch_idx``; ``logs`` is ignored."""
        run_batch = self.container.apply_batch_constraints
        run_batch(batch_idx)

    def on_epoch_end(self, epoch_idx, logs):
        """Run the container's epoch constraints for ``epoch_idx``; ``logs`` is ignored."""
        run_epoch = self.container.apply_epoch_constraints
        run_epoch(epoch_idx)


class Constraint:
    """
    Default class from which all Constraint implementations inherit.

    Subclasses are expected to define ``frequency``, ``unit`` (``'batch'`` or
    ``'epoch'``) and ``module_filter`` (an ``fnmatch`` pattern on module names),
    and to implement ``__call__(module)`` to adjust ``module.weight``.
    """

    def __call__(self, module):
        """
        Apply the constraint to ``module``.

        :param module: module whose ``weight`` the constraint adjusts
        :raises NotImplementedError: always; subclasses must override this
        """
        raise NotImplementedError('Subclasses must implement this method')
class UnitNorm(Constraint):
    """
    UnitNorm constraint.

    Constrains the weights to have column-wise unit norm: the L2 norm of
    ``module.weight`` is computed over dimension 0 and every entry is divided
    by the norm of its slice, so each slice along dimension 0 (a column of a
    2-D weight) ends up with norm 1. All-zero slices are left at zero.

    :param frequency: apply every ``frequency`` batches/epochs
    :param unit: ``'batch'`` or ``'epoch'``
    :param module_filter: ``fnmatch`` pattern selecting module names
    """

    def __init__(self, frequency=1, unit='batch', module_filter='*'):
        self.frequency = frequency
        self.unit = unit
        self.module_filter = module_filter

    def __call__(self, module):
        w = module.weight.data
        # keepdim=True makes the broadcast against w explicit. Clamping to the
        # dtype's smallest normal value avoids 0/0 = NaN for all-zero slices
        # (a fixed 1e-12 would underflow to 0 in float16).
        norms = th.norm(w, p=2, dim=0, keepdim=True).clamp(min=th.finfo(w.dtype).tiny)
        module.weight.data = w.div(norms)
class MaxNorm(Constraint):
    """
    MaxNorm weight constraint.

    Each sub-tensor of ``module.weight`` taken along ``axis`` whose L2 norm
    exceeds ``value`` is rescaled with ``torch.renorm`` so that its norm is at
    most ``value``; sub-tensors already within the max-norm constraint are not
    altered.

    :param value: maximum allowed L2 norm (stored as ``float``)
    :param axis: dimension handed to ``torch.renorm`` as ``dim``
    :param frequency: apply every ``frequency`` batches/epochs
    :param unit: ``'batch'`` or ``'epoch'``
    :param module_filter: ``fnmatch`` pattern selecting module names
    """

    def __init__(self, value, axis=0, frequency=1, unit='batch', module_filter='*'):
        self.value = float(value)
        self.axis = axis
        self.frequency, self.unit, self.module_filter = frequency, unit, module_filter

    def __call__(self, module):
        module.weight.data = th.renorm(module.weight.data, p=2, dim=self.axis, maxnorm=self.value)
class NonNeg(Constraint):
    """
    Constrains the weights to be non-negative.

    Every negative entry of ``module.weight`` is replaced with zero; other
    entries are kept as-is.

    :param frequency: apply every ``frequency`` batches/epochs
    :param unit: ``'batch'`` or ``'epoch'``
    :param module_filter: ``fnmatch`` pattern selecting module names
    """

    def __init__(self, frequency=1, unit='batch', module_filter='*'):
        self.frequency = frequency
        self.unit = unit
        self.module_filter = module_filter

    def __call__(self, module):
        w = module.weight.data
        # clamp preserves w's dtype (a float32 mask multiply would promote
        # half/bfloat16 weights to float32) and maps -inf to 0 instead of
        # producing NaN from 0 * -inf.
        module.weight.data = w.clamp(min=0)