diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5c225eb0f9..29aebc4501 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
 
 - Runner registry support for Config API ([#936](https://github.com/catalyst-team/catalyst/pull/936))
+- AdamP and SGDP to `catalyst.contrib.nn.optimizers` ([#942](https://github.com/catalyst-team/catalyst/pull/942))
 
 ### Changed
 
diff --git a/catalyst/contrib/nn/criterion/__init__.py b/catalyst/contrib/nn/criterion/__init__.py
index 70d4be2f77..75a9a548d9 100644
--- a/catalyst/contrib/nn/criterion/__init__.py
+++ b/catalyst/contrib/nn/criterion/__init__.py
@@ -1,6 +1,6 @@
 # flake8: noqa
-from torch.nn.modules.loss import *
+from torch.nn.modules.loss import *
 
 from catalyst.contrib.nn.criterion.ce import (
     MaskCrossEntropyLoss,
     NaiveCrossEntropyLoss,
diff --git a/catalyst/contrib/nn/optimizers/__init__.py b/catalyst/contrib/nn/optimizers/__init__.py
index 1ff099a342..d492967645 100644
--- a/catalyst/contrib/nn/optimizers/__init__.py
+++ b/catalyst/contrib/nn/optimizers/__init__.py
@@ -1,8 +1,10 @@
 # flake8: noqa
 from torch.optim import *
 
+from catalyst.contrib.nn.optimizers.adamp import AdamP
 from catalyst.contrib.nn.optimizers.lamb import Lamb
 from catalyst.contrib.nn.optimizers.lookahead import Lookahead
 from catalyst.contrib.nn.optimizers.qhadamw import QHAdamW
 from catalyst.contrib.nn.optimizers.radam import RAdam
 from catalyst.contrib.nn.optimizers.ralamb import Ralamb
+from catalyst.contrib.nn.optimizers.sgdp import SGDP
diff --git a/catalyst/contrib/nn/optimizers/adamp.py b/catalyst/contrib/nn/optimizers/adamp.py
new file mode 100644
index 0000000000..066085590b
--- /dev/null
+++ b/catalyst/contrib/nn/optimizers/adamp.py
@@ -0,0 +1,199 @@
+"""
+AdamP
+Copyright (c) 2020-present NAVER Corp.
+MIT license
+
+Original source code: https://github.com/clovaai/AdamP
+"""
+
+import math
+
+import torch
+import torch.nn.functional as F
+from torch.optim.optimizer import Optimizer
+
+
+class AdamP(Optimizer):
+    """Implements AdamP algorithm.
+
+    The original Adam algorithm was proposed in
+    `Adam: A Method for Stochastic Optimization`_.
+    The AdamP variant was proposed in
+    `Slowing Down the Weight Norm Increase in Momentum-based Optimizers`_.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay coefficient
+            (default: 0)
+        delta (float): threshold that determines whether
+            a set of parameters is scale invariant or not (default: 0.1)
+        wd_ratio (float): relative weight decay applied on scale-invariant
+            parameters compared to that applied on scale-variant parameters
+            (default: 0.1)
+        nesterov (boolean, optional): enables Nesterov momentum
+            (default: False)
+
+    .. _Adam\: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    .. _Slowing Down the Weight Norm Increase in Momentum-based Optimizers:
+        https://arxiv.org/abs/2006.08217
+    """
+
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=0,
+        delta=0.1,
+        wd_ratio=0.1,
+        nesterov=False,
+    ):
+        """
+
+        Args:
+            params (iterable): iterable of parameters to optimize
+                or dicts defining parameter groups
+            lr (float, optional): learning rate (default: 1e-3)
+            betas (Tuple[float, float], optional): coefficients
+                used for computing running averages of gradient
+                and its square (default: (0.9, 0.999))
+            eps (float, optional): term added to the denominator to improve
+                numerical stability (default: 1e-8)
+            weight_decay (float, optional): weight decay coefficient
+                (default: 0)
+            delta (float): threshold that determines whether
+                a set of parameters is scale invariant or not (default: 0.1)
+            wd_ratio (float): relative weight decay applied on scale-invariant
+                parameters compared to that applied on scale-variant parameters
+                (default: 0.1)
+            nesterov (boolean, optional): enables Nesterov momentum
+                (default: False)
+        """
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            delta=delta,
+            wd_ratio=wd_ratio,
+            nesterov=nesterov,
+        )
+        super(AdamP, self).__init__(params, defaults)
+
+    def _channel_view(self, x):
+        return x.view(x.size(0), -1)
+
+    def _layer_view(self, x):
+        return x.view(1, -1)
+
+    def _cosine_similarity(self, x, y, eps, view_func):
+        x = view_func(x)
+        y = view_func(y)
+
+        return F.cosine_similarity(x, y, dim=1, eps=eps).abs_()
+
+    def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
+        wd = 1
+        expand_size = [-1] + [1] * (len(p.shape) - 1)
+        for view_func in [self._channel_view, self._layer_view]:
+
+            cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
+
+            if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
+                p_n = p.data / view_func(p.data).norm(dim=1).view(
+                    expand_size
+                ).add_(eps)
+                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(
+                    expand_size
+                )
+                wd = wd_ratio
+
+                return perturb, wd
+
+        return perturb, wd
+
+    def step(self, closure=None):
+        """
+        Performs a single optimization step (parameter update).
+
+        Arguments:
+            closure (callable): A closure that reevaluates the model and
+                returns the loss. Optional for most optimizers.
+
+        Returns:
+            computed loss
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad.data
+                beta1, beta2 = group["betas"]
+                nesterov = group["nesterov"]
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state["step"] = 0
+                    state["exp_avg"] = torch.zeros_like(p.data)
+                    state["exp_avg_sq"] = torch.zeros_like(p.data)
+
+                # Adam
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+
+                state["step"] += 1
+                bias_correction1 = 1 - beta1 ** state["step"]
+                bias_correction2 = 1 - beta2 ** state["step"]
+
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
+                    group["eps"]
+                )
+                step_size = group["lr"] / bias_correction1
+
+                if nesterov:
+                    perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
+                else:
+                    perturb = exp_avg / denom
+
+                # Projection
+                wd_ratio = 1
+                if len(p.shape) > 1:
+                    perturb, wd_ratio = self._projection(
+                        p,
+                        grad,
+                        perturb,
+                        group["delta"],
+                        group["wd_ratio"],
+                        group["eps"],
+                    )
+
+                # Weight decay
+                if group["weight_decay"] > 0:
+                    p.data.mul_(
+                        1 - group["lr"] * group["weight_decay"] * wd_ratio
+                    )
+
+                # Step
+                p.data.add_(perturb, alpha=-step_size)
+
+        return loss
+
+
+__all__ = ["AdamP"]
diff --git a/catalyst/contrib/nn/optimizers/sgdp.py b/catalyst/contrib/nn/optimizers/sgdp.py
new file mode 100644
index 0000000000..e58d297d7e
--- /dev/null
+++ b/catalyst/contrib/nn/optimizers/sgdp.py
@@ -0,0 +1,184 @@
+"""
+SGDP
+Copyright (c) 2020-present NAVER Corp.
+MIT license
+
+Original source code: https://github.com/clovaai/AdamP
+"""
+
+import math
+
+import torch
+import torch.nn.functional as F
+from torch.optim.optimizer import Optimizer, required
+
+
+class SGDP(Optimizer):
+    """Implements SGDP algorithm.
+
+    The SGDP variant was proposed in
+    `Slowing Down the Weight Norm Increase in Momentum-based Optimizers`_.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float): learning rate
+        momentum (float, optional): momentum factor (default: 0)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        dampening (float, optional): dampening for momentum (default: 0)
+        nesterov (bool, optional): enables Nesterov momentum (default: False)
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        delta (float): threshold that determines whether
+            a set of parameters is scale invariant or not (default: 0.1)
+        wd_ratio (float): relative weight decay applied on scale-invariant
+            parameters compared to that applied on scale-variant parameters
+            (default: 0.1)
+
+    .. _Slowing Down the Weight Norm Increase in Momentum-based Optimizers:
+        https://arxiv.org/abs/2006.08217
+    """
+
+    def __init__(
+        self,
+        params,
+        lr=required,
+        momentum=0,
+        weight_decay=0,
+        dampening=0,
+        nesterov=False,
+        eps=1e-8,
+        delta=0.1,
+        wd_ratio=0.1,
+    ):
+        """
+
+        Args:
+            params (iterable): iterable of parameters to optimize
+                or dicts defining parameter groups
+            lr (float): learning rate
+            momentum (float, optional): momentum factor (default: 0)
+            weight_decay (float, optional): weight decay (L2 penalty)
+                (default: 0)
+            dampening (float, optional): dampening for momentum (default: 0)
+            nesterov (bool, optional): enables Nesterov momentum
+                (default: False)
+            eps (float, optional): term added to the denominator to improve
+                numerical stability (default: 1e-8)
+            delta (float): threshold that determines whether
+                a set of parameters is scale invariant or not (default: 0.1)
+            wd_ratio (float): relative weight decay applied on scale-invariant
+                parameters compared to that applied on scale-variant parameters
+                (default: 0.1)
+        """
+        defaults = dict(
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            eps=eps,
+            delta=delta,
+            wd_ratio=wd_ratio,
+        )
+        super(SGDP, self).__init__(params, defaults)
+
+    def _channel_view(self, x):
+        return x.view(x.size(0), -1)
+
+    def _layer_view(self, x):
+        return x.view(1, -1)
+
+    def _cosine_similarity(self, x, y, eps, view_func):
+        x = view_func(x)
+        y = view_func(y)
+
+        return F.cosine_similarity(x, y, dim=1, eps=eps).abs_()
+
+    def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
+        wd = 1
+        expand_size = [-1] + [1] * (len(p.shape) - 1)
+        for view_func in [self._channel_view, self._layer_view]:
+
+            cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
+
+            if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
+                p_n = p.data / view_func(p.data).norm(dim=1).view(
+                    expand_size
+                ).add_(eps)
+                perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(
+                    expand_size
+                )
+                wd = wd_ratio
+
+                return perturb, wd
+
+        return perturb, wd
+
+    def step(self, closure=None):
+        """
+        Performs a single optimization step (parameter update).
+
+        Arguments:
+            closure (callable): A closure that reevaluates the model and
+                returns the loss. Optional for most optimizers.
+
+        Returns:
+            computed loss
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            momentum = group["momentum"]
+            dampening = group["dampening"]
+            nesterov = group["nesterov"]
+
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state["momentum"] = torch.zeros_like(p.data)
+
+                # SGD
+                buf = state["momentum"]
+                buf.mul_(momentum).add_(grad, alpha=1 - dampening)
+                if nesterov:
+                    d_p = grad + momentum * buf
+                else:
+                    d_p = buf
+
+                # Projection
+                wd_ratio = 1
+                if len(p.shape) > 1:
+                    d_p, wd_ratio = self._projection(
+                        p,
+                        grad,
+                        d_p,
+                        group["delta"],
+                        group["wd_ratio"],
+                        group["eps"],
+                    )
+
+                # Weight decay
+                if group["weight_decay"] > 0:
+                    p.data.mul_(
+                        1
+                        - group["lr"]
+                        * group["weight_decay"]
+                        * wd_ratio
+                        / (1 - momentum)
+                    )
+
+                # Step
+                p.data.add_(d_p, alpha=-group["lr"])
+
+        return loss
+
+
+__all__ = ["SGDP"]
diff --git a/docs/api/contrib.rst b/docs/api/contrib.rst
index 7039a8a812..d2e1d8144c 100644
--- a/docs/api/contrib.rst
+++ b/docs/api/contrib.rst
@@ -286,10 +286,16 @@ SqueezeAndExcitation
     :show-inheritance:
 
-
 
 Optimizers
 ~~~~~~~~~~~~~~~~
 
+AdamP
+"""""""""""""
+.. automodule:: catalyst.contrib.nn.optimizers.adamp
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 Lamb
 """"""""""""""""""""
 .. automodule:: catalyst.contrib.nn.optimizers.lamb
@@ -325,6 +331,13 @@ Ralamb
     :undoc-members:
     :show-inheritance:
 
+SGDP
+"""""""""""""
+.. automodule:: catalyst.contrib.nn.optimizers.sgdp
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 Schedulers
 ~~~~~~~~~~~~~~~~
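
Usage note (not part of the patch above): the sketch below shows how the newly exported optimizers can be dropped into a plain PyTorch training step once this diff is applied. Only the import path and the constructor arguments come from the diff itself; the toy model, data, and hyperparameter values are illustrative assumptions.

```python
# Minimal usage sketch for the optimizers added in this patch.
# The toy model, data, and hyperparameter values are assumptions for
# illustration; the import path and constructor signatures come from the diff.
import torch
import torch.nn as nn

from catalyst.contrib.nn.optimizers import AdamP, SGDP

model = nn.Linear(10, 2)

# AdamP: Adam update plus a projection step that slows down weight-norm
# growth for scale-invariant parameters.
optimizer = AdamP(
    model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2
)

# SGDP applies the same projection idea to SGD with momentum:
# optimizer = SGDP(model.parameters(), lr=1e-1, momentum=0.9, nesterov=True)

x = torch.randn(4, 10)
y = torch.randint(0, 2, (4,))
criterion = nn.CrossEntropyLoss()

optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
```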