Skip to content

Commit

Permalink
Update loss.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yangguohao committed Jun 13, 2022
1 parent 187fa6d commit e6793b9
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions python/paddle/nn/layer/loss.py
Expand Up @@ -1485,7 +1485,7 @@ def __init__(self,
distance_function=None,
margin=1.0,
swap=False,
reduction: str='mean',
reduction: str = 'mean',
name=None):
super(TripletMarginWithDistanceLoss, self).__init__()
if reduction not in ['sum', 'mean', 'none']:
Expand All @@ -1500,14 +1500,13 @@ def __init__(self,
self.name = name

def forward(self, input, positive, negative):
    """Compute the triplet margin loss with the configured distance settings.

    Delegates to ``F.triplet_margin_with_distance_loss`` using the
    ``margin``, ``swap``, ``reduction`` and ``name`` values stored on
    this layer at construction time.

    Args:
        input: anchor tensor.
        positive: positive-sample tensor, same shape as ``input``.
        negative: negative-sample tensor, same shape as ``input``.

    Returns:
        The loss tensor produced by the functional implementation
        (scalar for 'mean'/'sum' reduction, per-sample for 'none').
    """
    return F.triplet_margin_with_distance_loss(
        input,
        positive,
        negative,
        margin=self.margin,
        swap=self.swap,
        reduction=self.reduction,
        name=self.name,
    )


class TripletMarginLoss(Layer):
Expand Down Expand Up @@ -1569,18 +1568,17 @@ class TripletMarginLoss(Layer):
.. code-block:: python
import paddle
from paddle.nn import TripletMarginLoss
input = paddle.to_tensor([[1, 5, 3], [0, 3, 2], [1, 4, 1]], dtype=paddle.float32)
positive = paddle.to_tensor([[5, 1, 2], [3, 2, 1], [3, -1, 1]], dtype=paddle.float32)
negative = paddle.to_tensor([[2, 1, -3], [1, 1, -1], [4, -2, 1]], dtype=paddle.float32)
triplet_margin_loss = TripletMarginLoss()
loss = triplet_margin_loss(input, positive, negative, reduction='none')
triplet_margin_loss = paddle.nn.TripletMarginLoss(reduction='none')
loss = triplet_margin_loss(input, positive, negative)
print(loss)
# Tensor([0. , 0.57496738, 0. ])
loss = triplet_margin_loss(input, positive, negative, margin=1.0, reduction='mean')
triplet_margin_loss = paddle.nn.TripletMarginLoss(margin=1.0, swap=True, reduction='mean')
loss = triplet_margin_loss(input, positive, negative)
print(loss)
# Tensor([0.19165580])
Expand Down

0 comments on commit e6793b9

Please sign in to comment.