Skip to content

Commit

Permalink
Avoid recomputing the affine matrix in bbox rotate (#6712)
Browse files Browse the repository at this point in the history
* Avoid recomputing the affine matrix in bbox rotate

* Fix linter

* inverted=True for estimating image size

* Update the image size estimation to match the one from the image kernel

* Nits

* Address comments.

* Center=0,0 when expand=true
  • Loading branch information
datumbox committed Oct 6, 2022
1 parent 026991b commit 61034d5
Showing 1 changed file with 26 additions and 25 deletions.
51 changes: 26 additions & 25 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,9 +279,9 @@ def affine_image_tensor(
center_f = [0.0, 0.0]
if center is not None:
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

translate_f = [1.0 * t for t in translate]
translate_f = [float(t) for t in translate]
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
Expand Down Expand Up @@ -321,7 +321,7 @@ def _affine_bounding_box_xyxy(
shear: List[float],
center: Optional[List[float]] = None,
expand: bool = False,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, Tuple[int, int]]:
angle, translate, shear, center = _affine_parse_args(
angle, translate, scale, shear, InterpolationMode.NEAREST, center
)
Expand All @@ -333,19 +333,24 @@ def _affine_bounding_box_xyxy(
dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
device = bounding_box.device

affine_matrix = torch.tensor(
_get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False),
dtype=dtype,
device=device,
).view(2, 3)
affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
transposed_affine_matrix = (
torch.tensor(
affine_vector,
dtype=dtype,
device=device,
)
.view(2, 3)
.T
)
# 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
# Single point structure is similar to
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
# 2) Now let's transform the points using affine matrix
transformed_points = torch.matmul(points, affine_matrix.T)
transformed_points = torch.matmul(points, transposed_affine_matrix)
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
# and compute bounding box from 4 transformed points:
transformed_points = transformed_points.view(-1, 4, 2)
Expand All @@ -360,20 +365,24 @@ def _affine_bounding_box_xyxy(
points = torch.tensor(
[
[0.0, 0.0, 1.0],
[0.0, 1.0 * height, 1.0],
[1.0 * width, 1.0 * height, 1.0],
[1.0 * width, 0.0, 1.0],
[0.0, float(height), 1.0],
[float(width), float(height), 1.0],
[float(width), 0.0, 1.0],
],
dtype=dtype,
device=device,
)
new_points = torch.matmul(points, affine_matrix.T)
new_points = torch.matmul(points, transposed_affine_matrix)
tr, _ = torch.min(new_points, dim=0, keepdim=True)
# Translate bounding boxes
out_bboxes[:, 0::2] = out_bboxes[:, 0::2] - tr[:, 0]
out_bboxes[:, 1::2] = out_bboxes[:, 1::2] - tr[:, 1]
# Estimate meta-data for image with inverted=True and with center=[0,0]
affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
new_width, new_height = _FT._compute_affine_output_size(affine_vector, width, height)
image_size = (new_height, new_width)

return out_bboxes.to(bounding_box.dtype)
return out_bboxes.to(bounding_box.dtype), image_size


def affine_bounding_box(
Expand All @@ -391,7 +400,7 @@ def affine_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)

out_bboxes = _affine_bounding_box_xyxy(bounding_box, image_size, angle, translate, scale, shear, center)
out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, image_size, angle, translate, scale, shear, center)

# out_bboxes should be of shape [N boxes, 4]

Expand Down Expand Up @@ -502,7 +511,7 @@ def rotate_image_tensor(
warnings.warn("The provided center argument has no effect on the result if expand is True")
else:
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]

# due to current incoherence of rotation angle direction between affine and rotate implementations
# we need to set -angle.
Expand Down Expand Up @@ -558,7 +567,7 @@ def rotate_bounding_box(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)

out_bboxes = _affine_bounding_box_xyxy(
out_bboxes, image_size = _affine_bounding_box_xyxy(
bounding_box,
image_size,
angle=-angle,
Expand All @@ -569,14 +578,6 @@ def rotate_bounding_box(
expand=expand,
)

if expand:
# TODO: Move this computation inside of `_affine_bounding_box_xyxy` to avoid computing the rotation and points
# matrix twice
height, width = image_size
rotation_matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, [0.0, 0.0], 1.0, [0.0, 0.0])
new_width, new_height = _FT._compute_affine_output_size(rotation_matrix, width, height)
image_size = (new_height, new_width)

return (
convert_format_bounding_box(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
Expand Down

0 comments on commit 61034d5

Please sign in to comment.