Skip to content

Commit

Permalink
2022-06-17_codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
yangguohao committed Jun 17, 2022
1 parent de173dd commit ce13bcf
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 162 deletions.
4 changes: 2 additions & 2 deletions python/paddle/nn/__init__.py
Expand Up @@ -104,12 +104,12 @@
from .layer.loss import BCELoss # noqa: F401
from .layer.loss import KLDivLoss # noqa: F401
from .layer.loss import MarginRankingLoss # noqa: F401
from .layer.loss import MultiLabelSoftMarginLoss
from .layer.loss import CTCLoss # noqa: F401
from .layer.loss import SmoothL1Loss # noqa: F401
from .layer.loss import HingeEmbeddingLoss # noqa: F401
from .layer.loss import CosineEmbeddingLoss # noqa: F401
from .layer.loss import TripletMarginWithDistanceLoss
from .layer.loss import MultiLabelSoftMarginLoss
from .layer.norm import BatchNorm # noqa: F401
from .layer.norm import SyncBatchNorm # noqa: F401
from .layer.norm import GroupNorm # noqa: F401
Expand Down Expand Up @@ -312,10 +312,10 @@ def weight_norm(*args):
'MaxUnPool1D',
'MaxUnPool2D',
'MaxUnPool3D',
'MultiLabelSoftMarginLoss',
'HingeEmbeddingLoss',
'Identity',
'CosineEmbeddingLoss',
'RReLU',
'MultiLabelSoftMarginLoss',
'TripletMarginWithDistanceLoss',
]
2 changes: 1 addition & 1 deletion python/paddle/nn/functional/__init__.py
Expand Up @@ -206,6 +206,7 @@
'log_loss',
'mse_loss',
'margin_ranking_loss',
'multi_label_soft_margin_loss',
'nll_loss',
'npair_loss',
'sigmoid_focal_loss',
Expand Down Expand Up @@ -235,5 +236,4 @@
'cosine_embedding_loss',
'rrelu',
'triplet_margin_with_distance_loss',
'multi_label_soft_margin_loss',
]
160 changes: 80 additions & 80 deletions python/paddle/nn/functional/loss.py
Expand Up @@ -2668,6 +2668,86 @@ def sigmoid_focal_loss(logit,
return loss


def multi_label_soft_margin_loss(input,
label,
weight=None,
reduction="mean",
name=None):
r"""
Parameters:
input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes, and if shape is more than 2D, this is (N, C, D1, D2,..., Dk), k >= 1.
label (Tensor): Label tensor, the data type is float32 or float64. The shape of label is the same as the shape of input.
weight (Tensor,optional): a manual rescaling weight given to each class.
If given, has to be a Tensor of size C and the data type is float32, float64.
Default is ``'None'`` .
reduction (str, optional): Indicate how to average the loss by batch_size,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default: ``'mean'``
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means number of classes, available dtype is float32, float64. The sum operationoperates over all the elements.
label: N-D Tensor, same shape as the input.
weight:N-D Tensor, the shape is [N,1]
output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input.
Returns:
Tensor, The tensor variable storing the multi_label_soft_margin_loss of input and label.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
# label elements in {1., -1.}
label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)
loss = F.multi_label_soft_margin_loss(input, label, reduction='none')
print(loss)
# Tensor([3.49625897, 0.71111226, 0.43989015])
loss = F.multi_label_soft_margin_loss(input, label, reduction='mean')
print(loss)
# Tensor([1.54908717])
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"'reduction' in 'multi_label_soft_margin_loss' should be 'sum', 'mean' or 'none', "
"but received {}.".format(reduction))

if not (input.shape == label.shape):
raise ValueError("The input and label should have same dimension,"
"but received {}!={}".format(input.shape, label.shape))

if not _non_static_mode():
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'multilabel_soft_margin_loss')
check_variable_and_dtype(label, 'label', ['float32', 'float64'],
'multilabel_soft_margin_loss')

loss = -(label * paddle.nn.functional.log_sigmoid(input) +
(1 - label) * paddle.nn.functional.log_sigmoid(-input))

if weight is not None:
if not _non_static_mode():
check_variable_and_dtype(weight, 'weight', ['float32', 'float64'],
'multilabel_soft_margin_loss')
loss = loss * weight

loss = loss.mean(axis=-1) # only return N loss values

if reduction == "none":
return loss
elif reduction == "mean":
return paddle.mean(loss)
elif reduction == "sum":
return paddle.sum(loss)


def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None):
r"""
This operator calculates hinge_embedding_loss. Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y`(containing 1 or -1).
Expand Down Expand Up @@ -2999,83 +3079,3 @@ def triplet_margin_with_distance_loss(input,
return paddle.sum(loss, name=name)
elif reduction == 'none':
return loss


def multi_label_soft_margin_loss(input,
label,
weight=None,
reduction="mean",
name=None):
r"""
Parameters:
input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes, and if shape is more than 2D, this is (N, C, D1, D2,..., Dk), k >= 1.
label (Tensor): Label tensor, the data type is float32 or float64. The shape of label is the same as the shape of input.
weight (Tensor,optional): a manual rescaling weight given to each class.
If given, has to be a Tensor of size C and the data type is float32, float64.
Default is ``'None'`` .
reduction (str, optional): Indicate how to average the loss by batch_size,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default: ``'mean'``
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means number of classes, available dtype is float32, float64. The sum operationoperates over all the elements.
label: N-D Tensor, same shape as the input.
weight:N-D Tensor, the shape is [N,1]
output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input.
Returns:
Tensor, The tensor variable storing the multi_label_soft_margin_loss of input and label.
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
# label elements in {1., -1.}
label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)
loss = F.multi_label_soft_margin_loss(input, label, reduction='none')
print(loss)
# Tensor([3.49625897, 0.71111226, 0.43989015])
loss = F.multi_label_soft_margin_loss(input, label, reduction='mean')
print(loss)
# Tensor([1.54908717])
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"'reduction' in 'multi_label_soft_margin_loss' should be 'sum', 'mean' or 'none', "
"but received {}.".format(reduction))

if not (input.shape == label.shape):
raise ValueError("The input and label should have same dimension,"
"but received {}!={}".format(input.shape, label.shape))

if not _non_static_mode():
check_variable_and_dtype(input, 'input', ['float32', 'float64'],
'multi_label_soft_margin_loss')
check_variable_and_dtype(label, 'label', ['float32', 'float64'],
'multi_label_soft_margin_loss')

loss = -(label * paddle.nn.functional.log_sigmoid(input) +
(1 - label) * paddle.nn.functional.log_sigmoid(-input))

if weight is not None:
if not _non_static_mode():
check_variable_and_dtype(weight, 'weight', ['float32', 'float64'],
'multi_label_soft_margin_loss')
loss = loss * weight

loss = loss.mean(axis=-1) # only return N loss values

if reduction == "none":
return loss
elif reduction == "mean":
return paddle.mean(loss)
elif reduction == "sum":
return paddle.sum(loss)
2 changes: 1 addition & 1 deletion python/paddle/nn/layer/__init__.py
Expand Up @@ -76,11 +76,11 @@
from .loss import BCELoss # noqa: F401
from .loss import KLDivLoss # noqa: F401
from .loss import MarginRankingLoss # noqa: F401
from .loss import MultiLabelSoftMarginLoss
from .loss import CTCLoss # noqa: F401
from .loss import SmoothL1Loss # noqa: F401
from .loss import HingeEmbeddingLoss # noqa: F401
from .loss import TripletMarginWithDistanceLoss
from .loss import MultiLabelSoftMarginLoss
from .norm import BatchNorm1D # noqa: F401
from .norm import BatchNorm2D # noqa: F401
from .norm import BatchNorm3D # noqa: F401
Expand Down
156 changes: 78 additions & 78 deletions python/paddle/nn/layer/loss.py
Expand Up @@ -1217,6 +1217,84 @@ def forward(self, input, label):
name=self.name)


class MultiLabelSoftMarginLoss(Layer):
r"""Creates a criterion that optimizes a multi-class multi-classification
hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
and output :math:`y` (which is a 2D `Tensor` of target class indices).
For each sample in the mini-batch:
.. math::
\text{loss}(x, y) = \sum_{ij}\frac{\max(0, 1 - (x[y[j]] - x[i]))}{\text{x.size}(0)}
where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`, \
:math:`y \in \left\{0, \; \cdots , \; \text{y.size}(0) - 1\right\}`, \
:math:`0 \leq y[j] \leq \text{x.size}(0)-1`, \
and :math:`i \neq y[j]` for all :math:`i` and :math:`j`.
:math:`y` and :math:`x` must have the same size.
Parameters:
weight (Tensor,optional): a manual rescaling weight given to each class.
If given, has to be a Tensor of size C and the data type is float32, float64.
Default is ``'None'`` .
reduction (str, optional): Indicate how to average the loss by batch_size,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default: ``'mean'``
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Call parameters:
input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes, and if shape is more than 2D, this is (N, C, D1, D2,..., Dk), k >= 1.
label (Tensor): Label tensor containing 1 or -1, the data type is float32 or float64. The shape of label is the same as the shape of input.
Shape:
input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means number of classes, available dtype is float32, float64. The sum operationoperates over all the elements.
label: N-D Tensor, same shape as the input.
output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input.
Returns:
A callable object of MultiLabelSoftMarginLoss.
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)
multi_label_soft_margin_loss = nn.MultiLabelSoftMarginLoss(reduction='none')
loss = multi_label_soft_margin_loss(input, label)
print(loss)
# Tensor([3.49625897, 0.71111226, 0.43989015])
multi_label_soft_margin_loss = nn.MultiLabelSoftMarginLoss(reduction='mean')
loss = multi_label_soft_margin_loss(input, label)
print(loss)
# Tensor([1.54908717])
"""

def __init__(self, weight=None, reduction="mean", name=None):
super(MultiLabelSoftMarginLoss, self).__init__()
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"'reduction' in 'MultiLabelSoftMarginloss' should be 'sum', 'mean' or 'none', "
"but received {}.".format(reduction))
self.weight = weight
self.reduction = reduction
self.name = name

def forward(self, input, label):
return F.multi_label_soft_margin_loss(input,
label,
weight=self.weight,
reduction=self.reduction,
name=self.name)


class HingeEmbeddingLoss(Layer):
r"""
This operator calculates hinge_embedding_loss. Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y`(containing 1 or -1).
Expand Down Expand Up @@ -1507,81 +1585,3 @@ def forward(self, input, positive, negative):
swap=self.swap,
reduction=self.reduction,
name=self.name)


class MultiLabelSoftMarginLoss(Layer):
r"""Creates a criterion that optimizes a multi-class multi-classification
hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
and output :math:`y` (which is a 2D `Tensor` of target class indices).
For each sample in the mini-batch:
.. math::
\text{loss}(x, y) = \sum_{ij}\frac{\max(0, 1 - (x[y[j]] - x[i]))}{\text{x.size}(0)}
where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`, \
:math:`y \in \left\{0, \; \cdots , \; \text{y.size}(0) - 1\right\}`, \
:math:`0 \leq y[j] \leq \text{x.size}(0)-1`, \
and :math:`i \neq y[j]` for all :math:`i` and :math:`j`.
:math:`y` and :math:`x` must have the same size.
Parameters:
weight (Tensor,optional): a manual rescaling weight given to each class.
If given, has to be a Tensor of size C and the data type is float32, float64.
Default is ``'None'`` .
reduction (str, optional): Indicate how to average the loss by batch_size,
the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default: ``'mean'``
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Call parameters:
input (Tensor): Input tensor, the data type is float32 or float64. Shape is (N, C), where C is number of classes, and if shape is more than 2D, this is (N, C, D1, D2,..., Dk), k >= 1.
label (Tensor): Label tensor containing 1 or -1, the data type is float32 or float64. The shape of label is the same as the shape of input.
Shape:
input: N-D Tensor, the shape is [N, \*], N is batch size and `\*` means number of classes, available dtype is float32, float64. The sum operationoperates over all the elements.
label: N-D Tensor, same shape as the input.
output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input.
Returns:
A callable object of MultiLabelSoftMarginLoss.
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)
multi_label_soft_margin_loss = nn.MultiLabelSoftMarginLoss(reduction='none')
loss = multi_label_soft_margin_loss(input, label)
print(loss)
# Tensor([3.49625897, 0.71111226, 0.43989015])
multi_label_soft_margin_loss = nn.MultiLabelSoftMarginLoss(reduction='mean')
loss = multi_label_soft_margin_loss(input, label)
print(loss)
# Tensor([1.54908717])
"""

def __init__(self, weight=None, reduction="mean", name=None):
super(MultiLabelSoftMarginLoss, self).__init__()
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"'reduction' in 'MultiLabelSoftMarginLoss' should be 'sum', 'mean' or 'none', "
"but received {}.".format(reduction))
self.weight = weight
self.reduction = reduction
self.name = name

def forward(self, input, label):
return F.multi_label_soft_margin_loss(input,
label,
weight=self.weight,
reduction=self.reduction,
name=self.name)

1 comment on commit ce13bcf

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.