Skip to content

Commit

Permalink
Clean up to_tensor.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Aug 27, 2021
1 parent 96f6e0a commit c762e39
Showing 1 changed file with 15 additions and 45 deletions.
60 changes: 15 additions & 45 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,9 @@ def get_image_num_channels(img: Tensor) -> int:
return F_pil.get_image_num_channels(img)


@torch.jit.unused
def _is_numpy(img: Any) -> bool:
return isinstance(img, np.ndarray)


@torch.jit.unused
def _is_numpy_image(img: Any) -> bool:
return img.ndim in {2, 3}


def to_tensor(pic):
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to a float tensor and normalize
the values to 0-1 scale if the input is of type `uint8`.
This function does not support torchscript.
See :class:`~torchvision.transforms.ToTensor` for more details.
Expand All @@ -110,46 +101,25 @@ def to_tensor(pic):
Returns:
Tensor: Converted image.
"""
if not(F_pil._is_pil_image(pic) or _is_numpy(pic)):
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

if _is_numpy(pic) and not _is_numpy_image(pic):
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))

default_float_dtype = torch.get_default_dtype()

if isinstance(pic, np.ndarray):
# handle numpy array
if pic.ndim == 2:
pic = pic[:, :, None]
elif pic.ndim != 3:
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
img = torch.from_numpy(pic.transpose((2, 0, 1)))
elif F_pil._is_pil_image(pic):
img = pil_to_tensor(pic)
else:
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
# backward compatibility
if isinstance(img, torch.ByteTensor):
return img.to(dtype=default_float_dtype).div(255)
else:
return img

if accimage is not None and isinstance(pic, accimage.Image):
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
pic.copyto(nppic)
return torch.from_numpy(nppic).to(dtype=default_float_dtype)
normalize = isinstance(img, torch.ByteTensor) and \
(accimage is None or isinstance(pic, accimage.Image))
img = img.to(dtype=default_float_dtype)
if normalize:
img.div_(255)

# handle PIL Image
mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32}
img = torch.from_numpy(
np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)
)

if pic.mode == '1':
img = 255 * img
img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
# put it from HWC to CHW format
img = img.permute((2, 0, 1)).contiguous()
if isinstance(img, torch.ByteTensor):
return img.to(dtype=default_float_dtype).div(255)
else:
return img
return img


def pil_to_tensor(pic):
Expand Down

0 comments on commit c762e39

Please sign in to comment.