diff --git a/python/paddle/fluid/tests/unittests/test_triplet_margin_loss.py b/python/paddle/fluid/tests/unittests/test_triplet_margin_loss.py index b6287394edb84..e285f6a31bfcb 100644 --- a/python/paddle/fluid/tests/unittests/test_triplet_margin_loss.py +++ b/python/paddle/fluid/tests/unittests/test_triplet_margin_loss.py @@ -188,6 +188,26 @@ def test_TripletMarginLoss_error(self): reduction="unsupport reduction") paddle.enable_static() + def test_TripletMarginLoss_dimension(self): + paddle.disable_static() + + input = paddle.to_tensor([[0.1, 0.3],[1, 2]], dtype='float32') + positive = paddle.to_tensor([[0.0, 1.0]], dtype='float32') + negative = paddle.to_tensor([[0.2, 0.1]], dtype='float32') + self.assertRaises( + ValueError, + paddle.nn.functional.triplet_margin_loss, + input=input, + positive=positive, + negative=negative,) + TMLoss = paddle.nn.TripletMarginLoss() + self.assertRaises( + ValueError, + TMLoss, + input=input, + positive=positive, + negative=negative,) + paddle.enable_static() if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 38cb532fbb635..914dea7dc7c1e 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -2250,6 +2250,16 @@ def triplet_margin_loss(input,positive,negative, check_variable_and_dtype(negative, 'negative', ['float32', 'float64'], 'triplet_margin_loss') + # reshape to [batch_size, N] + input = input.flatten(start_axis=1,stop_axis=-1) + positive = positive.flatten(start_axis=1,stop_axis=-1) + negative = negative.flatten(start_axis=1,stop_axis=-1) + if not(input.shape==positive.shape==negative.shape): + raise ValueError( + "input's shape must equal to " + "positive's shape and " + "negative's shape") + distance_function = paddle.nn.PairwiseDistance(p, epsilon=eps) positive_dist = distance_function(input, positive) negative_dist = distance_function(input, negative)