
Make get_image_size and get_image_num_channels public. (#4321)
datumbox committed Aug 26, 2021
1 parent 37a9ee5 commit 96f6e0a
Showing 7 changed files with 70 additions and 40 deletions.
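
For context, a minimal sketch of the now-public API (not part of the diff; assumes a torchvision build that includes this commit):

```python
import torch
from PIL import Image
import torchvision.transforms.functional as F

tensor_img = torch.rand(3, 18, 16)           # CHW tensor: height 18, width 16
pil_img = Image.new("RGB", (16, 18))         # PIL size is (width, height)

print(F.get_image_size(tensor_img))          # [16, 18] -> [width, height]
print(F.get_image_size(pil_img))             # [16, 18] -> same convention
print(F.get_image_num_channels(tensor_img))  # 3
```

Both helpers accept either a PIL Image or a Tensor and dispatch to the matching backend, as the functional.py hunks below show.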
8 changes: 4 additions & 4 deletions references/detection/transforms.py
@@ -33,7 +33,7 @@ def forward(self, image: Tensor,
if torch.rand(1) < self.p:
image = F.hflip(image)
if target is not None:
-width, _ = F._get_image_size(image)
+width, _ = F.get_image_size(image)
target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
@@ -76,7 +76,7 @@ def forward(self, image: Tensor,
elif image.ndimension() == 2:
image = image.unsqueeze(0)

-orig_w, orig_h = F._get_image_size(image)
+orig_w, orig_h = F.get_image_size(image)

while True:
# sample an option
@@ -157,7 +157,7 @@ def forward(self, image: Tensor,
if torch.rand(1) < self.p:
return image, target

-orig_w, orig_h = F._get_image_size(image)
+orig_w, orig_h = F.get_image_size(image)

r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
@@ -226,7 +226,7 @@ def forward(self, image: Tensor,
image = self._contrast(image)

if r[6] < self.p:
-channels = F._get_image_num_channels(image)
+channels = F.get_image_num_channels(image)
permutation = torch.randperm(channels)

is_pil = F._is_pil_image(image)
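The RandomHorizontalFlip hunk above mirrors box x-coordinates around the image width obtained from the renamed helper. A self-contained sketch of that arithmetic (image and box values are illustrative):

```python
import torch
import torchvision.transforms.functional as F

image = torch.rand(3, 100, 200)               # width = 200
boxes = torch.tensor([[10., 20., 50., 60.]])  # [x1, y1, x2, y2]

image = F.hflip(image)
width, _ = F.get_image_size(image)
boxes[:, [0, 2]] = width - boxes[:, [2, 0]]   # x1' = W - x2, x2' = W - x1
print(boxes)                                  # tensor([[150., 20., 190., 60.]])
```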
20 changes: 19 additions & 1 deletion test/test_functional_tensor.py
@@ -31,6 +31,24 @@
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC


+@pytest.mark.parametrize('device', cpu_and_gpu())
+@pytest.mark.parametrize('fn', [F.get_image_size, F.get_image_num_channels])
+def test_image_sizes(device, fn):
+    script_F = torch.jit.script(fn)
+
+    img_tensor, pil_img = _create_data(16, 18, 3, device=device)
+    value_img = fn(img_tensor)
+    value_pil_img = fn(pil_img)
+    assert value_img == value_pil_img
+
+    value_img_script = script_F(img_tensor)
+    assert value_img == value_img_script
+
+    batch_tensors = _create_data_batch(16, 18, 3, num_samples=4, device=device)
+    value_img_batch = fn(batch_tensors)
+    assert value_img == value_img_batch


@needs_cuda
def test_scale_channel():
"""Make sure that _scale_channel gives the same results on CPU and GPU as
@@ -908,7 +926,7 @@ def test_resized_crop(device, mode):

@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('func, args', [
-(F_t._get_image_size, ()), (F_t.vflip, ()),
+(F_t.get_image_size, ()), (F_t.vflip, ()),
(F_t.hflip, ()), (F_t.crop, (1, 2, 4, 5)),
(F_t.adjust_brightness, (0., )), (F_t.adjust_contrast, (1., )),
(F_t.adjust_hue, (-0.5, )), (F_t.adjust_saturation, (2., )),
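The new test_image_sizes above checks that both helpers agree across PIL images, single tensors, batched tensors, and their TorchScript-scripted versions. A condensed standalone sketch of the same assertions (shapes are illustrative):

```python
import torch
from PIL import Image
import torchvision.transforms.functional as F

scripted = torch.jit.script(F.get_image_size)  # the public helper is scriptable

img = torch.rand(3, 18, 16)                    # height 18, width 16
batch = torch.rand(4, 3, 18, 16)               # leading batch dims are ignored

assert F.get_image_size(img) == scripted(img) == F.get_image_size(batch) == [16, 18]
assert F.get_image_size(Image.new("RGB", (16, 18))) == [16, 18]
```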
6 changes: 3 additions & 3 deletions torchvision/transforms/autoaugment.py
@@ -188,7 +188,7 @@ def forward(self, img: Tensor) -> Tensor:
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
-fill = [float(fill)] * F._get_image_num_channels(img)
+fill = [float(fill)] * F.get_image_num_channels(img)
elif fill is not None:
fill = [float(f) for f in fill]

@@ -209,10 +209,10 @@ def forward(self, img: Tensor) -> Tensor:
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
interpolation=self.interpolation, fill=fill)
elif op_name == "TranslateX":
-img = F.affine(img, angle=0.0, translate=[int(F._get_image_size(img)[0] * magnitude), 0], scale=1.0,
+img = F.affine(img, angle=0.0, translate=[int(F.get_image_size(img)[0] * magnitude), 0], scale=1.0,
interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill)
elif op_name == "TranslateY":
-img = F.affine(img, angle=0.0, translate=[0, int(F._get_image_size(img)[1] * magnitude)], scale=1.0,
+img = F.affine(img, angle=0.0, translate=[0, int(F.get_image_size(img)[1] * magnitude)], scale=1.0,
interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill)
elif op_name == "Rotate":
img = F.rotate(img, magnitude, interpolation=self.interpolation, fill=fill)
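The autoaugment hunks keep the usual fill-broadcast idiom, now via the public helper: a scalar fill is expanded to one float per channel before reaching F.affine or F.rotate. A sketch (values are illustrative):

```python
import torch
import torchvision.transforms.functional as F

img = torch.zeros(3, 32, 32)
fill = 1  # scalar fill, as a caller might pass it
if isinstance(fill, (int, float)):
    fill = [float(fill)] * F.get_image_num_channels(img)
print(fill)  # [1.0, 1.0, 1.0] -> one entry per channel
```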
38 changes: 25 additions & 13 deletions torchvision/transforms/functional.py
@@ -58,22 +58,34 @@ def _interpolation_modes_from_int(i: int) -> InterpolationMode:
_is_pil_image = F_pil._is_pil_image


-def _get_image_size(img: Tensor) -> List[int]:
-    """Returns image size as [w, h]
+def get_image_size(img: Tensor) -> List[int]:
+    """Returns the size of an image as [width, height].
+
+    Args:
+        img (PIL Image or Tensor): The image to be checked.
+
+    Returns:
+        List[int]: The image size.
    """
    if isinstance(img, torch.Tensor):
-        return F_t._get_image_size(img)
+        return F_t.get_image_size(img)

-    return F_pil._get_image_size(img)
+    return F_pil.get_image_size(img)


-def _get_image_num_channels(img: Tensor) -> int:
-    """Returns number of image channels
+def get_image_num_channels(img: Tensor) -> int:
+    """Returns the number of channels of an image.
+
+    Args:
+        img (PIL Image or Tensor): The image to be checked.
+
+    Returns:
+        int: The number of channels.
    """
    if isinstance(img, torch.Tensor):
-        return F_t._get_image_num_channels(img)
+        return F_t.get_image_num_channels(img)

-    return F_pil._get_image_num_channels(img)
+    return F_pil.get_image_num_channels(img)


@torch.jit.unused
@@ -500,7 +512,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
output_size = (output_size[0], output_size[0])

-image_width, image_height = _get_image_size(img)
+image_width, image_height = get_image_size(img)
crop_height, crop_width = output_size

if crop_width > image_width or crop_height > image_height:
@@ -511,7 +523,7 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
]
img = pad(img, padding_ltrb, fill=0) # PIL uses fill value 0
-image_width, image_height = _get_image_size(img)
+image_width, image_height = get_image_size(img)
if crop_width == image_width and crop_height == image_height:
return img

@@ -696,7 +708,7 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
if len(size) != 2:
raise ValueError("Please provide only two dimensions (h, w) for size.")

-image_width, image_height = _get_image_size(img)
+image_width, image_height = get_image_size(img)
crop_height, crop_width = size
if crop_width > image_width or crop_height > image_height:
msg = "Requested crop size {} is bigger than input size {}"
@@ -993,7 +1005,7 @@ def rotate(

center_f = [0.0, 0.0]
if center is not None:
-img_size = _get_image_size(img)
+img_size = get_image_size(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]

@@ -1094,7 +1106,7 @@ def affine(
if len(shear) != 2:
raise ValueError("Shear should be a sequence containing two values. Got {}".format(shear))

-img_size = _get_image_size(img)
+img_size = get_image_size(img)
if not isinstance(img, torch.Tensor):
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
# it is visually better to estimate the center without 0.5 offset
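As the center_crop hunks above show, a crop larger than the input is handled by zero-padding first and re-reading the size via get_image_size. A quick sketch of that path:

```python
import torch
import torchvision.transforms.functional as F

img = torch.ones(3, 10, 10)
out = F.center_crop(img, [16, 16])  # crop larger than input -> zero-pad first
print(F.get_image_size(out))        # [16, 16]
print(out[0, 0, 0].item())          # 0.0 -> padded border pixel
```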
6 changes: 3 additions & 3 deletions torchvision/transforms/functional_pil.py
@@ -20,14 +20,14 @@ def _is_pil_image(img: Any) -> bool:


@torch.jit.unused
-def _get_image_size(img: Any) -> List[int]:
+def get_image_size(img: Any) -> List[int]:
    if _is_pil_image(img):
-        return img.size
+        return list(img.size)
raise TypeError("Unexpected type {}".format(type(img)))


@torch.jit.unused
-def _get_image_num_channels(img: Any) -> int:
+def get_image_num_channels(img: Any) -> int:
if _is_pil_image(img):
return 1 if img.mode == 'L' else 3
raise TypeError("Unexpected type {}".format(type(img)))
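Beyond the rename, note the behavior change in get_image_size above: PIL's img.size is a tuple, so it is now wrapped in list(...) to match the List[int] annotation and the tensor backend. A quick check (functional_pil is an internal module, imported here only for illustration):

```python
from PIL import Image
from torchvision.transforms import functional_pil as F_pil

img = Image.new("RGB", (16, 18))
print(img.size)                   # (16, 18) -> tuple from PIL
print(F_pil.get_image_size(img))  # [16, 18] -> list, matching the tensor backend
```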
14 changes: 7 additions & 7 deletions torchvision/transforms/functional_tensor.py
@@ -16,13 +16,13 @@ def _assert_image_tensor(img: Tensor) -> None:
raise TypeError("Tensor is not a torch image.")


-def _get_image_size(img: Tensor) -> List[int]:
+def get_image_size(img: Tensor) -> List[int]:
# Returns (w, h) of tensor image
_assert_image_tensor(img)
return [img.shape[-1], img.shape[-2]]


-def _get_image_num_channels(img: Tensor) -> int:
+def get_image_num_channels(img: Tensor) -> int:
if img.ndim == 2:
return 1
elif img.ndim > 2:
@@ -50,7 +50,7 @@ def _max_value(dtype: torch.dtype) -> float:


def _assert_channels(img: Tensor, permitted: List[int]) -> None:
-c = _get_image_num_channels(img)
+c = get_image_num_channels(img)
if c not in permitted:
raise TypeError("Input image tensor permitted channel values are {}, but found {}".format(permitted, c))

@@ -122,7 +122,7 @@ def hflip(img: Tensor) -> Tensor:
def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
_assert_image_tensor(img)

-w, h = _get_image_size(img)
+w, h = get_image_size(img)
right = left + width
bottom = top + height

@@ -187,7 +187,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
_assert_image_tensor(img)

_assert_channels(img, [1, 3])
-if _get_image_num_channels(img) == 1: # Match PIL behaviour
+if get_image_num_channels(img) == 1: # Match PIL behaviour
return img

orig_dtype = img.dtype
@@ -513,7 +513,7 @@ def resize(
if antialias and interpolation not in ["bilinear", "bicubic"]:
raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")

-w, h = _get_image_size(img)
+w, h = get_image_size(img)

if isinstance(size, int) or len(size) == 1: # specified size only for the smallest edge
short, long = (w, h) if w <= h else (h, w)
@@ -586,7 +586,7 @@ def _assert_grid_transform_inputs(
warnings.warn("Argument fill should be either int, float, tuple or list")

# Check fill
-num_channels = _get_image_num_channels(img)
+num_channels = get_image_num_channels(img)
if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels):
msg = ("The number of elements in 'fill' cannot broadcast to match the number of "
"channels of the image ({} != {})")
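For reference, how the tensor-side get_image_num_channels above reads shapes: 2-D tensors count as a single channel, and the ndim > 2 branch (truncated in this diff) is assumed here to return the third-from-last dimension:

```python
import torch
from torchvision.transforms import functional_tensor as F_t

print(F_t.get_image_num_channels(torch.rand(18, 16)))        # 1 -> (H, W)
print(F_t.get_image_num_channels(torch.rand(3, 18, 16)))     # 3 -> (C, H, W)
print(F_t.get_image_num_channels(torch.rand(4, 3, 18, 16)))  # 3 -> (N, C, H, W)
```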
18 changes: 9 additions & 9 deletions torchvision/transforms/transforms.py
@@ -575,7 +575,7 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
"""
-w, h = F._get_image_size(img)
+w, h = F.get_image_size(img)
th, tw = output_size

if h + 1 < th or w + 1 < tw:
@@ -613,7 +613,7 @@ def forward(self, img):
if self.padding is not None:
img = F.pad(img, self.padding, self.fill, self.padding_mode)

-width, height = F._get_image_size(img)
+width, height = F.get_image_size(img)
# pad the width if needed
if self.pad_if_needed and width < self.size[1]:
padding = [self.size[1] - width, 0]
@@ -742,12 +742,12 @@ def forward(self, img):
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
-fill = [float(fill)] * F._get_image_num_channels(img)
+fill = [float(fill)] * F.get_image_num_channels(img)
else:
fill = [float(f) for f in fill]

if torch.rand(1) < self.p:
-width, height = F._get_image_size(img)
+width, height = F.get_image_size(img)
startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
return F.perspective(img, startpoints, endpoints, self.interpolation, fill)
return img
@@ -858,7 +858,7 @@ def get_params(
tuple: params (i, j, h, w) to be passed to ``crop`` for a random
sized crop.
"""
-width, height = F._get_image_size(img)
+width, height = F.get_image_size(img)
area = height * width

log_ratio = torch.log(torch.tensor(ratio))
@@ -1280,7 +1280,7 @@ def forward(self, img):
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
-fill = [float(fill)] * F._get_image_num_channels(img)
+fill = [float(fill)] * F.get_image_num_channels(img)
else:
fill = [float(f) for f in fill]
angle = self.get_params(self.degrees)
@@ -1439,11 +1439,11 @@ def forward(self, img):
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
-fill = [float(fill)] * F._get_image_num_channels(img)
+fill = [float(fill)] * F.get_image_num_channels(img)
else:
fill = [float(f) for f in fill]

-img_size = F._get_image_size(img)
+img_size = F.get_image_size(img)

ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)

@@ -1529,7 +1529,7 @@ def forward(self, img):
Returns:
PIL Image or Tensor: Randomly grayscaled image.
"""
-num_output_channels = F._get_image_num_channels(img)
+num_output_channels = F.get_image_num_channels(img)
if torch.rand(1) < self.p:
return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
return img
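Since the transforms above now route all size queries through the public helpers, the same transform object serves PIL and tensor inputs. A minimal end-to-end sketch (sizes chosen arbitrarily):

```python
import torch
import torchvision.transforms as T

crop = T.RandomCrop(8)             # get_params() calls F.get_image_size()
out = crop(torch.rand(3, 18, 16))  # a tensor input works like a PIL input
print(out.shape)                   # torch.Size([3, 8, 8])
```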
