diff --git a/CHANGELOG.md b/CHANGELOG.md
index ed5da09cfd..b39b6925a6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
 
 - docs for MetricCallbacks ([#947](https://github.com/catalyst-team/catalyst/pull/947))
+- SoftMax, CosFace, ArcFace layers to contrib ([#939](https://github.com/catalyst-team/catalyst/pull/939))
 
 ### Changed
 
@@ -34,6 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
 
 - Runner registry support for Config API ([#936](https://github.com/catalyst-team/catalyst/pull/936))
+- `catalyst-dl tune` command - Optuna with Config API integration for AutoML hyperparameter optimization ([#937](https://github.com/catalyst-team/catalyst/pull/937))
 - `OptunaPruningCallback` alias for `OptunaCallback` ([#937](https://github.com/catalyst-team/catalyst/pull/937))
 - AdamP and SGDP to `catalyst.contrib.nn.criterion` ([#942](https://github.com/catalyst-team/catalyst/pull/942))
diff --git a/catalyst/contrib/nn/modules/__init__.py b/catalyst/contrib/nn/modules/__init__.py
index c4f4d015eb..8ec4d226c7 100644
--- a/catalyst/contrib/nn/modules/__init__.py
+++ b/catalyst/contrib/nn/modules/__init__.py
@@ -31,3 +31,7 @@
     scSE,
     cSE,
 )
+
+from catalyst.contrib.nn.modules.softmax import SoftMax
+from catalyst.contrib.nn.modules.arcface import ArcFace, SubCenterArcFace
+from catalyst.contrib.nn.modules.cosface import CosFace
diff --git a/catalyst/contrib/nn/modules/arcface.py b/catalyst/contrib/nn/modules/arcface.py
new file mode 100644
index 0000000000..74132f4be1
--- /dev/null
+++ b/catalyst/contrib/nn/modules/arcface.py
@@ -0,0 +1,230 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ArcFace(nn.Module):
+    """Implementation of
+    `ArcFace: Additive Angular Margin Loss for Deep Face Recognition`_.
+
+    .. _ArcFace\: Additive Angular Margin Loss for Deep Face Recognition:
+        https://arxiv.org/abs/1801.07698v1
+
+    Args:
+        in_features: size of each input sample.
+        out_features: size of each output sample.
+        s: norm of input feature.
+            Default: ``64.0``.
+        m: margin.
+            Default: ``0.5``.
+        eps: small value to clamp the cosine into
+            ``[-1 + eps, 1 - eps]`` for numerical stability of ``acos``.
+            Default: ``1e-6``.
+
+    Shape:
+        - Input: :math:`(batch, H_{in})` where
+          :math:`H_{in} = in\_features`.
+        - Output: :math:`(batch, H_{out})` where
+          :math:`H_{out} = out\_features`.
+
+    Example:
+        >>> layer = ArcFace(5, 10, s=1.31, m=0.5)
+        >>> loss_fn = nn.CrossEntropyLoss()
+        >>> embedding = torch.randn(3, 5, requires_grad=True)
+        >>> target = torch.empty(3, dtype=torch.long).random_(10)
+        >>> output = layer(embedding, target)
+        >>> loss = loss_fn(output, target)
+        >>> loss.backward()
+
+    """
+
+    def __init__(  # noqa: D107
+        self,
+        in_features: int,
+        out_features: int,
+        s: float = 64.0,
+        m: float = 0.5,
+        eps: float = 1e-6,
+    ):
+        super(ArcFace, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.s = s
+        self.m = m
+        self.threshold = math.pi - m
+        self.eps = eps
+
+        self.weight = nn.Parameter(
+            torch.FloatTensor(out_features, in_features)
+        )
+        nn.init.xavier_uniform_(self.weight)
+
+    def __repr__(self) -> str:
+        """Object representation."""
+        return (
+            "ArcFace("
+            + ",".join(
+                [
+                    f"in_features={self.in_features}",
+                    f"out_features={self.out_features}",
+                    f"s={self.s}",
+                    f"m={self.m}",
+                    f"eps={self.eps}",
+                ]
+            )
+            + ")"
+        )
+
+    def forward(self, input, target):
+        """
+        Args:
+            input (torch.Tensor): input features,
+                expected shapes ``BxF`` where ``B``
+                is batch dimension and ``F`` is an
+                input feature dimension.
+            target (torch.Tensor): target classes,
+                expected shapes ``B`` where
+                ``B`` is batch dimension.
+
+        Returns:
+            tensor (logits) with shapes ``BxC``
+            where ``C`` is a number of classes.
+        """
+        cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))
+        theta = torch.acos(
+            torch.clamp(cos_theta, -1.0 + self.eps, 1.0 - self.eps)
+        )
+
+        one_hot = torch.zeros_like(cos_theta, device=input.device)
+        one_hot.scatter_(1, target.view(-1, 1).long(), 1)
+
+        mask = torch.where(
+            theta > self.threshold, torch.zeros_like(one_hot), one_hot
+        )
+
+        logits = torch.cos(torch.where(mask.bool(), theta + self.m, theta))
+        logits *= self.s
+
+        return logits
+
+
+class SubCenterArcFace(nn.Module):
+    """Implementation of
+    `Sub-center ArcFace: Boosting Face Recognition
+    by Large-scale Noisy Web Faces`_.
+
+    .. _Sub-center ArcFace\: Boosting Face Recognition \
+        by Large-scale Noisy Web Faces:
+        https://ibug.doc.ic.ac.uk/media/uploads/documents/eccv_1445.pdf
+
+    Args:
+        in_features: size of each input sample.
+        out_features: size of each output sample.
+        s: norm of input feature.
+            Default: ``64.0``.
+        m: margin.
+            Default: ``0.5``.
+        k: number of possible class centroids.
+            Default: ``3``.
+        eps: small value to clamp the cosine into
+            ``[-1 + eps, 1 - eps]`` for numerical stability of ``acos``.
+            Default: ``1e-6``.
+
+    Shape:
+        - Input: :math:`(batch, H_{in})` where
+          :math:`H_{in} = in\_features`.
+        - Output: :math:`(batch, H_{out})` where
+          :math:`H_{out} = out\_features`.
+
+    Example:
+        >>> layer = SubCenterArcFace(5, 10, s=1.31, m=0.35, k=2)
+        >>> loss_fn = nn.CrossEntropyLoss()
+        >>> embedding = torch.randn(3, 5, requires_grad=True)
+        >>> target = torch.empty(3, dtype=torch.long).random_(10)
+        >>> output = layer(embedding, target)
+        >>> loss = loss_fn(output, target)
+        >>> loss.backward()
+
+    """
+
+    def __init__(  # noqa: D107
+        self,
+        in_features: int,
+        out_features: int,
+        s: float = 64.0,
+        m: float = 0.5,
+        k: int = 3,
+        eps: float = 1e-6,
+    ):
+        super(SubCenterArcFace, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+
+        self.s = s
+        self.m = m
+        self.k = k
+        self.eps = eps
+
+        self.weight = nn.Parameter(
+            torch.FloatTensor(k, in_features, out_features)
+        )
+        nn.init.xavier_uniform_(self.weight)
+
+        self.threshold = math.pi - self.m
+
+    def __repr__(self) -> str:
+        """Object representation."""
+        return (
+            "SubCenterArcFace("
+            + ",".join(
+                [
+                    f"in_features={self.in_features}",
+                    f"out_features={self.out_features}",
+                    f"s={self.s}",
+                    f"m={self.m}",
+                    f"k={self.k}",
+                    f"eps={self.eps}",
+                ]
+            )
+            + ")"
+        )
+
+    def forward(self, input, label):
+        """
+        Args:
+            input (torch.Tensor): input features,
+                expected shapes ``BxF`` where ``B``
+                is batch dimension and ``F`` is an
+                input feature dimension.
+            label (torch.Tensor): target classes,
+                expected shapes ``B`` where
+                ``B`` is batch dimension.
+
+        Returns:
+            tensor (logits) with shapes ``BxC``
+            where ``C`` is a number of classes.
+        """
+        cos_theta = torch.bmm(
+            F.normalize(input)
+            .unsqueeze(0)
+            .expand(self.k, *input.shape),  # k*b*f
+            F.normalize(
+                self.weight, dim=1
+            ),  # normalize in_features dim  # k*f*c
+        )  # k*b*c
+        cos_theta = torch.max(cos_theta, dim=0)[0]  # b*c
+        theta = torch.acos(
+            torch.clamp(cos_theta, -1.0 + self.eps, 1.0 - self.eps)
+        )
+
+        one_hot = torch.zeros(cos_theta.size()).to(input.device)
+        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
+
+        selected = torch.where(
+            theta > self.threshold, torch.zeros_like(one_hot), one_hot
+        )
+
+        logits = torch.cos(torch.where(selected.bool(), theta + self.m, theta))
+        logits *= self.s
+
+        return logits
diff --git a/catalyst/contrib/nn/modules/cosface.py b/catalyst/contrib/nn/modules/cosface.py
new file mode 100644
index 0000000000..7e2ceb02f9
--- /dev/null
+++ b/catalyst/contrib/nn/modules/cosface.py
@@ -0,0 +1,84 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class CosFace(nn.Module):
+    """Implementation of
+    `CosFace: Large Margin Cosine Loss for Deep Face Recognition`_.
+
+    .. _CosFace\: Large Margin Cosine Loss for Deep Face Recognition:
+        https://arxiv.org/abs/1801.09414
+
+    Args:
+        in_features: size of each input sample.
+        out_features: size of each output sample.
+        s: norm of input feature.
+            Default: ``64.0``.
+        m: margin.
+            Default: ``0.35``.
+
+    Shape:
+        - Input: :math:`(batch, H_{in})` where
+          :math:`H_{in} = in\_features`.
+        - Output: :math:`(batch, H_{out})` where
+          :math:`H_{out} = out\_features`.
+
+    Example:
+        >>> layer = CosFace(5, 10, s=1.31, m=0.1)
+        >>> loss_fn = nn.CrossEntropyLoss()
+        >>> embedding = torch.randn(3, 5, requires_grad=True)
+        >>> target = torch.empty(3, dtype=torch.long).random_(10)
+        >>> output = layer(embedding, target)
+        >>> loss = loss_fn(output, target)
+        >>> loss.backward()
+
+    """
+
+    def __init__(  # noqa: D107
+        self,
+        in_features: int,
+        out_features: int,
+        s: float = 64.0,
+        m: float = 0.35,
+    ):
+        super(CosFace, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.s = s
+        self.m = m
+
+        self.weight = nn.Parameter(
+            torch.FloatTensor(out_features, in_features)
+        )
+        nn.init.xavier_uniform_(self.weight)
+
+    def __repr__(self) -> str:
+        """Object representation."""
+        return "CosFace(in_features={},out_features={},s={},m={})".format(
+            self.in_features, self.out_features, self.s, self.m
+        )
+
+    def forward(self, input, target):
+        """
+        Args:
+            input (torch.Tensor): input features,
+                expected shapes ``BxF`` where ``B``
+                is batch dimension and ``F`` is an
+                input feature dimension.
+            target (torch.Tensor): target classes,
+                expected shapes ``B`` where
+                ``B`` is batch dimension.
+
+        Returns:
+            tensor (logits) with shapes ``BxC``
+            where ``C`` is a number of classes.
+        """
+        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
+        phi = cosine - self.m
+        one_hot = torch.zeros(cosine.size()).to(input.device)
+        one_hot.scatter_(1, target.view(-1, 1).long(), 1)
+        logits = (one_hot * phi) + ((1.0 - one_hot) * cosine)
+        logits *= self.s
+
+        return logits
diff --git a/catalyst/contrib/nn/modules/softmax.py b/catalyst/contrib/nn/modules/softmax.py
new file mode 100644
index 0000000000..7404b7f230
--- /dev/null
+++ b/catalyst/contrib/nn/modules/softmax.py
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SoftMax(nn.Module):
+    """Implementation of
+    `Significance of Softmax-based Features in Comparison to
+    Distance Metric Learning-based Features`_.
+
+    .. _Significance of Softmax-based Features in Comparison to \
+        Distance Metric Learning-based Features:
+        https://arxiv.org/abs/1712.10151
+
+    Args:
+        in_features: size of each input sample.
+        num_classes: number of output classes
+            (size of each output sample).
+
+    Shape:
+        - Input: :math:`(batch, H_{in})` where
+          :math:`H_{in} = in\_features`.
+        - Output: :math:`(batch, H_{out})` where
+          :math:`H_{out} = num\_classes`.
+
+    Example:
+        >>> layer = SoftMax(5, 10)
+        >>> loss_fn = nn.CrossEntropyLoss()
+        >>> embedding = torch.randn(3, 5, requires_grad=True)
+        >>> target = torch.empty(3, dtype=torch.long).random_(10)
+        >>> output = layer(embedding)
+        >>> loss = loss_fn(output, target)
+        >>> loss.backward()
+
+    """
+
+    def __init__(self, in_features: int, num_classes: int):  # noqa: D107
+        super(SoftMax, self).__init__()
+        self.in_features = in_features
+        self.out_features = num_classes
+        self.weight = nn.Parameter(torch.FloatTensor(num_classes, in_features))
+        self.bias = nn.Parameter(torch.FloatTensor(num_classes))
+
+        nn.init.xavier_uniform_(self.weight)
+        nn.init.zeros_(self.bias)
+
+    def __repr__(self) -> str:
+        """Object representation."""
+        return "SoftMax(in_features={},out_features={})".format(
+            self.in_features, self.out_features
+        )
+
+    def forward(self, input):
+        """
+        Args:
+            input (torch.Tensor): input features,
+                expected shapes ``BxF`` where ``B``
+                is batch dimension and ``F`` is an
+                input feature dimension.
+
+        Returns:
+            tensor (logits) with shapes ``BxC``
+            where ``C`` is a number of classes.
+ """ + return F.linear(input, self.weight, self.bias) diff --git a/catalyst/contrib/nn/tests/test_criterion.py b/catalyst/contrib/nn/tests/test_criterion.py index 01cadc42cb..f230071790 100644 --- a/catalyst/contrib/nn/tests/test_criterion.py +++ b/catalyst/contrib/nn/tests/test_criterion.py @@ -1,3 +1,8 @@ +# flake8: noqa +import numpy as np + +import torch + from catalyst.contrib.nn import criterion as module from catalyst.contrib.nn.criterion import ( CircleLoss, diff --git a/catalyst/contrib/nn/tests/test_modules.py b/catalyst/contrib/nn/tests/test_modules.py new file mode 100644 index 0000000000..a99ecd8140 --- /dev/null +++ b/catalyst/contrib/nn/tests/test_modules.py @@ -0,0 +1,211 @@ +# flake8: noqa +import numpy as np + +import torch +import torch.nn as nn + +from catalyst.contrib.nn.modules import ArcFace, CosFace, SoftMax + + +def normalize(m: np.ndarray) -> np.ndarray: + m_s = np.sqrt((m ** 2).sum(axis=1))[:, np.newaxis] # for each row + return m / m_s + + +def softmax(x: np.ndarray) -> np.ndarray: + e_x = np.exp(x - np.max(x)) + return e_x / e_x.sum(1)[:, np.newaxis] # for each row + + +def cross_entropy( + preds: np.ndarray, targs: np.ndarray, axis: int = 1 +) -> float: + return -(targs * np.log(softmax(preds))).sum(axis) + + +def test_softmax(): + emb_size = 4 + n_classes = 3 + + # fmt: off + features = np.array( + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ], + dtype="f", + ) + target = np.array([0, 2], dtype="l") + weight = np.array( + [ + [0.1, 0.2, 0.3, 0.4], + [1.1, 3.2, 5.3, 0.4], + [0.1, 0.2, 6.3, 0.4], + ], + dtype="f", + ) + bias = np.array([0.2, 0.01, 0.1], dtype="f") + # fmt: on + + layer = SoftMax(emb_size, n_classes) + layer.weight.data = torch.from_numpy(weight) + layer.bias.data = torch.from_numpy(bias) + + expected = features @ weight.T + bias + actual = layer(torch.from_numpy(features)).detach().numpy() + assert np.allclose(expected, actual) + + +def test_arcface_with_cross_entropy_loss(): + emb_size = 4 + n_classes = 3 + s = 3.0 + m = 0.5 + eps = 1e-8 + + # fmt: off + features = np.array( + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ], + dtype="f", + ) + target = np.array([0, 2], dtype="l") + weight = np.array( + [ + [0.1, 0.2, 0.3, 0.4], + [1.1, 3.2, 5.3, 0.4], + [0.1, 0.2, 6.3, 0.4], + ], + dtype="f", + ) + # fmt: on + + layer = ArcFace(emb_size, n_classes, s, m, eps) + layer.weight.data = torch.from_numpy(weight) + loss_fn = nn.CrossEntropyLoss(reduction="none") + + normalized_features = normalize(features) # 2x4 + normalized_projection = normalize(weight) # 3x4 + + cosine = normalized_features @ normalized_projection.T # 2x4 * 4x3 = 2x3 + theta = np.arccos(np.clip(cosine, -1 + eps, 1 - eps)) # 2x3 + + # one_hot(target) + mask = np.array([[1, 0, 0], [0, 0, 1]], dtype="l") + mask = np.where(theta > (np.pi - m), np.zeros_like(mask), mask) # 2x3 + feats = np.cos(np.where(mask > 0, theta + m, theta)) * s # 2x3 + + expected_loss = cross_entropy(feats, mask, 1) + actual = ( + loss_fn( + layer(torch.from_numpy(features), torch.LongTensor(target)), + torch.LongTensor(target), + ) + .detach() + .numpy() + ) + assert np.allclose(expected_loss, actual) + + loss_fn = nn.CrossEntropyLoss(reduction="mean") + + expected_loss = cross_entropy(feats, mask, 1) + actual = ( + loss_fn( + layer(torch.from_numpy(features), torch.LongTensor(target)), + torch.LongTensor(target), + ) + .detach() + .numpy() + ) + assert np.isclose(expected_loss.mean(), actual) + + loss_fn = nn.CrossEntropyLoss(reduction="sum") + + expected_loss = cross_entropy(feats, mask, 1) + actual = ( + loss_fn( + 
+            layer(torch.from_numpy(features), torch.LongTensor(target)),
+            torch.LongTensor(target),
+        )
+        .detach()
+        .numpy()
+    )
+    assert np.isclose(expected_loss.sum(), actual)
+
+
+def test_cosface_with_cross_entropy_loss():
+    emb_size = 4
+    n_classes = 3
+    s = 3.0
+    m = 0.1
+
+    # fmt: off
+    features = np.array(
+        [
+            [1, 2, 3, 4],
+            [5, 6, 7, 8],
+        ],
+        dtype="f",
+    )
+    target = np.array([0, 2], dtype="l")
+    weight = np.array(
+        [
+            [0.1, 0.2, 0.3, 0.4],
+            [1.1, 3.2, 5.3, 0.4],
+            [0.1, 0.2, 6.3, 0.4],
+        ],
+        dtype="f",
+    )
+    # fmt: on
+
+    layer = CosFace(emb_size, n_classes, s, m)
+    layer.weight.data = torch.from_numpy(weight)
+    loss_fn = nn.CrossEntropyLoss(reduction="none")
+
+    normalized_features = normalize(features)  # 2x4
+    normalized_projection = normalize(weight)  # 3x4
+
+    cosine = normalized_features @ normalized_projection.T  # 2x4 * 4x3 = 2x3
+    phi = cosine - m  # 2x3
+
+    mask = np.array([[1, 0, 0], [0, 0, 1]], dtype="l")  # one_hot(target)
+    feats = (mask * phi + (1.0 - mask) * cosine) * s  # 2x3
+
+    expected_loss = cross_entropy(feats, mask, 1)
+    actual = (
+        loss_fn(
+            layer(torch.from_numpy(features), torch.LongTensor(target)),
+            torch.LongTensor(target),
+        )
+        .detach()
+        .numpy()
+    )
+    assert np.allclose(expected_loss, actual)
+
+    loss_fn = nn.CrossEntropyLoss(reduction="mean")
+
+    expected_loss = cross_entropy(feats, mask, 1)
+    actual = (
+        loss_fn(
+            layer(torch.from_numpy(features), torch.LongTensor(target)),
+            torch.LongTensor(target),
+        )
+        .detach()
+        .numpy()
+    )
+    assert np.isclose(expected_loss.mean(), actual)
+
+    loss_fn = nn.CrossEntropyLoss(reduction="sum")
+
+    expected_loss = cross_entropy(feats, mask, 1)
+    actual = (
+        loss_fn(
+            layer(torch.from_numpy(features), torch.LongTensor(target)),
+            torch.LongTensor(target),
+        )
+        .detach()
+        .numpy()
+    )
+    assert np.isclose(expected_loss.sum(), actual)
diff --git a/docs/api/contrib.rst b/docs/api/contrib.rst
index d2e1d8144c..d48bbbb9a7 100644
--- a/docs/api/contrib.rst
+++ b/docs/api/contrib.rst
@@ -250,6 +250,13 @@ Wing
 Modules
 ~~~~~~~~~~~~~~~~
 
+ArcFace and SubCenterArcFace
+"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+.. automodule:: catalyst.contrib.nn.modules.arcface
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 Common modules
 """"""""""""""""""""""""""""""""""""""""""
 .. automodule:: catalyst.contrib.nn.modules.common
@@ -257,6 +264,13 @@ Common modules
     :undoc-members:
     :show-inheritance:
 
+CosFace: Large Margin Cosine Loss for Deep Face Recognition
+"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+.. automodule:: catalyst.contrib.nn.modules.cosface
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 Last-Mean-Average-Attention (LAMA)-Pooling
 """"""""""""""""""""""""""""""""""""""""""
 .. automodule:: catalyst.contrib.nn.modules.lama
@@ -285,6 +299,13 @@ SqueezeAndExcitation
     :undoc-members:
     :show-inheritance:
 
+SoftMax
+""""""""""""""""""""""""""""""""""""""""""
+.. automodule:: catalyst.contrib.nn.modules.softmax
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 Optimizers
 ~~~~~~~~~~~~~~~~
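
Reviewer note: a minimal end-to-end sketch of how the new heads plug into a training step. Only the `catalyst.contrib.nn.modules` imports come from this diff; the backbone, optimizer, and random data below are illustrative assumptions, not part of the PR.

    import torch
    import torch.nn as nn

    from catalyst.contrib.nn.modules import ArcFace, CosFace, SoftMax

    emb_size, n_classes, batch = 16, 10, 8

    # any backbone that emits `emb_size`-dimensional embeddings works here
    backbone = nn.Sequential(nn.Linear(32, emb_size), nn.ReLU())
    head = ArcFace(emb_size, n_classes, s=64.0, m=0.5)  # CosFace is a drop-in replacement
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        list(backbone.parameters()) + list(head.parameters())
    )

    x = torch.randn(batch, 32)
    targets = torch.randint(0, n_classes, (batch,))

    embeddings = backbone(x)
    # the margin-based heads need the targets at forward time to pick which
    # logits get the margin; SoftMax takes only the embeddings:
    #     logits = SoftMax(emb_size, n_classes)(embeddings)
    logits = head(embeddings, targets)
    loss = criterion(logits, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()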