Feature: SoftMax, CosFace, ArcFace layers #939
Merged
Changes from 11 commits
25 commits
92db47f  cosface loss  (ditwoo)
e63eedf  pep fixes  (ditwoo)
5539c7f  fixed link  (ditwoo)
df4a32f  more tests  (ditwoo)
a9734e1  ignore flake  (ditwoo)
e6c1588  cosface now is a layer, softmax, cosface, tests  (ditwoo)
2893c1b  docs  (ditwoo)
e8ae3f7  softmax, cosface, arcface layers  (ditwoo)
2f401d2  docs for __repr__  (ditwoo)
ee26d7a  another docs fix  (ditwoo)
0887d13  and another docs fix  (ditwoo)
e95ac6f  fixed arcface  (ditwoo)
60122fb  fixed conflict  (ditwoo)
6c5c5d9  fix: docs  (ditwoo)
e27db82  fixed docs  (ditwoo)
ef0d147  new docs format & SubCenterArcFace  (ditwoo)
46508aa  arcface title  (ditwoo)
81bf90e  fixed conflict  (ditwoo)
ed6a1c5  docs  (ditwoo)
c621648  docs for forward method  (ditwoo)
90418ca  typings & docs  (ditwoo)
f4752df  moved noqa comment  (ditwoo)
6379740  fixed docs  (ditwoo)
1fdbc89  fixed init docs  (ditwoo)
800ae52  fixed examples  (ditwoo)
@@ -0,0 +1,87 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ArcFace(nn.Module):
    """Implementation of ArcFace loss for metric learning.

    .. _ArcFace: Additive Angular Margin Loss for Deep Face Recognition:
        https://arxiv.org/abs/1801.07698v1

    Example:
        >>> layer = ArcFace(5, 10, s=1.31, m=0.5)
        >>> loss_fn = nn.CrossEntropyLoss()
        >>> embedding = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = layer(embedding, target)
        >>> loss = loss_fn(output, target)
        >>> loss.backward()

    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        s: float = 64.0,
        m: float = 0.5,
    ):
        """
        Args:
            in_features (int): size of each input sample.
            out_features (int): size of each output sample.
            s (float, optional): norm of input feature.
                Default: ``64.0``.
            m (float, optional): margin.
                Default: ``0.5``.
        """
        super(ArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m

        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

        self.weight = nn.Parameter(
            torch.FloatTensor(out_features, in_features)
        )
        nn.init.xavier_uniform_(self.weight)

    def __repr__(self) -> str:
        """Object representation."""
        return "ArcFace(in_features={},out_features={},s={},m={})".format(
            self.in_features, self.out_features, self.s, self.m
        )

    def forward(
        self, input: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): input features,
                expected shape BxF.
            target (torch.Tensor): target classes,
                expected shape B.

        Returns:
            torch.Tensor with scaled logits of shape BxC.
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(min=0.0))
        phi = cosine * self.cos_m - sine * self.sin_m

        phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        one_hot = torch.zeros(cosine.size()).to(input.device)
        one_hot.scatter_(1, target.view(-1, 1).long(), 1)
        logits = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        logits *= self.s

        return logits
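
The ArcFace forward pass above applies the identity cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m) to the target class only, with a linear fallback once theta + m would pass pi. Below is a minimal functional sketch of that same computation, not the code from this PR: the name `arcface_logits`, the use of `F.one_hot` in place of `scatter_`, and the example shapes are assumptions made for illustration.

```python
import math

import torch
import torch.nn.functional as F


def arcface_logits(embedding, weight, target, s=64.0, m=0.5):
    """Hypothetical functional sketch of the ArcFace forward pass."""
    # Cosine of the angle between each embedding and each class weight row.
    cosine = F.linear(F.normalize(embedding), F.normalize(weight))
    # Clamp guards against tiny negative values from rounding before sqrt.
    sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(min=0.0))
    # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
    phi = cosine * math.cos(m) - sine * math.sin(m)
    # Linear fallback where theta + m would exceed pi.
    threshold = math.cos(math.pi - m)
    fallback = cosine - math.sin(math.pi - m) * m
    phi = torch.where(cosine > threshold, phi, fallback)
    # Apply the margin to the target class only, then rescale by s.
    one_hot = F.one_hot(target, num_classes=weight.shape[0]).to(cosine.dtype)
    return s * (one_hot * phi + (1.0 - one_hot) * cosine)


torch.manual_seed(0)
embedding = torch.randn(3, 5, requires_grad=True)  # B x F
weight = torch.randn(10, 5)                        # C x F
target = torch.tensor([1, 4, 7])                   # B

logits = arcface_logits(embedding, weight, target)
loss = F.cross_entropy(logits, target)
loss.backward()
```

Non-target logits stay at `s * cosine`, while each target logit is strictly smaller than `s * cosine`, which is what makes the classification task harder during training.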
@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class CosFace(nn.Module):
    """Implementation of CosFace loss for metric learning.

    .. _CosFace: Large Margin Cosine Loss for Deep Face Recognition:
        https://arxiv.org/abs/1801.09414

    Example:
        >>> layer = CosFace(5, 10, s=1.31, m=0.1)
        >>> loss_fn = nn.CrossEntropyLoss()
        >>> embedding = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = layer(embedding, target)
        >>> loss = loss_fn(output, target)
        >>> loss.backward()

    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        s: float = 64.0,
        m: float = 0.35,
    ):
        """
        Args:
            in_features (int): size of each input sample.
            out_features (int): size of each output sample.
            s (float, optional): norm of input feature.
                Default: ``64.0``.
            m (float, optional): margin.
                Default: ``0.35``.
        """
        super(CosFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m

        self.weight = nn.Parameter(
            torch.FloatTensor(out_features, in_features)
        )
        nn.init.xavier_uniform_(self.weight)

    def __repr__(self) -> str:
        """Object representation."""
        return "CosFace(in_features={},out_features={},s={},m={})".format(
            self.in_features, self.out_features, self.s, self.m
        )

    def forward(
        self, input: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): input features,
                expected shape BxF.
            target (torch.Tensor): target classes,
                expected shape B.

        Returns:
            torch.Tensor with scaled logits of shape BxC.
        """
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        phi = cosine - self.m
        one_hot = torch.zeros(cosine.size()).to(input.device)
        one_hot.scatter_(1, target.view(-1, 1).long(), 1)
        logits = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        logits *= self.s

        return logits
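
The CosFace forward pass above subtracts the margin `m` from the cosine of the target class only. Since `one_hot * (cosine - m) + (1 - one_hot) * cosine` simplifies to `cosine - m * one_hot`, the computation can be sketched in a few lines. This is an illustrative sketch rather than the PR's code: `cosface_logits` is a hypothetical name, and the shapes and seed are arbitrary.

```python
import torch
import torch.nn.functional as F


def cosface_logits(embedding, weight, target, s=64.0, m=0.35):
    """Hypothetical functional sketch of the CosFace forward pass."""
    # Cosine similarity between normalized embeddings and class weights.
    cosine = F.linear(F.normalize(embedding), F.normalize(weight))
    # Subtract the margin from the target-class entry only.
    one_hot = F.one_hot(target, num_classes=weight.shape[0]).to(cosine.dtype)
    return s * (cosine - m * one_hot)


torch.manual_seed(0)
embedding = torch.randn(3, 5, requires_grad=True)  # B x F
weight = torch.randn(10, 5)                        # C x F
target = torch.tensor([0, 2, 9])                   # B

logits = cosface_logits(embedding, weight, target)
loss = F.cross_entropy(logits, target)
loss.backward()
```

Compared with plain scaled cosine logits, each target logit is lowered by exactly `s * m` and every other logit is unchanged.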
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftMax(nn.Module):
    """Implementation of SoftMax head for metric learning.

    Example:
        >>> layer = SoftMax(5, 10)
        >>> loss_fn = nn.CrossEntropyLoss()
        >>> embedding = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = layer(embedding)
        >>> loss = loss_fn(output, target)
        >>> loss.backward()

    """

    def __init__(self, in_features: int, num_classes: int):
        """
        Args:
            in_features (int): size of each input sample.
            num_classes (int): size of each output sample.
        """
        super(SoftMax, self).__init__()
        self.in_features = in_features
        self.out_features = num_classes
        self.weight = nn.Parameter(torch.FloatTensor(num_classes, in_features))
        self.bias = nn.Parameter(torch.FloatTensor(num_classes))

        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def __repr__(self) -> str:
        """Object representation."""
        return "SoftMax(in_features={},out_features={})".format(
            self.in_features, self.out_features
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): input features,
                expected shape BxF.

        Returns:
            torch.Tensor with logits of shape BxC.
        """
        return F.linear(input, self.weight, self.bias)
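
The SoftMax head above is a plain affine map that produces logits; the softmax itself is applied inside the cross-entropy loss. The sketch below illustrates this by checking `F.linear` against an `nn.Linear` carrying the same parameters. The variable names, shapes, and seed are assumptions chosen for illustration and are not part of this PR.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
in_features, num_classes = 5, 10

# Same initialization scheme as the head above: Xavier weights, zero bias.
weight = torch.empty(num_classes, in_features)
nn.init.xavier_uniform_(weight)
bias = torch.zeros(num_classes)

embedding = torch.randn(3, in_features, requires_grad=True)  # B x F
target = torch.tensor([0, 3, 9])                             # B

# The head's forward pass: logits = embedding @ weight.T + bias.
logits = F.linear(embedding, weight, bias)

# An nn.Linear with identical parameters yields the same logits.
reference = nn.Linear(in_features, num_classes)
with torch.no_grad():
    reference.weight.copy_(weight)
    reference.bias.copy_(bias)
same = torch.allclose(logits, reference(embedding))

# CrossEntropyLoss applies log-softmax internally, so raw logits go in.
loss = nn.CrossEntropyLoss()(logits, target)
loss.backward()
```

Unlike the ArcFace and CosFace layers, this head does not normalize its inputs or weights and does not take `target` as an argument.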
Since these are criterions, could we move them to contrib/nn/criterion? What do you think?