-
Notifications
You must be signed in to change notification settings - Fork 44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added CosFace Loss #202
Added CosFace Loss #202
Conversation
✅ Deploy Preview for capable-unicorn-d5e336 ready!
To edit notification comments on pull requests, go to your Netlify site settings. |
@monatis All checks have passed. Please could you review it? Thanks! |
def l2_norm(inputs: torch.Tensor, dim: int = 0) -> torch.Tensor: | ||
"""Apply L2 normalization to tensor | ||
|
||
Args: | ||
inputs: Input tensor. | ||
dim: Dimension to operate on. | ||
|
||
Returns: | ||
torch.Tensor: L2-normalized tensor | ||
""" | ||
outputs = inputs / torch.norm(inputs, 2, dim, True) | ||
|
||
return outputs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
According to the documentation, torch.norm
is deprecated and may be removed in the future. That would be better to use torch.linalg.norm
instead here: outputs = inputs / torch.linalg.norm(inputs, dim=dim, ord=2, keepdim=True)
.
This can also be used: return torch.nn.functional.normalize(inputs, p=2, dim=dim)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree. I think it is better to create a separate issue and a consequent PR for it.
Co-authored-by: M. Yusuf Sarıgöz <yusufsarigoz@gmail.com>
@monatis Sorry, I must have missed changing it after reverting my commits. Thank you! Is there anything more before merging? Sorry for the late update. My health hasn't been keeping up for the past couple of days. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Get well soon! Hope you feel better now.
The tests didn't pass again. See why in the review comments below. Also let's add that assertion to ArcFaceLoss
as well, just to the beginning of .forward()
method as in CosFaceLoss
. Finally, please run tests locally with the pytest
command.
Co-authored-by: M. Yusuf Sarıgöz <yusufsarigoz@gmail.com>
Co-authored-by: M. Yusuf Sarıgöz <yusufsarigoz@gmail.com>
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
@monatis Ran the tests locally. Made the requested changes. Anything more, to be done? Also, thankyou for the help! |
@crsdvaibhav ArcFace has the assertion missing, and CosFace has it twice. See also #203 (thanks @KarahanS) |
I changed it. Thanks for the merge! I will move to #203 |
l2_norm
from ArcFace Loss toquaterion.utils
and updated the imports and__init__.py