[PaddleHackathon No.14] (PaddlePaddle#41183)
* 2022-04-28

* 2022-05-04

* 2022-05-05_V1

* 2022-05-05_V1

* Update loss.py

* Update loss.py

* 2022-06-01_hook

* 2022-06-05

* 2022-06-07

* 2022-06-07_V2

* 2022-06-07_V2

* 2022-06-17_codestyle
yangguohao authored and sneaxiy committed Jun 27, 2022
1 parent 5ec255a commit c360750
Showing 6 changed files with 415 additions and 0 deletions.
@@ -0,0 +1,252 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numpy as np
import unittest


def call_MultiLabelSoftMarginLoss_layer(
    input,
    label,
    weight=None,
    reduction='mean',
):
    multilabel_margin_loss = paddle.nn.MultiLabelSoftMarginLoss(
        weight=weight, reduction=reduction)
    res = multilabel_margin_loss(
        input=input,
        label=label,
    )
    return res


def call_MultiLabelSoftMarginLoss_functional(
    input,
    label,
    weight=None,
    reduction='mean',
):
    res = paddle.nn.functional.multi_label_soft_margin_loss(
        input,
        label,
        reduction=reduction,
        weight=weight,
    )
    return res


def test_static(place,
                input_np,
                label_np,
                weight_np=None,
                reduction='mean',
                functional=False):
    paddle.enable_static()
    prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(prog, startup_prog):
        input = paddle.static.data(name='input',
                                   shape=input_np.shape,
                                   dtype='float64')
        label = paddle.static.data(name='label',
                                   shape=label_np.shape,
                                   dtype='float64')
        feed_dict = {
            "input": input_np,
            "label": label_np,
        }
        weight = None
        if weight_np is not None:
            weight = paddle.static.data(name='weight',
                                        shape=weight_np.shape,
                                        dtype='float64')
            feed_dict['weight'] = weight_np

        if functional:
            res = call_MultiLabelSoftMarginLoss_functional(input=input,
                                                           label=label,
                                                           weight=weight,
                                                           reduction=reduction)
        else:
            res = call_MultiLabelSoftMarginLoss_layer(input=input,
                                                      label=label,
                                                      weight=weight,
                                                      reduction=reduction)

        exe = paddle.static.Executor(place)
        static_result = exe.run(prog, feed=feed_dict, fetch_list=[res])
    return static_result


def test_dygraph(place,
                 input_np,
                 label_np,
                 weight=None,
                 reduction='mean',
                 functional=False):
    # Note: `place` is accepted for symmetry with test_static; the dygraph
    # guard below runs on the current device.
    with paddle.fluid.dygraph.base.guard():
        input = paddle.to_tensor(input_np)
        label = paddle.to_tensor(label_np)
        if weight is not None:
            weight = paddle.to_tensor(weight)

        if functional:
            dy_res = call_MultiLabelSoftMarginLoss_functional(
                input=input, label=label, weight=weight, reduction=reduction)
        else:
            dy_res = call_MultiLabelSoftMarginLoss_layer(input=input,
                                                         label=label,
                                                         weight=weight,
                                                         reduction=reduction)
        dy_result = dy_res.numpy()
    return dy_result


def calc_multilabel_margin_loss(
    input,
    label,
    weight=None,
    reduction="mean",
):
    # NumPy reference implementation used to check the Paddle results.

    def LogSigmoid(x):
        return np.log(1 / (1 + np.exp(-x)))

    loss = -(label * LogSigmoid(input) + (1 - label) * LogSigmoid(-input))

    if weight is not None:
        loss = loss * weight

    loss = loss.mean(axis=-1)  # only return N loss values

    if reduction == "none":
        return loss
    elif reduction == "mean":
        return np.mean(loss)
    elif reduction == "sum":
        return np.sum(loss)


class TestMultiLabelMarginLoss(unittest.TestCase):

    def test_MultiLabelSoftMarginLoss(self):
        input = np.random.uniform(0.1, 0.8, size=(5, 5)).astype(np.float64)
        label = np.random.randint(0, 2, size=(5, 5)).astype(np.float64)

        places = ['cpu']
        if paddle.device.is_compiled_with_cuda():
            places.append('gpu')
        reductions = ['sum', 'mean', 'none']
        for place in places:
            for reduction in reductions:
                expected = calc_multilabel_margin_loss(input=input,
                                                       label=label,
                                                       reduction=reduction)

                dy_result = test_dygraph(place=place,
                                         input_np=input,
                                         label_np=label,
                                         reduction=reduction)

                static_result = test_static(place=place,
                                            input_np=input,
                                            label_np=label,
                                            reduction=reduction)
                self.assertTrue(np.allclose(static_result, expected))
                self.assertTrue(np.allclose(static_result, dy_result))
                self.assertTrue(np.allclose(dy_result, expected))

                static_functional = test_static(place=place,
                                                input_np=input,
                                                label_np=label,
                                                reduction=reduction,
                                                functional=True)
                dy_functional = test_dygraph(place=place,
                                             input_np=input,
                                             label_np=label,
                                             reduction=reduction,
                                             functional=True)
                self.assertTrue(np.allclose(static_functional, expected))
                self.assertTrue(np.allclose(static_functional, dy_functional))
                self.assertTrue(np.allclose(dy_functional, expected))

    def test_MultiLabelSoftMarginLoss_error(self):
        paddle.disable_static()
        self.assertRaises(ValueError,
                          paddle.nn.MultiLabelSoftMarginLoss,
                          reduction="unsupport reduction")
        input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
        label = paddle.to_tensor([[0.0, 1.0]], dtype='float32')
        self.assertRaises(ValueError,
                          paddle.nn.functional.multi_label_soft_margin_loss,
                          input=input,
                          label=label,
                          reduction="unsupport reduction")
        paddle.enable_static()

    def test_MultiLabelSoftMarginLoss_weights(self):
        input = np.random.uniform(0.1, 0.8, size=(5, 5)).astype(np.float64)
        label = np.random.randint(0, 2, size=(5, 5)).astype(np.float64)
        weight = np.random.randint(0, 2, size=(5, 5)).astype(np.float64)
        place = 'cpu'
        reduction = 'mean'
        expected = calc_multilabel_margin_loss(input=input,
                                               label=label,
                                               weight=weight,
                                               reduction=reduction)

        dy_result = test_dygraph(place=place,
                                 input_np=input,
                                 label_np=label,
                                 weight=weight,
                                 reduction=reduction)

        static_result = test_static(place=place,
                                    input_np=input,
                                    label_np=label,
                                    weight_np=weight,
                                    reduction=reduction)
        self.assertTrue(np.allclose(static_result, expected))
        self.assertTrue(np.allclose(static_result, dy_result))
        self.assertTrue(np.allclose(dy_result, expected))

        static_functional = test_static(place=place,
                                        input_np=input,
                                        label_np=label,
                                        weight_np=weight,
                                        reduction=reduction,
                                        functional=True)
        dy_functional = test_dygraph(place=place,
                                     input_np=input,
                                     label_np=label,
                                     weight=weight,
                                     reduction=reduction,
                                     functional=True)
        self.assertTrue(np.allclose(static_functional, expected))
        self.assertTrue(np.allclose(static_functional, dy_functional))
        self.assertTrue(np.allclose(dy_functional, expected))

    def test_MultiLabelSoftMarginLoss_dimension(self):
        paddle.disable_static()

        input = paddle.to_tensor([[0.1, 0.3], [1, 2]], dtype='float32')
        label = paddle.to_tensor([[0.2, 0.1]], dtype='float32')
        self.assertRaises(ValueError,
                          paddle.nn.functional.multi_label_soft_margin_loss,
                          input=input,
                          label=label)
        paddle.enable_static()


if __name__ == "__main__":
    unittest.main()
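
For reference, the NumPy helper calc_multilabel_margin_loss above (and the Paddle implementation added to loss.py below) computes the standard multi-label soft margin loss. Written out in LaTeX, for one sample with logits x, a 0/1 label vector y, and C classes:

% Per-sample loss implemented by calc_multilabel_margin_loss above.
% `weight` (if given) rescales each term before the mean over classes,
% and `reduction` is then applied over the N per-sample values.
\operatorname{loss}(x, y)
    = -\frac{1}{C} \sum_{i=1}^{C}
      \left[ y_i \log \sigma(x_i) + (1 - y_i) \log \sigma(-x_i) \right],
\qquad
\sigma(t) = \frac{1}{1 + e^{-t}}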
2 changes: 2 additions & 0 deletions python/paddle/nn/__init__.py
@@ -104,6 +104,7 @@
from .layer.loss import BCELoss # noqa: F401
from .layer.loss import KLDivLoss # noqa: F401
from .layer.loss import MarginRankingLoss # noqa: F401
from .layer.loss import MultiLabelSoftMarginLoss
from .layer.loss import CTCLoss # noqa: F401
from .layer.loss import SmoothL1Loss # noqa: F401
from .layer.loss import HingeEmbeddingLoss # noqa: F401
@@ -312,6 +313,7 @@ def weight_norm(*args):
    'MaxUnPool1D',
    'MaxUnPool2D',
    'MaxUnPool3D',
    'MultiLabelSoftMarginLoss',
    'HingeEmbeddingLoss',
    'Identity',
    'CosineEmbeddingLoss',
2 changes: 2 additions & 0 deletions python/paddle/nn/functional/__init__.py
@@ -91,6 +91,7 @@
from .loss import ctc_loss # noqa: F401
from .loss import hinge_embedding_loss # noqa: F401
from .loss import cosine_embedding_loss # noqa: F401
from .loss import multi_label_soft_margin_loss
from .loss import triplet_margin_with_distance_loss
from .loss import triplet_margin_loss
from .norm import batch_norm # noqa: F401
@@ -206,6 +207,7 @@
    'log_loss',
    'mse_loss',
    'margin_ranking_loss',
    'multi_label_soft_margin_loss',
    'nll_loss',
    'npair_loss',
    'sigmoid_focal_loss',
80 changes: 80 additions & 0 deletions python/paddle/nn/functional/loss.py
@@ -2668,6 +2668,86 @@ def sigmoid_focal_loss(logit,
    return loss


def multi_label_soft_margin_loss(input,
                                 label,
                                 weight=None,
                                 reduction="mean",
                                 name=None):
    r"""
    Calculate the multi-label soft margin loss of ``input`` and ``label``.

    Parameters:
        input (Tensor): Input tensor, the data type is float32 or float64. The shape is (N, C), where C is the number of classes; if the input has more than 2 dimensions, the shape is (N, C, D1, D2, ..., Dk), k >= 1.
        label (Tensor): Label tensor, the data type is float32 or float64. The shape of label is the same as the shape of input.
        weight (Tensor, optional): A manual rescaling weight given to each class.
            If given, it has to be a Tensor of size C and the data type is float32 or float64.
            Default is ``None``.
        reduction (str, optional): Indicates how to average the loss by batch_size,
            the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
            If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
            If :attr:`reduction` is ``'sum'``, the summed loss is returned.
            Default: ``'mean'``
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        input: N-D Tensor, the shape is [N, \*], where N is batch size and `\*` means the number of classes. Available dtypes are float32, float64. The sum operation operates over all the elements.
        label: N-D Tensor, same shape as the input.
        weight: N-D Tensor, the shape is [N, 1].
        output: Scalar. If :attr:`reduction` is ``'none'``, the output has the same shape as the input.

    Returns:
        Tensor, the tensor storing the multi_label_soft_margin_loss of input and label.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            input = paddle.to_tensor([[1, -2, 3], [0, -1, 2], [1, 0, 1]], dtype=paddle.float32)
            # label elements in {1., -1.}
            label = paddle.to_tensor([[-1, 1, -1], [1, 1, 1], [1, -1, 1]], dtype=paddle.float32)

            loss = F.multi_label_soft_margin_loss(input, label, reduction='none')
            print(loss)
            # Tensor([3.49625897, 0.71111226, 0.43989015])

            loss = F.multi_label_soft_margin_loss(input, label, reduction='mean')
            print(loss)
            # Tensor([1.54908717])
    """
    if reduction not in ['sum', 'mean', 'none']:
        raise ValueError(
            "'reduction' in 'multi_label_soft_margin_loss' should be 'sum', 'mean' or 'none', "
            "but received {}.".format(reduction))

    if not (input.shape == label.shape):
        raise ValueError("The input and label should have the same shape, "
                         "but received {} != {}.".format(
                             input.shape, label.shape))

    if not _non_static_mode():
        check_variable_and_dtype(input, 'input', ['float32', 'float64'],
                                 'multilabel_soft_margin_loss')
        check_variable_and_dtype(label, 'label', ['float32', 'float64'],
                                 'multilabel_soft_margin_loss')

    loss = -(label * paddle.nn.functional.log_sigmoid(input) +
             (1 - label) * paddle.nn.functional.log_sigmoid(-input))

    if weight is not None:
        if not _non_static_mode():
            check_variable_and_dtype(weight, 'weight', ['float32', 'float64'],
                                     'multilabel_soft_margin_loss')
        loss = loss * weight

    loss = loss.mean(axis=-1)  # only return N loss values

    if reduction == "none":
        return loss
    elif reduction == "mean":
        return paddle.mean(loss)
    elif reduction == "sum":
        return paddle.sum(loss)


def hinge_embedding_loss(input, label, margin=1.0, reduction='mean', name=None):
r"""
This operator calculates hinge_embedding_loss. Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y`(containing 1 or -1).
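
As a quick, hedged illustration of the two new entry points (not part of the diff): assuming a Paddle build that includes this commit, the functional and layer forms can be exercised as follows; the tensor values are made up for the example.

# Illustrative usage of the APIs added in this commit (not part of the
# diff). Assumes a Paddle build containing this change.
import paddle
import paddle.nn.functional as F

input = paddle.to_tensor([[1., -2., 3.], [0., -1., 2.]])
label = paddle.to_tensor([[0., 1., 0.], [1., 1., 1.]])

# Functional form: one loss value per sample with reduction='none'.
per_sample = F.multi_label_soft_margin_loss(input, label, reduction='none')

# Layer form, with an optional rescaling weight that broadcasts
# against the (N, C) elementwise loss.
layer = paddle.nn.MultiLabelSoftMarginLoss(weight=paddle.ones([3]),
                                           reduction='mean')
mean_loss = layer(input=input, label=label)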
1 change: 1 addition & 0 deletions python/paddle/nn/layer/__init__.py
@@ -76,6 +76,7 @@
from .loss import BCELoss # noqa: F401
from .loss import KLDivLoss # noqa: F401
from .loss import MarginRankingLoss # noqa: F401
from .loss import MultiLabelSoftMarginLoss
from .loss import CTCLoss # noqa: F401
from .loss import SmoothL1Loss # noqa: F401
from .loss import HingeEmbeddingLoss # noqa: F401
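
The page capture above shows only five of the six changed files; the diff for python/paddle/nn/layer/loss.py, which defines the MultiLabelSoftMarginLoss layer itself, is omitted. Based on the functional API and the tests (the constructor validates `reduction` and the instance is called with `input=` and `label=`), a hypothetical sketch of that class could look like this; the actual committed code may differ:

# Hedged reconstruction of the layer wrapper -- NOT the verbatim
# contents of python/paddle/nn/layer/loss.py, which this capture omits.
import paddle.nn.functional as F
from paddle.nn import Layer


class MultiLabelSoftMarginLoss(Layer):

    def __init__(self, weight=None, reduction="mean", name=None):
        super(MultiLabelSoftMarginLoss, self).__init__()
        # The tests expect an invalid `reduction` to raise at
        # construction time.
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "'reduction' in 'MultiLabelSoftMarginLoss' should be 'sum', "
                "'mean' or 'none', but received {}.".format(reduction))
        self.weight = weight
        self.reduction = reduction
        self.name = name

    def forward(self, input, label):
        # Delegate to the functional form added in this commit.
        return F.multi_label_soft_margin_loss(input,
                                              label,
                                              weight=self.weight,
                                              reduction=self.reduction,
                                              name=self.name)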
