Clean up to_tensor. #4326
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -88,18 +88,9 @@ def get_image_num_channels(img: Tensor) -> int: | |||||||||
return F_pil.get_image_num_channels(img) | ||||||||||
|
||||||||||
|
||||||||||
@torch.jit.unused | ||||||||||
def _is_numpy(img: Any) -> bool: | ||||||||||
return isinstance(img, np.ndarray) | ||||||||||
|
||||||||||
|
||||||||||
@torch.jit.unused | ||||||||||
def _is_numpy_image(img: Any) -> bool: | ||||||||||
return img.ndim in {2, 3} | ||||||||||
|
||||||||||
|
||||||||||
def to_tensor(pic): | ||||||||||
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. | ||||||||||
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to a float tensor and normalize | ||||||||||
the values to 0-1 scale if the input is of type `uint8`. | ||||||||||
Review comment: We should disclose under which conditions the normalization happens. |
||||||||||
This function does not support torchscript. | ||||||||||
|
||||||||||
See :class:`~torchvision.transforms.ToTensor` for more details. | ||||||||||
|
@@ -110,46 +101,25 @@ def to_tensor(pic): | |||||||||
Returns: | ||||||||||
Tensor: Converted image. | ||||||||||
""" | ||||||||||
if not(F_pil._is_pil_image(pic) or _is_numpy(pic)): | ||||||||||
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) | ||||||||||
|
||||||||||
if _is_numpy(pic) and not _is_numpy_image(pic): | ||||||||||
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) | ||||||||||
|
||||||||||
default_float_dtype = torch.get_default_dtype() | ||||||||||
|
||||||||||
if isinstance(pic, np.ndarray): | ||||||||||
# handle numpy array | ||||||||||
if pic.ndim == 2: | ||||||||||
pic = pic[:, :, None] | ||||||||||
elif pic.ndim != 3: | ||||||||||
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) | ||||||||||
img = torch.from_numpy(pic.transpose((2, 0, 1))) | ||||||||||
elif F_pil._is_pil_image(pic): | ||||||||||
img = pil_to_tensor(pic) | ||||||||||
Review comment: Turns out we can't just use
Review comment: This is because PyTorch doesn't (yet) support
Review comment: I understand but this explains why the
Review comment: I actually think that there could be corner cases where This needs to be validated, but would indicate problems with the current function that we don't want to keep
Review comment: Agreed, see my earlier remark about overflowing #4326 (comment). I agree that it's not necessary to keep all the workarounds. |
||||||||||
else: | ||||||||||
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) | ||||||||||
|
||||||||||
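The shape handling in the ndarray branch above can be sketched in plain Python, without torch or numpy. This is only an illustration of the bookkeeping (a 2-D input gains a channel axis, then HWC becomes CHW); `chw_shape` is a hypothetical helper name and is not part of torchvision.

```python
from typing import Tuple


def chw_shape(shape: Tuple[int, ...]) -> Tuple[int, int, int]:
    """Return the (C, H, W) shape the ndarray branch would produce.

    Hypothetical helper for illustration only; not a torchvision function.
    """
    if len(shape) == 2:
        # Grayscale H x W input gains a trailing channel axis: H x W x 1.
        shape = (shape[0], shape[1], 1)
    elif len(shape) != 3:
        raise ValueError(
            "pic should be 2/3 dimensional. Got {} dimensions.".format(len(shape))
        )
    height, width, channels = shape
    # transpose((2, 0, 1)) moves the channel axis to the front: HWC -> CHW.
    return (channels, height, width)


print(chw_shape((480, 640)))     # (1, 480, 640)
print(chw_shape((480, 640, 3)))  # (3, 480, 640)
```

Folding the dimension check into the same `if`/`elif` chain is what lets the diff drop the separate `_is_numpy_image` helper.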
img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous() | ||||||||||
Review comment: I advocate losing |
||||||||||
# backward compatibility | ||||||||||
if isinstance(img, torch.ByteTensor): | ||||||||||
return img.to(dtype=default_float_dtype).div(255) | ||||||||||
else: | ||||||||||
return img | ||||||||||
|
||||||||||
if accimage is not None and isinstance(pic, accimage.Image): | ||||||||||
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) | ||||||||||
pic.copyto(nppic) | ||||||||||
return torch.from_numpy(nppic).to(dtype=default_float_dtype) | ||||||||||
Review comment: As I said, the previous implementation does not normalize
Review comment: FYI I believe
Review comment: Thanks for the reference. Is this for all types?
Review comment: So the values of pixels are stored at I'm a bit confused on why this copy works then: vision/torchvision/transforms/functional.py Lines 171 to 174 in 96f6e0a
Review comment: There are two functions for copying data from accimage to Tensors: one that works on |
||||||||||
normalize = isinstance(img, torch.ByteTensor) and \ | ||||||||||
(accimage is None or isinstance(pic, accimage.Image)) | ||||||||||
Review comment: I added this reluctantly for BC. For some reason the previous method did not normalize the picture if it was of type
Suggested change
Finally, note that due to the explicit casts that we do in this version, having or not having the call is likely to affect fewer cases.
Review comment: See Francisco's comment at #4326 (comment) for this |
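To make the `normalize` condition in this hunk easier to reason about, here is a minimal plain-Python truth table of the boolean expression as written in the diff. It does not import torch or accimage; `should_normalize` and its three flags are hypothetical stand-ins for `isinstance(img, torch.ByteTensor)`, `accimage is not None`, and `isinstance(pic, accimage.Image)`.

```python
from itertools import product


def should_normalize(is_uint8: bool, accimage_installed: bool, pic_is_accimage: bool) -> bool:
    """Mirror the `normalize = ...` expression from the diff (illustrative only)."""
    # `accimage is None` in the diff corresponds to `not accimage_installed` here.
    return is_uint8 and ((not accimage_installed) or pic_is_accimage)


for is_uint8, installed, is_acc in product([True, False], repeat=3):
    if is_acc and not installed:
        # An accimage.Image cannot exist if accimage is not importable.
        continue
    result = should_normalize(is_uint8, installed, is_acc)
    print(f"uint8={is_uint8!s:5} accimage_installed={installed!s:5} "
          f"pic_is_accimage={is_acc!s:5} -> normalize={result}")
```

Read literally, the expression skips the division by 255 for a `uint8` input that is not an `accimage.Image` whenever accimage is importable. Whether that matches the backward-compatible behaviour the thread is after is the open question in the comments above.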
||||||||||
img = img.to(dtype=default_float_dtype) | ||||||||||
if normalize: | ||||||||||
img.div_(255) | ||||||||||
|
||||||||||
# handle PIL Image | ||||||||||
mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32} | ||||||||||
img = torch.from_numpy( | ||||||||||
np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True) | ||||||||||
Review comment: If the type does not match with any of the ones in the array then |
||||||||||
) | ||||||||||
|
||||||||||
if pic.mode == '1': | ||||||||||
img = 255 * img | ||||||||||
Review comment: This is no longer necessary. A boolean tensor will be cast to float, so we don't need to multiply by 255 and then divide by it later. |
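The arithmetic behind this remark can be checked with a small plain-Python sketch, assuming a mode `'1'` pixel is represented as a single bit. `old_path` and `new_path` are hypothetical names that model the two code paths; they are not torchvision functions and do not use tensors.

```python
def old_path(bit: bool) -> float:
    """Previous behaviour: scale the bit to 0 or 255, then divide by 255."""
    scaled = 255 * int(bit)  # corresponds to `img = 255 * img` for mode '1'
    return scaled / 255      # corresponds to the later `.div(255)`


def new_path(bit: bool) -> float:
    """Proposed behaviour: cast the boolean value straight to float."""
    return float(bit)


for bit in (False, True):
    # Both paths yield 0.0 for an unset pixel and 1.0 for a set pixel.
    assert old_path(bit) == new_path(bit)
    print(bit, old_path(bit), new_path(bit))
```

Since both paths agree on the only two possible pixel values, dropping the multiply-then-divide round trip should not change the result for mode `'1'` images under this model.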
||||||||||
img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) | ||||||||||
# put it from HWC to CHW format | ||||||||||
img = img.permute((2, 0, 1)).contiguous() | ||||||||||
Review comment: See my remark for the calls on |
||||||||||
if isinstance(img, torch.ByteTensor): | ||||||||||
return img.to(dtype=default_float_dtype).div(255) | ||||||||||
else: | ||||||||||
return img | ||||||||||
return img | ||||||||||
|
||||||||||
|
||||||||||
def pil_to_tensor(pic): | ||||||||||
|
Review comment: I believe the previous code intended to return a float tensor. That was not the case when the `pic` was a PIL but not when it was a numpy array. Proof: