Skip to content

Commit

Permalink
【Hackathon No.21】为 Paddle 新增 SoftMarginLoss (#42364)
Browse files Browse the repository at this point in the history
* 2022-04-28

* 2022-04-28_V2

* 2022-04-30

* 2022-04-30_V2

* 2022-05-01

* 2022-05-02

* 2022-05-02_V2

* 2022-05-05_V1

* 2022-05-06_V1

* 2022-05-07_V1

* Update loss.py

* 2022-05-07_V2

* 2022-05-13_V1

* Update test_soft_margin_loss.py

* Update loss.py

* Update loss.py

* 2022-05-16_V1

* 2022-05-19_V1

* 2022-05-20_V1

* Update test_soft_margin_loss.py

* 2022-06-01_V1

* 2022-06-05

* 2022-06-07

* 2022-06-07

* 2022-06-08

* 2022-06-08_V2

* 2022-06-17-code_style

* Modify python

* 2022-06-20

* for

* for CI;test=document_fix

Co-authored-by: Ligoml <39876205+Ligoml@users.noreply.github.com>
  • Loading branch information
yangguohao and Ligoml committed Jul 25, 2022
1 parent 243acdb commit f9cd526
Show file tree
Hide file tree
Showing 6 changed files with 332 additions and 0 deletions.
177 changes: 177 additions & 0 deletions python/paddle/fluid/tests/unittests/test_soft_margin_loss.py
@@ -0,0 +1,177 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numpy as np
import unittest


def test_static_layer(
place,
input_np,
label_np,
reduction='mean',
):
paddle.enable_static()
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.static.data(name='input',
shape=input_np.shape,
dtype=input_np.dtype)
label = paddle.static.data(name='label',
shape=label_np.shape,
dtype=label_np.dtype)
sm_loss = paddle.nn.loss.SoftMarginLoss(reduction=reduction)
res = sm_loss(input, label)
exe = paddle.static.Executor(place)
static_result = exe.run(prog,
feed={
"input": input_np,
"label": label_np
},
fetch_list=[res])
return static_result


def test_static_functional(
place,
input_np,
label_np,
reduction='mean',
):
paddle.enable_static()
prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(prog, startup_prog):
input = paddle.static.data(name='input',
shape=input_np.shape,
dtype=input_np.dtype)
label = paddle.static.data(name='label',
shape=label_np.shape,
dtype=label_np.dtype)

res = paddle.nn.functional.soft_margin_loss(input,
label,
reduction=reduction)
exe = paddle.static.Executor(place)
static_result = exe.run(prog,
feed={
"input": input_np,
"label": label_np
},
fetch_list=[res])
return static_result


def test_dygraph_layer(
place,
input_np,
label_np,
reduction='mean',
):
paddle.disable_static()
sm_loss = paddle.nn.loss.SoftMarginLoss(reduction=reduction)
dy_res = sm_loss(paddle.to_tensor(input_np), paddle.to_tensor(label_np))
dy_result = dy_res.numpy()
paddle.enable_static()
return dy_result


def test_dygraph_functional(
place,
input_np,
label_np,
reduction='mean',
):
paddle.disable_static()
input = paddle.to_tensor(input_np)
label = paddle.to_tensor(label_np)

dy_res = paddle.nn.functional.soft_margin_loss(input,
label,
reduction=reduction)
dy_result = dy_res.numpy()
paddle.enable_static()
return dy_result


def calc_softmarginloss(
input_np,
label_np,
reduction='mean',
):
expected = np.log(1 + np.exp(-label_np * input_np))
# expected = np.mean(expected, axis=-1)

if reduction == 'mean':
expected = np.mean(expected)
elif reduction == 'sum':
expected = np.sum(expected)
else:
expected = expected

return expected


class TestSoftMarginLoss(unittest.TestCase):

def test_SoftMarginLoss(self):
input_np = np.random.uniform(0.1, 0.8, size=(5, 5)).astype(np.float64)
types = [np.int32, np.int64, np.float32, np.float64]
places = ['cpu']
if paddle.device.is_compiled_with_cuda():
places.append('gpu')
reductions = ['sum', 'mean', 'none']
for place in places:
for reduction in reductions:
for _type in types:
label_np = np.random.randint(0, 2,
size=(5, 5)).astype(_type)
label_np[label_np == 0] = -1
static_result = test_static_layer(place, input_np, label_np,
reduction)
dy_result = test_dygraph_layer(place, input_np, label_np,
reduction)
expected = calc_softmarginloss(input_np, label_np,
reduction)
self.assertTrue(np.allclose(static_result, expected))
self.assertTrue(np.allclose(static_result, dy_result))
self.assertTrue(np.allclose(dy_result, expected))
static_functional = test_static_functional(
place, input_np, label_np, reduction)
dy_functional = test_dygraph_functional(
place, input_np, label_np, reduction)
self.assertTrue(np.allclose(static_functional, expected))
self.assertTrue(
np.allclose(static_functional, dy_functional))
self.assertTrue(np.allclose(dy_functional, expected))

def test_SoftMarginLoss_error(self):
paddle.disable_static()
self.assertRaises(ValueError,
paddle.nn.loss.SoftMarginLoss,
reduction="unsupport reduction")
input = paddle.to_tensor([[0.1, 0.3]], dtype='float32')
label = paddle.to_tensor([[-1.0, 1.0]], dtype='float32')
self.assertRaises(ValueError,
paddle.nn.functional.soft_margin_loss,
input=input,
label=label,
reduction="unsupport reduction")
paddle.enable_static()


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions python/paddle/nn/__init__.py
Expand Up @@ -111,6 +111,7 @@
from .layer.loss import CosineEmbeddingLoss # noqa: F401
from .layer.loss import TripletMarginWithDistanceLoss
from .layer.loss import TripletMarginLoss
from .layer.loss import SoftMarginLoss
from .layer.norm import BatchNorm # noqa: F401
from .layer.norm import SyncBatchNorm # noqa: F401
from .layer.norm import GroupNorm # noqa: F401
Expand Down Expand Up @@ -320,4 +321,5 @@ def weight_norm(*args):
'RReLU',
'TripletMarginWithDistanceLoss',
'TripletMarginLoss',
'SoftMarginLoss',
]
2 changes: 2 additions & 0 deletions python/paddle/nn/functional/__init__.py
Expand Up @@ -94,6 +94,7 @@
from .loss import multi_label_soft_margin_loss
from .loss import triplet_margin_with_distance_loss
from .loss import triplet_margin_loss
from .loss import soft_margin_loss
from .norm import batch_norm # noqa: F401
from .norm import instance_norm # noqa: F401
from .norm import layer_norm # noqa: F401
Expand Down Expand Up @@ -238,4 +239,5 @@
'rrelu',
'triplet_margin_with_distance_loss',
'triplet_margin_loss',
'soft_margin_loss',
]
79 changes: 79 additions & 0 deletions python/paddle/nn/functional/loss.py
Expand Up @@ -3200,3 +3200,82 @@ def triplet_margin_loss(input,
return paddle.sum(loss, name=name)
elif reduction == 'none':
return loss


def soft_margin_loss(input, label, reduction='mean', name=None):
"""
The API measures the soft margin loss between input predictions ``input``
and target labels ``label`` . It can be described as:
.. math::
Out = log(1 + exp((-label * input)))
Parameters:
input (Tensor): The input predications tensor with shape: [N, *],
N is batch_size, `*` means any number of additional dimensions. The ``input`` ranges from -inf to inf.
Available dtype is float32, float64.
label (Tensor): The target labels tensor with the same shape as
``input``. The target labels which values should be numbers -1 or 1.
Available dtype is int32, int64, float32, float64.
reduction (str, optional): Indicate how to average the loss by batch_size,
the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default is ``'mean'``.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
same as ``input`` , else the shape of output is [1].
Examples:
.. code-block:: python
import paddle
import numpy as np
input = paddle.to_tensor([[0.5, 0.6, 0.7],[0.3, 0.5, 0.2]], 'float32')
label = paddle.to_tensor([[1.0, -1.0, 1.0],[-1.0, 1.0, 1.0]], 'float32')
output = paddle.nn.functional.soft_margin_loss(input, label)
input_np = np.random.uniform(0.1, 0.8, size=(5, 5)).astype(np.float64)
label_np = np.random.randint(0, 2, size=(5, 5)).astype(np.int64)
label_np[label_np==0]=-1
input = paddle.to_tensor(input_np)
label = paddle.to_tensor(label_np)
output = paddle.nn.functional.soft_margin_loss(input, label, reduction='none')
"""
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in soft_margin_loss should be 'sum', "
"'mean' or 'none', but received %s, which is not allowed." %
reduction)

if not _non_static_mode():
fluid.data_feeder.check_variable_and_dtype(input, 'input',
['float32', 'float64'],
'soft_margin_loss')
fluid.data_feeder.check_variable_and_dtype(
label, 'label', ['int32', 'int64', 'float32', 'float64'],
'soft_margin_loss')

if not (input.shape == label.shape):
raise ValueError("input's shape must equal to "
"label's shape")

label = fluid.layers.cast(label, input.dtype)
out = paddle.log(1 + paddle.exp(-label * input))

if reduction == 'sum':
return paddle.sum(out, name=name)
elif reduction == 'mean':
return paddle.mean(out, name=name)
else:
return out
1 change: 1 addition & 0 deletions python/paddle/nn/layer/__init__.py
Expand Up @@ -82,6 +82,7 @@
from .loss import HingeEmbeddingLoss # noqa: F401
from .loss import TripletMarginWithDistanceLoss
from .loss import TripletMarginLoss
from .loss import SoftMarginLoss
from .norm import BatchNorm1D # noqa: F401
from .norm import BatchNorm2D # noqa: F401
from .norm import BatchNorm3D # noqa: F401
Expand Down
71 changes: 71 additions & 0 deletions python/paddle/nn/layer/loss.py
Expand Up @@ -1691,3 +1691,74 @@ def forward(self, input, positive, negative):
swap=self.swap,
reduction=self.reduction,
name=self.name)


class SoftMarginLoss(Layer):
r"""
Creates a criterion that measures a two-class soft margin loss between input predictions ``input``
and target labels ``label`` . It can be described as:
.. math::
Out = log(1 + exp((-label * input)))
Parameters:
reduction (str, optional): Indicate how to average the loss by batch_size,
the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
If :attr:`reduction` is ``'sum'``, the summed loss is returned.
Default is ``'mean'``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Shapes:
Input (Tensor): The input tensor with shape: [N, *],
N is batch_size, `*` means any number of additional dimensions. The ``input`` ranges from -inf to inf
Available dtype is float32, float64.
Label (Tensor): The target labels tensor with the same shape as
``input``. The target labels which values should be numbers -1 or 1.
Available dtype is int32, int64, float32, float64.
Output (Tensor): If ``reduction`` is ``'none'``, the shape of output is
same as ``input`` , else the shape of output is [1].
Returns:
A callable object of SoftMarginLoss.
Examples:
.. code-block:: python
import paddle
import numpy as np
input = paddle.to_tensor([[0.5, 0.6, 0.7],[0.3, 0.5, 0.2]], 'float32')
label = paddle.to_tensor([[1.0, -1.0, 1.0],[-1.0, 1.0, 1.0]], 'float32')
soft_margin_loss = paddle.nn.SoftMarginLoss()
output = soft_margin_loss(input, label)
input_np = np.random.uniform(0.1, 0.8, size=(5, 5)).astype(np.float64)
label_np = np.random.randint(0, 2, size=(5, 5)).astype(np.int64)
label_np[label_np==0]=-1
input = paddle.to_tensor(input_np)
label = paddle.to_tensor(label_np)
soft_margin_loss = paddle.nn.SoftMarginLoss(reduction='none')
output = soft_margin_loss(input, label)
"""

def __init__(self, reduction='mean', name=None):
if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in SoftMarginLoss should be 'sum', 'mean' or 'none', but "
"received %s, which is not allowed." % reduction)

super(SoftMarginLoss, self).__init__()
self.reduction = reduction
self.name = name

def forward(self, input, label):
out = paddle.nn.functional.soft_margin_loss(input, label,
self.reduction, self.name)
return out

0 comments on commit f9cd526

Please sign in to comment.