Skip to content

Commit

Permalink
Feature: AdaCos (#958)
Browse files Browse the repository at this point in the history
* adacos

* adacos

* fixed repr & zeros_like

* removed redundant comma
  • Loading branch information
ditwoo committed Oct 10, 2020
1 parent f9b68b6 commit cdad455
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- docs for MetricCallbacks ([#947](https://github.com/catalyst-team/catalyst/pull/947))
- SoftMax, CosFace, ArcFace layers to contrib ([#939](https://github.com/catalyst-team/catalyst/pull/939))
- AdaCos to contrib ([#958](https://github.com/catalyst-team/catalyst/pull/958))

### Changed

Expand Down
2 changes: 1 addition & 1 deletion catalyst/contrib/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@

from catalyst.contrib.nn.modules.softmax import SoftMax
from catalyst.contrib.nn.modules.arcface import ArcFace, SubCenterArcFace
from catalyst.contrib.nn.modules.cosface import CosFace
from catalyst.contrib.nn.modules.cosface import CosFace, AdaCos
114 changes: 114 additions & 0 deletions catalyst/contrib/nn/modules/cosface.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -82,3 +84,115 @@ def forward(self, input, target):
logits *= self.s

return logits


class AdaCos(nn.Module):
    r"""Cosine-similarity classification head with an adaptive scale.

    Implementation of "AdaCos: Adaptively Scaling Cosine Logits for
    Effectively Learning Deep Face Representations"
    (https://arxiv.org/abs/1905.00292).

    The layer returns ``s * cos(theta)`` where ``theta`` is the angle between
    the L2-normalized input and the L2-normalized class weights. When
    ``dynamical_s`` is enabled and the module is in training mode, ``s`` is
    re-estimated from the current batch on every forward pass.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample (number of classes),
            must be at least ``2``.
        dynamical_s: option to use dynamical scale parameter.
            If ``False`` then will be used initial scale.
            Default: ``True``.
        eps: operation accuracy.
            Default: ``1e-6``.

    Raises:
        ValueError: if ``out_features`` is less than ``2`` (the initial
            scale ``sqrt(2) * log(out_features - 1)`` is undefined).

    Shape:
        - Input: :math:`(batch, H_{in})` where
          :math:`H_{in} = in\_features`.
        - Output: :math:`(batch, H_{out})` where
          :math:`H_{out} = out\_features`.

    Example:
        >>> layer = AdaCos(5, 10)
        >>> loss_fn = nn.CrossEntropyLoss()
        >>> embedding = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(10)
        >>> output = layer(embedding, target)
        >>> loss = loss_fn(output, target)
        >>> loss.backward()
    """

    def __init__(  # noqa: D107
        self,
        in_features: int,
        out_features: int,
        dynamical_s: bool = True,
        eps: float = 1e-6,
    ):
        super(AdaCos, self).__init__()
        if out_features < 2:
            # math.log(out_features - 1) would otherwise fail with an
            # opaque "math domain error".
            raise ValueError(
                f"out_features must be >= 2, got {out_features}"
            )
        self.in_features = in_features
        self.out_features = out_features
        self.dynamical_s = dynamical_s
        # Initial (fixed) scale; overwritten batch by batch when the
        # dynamical scale is active.
        self.s = math.sqrt(2) * math.log(out_features - 1)
        self.eps = eps

        self.weight = nn.Parameter(
            torch.FloatTensor(out_features, in_features)
        )
        nn.init.xavier_uniform_(self.weight)

    def __repr__(self) -> str:
        """Object representation."""
        rep = (
            "AdaCos("
            f"in_features={self.in_features},"
            f"out_features={self.out_features},"
            f"s={self.s},"
            f"eps={self.eps}"
            ")"
        )
        return rep

    def _update_scale(
        self, cos_theta: torch.Tensor, target: torch.Tensor
    ) -> None:
        """Re-estimate ``self.s`` from the cosine logits of one batch."""
        with torch.no_grad():
            # Clamp away from +-1 before acos to stay inside its domain.
            theta = torch.acos(
                torch.clamp(cos_theta, -1.0 + self.eps, 1.0 - self.eps)
            )

            one_hot = torch.zeros_like(cos_theta)
            one_hot.scatter_(1, target.view(-1, 1).long(), 1)

            # Batch average of the summed exp-logits of non-target classes.
            b_avg = (
                torch.where(
                    one_hot < 1,
                    torch.exp(self.s * cos_theta),
                    torch.zeros_like(cos_theta),
                )
                .sum(1)
                .mean()
            )
            # Median angle to the target class, capped at pi / 4.
            theta_median = torch.clamp(
                theta[one_hot > 0].median(), max=math.pi / 4
            )
            self.s = (torch.log(b_avg) / torch.cos(theta_median)).item()

    def forward(
        self, input: torch.Tensor, target: torch.LongTensor
    ) -> torch.Tensor:
        """
        Args:
            input: input features,
                expected shapes ``BxF`` where ``B``
                is batch dimension and ``F`` is an
                input feature dimension.
            target: target classes,
                expected shapes ``B`` where
                ``B`` is batch dimension.

        Returns:
            tensor (logits) with shapes ``BxC``
            where ``C`` is a number of classes
            (out_features).
        """
        cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))

        # ``self.training`` is the mode flag; ``self.train`` is a bound
        # method and therefore always truthy.
        if self.dynamical_s and self.training:
            self._update_scale(cos_theta, target)

        logits = self.s * cos_theta
        return logits
2 changes: 1 addition & 1 deletion docs/api/contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ Common modules
:undoc-members:
:show-inheritance:

CosFace: Large Margin Cosine Loss for Deep Face Recognition
CosFace and AdaCos
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
.. automodule:: catalyst.contrib.nn.modules.cosface
:members:
Expand Down

0 comments on commit cdad455

Please sign in to comment.