diff --git a/python/paddle/fluid/tests/unittests/test_triplet_margin_loss.py b/python/paddle/fluid/tests/unittests/test_triplet_margin_loss.py new file mode 100644 index 0000000000000..745cb6a178032 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_triplet_margin_loss.py @@ -0,0 +1,395 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +import unittest + + +def call_TripletMarginLoss_layer( + input, + positive, + negative, + p=2, + margin=0.3, + swap=False, + eps=1e-6, + reduction='mean', +): + triplet_margin_loss = paddle.nn.TripletMarginLoss(p=p, + epsilon=eps, + margin=margin, + swap=swap, + reduction=reduction) + res = triplet_margin_loss( + input=input, + positive=positive, + negative=negative, + ) + return res + + +def call_TripletMarginLoss_functional( + input, + positive, + negative, + p=2, + margin=0.3, + swap=False, + eps=1e-6, + reduction='mean', +): + res = paddle.nn.functional.triplet_margin_loss(input=input, + positive=positive, + negative=negative, + p=p, + epsilon=eps, + margin=margin, + swap=swap, + reduction=reduction) + return res + + +def test_static(place, + input_np, + positive_np, + negative_np, + p=2, + margin=0.3, + swap=False, + eps=1e-6, + reduction='mean', + functional=False): + prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(prog, startup_prog): + input = paddle.static.data(name='input', + shape=input_np.shape, + dtype='float64') + positive = paddle.static.data(name='positive', + shape=positive_np.shape, + dtype='float64') + negative = paddle.static.data(name='negative', + shape=negative_np.shape, + dtype='float64') + feed_dict = { + "input": input_np, + "positive": positive_np, + "negative": negative_np + } + + if functional: + res = call_TripletMarginLoss_functional(input=input, + positive=positive, + negative=negative, + p=p, + eps=eps, + margin=margin, + swap=swap, + reduction=reduction) + else: + res = call_TripletMarginLoss_layer(input=input, + positive=positive, + negative=negative, + p=p, + eps=eps, + margin=margin, + swap=swap, + reduction=reduction) + + exe = paddle.static.Executor(place) + static_result = exe.run(prog, feed=feed_dict, fetch_list=[res]) + return static_result + + +def test_dygraph(place, + input, + positive, + negative, + p=2, + margin=0.3, + swap=False, + eps=1e-6, + reduction='mean', + functional=False): + paddle.disable_static() + input = paddle.to_tensor(input) + positive = paddle.to_tensor(positive) + negative = paddle.to_tensor(negative) + + if functional: + dy_res = call_TripletMarginLoss_functional(input=input, + positive=positive, + negative=negative, + p=p, + eps=eps, + margin=margin, + swap=swap, + reduction=reduction) + else: + dy_res = call_TripletMarginLoss_layer(input=input, + positive=positive, + negative=negative, + p=p, + eps=eps, + margin=margin, + swap=swap, + reduction=reduction) + dy_result = dy_res.numpy() + paddle.enable_static() + return dy_result + + +def calc_triplet_margin_loss( + input, + positive, + negative, + p=2, + margin=0.3, + swap=False, + reduction='mean', +): + positive_dist = np.linalg.norm((input - positive), p, axis=1) + negative_dist = np.linalg.norm((input - negative), p, axis=1) + + if swap: + swap_dist = np.linalg.norm((positive - negative), p, axis=1) + negative_dist = np.minimum(negative_dist, swap_dist) + expected = np.maximum(positive_dist - negative_dist + margin, 0) + + if reduction == 'mean': + expected = np.mean(expected) + elif reduction == 'sum': + expected = np.sum(expected) + else: + expected = expected + + return expected + + +class TestTripletMarginLoss(unittest.TestCase): + + def test_TripletMarginLoss(self): + shape = (2, 2) + input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) + positive = np.random.uniform(0, 2, size=shape).astype(np.float64) + negative = np.random.uniform(0, 2, size=shape).astype(np.float64) + + places = [paddle.CPUPlace()] + if paddle.device.is_compiled_with_cuda(): + places.append(paddle.CUDAPlace(0)) + reductions = ['sum', 'mean', 'none'] + for place in places: + for reduction in reductions: + expected = calc_triplet_margin_loss(input=input, + positive=positive, + negative=negative, + reduction=reduction) + + dy_result = test_dygraph( + place=place, + input=input, + positive=positive, + negative=negative, + reduction=reduction, + ) + + static_result = test_static( + place=place, + input_np=input, + positive_np=positive, + negative_np=negative, + reduction=reduction, + ) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + static_functional = test_static(place=place, + input_np=input, + positive_np=positive, + negative_np=negative, + reduction=reduction, + functional=True) + dy_functional = test_dygraph(place=place, + input=input, + positive=positive, + negative=negative, + reduction=reduction, + functional=True) + self.assertTrue(np.allclose(static_functional, expected)) + self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue(np.allclose(dy_functional, expected)) + + def test_TripletMarginLoss_error(self): + paddle.disable_static() + self.assertRaises(ValueError, + paddle.nn.loss.TripletMarginLoss, + reduction="unsupport reduction") + input = paddle.to_tensor([[0.1, 0.3]], dtype='float32') + positive = paddle.to_tensor([[0.0, 1.0]], dtype='float32') + negative = paddle.to_tensor([[0.2, 0.1]], dtype='float32') + self.assertRaises(ValueError, + paddle.nn.functional.triplet_margin_loss, + input=input, + positive=positive, + negative=negative, + reduction="unsupport reduction") + paddle.enable_static() + + def test_TripletMarginLoss_dimension(self): + paddle.disable_static() + + input = paddle.to_tensor([[0.1, 0.3], [1, 2]], dtype='float32') + positive = paddle.to_tensor([[0.0, 1.0]], dtype='float32') + negative = paddle.to_tensor([[0.2, 0.1]], dtype='float32') + self.assertRaises( + ValueError, + paddle.nn.functional.triplet_margin_loss, + input=input, + positive=positive, + negative=negative, + ) + TMLoss = paddle.nn.loss.TripletMarginLoss() + self.assertRaises( + ValueError, + TMLoss, + input=input, + positive=positive, + negative=negative, + ) + paddle.enable_static() + + def test_TripletMarginLoss_swap(self): + reduction = 'mean' + place = paddle.CPUPlace() + shape = (2, 2) + input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) + positive = np.random.uniform(0, 2, size=shape).astype(np.float64) + negative = np.random.uniform(0, 2, size=shape).astype(np.float64) + expected = calc_triplet_margin_loss(input=input, + swap=True, + positive=positive, + negative=negative, + reduction=reduction) + + dy_result = test_dygraph( + place=place, + swap=True, + input=input, + positive=positive, + negative=negative, + reduction=reduction, + ) + + static_result = test_static( + place=place, + swap=True, + input_np=input, + positive_np=positive, + negative_np=negative, + reduction=reduction, + ) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + static_functional = test_static(place=place, + swap=True, + input_np=input, + positive_np=positive, + negative_np=negative, + reduction=reduction, + functional=True) + dy_functional = test_dygraph(place=place, + swap=True, + input=input, + positive=positive, + negative=negative, + reduction=reduction, + functional=True) + self.assertTrue(np.allclose(static_functional, expected)) + self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue(np.allclose(dy_functional, expected)) + + def test_TripletMarginLoss_margin(self): + paddle.disable_static() + + input = paddle.to_tensor([[0.1, 0.3]], dtype='float32') + positive = paddle.to_tensor([[0.0, 1.0]], dtype='float32') + negative = paddle.to_tensor([[0.2, 0.1]], dtype='float32') + margin = -0.5 + self.assertRaises( + ValueError, + paddle.nn.functional.triplet_margin_loss, + margin=margin, + input=input, + positive=positive, + negative=negative, + ) + paddle.enable_static() + + def test_TripletMarginLoss_p(self): + p = 3 + shape = (2, 2) + reduction = 'mean' + place = paddle.CPUPlace() + input = np.random.uniform(0.1, 0.8, size=shape).astype(np.float64) + positive = np.random.uniform(0, 2, size=shape).astype(np.float64) + negative = np.random.uniform(0, 2, size=shape).astype(np.float64) + expected = calc_triplet_margin_loss(input=input, + p=p, + positive=positive, + negative=negative, + reduction=reduction) + + dy_result = test_dygraph( + place=place, + p=p, + input=input, + positive=positive, + negative=negative, + reduction=reduction, + ) + + static_result = test_static( + place=place, + p=p, + input_np=input, + positive_np=positive, + negative_np=negative, + reduction=reduction, + ) + self.assertTrue(np.allclose(static_result, expected)) + self.assertTrue(np.allclose(static_result, dy_result)) + self.assertTrue(np.allclose(dy_result, expected)) + static_functional = test_static(place=place, + p=p, + input_np=input, + positive_np=positive, + negative_np=negative, + reduction=reduction, + functional=True) + dy_functional = test_dygraph(place=place, + p=p, + input=input, + positive=positive, + negative=negative, + reduction=reduction, + functional=True) + self.assertTrue(np.allclose(static_functional, expected)) + self.assertTrue(np.allclose(static_functional, dy_functional)) + self.assertTrue(np.allclose(dy_functional, expected)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index a1e02dab4707d..8b29659a1f400 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -109,6 +109,7 @@ from .layer.loss import HingeEmbeddingLoss # noqa: F401 from .layer.loss import CosineEmbeddingLoss # noqa: F401 from .layer.loss import TripletMarginWithDistanceLoss +from .layer.loss import TripletMarginLoss from .layer.norm import BatchNorm # noqa: F401 from .layer.norm import SyncBatchNorm # noqa: F401 from .layer.norm import GroupNorm # noqa: F401 @@ -316,4 +317,5 @@ def weight_norm(*args): 'CosineEmbeddingLoss', 'RReLU', 'TripletMarginWithDistanceLoss', + 'TripletMarginLoss', ] diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 43ce403ab0b23..cdb1135eba800 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -92,6 +92,7 @@ from .loss import hinge_embedding_loss # noqa: F401 from .loss import cosine_embedding_loss # noqa: F401 from .loss import triplet_margin_with_distance_loss +from .loss import triplet_margin_loss from .norm import batch_norm # noqa: F401 from .norm import instance_norm # noqa: F401 from .norm import layer_norm # noqa: F401 @@ -234,4 +235,5 @@ 'cosine_embedding_loss', 'rrelu', 'triplet_margin_with_distance_loss', + 'triplet_margin_loss', ] diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index c882ab08296ae..2f37f8a50f4d1 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -28,7 +28,7 @@ from paddle.utils import deprecated from paddle import _C_ops from paddle import in_dynamic_mode -from paddle.framework import core +from paddle.framework import core, _non_static_mode from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode, _non_static_mode, _current_expected_place __all__ = [] @@ -2999,3 +2999,124 @@ def triplet_margin_with_distance_loss(input, return paddle.sum(loss, name=name) elif reduction == 'none': return loss + + +def triplet_margin_loss(input, + positive, + negative, + margin=1.0, + p=2, + epsilon=1e-6, + swap=False, + reduction='mean', + name=None): + r""" + Measures the triplet loss given an input + tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`. + This is used for measuring a relative similarity between samples. A triplet + is composed by `input`, `positive` and `negative` (i.e., `input`, `positive examples` and `negative + examples` respectively). The shapes of all input tensors should be + :math:`(N, *)`. + + The loss function for each sample in the mini-batch is: + + .. math:: + L(input, pos, neg) = \max \{d(input_i, pos_i) - d(input_i, neg_i) + {\rm margin}, 0\} + + + where + + .. math:: + d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p + + Parameters: + input (Tensor): Input tensor, the data type is float32 or float64. + the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64. + + positive (Tensor): Positive tensor, the data type is float32 or float64. + The shape of label is the same as the shape of input. + + negative (Tensor): Negative tensor, the data type is float32 or float64. + The shape of label is the same as the shape of input. + + margin (float, Optional): Default: :math:`1`. + + p (int, Optional): The norm degree for pairwise distance. Default: :math:`2`. + + epsilon (float, Optional): Add small value to avoid division by zero, + default value is 1e-6. + + swap (bool,Optional): The distance swap change the negative distance to the distance between + positive sample and negative sample. For more details, see `Learning shallow convolutional feature descriptors with triplet losses`. + Default: ``False``. + + + reduction (str, Optional):Indicate how to average the loss by batch_size. + the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'none'``, the unreduced loss is returned; + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; + If :attr:`reduction` is ``'sum'``, the summed loss is returned. + Default: ``'mean'`` + + name (str, Optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Output: Tensor. The tensor variable storing the triplet_margin_loss of input and positive and negative. + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + + input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32) + positive= paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32) + negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32) + loss = F.triplet_margin_loss(input, positive, negative, margin=1.0, reduction='none') + print(loss) + # Tensor([0. , 0.57496738, 0. ]) + + + loss = F.triplet_margin_loss(input, positive, negative, margin=1.0, reduction='mean') + print(loss) + # Tensor([0.19165580]) + + """ + if reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "'reduction' in 'triplet_margin_loss' should be 'sum', 'mean' or 'none', " + "but received {}.".format(reduction)) + if margin < 0: + raise ValueError( + "The margin between positive samples and negative samples should be greater than 0." + ) + if not _non_static_mode(): + check_variable_and_dtype(input, 'input', ['float32', 'float64'], + 'triplet_margin_loss') + check_variable_and_dtype(positive, 'positive', ['float32', 'float64'], + 'triplet_margin_loss') + check_variable_and_dtype(negative, 'negative', ['float32', 'float64'], + 'triplet_margin_loss') + + if not (input.shape == positive.shape == negative.shape): + raise ValueError("input's shape must equal to " + "positive's shape and " + "negative's shape") + + distance_function = paddle.nn.PairwiseDistance(p, epsilon=epsilon) + positive_dist = distance_function(input, positive) + negative_dist = distance_function(input, negative) + + if swap: + swap_dist = distance_function(positive, negative) + negative_dist = paddle.minimum(negative_dist, swap_dist) + + loss = paddle.clip(positive_dist - negative_dist + margin, min=0.0) + + if reduction == 'mean': + return paddle.mean(loss, name=name) + elif reduction == 'sum': + return paddle.sum(loss, name=name) + elif reduction == 'none': + return loss diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index a8e3d8ec1d464..e9ccee1bd3829 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -80,6 +80,7 @@ from .loss import SmoothL1Loss # noqa: F401 from .loss import HingeEmbeddingLoss # noqa: F401 from .loss import TripletMarginWithDistanceLoss +from .loss import TripletMarginLoss from .norm import BatchNorm1D # noqa: F401 from .norm import BatchNorm2D # noqa: F401 from .norm import BatchNorm3D # noqa: F401 diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index 9b796d6965c33..1e72548ecc138 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -1507,3 +1507,109 @@ def forward(self, input, positive, negative): swap=self.swap, reduction=self.reduction, name=self.name) + + +class TripletMarginLoss(Layer): + r""" + Creates a criterion that measures the triplet loss given an input + tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`. + This is used for measuring a relative similarity between samples. A triplet + is composed by `input`, `positive` and `negative` (i.e., `input`, `positive examples` and `negative + examples` respectively). The shapes of all input tensors should be + :math:`(N, *)`. + + The loss function for each sample in the mini-batch is: + + .. math:: + L(input, pos, neg) = \max \{d(input_i, pos_i) - d(input_i, neg_i) + {\rm margin}, 0\} + + + where + + .. math:: + d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p + + Parameters: + margin (float, Optional):Default: :math:`1`. + + p (int, Optional):The norm degree for pairwise distance. Default: :math:`2`. + + epsilon (float, Optional):Add small value to avoid division by zero, + default value is 1e-6. + + swap (bool, Optional):The distance swap change the negative distance to the distance between + positive sample and negative sample. For more details, see `Learning shallow convolutional feature descriptors with triplet losses`. + Default: ``False``. + + reduction (str, Optional):Indicate how to average the loss by batch_size. + the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. + If :attr:`reduction` is ``'none'``, the unreduced loss is returned; + If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned; + If :attr:`reduction` is ``'sum'``, the summed loss is returned. + Default: ``'mean'`` + + name (str,Optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Call Parameters: + input (Tensor):Input tensor, the data type is float32 or float64. + the shape is [N, \*], N is batch size and `\*` means any number of additional dimensions, available dtype is float32, float64. + + positive (Tensor):Positive tensor, the data type is float32 or float64. + The shape of label is the same as the shape of input. + + negative (Tensor):Negative tensor, the data type is float32 or float64. + The shape of label is the same as the shape of input. + + Returns: + Tensor. The tensor variable storing the triplet_margin_loss of input and positive and negative. + + Examples: + .. code-block:: python + + import paddle + + input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32) + positive= paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32) + negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32) + triplet_margin_loss = paddle.nn.TripletMarginLoss(reduction='none') + loss = triplet_margin_loss(input, positive, negative) + print(loss) + # Tensor([0. , 0.57496738, 0. ]) + + triplet_margin_loss = paddle.nn.TripletMarginLoss(margin=1.0, swap=True, reduction='mean', ) + loss = triplet_margin_loss(input, positive, negative,) + print(loss) + # Tensor([0.19165580]) + + """ + + def __init__(self, + margin=1.0, + p=2., + epsilon=1e-6, + swap=False, + reduction='mean', + name=None): + super(TripletMarginLoss, self).__init__() + if reduction not in ['sum', 'mean', 'none']: + raise ValueError( + "The value of 'reduction' in TripletMarginLoss should be 'sum', 'mean' or 'none', but " + "received %s, which is not allowed." % reduction) + self.margin = margin + self.p = p + self.epsilon = epsilon + self.swap = swap + self.reduction = reduction + self.name = name + + def forward(self, input, positive, negative): + return F.triplet_margin_loss(input, + positive, + negative, + margin=self.margin, + p=self.p, + epsilon=self.epsilon, + swap=self.swap, + reduction=self.reduction, + name=self.name)