Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Arc Margin Product #957

Merged
merged 10 commits into from
Oct 10, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- docs for MetricCallbacks ([#947](https://github.com/catalyst-team/catalyst/pull/947))
- SoftMax, CosFace, ArcFace layers to contrib ([#939](https://github.com/catalyst-team/catalyst/pull/939))
- ArcMargin layer to contrib ([#957](https://github.com/catalyst-team/catalyst/pull/957))

### Changed

Expand Down
1 change: 1 addition & 0 deletions catalyst/contrib/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@
from catalyst.contrib.nn.modules.softmax import SoftMax
from catalyst.contrib.nn.modules.arcface import ArcFace, SubCenterArcFace
from catalyst.contrib.nn.modules.cosface import CosFace
from catalyst.contrib.nn.modules.arcmargin import ArcMarginProduct
55 changes: 27 additions & 28 deletions catalyst/contrib/nn/modules/arcface.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,20 @@ def __init__( # noqa: D107

def __repr__(self) -> str:
"""Object representation."""
return (
rep = (
"ArcFace("
+ ",".join(
[
f"in_features={self.in_features}",
f"out_features={self.out_features}",
f"s={self.s}",
f"m={self.m}",
f"eps={self.eps}",
]
)
+ ")"
f"in_features={self.in_features},"
f"out_features={self.out_features},"
f"s={self.s},"
f"m={self.m},"
f"eps={self.eps}"
")"
)
return rep

def forward(self, input, target):
def forward(
self, input: torch.Tensor, target: torch.LongTensor
) -> torch.Tensor:
"""
Args:
input: input features,
Expand All @@ -89,14 +88,15 @@ def forward(self, input, target):

Returns:
tensor (logits) with shapes ``BxC``
where ``C`` is a number of classes.
where ``C`` is a number of classes
(out_features).
"""
cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))
theta = torch.acos(
torch.clamp(cos_theta, -1.0 + self.eps, 1.0 - self.eps)
)

one_hot = torch.zeros_like(cos_theta, device=input.device)
one_hot = torch.zeros_like(cos_theta)
one_hot.scatter_(1, target.view(-1, 1).long(), 1)

mask = torch.where(
Expand Down Expand Up @@ -174,22 +174,21 @@ def __init__( # noqa: D107

def __repr__(self) -> str:
"""Object representation."""
return (
rep = (
"SubCenterArcFace("
+ ",".join(
[
f"in_features={self.in_features}",
f"out_features={self.out_features}",
f"s={self.s}",
f"m={self.m}",
f"k={self.k}",
f"eps={self.eps}",
]
)
+ ")"
f"in_features={self.in_features},"
f"out_features={self.out_features},"
f"s={self.s},"
f"m={self.m},"
f"k={self.k},"
f"eps={self.eps}"
")"
)
return rep

def forward(self, input, label):
def forward(
self, input: torch.Tensor, label: torch.LongTensor
) -> torch.Tensor:
"""
Args:
input: input features,
Expand Down Expand Up @@ -217,7 +216,7 @@ def forward(self, input, label):
torch.clamp(cos_theta, -1.0 + self.eps, 1.0 - self.eps)
)

one_hot = torch.zeros(cos_theta.size()).to(input.device)
one_hot = torch.zeros_like(cos_theta)
one_hot.scatter_(1, label.view(-1, 1).long(), 1)

selected = torch.where(
Expand Down
58 changes: 58 additions & 0 deletions catalyst/contrib/nn/modules/arcmargin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ArcMarginProduct(nn.Module):
"""Implementation of Arc Margin Product.

Args:
in_features: size of each input sample.
out_features: size of each output sample.

Shape:
- Input: :math:`(batch, H_{in})` where
:math:`H_{in} = in\_features`.
- Output: :math:`(batch, H_{out})` where
:math:`H_{out} = out\_features`.

Example:
>>> layer = ArcMarginProduct(5, 10)
>>> loss_fn = nn.CrosEntropyLoss()
>>> embedding = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(10)
>>> output = layer(embedding)
>>> loss = loss_fn(output, target)
>>> loss.backward()

"""

def __init__(self, in_features: int, out_features: int): # noqa: D107
super(ArcMarginProduct, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
nn.init.xavier_uniform_(self.weight)

def __repr__(self) -> str:
"""Object representation."""
rep = "ArcMarginProduct(in_features={},out_features={})".format(
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
self.in_features, self.out_features
)
return rep

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Args:
input: input features,
expected shapes ``BxF`` where ``B``
is batch dimension and ``F`` is an
input feature dimension.

Returns:
tensor (logits) with shapes ``BxC``
where ``C`` is a number of classes
(out_features).
"""
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
return cosine
14 changes: 10 additions & 4 deletions catalyst/contrib/nn/modules/cosface.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,14 @@ def __init__( # noqa: D107

def __repr__(self) -> str:
"""Object representation."""
return "CosFace(in_features={},out_features={},s={},m={})".format(
rep = "CosFace(in_features={},out_features={},s={},m={})".format(
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
self.in_features, self.out_features, self.s, self.m
)
return rep

def forward(self, input, target):
def forward(
self, input: torch.Tensor, target: torch.LongTensor
) -> torch.Tensor:
"""
Args:
input: input features,
Expand All @@ -72,12 +75,15 @@ def forward(self, input, target):

Returns:
tensor (logits) with shapes ``BxC``
where ``C`` is a number of classes.
where ``C`` is a number of classes
(out_features).
"""
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
phi = cosine - self.m
one_hot = torch.zeros(cosine.size()).to(input.device)

one_hot = torch.zeros_like(cosine)
one_hot.scatter_(1, target.view(-1, 1).long(), 1)

logits = (one_hot * phi) + ((1.0 - one_hot) * cosine)
logits *= self.s

Expand Down
8 changes: 5 additions & 3 deletions catalyst/contrib/nn/modules/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ def __init__(self, in_features: int, num_classes: int): # noqa: D107

def __repr__(self) -> str:
"""Object representation."""
return "SoftMax(in_features={},out_features={})".format(
rep = "SoftMax(in_features={},out_features={})".format(
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
self.in_features, self.out_features
)
return rep

def forward(self, input):
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Args:
input: input features,
Expand All @@ -59,6 +60,7 @@ def forward(self, input):

Returns:
tensor (logits) with shapes ``BxC``
where ``C`` is a number of classes.
where ``C`` is a number of classes
(out_features).
"""
return F.linear(input, self.weight, self.bias)
7 changes: 7 additions & 0 deletions docs/api/contrib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,13 @@ ArcFace and SubCenterArcFace
:undoc-members:
:show-inheritance:

Arc Margin Product
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
.. automodule:: catalyst.contrib.nn.modules.arcmargin
:members:
:undoc-members:
:show-inheritance:

Common modules
""""""""""""""""""""""""""""""""""""""""""
.. automodule:: catalyst.contrib.nn.modules.common
Expand Down