Skip to content

Commit

Permalink
[fbsync] Fix d/c IoU for different batch sizes (#6338)
Browse files Browse the repository at this point in the history
Summary:
* Fix bug in calculating cIoU for unequal sizes

* Remove comment

* what the epsilon?

* Fixing DIoU

* Optimization by Francisco.

* Fix the expected values on CompleteBoxIoU

* Apply suggestions from code review

* Adding cartesian product test.

* remove static

Reviewed By: NicolasHug

Differential Revision: D38351751

fbshipit-source-id: 097e5f7048c650767e275fbb2c30ed0c800b1314

Co-authored-by: Abhijit Deo <72816663+abhi-glitchhg@users.noreply.github.com>
Co-authored-by: Vasilis Vryniotis <vvryniotis@fb.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Abhijit Deo <72816663+abhi-glitchhg@users.noreply.github.com>
  • Loading branch information
3 people authored and facebook-github-bot committed Aug 2, 2022
1 parent d893469 commit e2e1db6
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 49 deletions.
130 changes: 89 additions & 41 deletions test/test_ops.py
Expand Up @@ -1111,14 +1111,6 @@ def test_bbox_convert_jit(self):
torch.testing.assert_close(scripted_cxcywh, box_cxcywh)


INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
FLOAT_BOXES = [
[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
]


class TestBoxArea:
def area_check(self, box, expected, atol=1e-4):
out = ops.box_area(box)
Expand Down Expand Up @@ -1152,99 +1144,155 @@ def test_box_area_jit(self):
torch.testing.assert_close(scripted_area, expected)


# Shared box fixtures for the IoU test classes below. Each box is four numbers
# (presumably ``(x1, y1, x2, y2)``, the format the box ops document -- confirm).
# INT_BOXES holds 4 boxes while INT_BOXES2 holds 3, so passing them as
# (boxes1, boxes2) exercises unequal batch sizes and yields a 4 x 3 result.
INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]]
INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
# Three heavily overlapping float boxes; the tests pair this list with itself.
FLOAT_BOXES = [
[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
]


def gen_box(size, dtype=torch.float):
    """Return ``size`` random boxes as a ``(size, 4)`` tensor.

    The first two columns are a corner sampled uniformly from ``[0, 1)``; the
    last two columns are that corner plus a second uniform ``[0, 1)`` offset,
    so columns 2-3 are never smaller than columns 0-1.

    Args:
        size: number of boxes to generate.
        dtype: dtype passed to ``torch.rand`` for both draws.

    Returns:
        Tensor of shape ``(size, 4)`` with the requested dtype.
    """
    # Two separate draws, in this order, so the random stream is consumed
    # exactly as before.
    top_left = torch.rand((size, 2), dtype=dtype)
    extent = torch.rand((size, 2), dtype=dtype)
    bottom_right = top_left + extent
    return torch.cat((top_left, bottom_right), dim=-1)


class TestIouBase:
    """Shared helpers for the pairwise box-IoU test classes.

    Every helper receives the op under test as ``target_fn`` and calls it as
    ``target_fn(boxes1, boxes2)`` with two box tensors; the op is expected to
    return the matrix of pairwise scores (one row per box in ``boxes1``, one
    column per box in ``boxes2``).
    """

    @staticmethod
    def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
        """Compare ``target_fn(actual_box1, actual_box2)`` with ``expected`` for each dtype.

        Args:
            target_fn: the op under test.
            actual_box1: nested list of boxes used as the first argument.
            actual_box2: nested list of boxes used as the second argument.
            dtypes: dtypes the two box lists are converted to, one run per dtype.
            atol: absolute tolerance for the comparison (``rtol`` is fixed at 0).
            expected: nested list holding the expected pairwise result.

        Raises:
            AssertionError: if any run differs from ``expected`` by more than ``atol``.
        """
        # The expected values do not depend on the dtype under test.
        expected_box = torch.tensor(expected)
        for dtype in dtypes:
            # Always convert from the raw inputs. Rebinding the arguments here
            # would make later iterations call torch.tensor() on a tensor
            # (copy-construct warning) and convert from the previous dtype
            # instead of from the original values.
            box1 = torch.tensor(actual_box1, dtype=dtype)
            box2 = torch.tensor(actual_box2, dtype=dtype)
            out = target_fn(box1, box2)
            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

    @staticmethod
    def _run_jit_test(target_fn: Callable, actual_box: List):
        """Check that the TorchScript-compiled op matches eager mode.

        ``actual_box`` is converted to a float tensor and used as both
        arguments of ``target_fn``.
        """
        box_tensor = torch.tensor(actual_box, dtype=torch.float)
        expected = target_fn(box_tensor, box_tensor)
        scripted_fn = torch.jit.script(target_fn)
        scripted_out = scripted_fn(box_tensor, box_tensor)
        torch.testing.assert_close(scripted_out, expected)

    @staticmethod
    def _cartesian_product(boxes1, boxes2, target_fn: Callable):
        """Build the pairwise result one box pair at a time.

        Calls ``target_fn`` on every (single box, single box) combination and
        stores each score at ``[i, j]``. This is the slow reference the batched
        call is compared against.

        Returns:
            Tensor of shape ``(boxes1.size(0), boxes2.size(0))``, allocated with
            torch's default dtype.
        """
        N = boxes1.size(0)
        M = boxes2.size(0)
        result = torch.zeros((N, M))
        for i in range(N):
            for j in range(M):
                # unsqueeze(0) turns one box back into a batch of one.
                result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
        return result

    @staticmethod
    def _run_cartesian_test(target_fn: Callable):
        """Check the batched op against the pair-by-pair reference.

        Uses 5 and 7 random boxes so the two batch sizes differ.
        """
        boxes1 = gen_box(5)
        boxes2 = gen_box(7)
        a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
        b = target_fn(boxes1, boxes2)
        # Same tolerances as torch.allclose's defaults, but a failure reports
        # which elements differ instead of a bare AssertionError.
        torch.testing.assert_close(a, b, rtol=1e-5, atol=1e-8)


class TestBoxIou(TestIouBase):
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]]
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, atol, expected",
"actual_box1, actual_box2, dtypes, atol, expected",
[
pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
],
)
def test_iou(self, test_input, dtypes, atol, expected):
self._run_test(ops.box_iou, test_input, dtypes, atol, expected)
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
self._run_test(ops.box_iou, actual_box1, actual_box2, dtypes, atol, expected)

def test_iou_jit(self):
self._run_jit_test(ops.box_iou, INT_BOXES)

def test_iou_cartesian(self):
self._run_cartesian_test(ops.box_iou)


class TestGeneralizedBoxIou(TestIouBase):
int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]]
int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]]
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, atol, expected",
"actual_box1, actual_box2, dtypes, atol, expected",
[
pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
],
)
def test_iou(self, test_input, dtypes, atol, expected):
self._run_test(ops.generalized_box_iou, test_input, dtypes, atol, expected)
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
self._run_test(ops.generalized_box_iou, actual_box1, actual_box2, dtypes, atol, expected)

def test_iou_jit(self):
self._run_jit_test(ops.generalized_box_iou, INT_BOXES)

def test_iou_cartesian(self):
self._run_cartesian_test(ops.generalized_box_iou)


class TestDistanceBoxIoU(TestIouBase):
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
int_expected = [
[1.0000, 0.1875, -0.4444],
[0.1875, 1.0000, -0.5625],
[-0.4444, -0.5625, 1.0000],
[-0.0781, 0.1875, -0.6267],
]
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, atol, expected",
"actual_box1, actual_box2, dtypes, atol, expected",
[
pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
],
)
def test_iou(self, test_input, dtypes, atol, expected):
self._run_test(ops.distance_box_iou, test_input, dtypes, atol, expected)
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
self._run_test(ops.distance_box_iou, actual_box1, actual_box2, dtypes, atol, expected)

def test_iou_jit(self):
self._run_jit_test(ops.distance_box_iou, INT_BOXES)

def test_iou_cartesian(self):
self._run_cartesian_test(ops.distance_box_iou)


class TestCompleteBoxIou(TestIouBase):
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
int_expected = [
[1.0000, 0.1875, -0.4444],
[0.1875, 1.0000, -0.5625],
[-0.4444, -0.5625, 1.0000],
[-0.0781, 0.1875, -0.6267],
]
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, atol, expected",
"actual_box1, actual_box2, dtypes, atol, expected",
[
pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
],
)
def test_iou(self, test_input, dtypes, atol, expected):
self._run_test(ops.complete_box_iou, test_input, dtypes, atol, expected)
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
self._run_test(ops.complete_box_iou, actual_box1, actual_box2, dtypes, atol, expected)

def test_iou_jit(self):
self._run_jit_test(ops.complete_box_iou, INT_BOXES)

def test_iou_cartesian(self):
self._run_cartesian_test(ops.complete_box_iou)


def get_boxes(dtype, device):
box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
Expand Down
12 changes: 7 additions & 5 deletions torchvision/ops/boxes.py
Expand Up @@ -325,13 +325,13 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso

diou, iou = _box_diou_iou(boxes1, boxes2, eps)

w_pred = boxes1[:, 2] - boxes1[:, 0]
h_pred = boxes1[:, 3] - boxes1[:, 1]
w_pred = boxes1[:, None, 2] - boxes1[:, None, 0]
h_pred = boxes1[:, None, 3] - boxes1[:, None, 1]

w_gt = boxes2[:, 2] - boxes2[:, 0]
h_gt = boxes2[:, 3] - boxes2[:, 1]

v = (4 / (torch.pi**2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
v = (4 / (torch.pi**2)) * torch.pow(torch.atan(w_pred / h_pred) - torch.atan(w_gt / h_gt), 2)
with torch.no_grad():
alpha = v / (1 - iou + v + eps)
return diou - alpha * v
Expand All @@ -358,7 +358,7 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso

boxes1 = _upcast(boxes1)
boxes2 = _upcast(boxes2)
diou, _ = _box_diou_iou(boxes1, boxes2)
diou, _ = _box_diou_iou(boxes1, boxes2, eps=eps)
return diou


Expand All @@ -375,7 +375,9 @@ def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Te
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
# The distance between boxes' centers squared.
centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2)
centers_distance_squared = (_upcast((x_p[:, None] - x_g[None, :])) ** 2) + (
_upcast((y_p[:, None] - y_g[None, :])) ** 2
)
# The distance IoU is the IoU penalized by a normalized
# distance between boxes' centers squared.
return iou - (centers_distance_squared / diagonal_distance_squared), iou
Expand Down
6 changes: 3 additions & 3 deletions torchvision/ops/ciou_loss.py
Expand Up @@ -14,8 +14,8 @@ def complete_box_iou_loss(

"""
Gradient-friendly IoU loss with an additional penalty that is non-zero when the
boxes do not overlap overlap area, This loss function considers important geometrical
factors such as overlap area, normalized central point distance and aspect ratio.
boxes do not overlap. This loss function considers important geometrical
factors such as overlap area, normalized central point distance and aspect ratio.
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
Expand All @@ -35,7 +35,7 @@ def complete_box_iou_loss(
Tensor: Loss tensor with the reduction option applied.
Reference:
Zhaohui Zheng et. al: Complete Intersection over Union Loss:
Zhaohui Zheng et al.: Complete Intersection over Union Loss:
https://arxiv.org/abs/1911.08287
"""
Expand Down

0 comments on commit e2e1db6

Please sign in to comment.