Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: inference mode for faces layers #1045

Merged
merged 3 commits into from Dec 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Expand Up @@ -12,6 +12,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fix bug in `OptimizerCallback` when mixed-precision params are set both
in callback arguments and in distributed_params ([#1042](https://github.com/catalyst-team/catalyst/pull/1042))

## [20.12.1] - XXXX-XX-XX


### Added

- Inference mode for face layers ([#1045](https://github.com/catalyst-team/catalyst/pull/1045))


## [20.12] - 2020-12-20

Expand Down
32 changes: 21 additions & 11 deletions catalyst/contrib/nn/modules/arcface.py
Expand Up @@ -74,7 +74,7 @@ def __repr__(self) -> str:
return rep

def forward(
self, input: torch.Tensor, target: torch.LongTensor
self, input: torch.Tensor, target: torch.LongTensor = None
) -> torch.Tensor:
"""
Args:
Expand All @@ -85,13 +85,20 @@ def forward(
target: target classes,
expected shapes ``B`` where
``B`` is batch dimension.
If `None`, the projection on centroids
is returned.
Default is `None`.

Returns:
tensor (logits) with shapes ``BxC``
where ``C`` is a number of classes
(out_features).
"""
cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))

if target is None:
return cos_theta

theta = torch.acos(
torch.clamp(cos_theta, -1.0 + self.eps, 1.0 - self.eps)
)
Expand Down Expand Up @@ -187,37 +194,40 @@ def __repr__(self) -> str:
return rep

def forward(
self, input: torch.Tensor, label: torch.LongTensor
self, input: torch.Tensor, target: torch.LongTensor = None
) -> torch.Tensor:
"""
Args:
input: input features,
expected shapes ``BxF`` where ``B``
is batch dimension and ``F`` is an
input feature dimension.
label: target classes,
target: target classes,
expected shapes ``B`` where
``B`` is batch dimension.
If `None`, the projection on centroids
is returned.
Default is `None`.

Returns:
tensor (logits) with shapes ``BxC``
where ``C`` is a number of classes.
"""
cos_theta = torch.bmm(
F.normalize(input)
.unsqueeze(0)
.expand(self.k, *input.shape), # k*b*f
F.normalize(
self.weight, dim=1
), # normalize in_features dim # k*f*c
feats = (
F.normalize(input).unsqueeze(0).expand(self.k, *input.shape)
) # k*b*f
wght = F.normalize(self.weight, dim=1) # k*f*c
cos_theta = torch.bmm(feats, wght) # k*b*c
cos_theta = torch.max(cos_theta, dim=0)[0] # b*c
theta = torch.acos(
torch.clamp(cos_theta, -1.0 + self.eps, 1.0 - self.eps)
)

if target is None:
return cos_theta

one_hot = torch.zeros_like(cos_theta)
one_hot.scatter_(1, label.view(-1, 1).long(), 1)
one_hot.scatter_(1, target.view(-1, 1).long(), 1)

selected = torch.where(
theta > self.threshold, torch.zeros_like(one_hot), one_hot
Expand Down
16 changes: 14 additions & 2 deletions catalyst/contrib/nn/modules/cosface.py
Expand Up @@ -68,7 +68,7 @@ def __repr__(self) -> str:
return rep

def forward(
self, input: torch.Tensor, target: torch.LongTensor
self, input: torch.Tensor, target: torch.LongTensor = None
) -> torch.Tensor:
"""
Args:
Expand All @@ -79,6 +79,9 @@ def forward(
target: target classes,
expected shapes ``B`` where
``B`` is batch dimension.
If `None`, the projection on centroids
is returned.
Default is `None`.

Returns:
tensor (logits) with shapes ``BxC``
Expand All @@ -88,6 +91,9 @@ def forward(
cosine = F.linear(F.normalize(input), F.normalize(self.weight))
phi = cosine - self.m

if target is None:
return cosine

one_hot = torch.zeros_like(cosine)
one_hot.scatter_(1, target.view(-1, 1).long(), 1)

Expand Down Expand Up @@ -162,7 +168,7 @@ def __repr__(self) -> str:
return rep

def forward(
self, input: torch.Tensor, target: torch.LongTensor
self, input: torch.Tensor, target: torch.LongTensor = None
) -> torch.Tensor:
"""
Args:
Expand All @@ -173,6 +179,9 @@ def forward(
target: target classes,
expected shapes ``B`` where
``B`` is batch dimension.
If `None`, the projection on centroids
is returned.
Default is `None`.

Returns:
tensor (logits) with shapes ``BxC``
Expand All @@ -184,6 +193,9 @@ def forward(
torch.clamp(cos_theta, -1.0 + self.eps, 1.0 - self.eps)
)

if target is None:
return cos_theta

one_hot = torch.zeros_like(cos_theta)
one_hot.scatter_(1, target.view(-1, 1).long(), 1)

Expand Down
8 changes: 7 additions & 1 deletion catalyst/contrib/nn/modules/curricularface.py
Expand Up @@ -79,7 +79,7 @@ def __repr__(self) -> str: # noqa: D105
return rep

def forward(
self, input: torch.Tensor, label: torch.LongTensor
self, input: torch.Tensor, label: torch.LongTensor = None
) -> torch.Tensor:
"""
Args:
Expand All @@ -90,6 +90,9 @@ def forward(
label: target classes,
expected shapes ``B`` where
``B`` is batch dimension.
If `None`, the projection on centroids
is returned.
Default is `None`.

Returns:
tensor (logits) with shapes ``BxC``
Expand All @@ -100,6 +103,9 @@ def forward(
)
cos_theta = cos_theta.clamp(-1, 1) # for numerical stability

if label is None:
return cos_theta

target_logit = cos_theta[torch.arange(0, input.size(0)), label].view(
-1, 1
)
Expand Down
33 changes: 33 additions & 0 deletions catalyst/contrib/nn/tests/test_modules.py
Expand Up @@ -5,10 +5,12 @@
import torch.nn as nn

from catalyst.contrib.nn.modules import (
AdaCos,
ArcFace,
CosFace,
CurricularFace,
SoftMax,
SubCenterArcFace,
)


Expand Down Expand Up @@ -61,6 +63,37 @@ def test_softmax():
assert np.allclose(expected, actual)


def _check_layer(layer):
embedding = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(10)

output = layer(embedding, target)
assert output.shape == (3, 10)

output = layer(embedding)
assert output.shape == (3, 10)


def test_arcface_iference_mode():
    """Run the shared with/without-target shape check on ``ArcFace``."""
    layer = ArcFace(5, 10, s=1.31, m=0.5)
    _check_layer(layer)


def test_subcenter_arcface_iference_mode():
    """Run the shared with/without-target shape check on ``SubCenterArcFace``."""
    layer = SubCenterArcFace(5, 10, s=1.31, m=0.35, k=2)
    _check_layer(layer)


def test_cosface_iference_mode():
    """Run the shared with/without-target shape check on ``CosFace``."""
    layer = CosFace(5, 10, s=1.31, m=0.1)
    _check_layer(layer)


def test_adacos_iference_mode():
    """Run the shared with/without-target shape check on ``AdaCos``."""
    layer = AdaCos(5, 10)
    _check_layer(layer)


def test_curricularface_iference_mode():
    """Run the shared with/without-target shape check on ``CurricularFace``."""
    layer = CurricularFace(5, 10, s=1.31, m=0.5)
    _check_layer(layer)


def test_arcface_with_cross_entropy_loss():
emb_size = 4
n_classes = 3
Expand Down