diff --git a/CHANGELOG.md b/CHANGELOG.md index dbc83f2128..164505ecbe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fix bug in `OptimizerCallback` when mixed-precision params set both: in callback arguments and in distributed_params ([#1042](https://github.com/catalyst-team/catalyst/pull/1042)) +## [20.12.1] - XXXX-XX-XX + + +### Added + +- Inference mode for face layers ([#1045](https://github.com/catalyst-team/catalyst/pull/1045)) + ## [20.12] - 2020-12-20 diff --git a/catalyst/contrib/nn/modules/arcface.py b/catalyst/contrib/nn/modules/arcface.py index 1fa08aaf78..f0e45c9633 100644 --- a/catalyst/contrib/nn/modules/arcface.py +++ b/catalyst/contrib/nn/modules/arcface.py @@ -74,7 +74,7 @@ def __repr__(self) -> str: return rep def forward( - self, input: torch.Tensor, target: torch.LongTensor + self, input: torch.Tensor, target: torch.LongTensor = None ) -> torch.Tensor: """ Args: @@ -85,6 +85,9 @@ def forward( target: target classes, expected shapes ``B`` where ``B`` is batch dimension. + If `None` then will be returned + projection on centroids. + Default is `None`. Returns: tensor (logits) with shapes ``BxC`` @@ -92,6 +95,10 @@ def forward( (out_features). """ cos_theta = F.linear(F.normalize(input), F.normalize(self.weight)) + + if target is None: + return cos_theta + theta = torch.acos( torch.clamp(cos_theta, -1.0 + self.eps, 1.0 - self.eps) ) @@ -187,7 +194,7 @@ def __repr__(self) -> str: return rep def forward( - self, input: torch.Tensor, label: torch.LongTensor + self, input: torch.Tensor, target: torch.LongTensor = None ) -> torch.Tensor: """ Args: @@ -195,29 +202,32 @@ def forward( expected shapes ``BxF`` where ``B`` is batch dimension and ``F`` is an input feature dimension. - label: target classes, + target: target classes, expected shapes ``B`` where ``B`` is batch dimension. + If `None` then will be returned + projection on centroids. + Default is `None`. Returns: tensor (logits) with shapes ``BxC`` where ``C`` is a number of classes. """ - cos_theta = torch.bmm( - F.normalize(input) - .unsqueeze(0) - .expand(self.k, *input.shape), # k*b*f - F.normalize( - self.weight, dim=1 - ), # normalize in_features dim # k*f*c + feats = ( + F.normalize(input).unsqueeze(0).expand(self.k, *input.shape) ) # k*b*f + wght = F.normalize(self.weight, dim=1) # k*f*c + cos_theta = torch.bmm(feats, wght) # k*b*f cos_theta = torch.max(cos_theta, dim=0)[0] # b*f theta = torch.acos( torch.clamp(cos_theta, -1.0 + self.eps, 1.0 - self.eps) ) + if target is None: + return cos_theta + one_hot = torch.zeros_like(cos_theta) - one_hot.scatter_(1, label.view(-1, 1).long(), 1) + one_hot.scatter_(1, target.view(-1, 1).long(), 1) selected = torch.where( theta > self.threshold, torch.zeros_like(one_hot), one_hot diff --git a/catalyst/contrib/nn/modules/cosface.py b/catalyst/contrib/nn/modules/cosface.py index da3acaeb51..2ed28bc300 100644 --- a/catalyst/contrib/nn/modules/cosface.py +++ b/catalyst/contrib/nn/modules/cosface.py @@ -68,7 +68,7 @@ def __repr__(self) -> str: return rep def forward( - self, input: torch.Tensor, target: torch.LongTensor + self, input: torch.Tensor, target: torch.LongTensor = None ) -> torch.Tensor: """ Args: @@ -79,6 +79,9 @@ def forward( target: target classes, expected shapes ``B`` where ``B`` is batch dimension. + If `None` then will be returned + projection on centroids. + Default is `None`. Returns: tensor (logits) with shapes ``BxC`` @@ -88,6 +91,9 @@ def forward( cosine = F.linear(F.normalize(input), F.normalize(self.weight)) phi = cosine - self.m + if target is None: + return cosine + one_hot = torch.zeros_like(cosine) one_hot.scatter_(1, target.view(-1, 1).long(), 1) @@ -162,7 +168,7 @@ def __repr__(self) -> str: return rep def forward( - self, input: torch.Tensor, target: torch.LongTensor + self, input: torch.Tensor, target: torch.LongTensor = None ) -> torch.Tensor: """ Args: @@ -173,6 +179,9 @@ def forward( target: target classes, expected shapes ``B`` where ``B`` is batch dimension. + If `None` then will be returned + projection on centroids. + Default is `None`. Returns: tensor (logits) with shapes ``BxC`` @@ -184,6 +193,9 @@ def forward( torch.clamp(cos_theta, -1.0 + self.eps, 1.0 - self.eps) ) + if target is None: + return cos_theta + one_hot = torch.zeros_like(cos_theta) one_hot.scatter_(1, target.view(-1, 1).long(), 1) diff --git a/catalyst/contrib/nn/modules/curricularface.py b/catalyst/contrib/nn/modules/curricularface.py index fbd7fe74e9..11f8342298 100644 --- a/catalyst/contrib/nn/modules/curricularface.py +++ b/catalyst/contrib/nn/modules/curricularface.py @@ -79,7 +79,7 @@ def __repr__(self) -> str: # noqa: D105 return rep def forward( - self, input: torch.Tensor, label: torch.LongTensor + self, input: torch.Tensor, label: torch.LongTensor = None ) -> torch.Tensor: """ Args: @@ -90,6 +90,9 @@ def forward( label: target classes, expected shapes ``B`` where ``B`` is batch dimension. + If `None` then will be returned + projection on centroids. + Default is `None`. Returns: tensor (logits) with shapes ``BxC`` @@ -100,6 +103,9 @@ def forward( ) cos_theta = cos_theta.clamp(-1, 1) # for numerical stability + if label is None: + return cos_theta + target_logit = cos_theta[torch.arange(0, input.size(0)), label].view( -1, 1 ) diff --git a/catalyst/contrib/nn/tests/test_modules.py b/catalyst/contrib/nn/tests/test_modules.py index cd1d0110b5..f0f44f4bc9 100644 --- a/catalyst/contrib/nn/tests/test_modules.py +++ b/catalyst/contrib/nn/tests/test_modules.py @@ -5,10 +5,12 @@ import torch.nn as nn from catalyst.contrib.nn.modules import ( + AdaCos, ArcFace, CosFace, CurricularFace, SoftMax, + SubCenterArcFace, ) @@ -61,6 +63,37 @@ def test_softmax(): assert np.allclose(expected, actual) +def _check_layer(layer): + embedding = torch.randn(3, 5, requires_grad=True) + target = torch.empty(3, dtype=torch.long).random_(10) + + output = layer(embedding, target) + assert output.shape == (3, 10) + + output = layer(embedding) + assert output.shape == (3, 10) + + +def test_arcface_iference_mode(): + _check_layer(ArcFace(5, 10, s=1.31, m=0.5)) + + +def test_subcenter_arcface_iference_mode(): + _check_layer(SubCenterArcFace(5, 10, s=1.31, m=0.35, k=2)) + + +def test_cosface_iference_mode(): + _check_layer(CosFace(5, 10, s=1.31, m=0.1)) + + +def test_adacos_iference_mode(): + _check_layer(AdaCos(5, 10)) + + +def test_curricularface_iference_mode(): + _check_layer(CurricularFace(5, 10, s=1.31, m=0.5)) + + def test_arcface_with_cross_entropy_loss(): emb_size = 4 n_classes = 3