
Clean up to_tensor. #4326

Closed
wants to merge 2 commits into from

Conversation

@datumbox (Contributor) commented Aug 27, 2021

This PR is an effort to clean up to_tensor and make it use pil_to_tensor internally. As we can see from the tests, this is not fully possible, because we need to handle several special cases.

The to_tensor method is problematic and we should deprecate it. It does too many things and it does not have consistent behaviour.

@datumbox datumbox force-pushed the transforms/cleanup_to_tensor branch from 0a50280 to c762e39 Compare August 27, 2021 10:49
@datumbox datumbox marked this pull request as draft August 27, 2021 10:49
@datumbox (Contributor, Author) left a comment:

I'm adding some notes:

 def to_tensor(pic):
-    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+    """Convert a ``PIL Image`` or ``numpy.ndarray`` to a float tensor and normalize
datumbox (Contributor, Author):

I believe the previous code intended to return a float tensor.

That was the case when the pic was a PIL image, but not when it was a numpy array. Proof:

>>> from torchvision.transforms.functional import *
>>> import numpy as np
>>> to_tensor(np.ones((5,5,3), dtype=np.int32)).dtype
torch.int32
>>> to_tensor(np.ones((5,5,3), dtype=np.uint8)).dtype
torch.float32

 def to_tensor(pic):
-    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+    """Convert a ``PIL Image`` or ``numpy.ndarray`` to a float tensor and normalize
+    the values to 0-1 scale if the input is of type `uint8`.
datumbox (Contributor, Author):

We should disclose under which conditions we normalize.
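As a minimal sketch of such a policy (a hypothetical helper, using a plain numpy array as a stand-in for the tensor logic), normalization to the 0-1 range would apply only when the input dtype is uint8:

```python
import numpy as np

def to_float_and_maybe_normalize(arr):
    """Hypothetical sketch: always return float32, but rescale to the
    0-1 range only when the input dtype is uint8."""
    out = arr.astype(np.float32)
    if arr.dtype == np.uint8:
        out = out / 255.0  # uint8 pixel values live in [0, 255]
    return out
```

With this policy, an int32 image keeps its values (only the dtype changes), while a uint8 image is rescaled.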


img = torch.from_numpy(pic.transpose((2, 0, 1))).contiguous()
datumbox (Contributor, Author):

I advocate dropping contiguous() from here to align with the behaviour of pil_to_tensor(). Calls to contiguous() should be made by users of the library if they have good reason to do so. No other transform currently makes such a call.
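For reference, the same effect can be demonstrated with a plain numpy transpose (used here as a stand-in for the tensor permute): the HWC-to-CHW rearrangement yields a non-contiguous view, and callers who need contiguity can request it themselves.

```python
import numpy as np

# HWC -> CHW rearrangement returns a non-contiguous view of the data...
hwc = np.ones((5, 5, 3), dtype=np.uint8)
chw = hwc.transpose((2, 0, 1))
print(chw.flags['C_CONTIGUOUS'])  # False

# ...and callers can opt into a contiguous copy explicitly.
chw_contig = np.ascontiguousarray(chw)
print(chw_contig.flags['C_CONTIGUOUS'])  # True
```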

pic.copyto(nppic)
return torch.from_numpy(nppic).to(dtype=default_float_dtype)
normalize = isinstance(img, torch.ByteTensor) and \
(accimage is None or isinstance(pic, accimage.Image))
datumbox (Contributor, Author):

I added this reluctantly for BC. For some reason the previous method did not normalize the picture if it was of type accimage. I think it would be best to have a clear policy on when normalization happens, namely when the input type is uint8. Perhaps we should consider this a bug of the previous implementation. Thus I suggest dropping this check:

Suggested change:
-    (accimage is None or isinstance(pic, accimage.Image))

Finally, note that due to the explicit casts we do in this version, having or not having the call is likely to affect fewer cases.

datumbox (Contributor, Author):

See Francisco's comment at #4326 (comment) for this

if accimage is not None and isinstance(pic, accimage.Image):
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
pic.copyto(nppic)
return torch.from_numpy(nppic).to(dtype=default_float_dtype)
datumbox (Contributor, Author):

As I said, the previous implementation does not normalize accimage pics even if they are of uint8 type. I think this should be considered a bug of the previous implementation, not a feature.

Member:

FYI I believe accimage internally normalizes by 255

datumbox (Contributor, Author):

Thanks for the reference. Is this for all types?

Member:

accimage only supports uint8 internally, and this function copies to float types right away.

datumbox (Contributor, Author):

So the pixel values are stored as uint8 internally, but when you copy them they are rescaled to 0-1 and converted to float?

I'm a bit confused about why this copy works then:

# accimage format is always uint8 internally, so always return uint8 here
nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.uint8)
pic.copyto(nppic)
return torch.as_tensor(nppic)

Member:

There are two functions for copying data from accimage to Tensors: one that works on uint8 tensors and one that works for float tensors. They get dispatched in the code in here

# handle PIL Image
mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32}
img = torch.from_numpy(
np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)
datumbox (Contributor, Author):

If the mode does not match any of the ones in the dictionary, then np.uint8 is assumed. I think this risks overflowing the array. Instead, I'm in favour of doing an explicit cast, as in my implementation.
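The overflow concern is easy to demonstrate with plain numpy: an unsafe cast of values above 255 into uint8 silently wraps them around modulo 256.

```python
import numpy as np

# Pixel values that do not fit in uint8 wrap around under an unsafe cast:
# 300 becomes 44 (300 % 256) and 65535 becomes 255 (65535 % 256).
pixels = np.array([0, 255, 300, 65535], dtype=np.int32)
as_uint8 = pixels.astype(np.uint8)
print(as_uint8)
```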

)

if pic.mode == '1':
img = 255 * img
datumbox (Contributor, Author):

This is no longer necessary. A boolean tensor will be cast to float, so we don't need to multiply by 255 and divide by it later.
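This can be checked with plain numpy as a stand-in: a boolean array casts straight to 0.0/1.0 floats, so the round-trip through 255 is redundant.

```python
import numpy as np

# Mode '1' pixels are booleans; casting to float already yields 0.0 / 1.0.
bits = np.array([True, False, True])
print(bits.astype(np.float32))  # [1. 0. 1.]
```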

img = 255 * img
img = img.view(pic.size[1], pic.size[0], len(pic.getbands()))
# put it from HWC to CHW format
img = img.permute((2, 0, 1)).contiguous()
datumbox (Contributor, Author):

See my earlier remark about the calls to contiguous().

raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
img = torch.from_numpy(pic.transpose((2, 0, 1)))
elif F_pil._is_pil_image(pic):
img = pil_to_tensor(pic)
datumbox (Contributor, Author):

Turns out we can't just use pil_to_tensor here. The type handling that the previous method was doing was necessary, because:

TypeError: can't convert np.ndarray of type numpy.uint16. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.

Member:

This is because PyTorch doesn't (yet) support uint16, and is also a problem when reading PNG images of type uint16. There is an open issue in PyTorch about this in pytorch/pytorch#58734, which is part of the numpy compatibility workstream

datumbox (Contributor, Author):

I understand, but this explains why to_tensor() requires all the extra handling. BTW, I think this discussion provides more context on why it's there: #4146 (comment)

Member:

I actually think there could be corner cases where ToTensor won't work properly, if what we are doing is casting uint16 data into int16.

This needs to be validated, but would indicate problems with the current function that we don't want to keep
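That corner case is easy to confirm with plain numpy: any uint16 value above the int16 maximum (32767) changes meaning when reinterpreted as int16.

```python
import numpy as np

# uint16 pixels above 32767 wrap to negative values under an int16 cast:
# 40000 becomes -25536 and 65535 becomes -1, while 1000 survives unchanged.
bright = np.array([1000, 40000, 65535], dtype=np.uint16)
print(bright.astype(np.int16))
```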

datumbox (Contributor, Author):

Agreed, see my earlier remark about overflowing at #4326 (comment). It's not necessary to keep all the workarounds.
