From e2e1db6e41ef49a033d038f5d8bed269ea580ebd Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 2 Aug 2022 13:36:35 -0700 Subject: [PATCH] [fbsync] Fix d/c IoU for different batch sizes (#6338) Summary: * Fix bug in calculating cIoU for unequal sizes * Remove comment * what the epsilon? * Fixing DIoU * Optimization by Francisco. * Fix the expected values on CompleteBoxIoU * Apply suggestions from code review * Adding cartesian product test. * remove static Reviewed By: NicolasHug Differential Revision: D38351751 fbshipit-source-id: 097e5f7048c650767e275fbb2c30ed0c800b1314 Co-authored-by: Abhijit Deo <72816663+abhi-glitchhg@users.noreply.github.com> Co-authored-by: Vasilis Vryniotis Co-authored-by: Vasilis Vryniotis Co-authored-by: Abhijit Deo <72816663+abhi-glitchhg@users.noreply.github.com> --- test/test_ops.py | 130 ++++++++++++++++++++++++----------- torchvision/ops/boxes.py | 12 ++-- torchvision/ops/ciou_loss.py | 6 +- 3 files changed, 99 insertions(+), 49 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 8ec0e6c7ea9..8f961e37117 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1111,14 +1111,6 @@ def test_bbox_convert_jit(self): torch.testing.assert_close(scripted_cxcywh, box_cxcywh) -INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] -FLOAT_BOXES = [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], -] - - class TestBoxArea: def area_check(self, box, expected, atol=1e-4): out = ops.box_area(box) @@ -1152,99 +1144,155 @@ def test_box_area_jit(self): torch.testing.assert_close(scripted_area, expected) +INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]] +INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] +FLOAT_BOXES = [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], +] + + +def gen_box(size, dtype=torch.float): + xy1 = torch.rand((size, 2), dtype=dtype) + xy2 = xy1 + torch.rand((size, 2), dtype=dtype) + return torch.cat([xy1, xy2], axis=-1) + + class TestIouBase: @staticmethod - def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], atol: float, expected: List): + def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected): for dtype in dtypes: - actual_box = torch.tensor(test_input, dtype=dtype) + actual_box1 = torch.tensor(actual_box1, dtype=dtype) + actual_box2 = torch.tensor(actual_box2, dtype=dtype) expected_box = torch.tensor(expected) - out = target_fn(actual_box, actual_box) + out = target_fn(actual_box1, actual_box2) torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) @staticmethod - def _run_jit_test(target_fn: Callable, test_input: List): - box_tensor = torch.tensor(test_input, dtype=torch.float) + def _run_jit_test(target_fn: Callable, actual_box: List): + box_tensor = torch.tensor(actual_box, dtype=torch.float) expected = target_fn(box_tensor, box_tensor) scripted_fn = torch.jit.script(target_fn) scripted_out = scripted_fn(box_tensor, box_tensor) torch.testing.assert_close(scripted_out, expected) + @staticmethod + def _cartesian_product(boxes1, boxes2, target_fn: Callable): + N = boxes1.size(0) + M = boxes2.size(0) + result = torch.zeros((N, M)) + for i in range(N): + for j in range(M): + result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0)) + return result + + @staticmethod + def _run_cartesian_test(target_fn: Callable): + boxes1 = gen_box(5) + boxes2 = gen_box(7) + a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn) + b = target_fn(boxes1, boxes2) + assert torch.allclose(a, b) + class TestBoxIou(TestIouBase): - int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]] float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( - "test_input, dtypes, atol, expected", + "actual_box1, actual_box2, dtypes, atol, expected", [ - pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), - pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected), - pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), + pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) - def test_iou(self, test_input, dtypes, atol, expected): - self._run_test(ops.box_iou, test_input, dtypes, atol, expected) + def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected): + self._run_test(ops.box_iou, actual_box1, actual_box2, dtypes, atol, expected) def test_iou_jit(self): self._run_jit_test(ops.box_iou, INT_BOXES) + def test_iou_cartesian(self): + self._run_cartesian_test(ops.box_iou) + class TestGeneralizedBoxIou(TestIouBase): - int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]] + int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]] float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( - "test_input, dtypes, atol, expected", + "actual_box1, actual_box2, dtypes, atol, expected", [ - pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), - pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected), - pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), + pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) - def test_iou(self, test_input, dtypes, atol, expected): - self._run_test(ops.generalized_box_iou, test_input, dtypes, atol, expected) + def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected): + self._run_test(ops.generalized_box_iou, actual_box1, actual_box2, dtypes, atol, expected) def test_iou_jit(self): self._run_jit_test(ops.generalized_box_iou, INT_BOXES) + def test_iou_cartesian(self): + self._run_cartesian_test(ops.generalized_box_iou) + class TestDistanceBoxIoU(TestIouBase): - int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + int_expected = [ + [1.0000, 0.1875, -0.4444], + [0.1875, 1.0000, -0.5625], + [-0.4444, -0.5625, 1.0000], + [-0.0781, 0.1875, -0.6267], + ] float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( - "test_input, dtypes, atol, expected", + "actual_box1, actual_box2, dtypes, atol, expected", [ - pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), - pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected), - pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), + pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) - def test_iou(self, test_input, dtypes, atol, expected): - self._run_test(ops.distance_box_iou, test_input, dtypes, atol, expected) + def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected): + self._run_test(ops.distance_box_iou, actual_box1, actual_box2, dtypes, atol, expected) def test_iou_jit(self): self._run_jit_test(ops.distance_box_iou, INT_BOXES) + def test_iou_cartesian(self): + self._run_cartesian_test(ops.distance_box_iou) + class TestCompleteBoxIou(TestIouBase): - int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + int_expected = [ + [1.0000, 0.1875, -0.4444], + [0.1875, 1.0000, -0.5625], + [-0.4444, -0.5625, 1.0000], + [-0.0781, 0.1875, -0.6267], + ] float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( - "test_input, dtypes, atol, expected", + "actual_box1, actual_box2, dtypes, atol, expected", [ - pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), - pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected), - pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), + pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) - def test_iou(self, test_input, dtypes, atol, expected): - self._run_test(ops.complete_box_iou, test_input, dtypes, atol, expected) + def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected): + self._run_test(ops.complete_box_iou, actual_box1, actual_box2, dtypes, atol, expected) def test_iou_jit(self): self._run_jit_test(ops.complete_box_iou, INT_BOXES) + def test_iou_cartesian(self): + self._run_cartesian_test(ops.complete_box_iou) + def get_boxes(dtype, device): box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index e42e7e04a70..a541f8d880a 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -325,13 +325,13 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso diou, iou = _box_diou_iou(boxes1, boxes2, eps) - w_pred = boxes1[:, 2] - boxes1[:, 0] - h_pred = boxes1[:, 3] - boxes1[:, 1] + w_pred = boxes1[:, None, 2] - boxes1[:, None, 0] + h_pred = boxes1[:, None, 3] - boxes1[:, None, 1] w_gt = boxes2[:, 2] - boxes2[:, 0] h_gt = boxes2[:, 3] - boxes2[:, 1] - v = (4 / (torch.pi**2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) + v = (4 / (torch.pi**2)) * torch.pow(torch.atan(w_pred / h_pred) - torch.atan(w_gt / h_gt), 2) with torch.no_grad(): alpha = v / (1 - iou + v + eps) return diou - alpha * v @@ -358,7 +358,7 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso boxes1 = _upcast(boxes1) boxes2 = _upcast(boxes2) - diou, _ = _box_diou_iou(boxes1, boxes2) + diou, _ = _box_diou_iou(boxes1, boxes2, eps=eps) return diou @@ -375,7 +375,9 @@ def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Te x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2 y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2 # The distance between boxes' centers squared. - centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2) + centers_distance_squared = (_upcast((x_p[:, None] - x_g[None, :])) ** 2) + ( + _upcast((y_p[:, None] - y_g[None, :])) ** 2 + ) # The distance IoU is the IoU penalized by a normalized # distance between boxes' centers squared. return iou - (centers_distance_squared / diagonal_distance_squared), iou diff --git a/torchvision/ops/ciou_loss.py b/torchvision/ops/ciou_loss.py index a71baf28e70..a9f20a5f4c8 100644 --- a/torchvision/ops/ciou_loss.py +++ b/torchvision/ops/ciou_loss.py @@ -14,8 +14,8 @@ def complete_box_iou_loss( """ Gradient-friendly IoU loss with an additional penalty that is non-zero when the - boxes do not overlap overlap area, This loss function considers important geometrical - factors such as overlap area, normalized central point distance and aspect ratio. + boxes do not overlap. This loss function considers important geometrical + factors such as overlap area, normalized central point distance and aspect ratio. This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable. Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with @@ -35,7 +35,7 @@ def complete_box_iou_loss( Tensor: Loss tensor with the reduction option applied. Reference: - Zhaohui Zheng et. al: Complete Intersection over Union Loss: + Zhaohui Zheng et al.: Complete Intersection over Union Loss: https://arxiv.org/abs/1911.08287 """