Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: SoftMax, CosFace, ArcFace layers #939

Merged
merged 25 commits into from Oct 2, 2020
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Runner registry support for Config API ([#936](https://github.com/catalyst-team/catalyst/pull/936))
- SoftMax, CosFace, ArcFace layers to contrib ([#939](https://github.com/catalyst-team/catalyst/pull/939))

### Changed

Expand Down
4 changes: 4 additions & 0 deletions catalyst/contrib/nn/modules/__init__.py
Expand Up @@ -31,3 +31,7 @@
scSE,
cSE,
)

from catalyst.contrib.nn.modules.softmax import SoftMax
from catalyst.contrib.nn.modules.arcface import ArcFace
from catalyst.contrib.nn.modules.cosface import CosFace
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since these are effectively criterion layers, could we move them to contrib/nn/criterion? What do you think?

87 changes: 87 additions & 0 deletions catalyst/contrib/nn/modules/arcface.py
@@ -0,0 +1,87 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ArcFace(nn.Module):
    """Implementation of ArcFace head for metric learning.

    .. _ArcFace: Additive Angular Margin Loss for Deep Face Recognition:
        https://arxiv.org/abs/1801.07698v1

    Example:
        >>> layer = ArcFace(5, 10, s=1.31, m=0.5)
        >>> loss_fn = nn.CrossEntropyLoss()
        >>> embedding = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = layer(embedding, target)
        >>> loss = loss_fn(output, target)
        >>> loss.backward()

    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        s: float = 64.0,
        m: float = 0.5,
    ):
        """
        Args:
            in_features (int): size of each input sample.
            out_features (int): size of each output sample.
            s (float, optional): norm of input feature.
                Default: ``64.0``.
            m (float, optional): margin.
                Default: ``0.5``.
        """
        super(ArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m

        # Precomputed constants for the angle-addition formula:
        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Threshold past which theta + m would exceed pi; beyond it a
        # monotonic linear fallback ``cos(theta) - mm`` is used instead.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

        # Class-center weight matrix, one row per class.
        self.weight = nn.Parameter(
            torch.FloatTensor(out_features, in_features)
        )
        nn.init.xavier_uniform_(self.weight)

    def __repr__(self) -> str:
        """Object representation."""
        return "ArcFace(in_features={},out_features={},s={},m={})".format(
            self.in_features, self.out_features, self.s, self.m
        )

    def forward(
        self, input: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): input features,
                expected shapes BxF.
            target (torch.Tensor): target classes,
                expected shapes B.

        Returns:
            torch.Tensor: scaled, margin-adjusted logits of shape BxC
                (C == ``out_features``), suitable as input for
                ``nn.CrossEntropyLoss``.
        """
        # Cosine similarity between L2-normalized embeddings and class centers.
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Clamp guards against tiny negative values from fp rounding
        # (cosine marginally outside [-1, 1]) which would make sqrt NaN.
        sine = torch.sqrt(
            torch.clamp(1.0 - torch.pow(cosine, 2), min=0.0)
        )
        # cos(theta + m) via the angle-addition formula.
        phi = cosine * self.cos_m - sine * self.sin_m
        # Keep the logit monotonically decreasing when theta + m > pi.
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Apply the angular margin only to each sample's target class.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, target.view(-1, 1).long(), 1)
        logits = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        logits *= self.s

        return logits
76 changes: 76 additions & 0 deletions catalyst/contrib/nn/modules/cosface.py
@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class CosFace(nn.Module):
    """Implementation of CosFace head for metric learning.

    .. _CosFace: Large Margin Cosine Loss for Deep Face Recognition:
        https://arxiv.org/abs/1801.09414

    Example:
        >>> layer = CosFace(5, 10, s=1.31, m=0.1)
        >>> loss_fn = nn.CrossEntropyLoss()
        >>> embedding = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = layer(embedding, target)
        >>> loss = loss_fn(output, target)
        >>> loss.backward()

    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        s: float = 64.0,
        m: float = 0.35,
    ):
        """
        Args:
            in_features (int): size of each input sample.
            out_features (int): size of each output sample.
            s (float, optional): norm of input feature.
                Default: ``64.0``.
            m (float, optional): margin.
                Default: ``0.35``.
        """
        super(CosFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m

        # Class-center weight matrix, one row per class.
        self.weight = nn.Parameter(
            torch.FloatTensor(out_features, in_features)
        )
        nn.init.xavier_uniform_(self.weight)

    def __repr__(self) -> str:
        """Object representation."""
        return "CosFace(in_features={},out_features={},s={},m={})".format(
            self.in_features, self.out_features, self.s, self.m
        )

    def forward(
        self, input: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): input features,
                expected shapes BxF.
            target (torch.Tensor): target classes,
                expected shapes B.

        Returns:
            torch.Tensor: scaled, margin-adjusted logits of shape BxC
                (C == ``out_features``), suitable as input for
                ``nn.CrossEntropyLoss``.
        """
        # Cosine similarity between L2-normalized embeddings and class centers.
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Additive cosine margin.
        phi = cosine - self.m
        # Apply the margin only to each sample's target class.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, target.view(-1, 1).long(), 1)
        logits = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        logits *= self.s

        return logits
50 changes: 50 additions & 0 deletions catalyst/contrib/nn/modules/softmax.py
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftMax(nn.Module):
    """Implementation of SoftMax head for metric learning.

    A plain learnable linear projection (weight + bias) from embedding
    space to class logits; the margin-free baseline counterpart of
    ``ArcFace``/``CosFace``.

    Example:
        >>> layer = SoftMax(5, 10)
        >>> loss_fn = nn.CrossEntropyLoss()
        >>> embedding = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = layer(embedding)
        >>> loss = loss_fn(output, target)
        >>> loss.backward()

    """

    def __init__(self, in_features: int, num_classes: int):
        """
        Args:
            in_features (int): size of each input sample.
            num_classes (int): size of each output sample.
        """
        super(SoftMax, self).__init__()
        self.in_features = in_features
        self.out_features = num_classes
        self.weight = nn.Parameter(torch.FloatTensor(num_classes, in_features))
        self.bias = nn.Parameter(torch.FloatTensor(num_classes))

        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def __repr__(self) -> str:
        """Object representation."""
        return "SoftMax(in_features={},out_features={})".format(
            self.in_features, self.out_features
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): input features,
                expected shapes BxF.

        Returns:
            torch.Tensor: logits of shape BxC (C == ``num_classes``),
                suitable as input for ``nn.CrossEntropyLoss``.
        """
        return F.linear(input, self.weight, self.bias)
5 changes: 5 additions & 0 deletions catalyst/contrib/nn/tests/test_criterion.py
@@ -1,3 +1,8 @@
# flake8: noqa
import numpy as np
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
ditwoo marked this conversation as resolved.
Show resolved Hide resolved

import torch

from catalyst.contrib.nn import criterion as module
ditwoo marked this conversation as resolved.
Show resolved Hide resolved
from catalyst.contrib.nn.criterion import (
CircleLoss,
Expand Down