From e2117ab7f1da622c32fe6c92182ec962ccf73b33 Mon Sep 17 00:00:00 2001 From: yangguohao <1901212980@pku.edu.cn> Date: Sun, 27 Mar 2022 01:29:03 +0800 Subject: [PATCH] '2022_03_27' --- .../unittests/test_triplet_margin_loss.py | 20 +++++++++++++++++++ python/paddle/nn/functional/loss.py | 10 ++++++++++ 2 files changed, 30 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_triplet_margin_loss.py b/python/paddle/fluid/tests/unittests/test_triplet_margin_loss.py index b6287394edb84..e285f6a31bfcb 100644 --- a/python/paddle/fluid/tests/unittests/test_triplet_margin_loss.py +++ b/python/paddle/fluid/tests/unittests/test_triplet_margin_loss.py @@ -188,6 +188,26 @@ def test_TripletMarginLoss_error(self): reduction="unsupport reduction") paddle.enable_static() + def test_TripletMarginLoss_dimension(self): + paddle.disable_static() + + input = paddle.to_tensor([[0.1, 0.3],[1, 2]], dtype='float32') + positive = paddle.to_tensor([[0.0, 1.0]], dtype='float32') + negative = paddle.to_tensor([[0.2, 0.1]], dtype='float32') + self.assertRaises( + ValueError, + paddle.nn.functional.triplet_margin_loss, + input=input, + positive=positive, + negative=negative,) + TMLoss = paddle.nn.TripletMarginLoss() + self.assertRaises( + ValueError, + TMLoss, + input=input, + positive=positive, + negative=negative,) + paddle.enable_static() if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 38cb532fbb635..914dea7dc7c1e 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2250,6 +2250,16 @@ def triplet_margin_loss(input,positive,negative, check_variable_and_dtype(negative, 'negative', ['float32', 'float64'], 'triplet_margin_loss') + # reshape to [batch_size, N] + input = input.flatten(start_axis=1,stop_axis=-1) + positive = positive.flatten(start_axis=1,stop_axis=-1) + negative = negative.flatten(start_axis=1,stop_axis=-1) + if not(input.shape==positive.shape==negative.shape): + raise ValueError( + "input's shape must equal to " + "positive's shape and " + "negative's shape") + distance_function = paddle.nn.PairwiseDistance(p, epsilon=eps) positive_dist = distance_function(input, positive) negative_dist = distance_function(input, negative)