Skip to content

Commit

Permalink
[fbsync] Handle invalid reduction values (#6675)
Browse files Browse the repository at this point in the history
Summary:
* Add ValueError

* Add tests for ValueError

* Add tests for ValueError

* Add ValueError

* Change to if/else

* Ammend iou_fn tests

* Move code excerpt

* Format tests

Reviewed By: datumbox

Differential Revision: D40138724

fbshipit-source-id: 56c742a8c2ff80f2f51cba4cb3156835ed250653

Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
  • Loading branch information
3 people authored and facebook-github-bot committed Oct 7, 2022
1 parent 7516e02 commit ac0cef0
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 7 deletions.
22 changes: 22 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,6 +1394,11 @@ def test_giou_loss(self, dtype, device):
assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, device=device, reduction="sum")
assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, device=device, reduction="mean")

# Test reduction value
# reduction value other than ["none", "mean", "sum"] should raise a ValueError
with pytest.raises(ValueError, match="Invalid"):
ops.generalized_box_iou_loss(box1s, box2s, reduction="xyz")

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_empty_inputs(self, dtype, device):
Expand All @@ -1413,6 +1418,9 @@ def test_ciou_loss(self, dtype, device):
assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean")
assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")

with pytest.raises(ValueError, match="Invalid"):
ops.complete_box_iou_loss(box1s, box2s, reduction="xyz")

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_empty_inputs(self, dtype, device):
Expand All @@ -1432,6 +1440,9 @@ def test_distance_iou_loss(self, dtype, device):
assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean")
assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")

with pytest.raises(ValueError, match="Invalid"):
ops.distance_box_iou_loss(box1s, box2s, reduction="xyz")

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_empty_distance_iou_inputs(self, dtype, device):
Expand Down Expand Up @@ -1554,6 +1565,17 @@ def test_jit(self, alpha, gamma, reduction, device, dtype, seed):
tol = 1e-3 if dtype is torch.half else 1e-5
torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)

# Raise ValueError for anonymous reduction mode
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_reduction_mode(self, device, dtype, reduction="xyz"):
if device == "cpu" and dtype is torch.half:
pytest.skip("Currently torch.half is not fully supported on cpu")
torch.random.manual_seed(0)
inputs, targets = self._generate_diverse_input_target_pair(device=device, dtype=dtype)
with pytest.raises(ValueError, match="Invalid"):
ops.sigmoid_focal_loss(inputs, targets, 0.25, 2, reduction)


class TestMasksToBoxes:
def test_masks_box(self):
Expand Down
11 changes: 9 additions & 2 deletions torchvision/ops/ciou_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,16 @@ def complete_box_iou_loss(
alpha = v / (1 - iou + v + eps)

loss = diou_loss + alpha * v
if reduction == "mean":

# Check reduction option and return loss accordingly
if reduction == "none":
pass
elif reduction == "mean":
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
elif reduction == "sum":
loss = loss.sum()

else:
raise ValueError(
f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
)
return loss
9 changes: 8 additions & 1 deletion torchvision/ops/diou_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,17 @@ def distance_box_iou_loss(

loss, _ = _diou_iou_loss(boxes1, boxes2, eps)

if reduction == "mean":
# Check reduction option and return loss accordingly
if reduction == "none":
pass
elif reduction == "mean":
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
elif reduction == "sum":
loss = loss.sum()
else:
raise ValueError(
f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
)
return loss


Expand Down
11 changes: 9 additions & 2 deletions torchvision/ops/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def sigmoid_focal_loss(
Loss tensor with the reduction option applied.
"""
# Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py

if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(sigmoid_focal_loss)
p = torch.sigmoid(inputs)
Expand All @@ -43,9 +44,15 @@ def sigmoid_focal_loss(
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss

if reduction == "mean":
# Check reduction option and return loss accordingly
if reduction == "none":
pass
elif reduction == "mean":
loss = loss.mean()
elif reduction == "sum":
loss = loss.sum()

else:
raise ValueError(
f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
)
return loss
10 changes: 8 additions & 2 deletions torchvision/ops/giou_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,15 @@ def generalized_box_iou_loss(

loss = 1 - miouk

if reduction == "mean":
# Check reduction option and return loss accordingly
if reduction == "none":
pass
elif reduction == "mean":
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
elif reduction == "sum":
loss = loss.sum()

else:
raise ValueError(
f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
)
return loss

0 comments on commit ac0cef0

Please sign in to comment.