Error with PIL 8.3.0: __array__() takes 1 positional argument but 2 were given #4146

Closed
slettner opened this issue Jul 1, 2021 · 15 comments · Fixed by #4148

Comments

@slettner

slettner commented Jul 1, 2021

🐛 Bug

The torchvision.transforms class ToTensor raises a TypeError when it receives an image loaded with PIL.
The error originates from a call to np.array and has the message:
__array__() takes 1 positional argument but 2 were given

To Reproduce

Steps to reproduce the behavior:

import io
import requests
import torchvision.transforms as T
from PIL import Image

resp = requests.get('https://picsum.photos/200')
img = Image.open(io.BytesIO(resp.content))

preprocess1 = T.Compose([
    T.ToTensor(),
])

x = preprocess1(img)
print(x.shape)

Pip Freeze Log:

certifi==2021.5.30
chardet==4.0.0
idna==2.10
numpy==1.21.0
Pillow==8.3.0
requests==2.25.1
torch==1.9.0
torchvision==0.10.0
typing-extensions==3.10.0.0
urllib3==1.26.6

Full Traceback:

Traceback (most recent call last):
  File "/Users/sebastianlettner/Desktop/tmp/test.py", line 13, in <module>
    x = preprocess1(img)
  File "/Users/sebastianlettner/Desktop/tmp/.venv/lib/python3.8/site-packages/torchvision/transforms/transforms.py", line 60, in __call__
    img = t(img)
  File "/Users/sebastianlettner/Desktop/tmp/.venv/lib/python3.8/site-packages/torchvision/transforms/transforms.py", line 97, in __call__
    return F.to_tensor(pic)
  File "/Users/sebastianlettner/Desktop/tmp/.venv/lib/python3.8/site-packages/torchvision/transforms/functional.py", line 129, in to_tensor
    np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)
TypeError: __array__() takes 1 positional argument but 2 were given

Expected behavior

I expect the conversion to work and to receive a torch.Tensor of shape (3, 200, 200) (ToTensor returns channels first, C x H x W).

Environment

Collecting environment information...
PyTorch version: 1.9.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 10.15.7 (x86_64)
GCC version: Could not collect
Clang version: 12.0.0 (clang-1200.0.32.29)
CMake version: Could not collect
Libc version: N/A

Python version: 3.8.10 (default, May  4 2021, 03:05:50)  [Clang 12.0.0 (clang-1200.0.32.29)] (64-bit runtime)
Python platform: macOS-10.15.7-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.0
[pip3] torch==1.9.0
[pip3] torchvision==0.10.0
[conda] Could not collect
  • PyTorch / torchvision Version (e.g., 1.0 / 0.4.0): 1.9.0 / 0.10.0
  • OS (e.g., Linux): macOS (10.15.7 (19H15))
  • How you installed PyTorch / torchvision (conda, pip, source): pip
  • Build command you used (if compiling from source): N/A
  • Python version: 3.8.10
  • CUDA/cuDNN version: N/A
  • GPU models and configuration: N/A
  • Any other relevant information:

Additional context

N/A

@NicolasHug
Member

Thanks for the report, I can reproduce.

The issue seems to come from the PIL version. This passes with PIL 8.2, but fails with PIL 8.3.

import numpy as np
from PIL import Image
img = Image.fromarray(np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8))
np.array(img, np.int16)
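The error message counts two positional arguments because NumPy hands the requested dtype to the object's `__array__` method. The pure-Python sketch below reproduces that mismatch without NumPy or Pillow. It assumes, as the traceback suggests, that NumPy 1.21 passes the dtype positionally; `FakeImage`, `FixedImage` and `fake_np_array` are hypothetical stand-ins, not PIL or NumPy code.

```python
class FakeImage:
    """Hypothetical stand-in mimicking a signature with no dtype parameter."""

    def __array__(self):
        return [[0, 1], [2, 3]]


class FixedImage:
    """Hypothetical stand-in whose __array__ accepts an optional dtype."""

    def __array__(self, dtype=None):
        data = [[0, 1], [2, 3]]
        # Apply the requested element type, if any (stands in for a real cast).
        return data if dtype is None else [[dtype(x) for x in row] for row in data]


def fake_np_array(obj, dtype=None):
    """Hypothetical helper mimicking the call convention assumed above:
    no arguments when no dtype is requested, else the dtype positionally."""
    if dtype is None:
        return obj.__array__()
    return obj.__array__(dtype)


# Without a dtype, the dtype-less signature works fine.
print(fake_np_array(FakeImage()))

# With a dtype, the method receives (self, dtype) and raises TypeError.
error_message = None
try:
    fake_np_array(FakeImage(), float)
except TypeError as exc:
    error_message = str(exc)
print(error_message)

# A signature with an optional dtype parameter handles both cases.
print(fake_np_array(FixedImage(), float))
```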

As a temporary workaround, downgrade PIL to 8.2. However, it would be nice to get this fixed.

Would you mind checking if a similar issue was raised in the PIL repo, and if not, submit one and ping me there? Thanks!!

@NicolasHug
Member

lol: python-pillow/Pillow#5571

@slettner
Author

slettner commented Jul 1, 2021

Hi Nicolas, nice to hear that you can reproduce the issue. Let's see how they proceed in the PIL repo.

Cheers, Basti

@pmeier
Collaborator

pmeier commented Jul 1, 2021

@NicolasHug We should change

pillow_ver = ' >= 5.3.0'

to pillow_ver = ' >= 5.3.0, != 8.3.0'. Depending on how fast they fix it, this patch could make it into torchvision==0.10.1 (or are we not going to release a bug fix patch?). This way users can get the normal behavior back out of the box fairly soon.
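A deliberately simplified sketch of how the proposed requirement string excludes a single release. It only handles purely numeric versions and the six basic comparison operators, ignoring trailing zeros so that 8.3 and 8.3.0 compare equal; pip itself applies the full PEP 440 rules through the `packaging` library. `satisfies` and `_key` are hypothetical helpers for illustration.

```python
import operator

# Two-character operators come first so that ">=" is not mistaken for ">".
_OPS = [
    (">=", operator.ge),
    ("<=", operator.le),
    ("==", operator.eq),
    ("!=", operator.ne),
    (">", operator.gt),
    ("<", operator.lt),
]


def _key(version):
    """Turn '8.3.0' into (8, 3) so that '8.3' and '8.3.0' compare equal."""
    parts = [int(p) for p in version.strip().split(".")]
    while parts and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def satisfies(version, spec):
    """Return True if `version` meets every comma-separated clause in `spec`."""
    v = _key(version)
    for clause in spec.split(","):
        clause = clause.strip()
        for symbol, compare in _OPS:
            if clause.startswith(symbol):
                if not compare(v, _key(clause[len(symbol):])):
                    return False
                break
        else:
            raise ValueError(f"unsupported clause: {clause!r}")
    return True


pillow_ver = ' >= 5.3.0, != 8.3.0'
for candidate in ["5.2.0", "8.2.0", "8.3.0", "8.3.1"]:
    print(candidate, satisfies(candidate, pillow_ver))
```

Under this sketch only 8.3.0 (and anything below 5.3.0) is rejected, so 8.2.x and a later fixed release such as 8.3.1 remain installable.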

@NicolasHug
Member

NicolasHug commented Jul 1, 2021

It's up to PIL to release a bugfix release, not us :)

8.3 was released 8 hours ago and they're already working on a fix python-pillow/Pillow#5572 which should be in soon.

@pmeier
Collaborator

pmeier commented Jul 1, 2021

True, but we know that our code will not work with 8.3.0. That means that if we don't exclude it, pip install torchvision pillow==8.3.0 is a valid install.

@NicolasHug
Member

Is this a common practice? I've never seen it.
I understand that in this case it is a question of compatibility between torchvision and PIL, but if we pushed the logic further, we might as well exclude all PIL versions that have known bugs, i.e., only accept the very latest PIL?

@pmeier
Collaborator

pmeier commented Jul 1, 2021

Is this a common practice? I've never seen it.

I've seen it a few times, for example here

but if we pushed the logic further, we might as well exclude all PIL versions that have known bugs, a.k.a only accept the very latest PIL?

If we know that there are Pillow versions that will break our code, then yes, we should explicitly exclude them. Most of the time, Pillow bugs will only degrade our output quality.

@NicolasHug
Member

Thanks for the details, that sounds good then. I opened #4148.

@t-vi
Contributor

t-vi commented Jul 1, 2021

The alternatives are:

  • Don't pass dtype to np.array,
  • detect the problem and monkey-patch __array__. Should be low risk to check for Image.__array__ and replace it with a wrapper.

@pmeier
Collaborator

pmeier commented Jul 1, 2021

@t-vi

Don't pass dtype to np.array,

I don't know for sure, but the extra dtype handling suggests that we need to pass it:

# handle PIL Image
mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32}
img = torch.from_numpy(
    np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)
)

detect the problem and monkey-patch __array__. Should be low risk to check for Image.__array__ and replace it with a wrapper.

If Pillow publishes a release with your patch before we release, this would be irrelevant. I imagine they will ship a small patch release a lot sooner than we do.

@jamt9000
Contributor

jamt9000 commented Jul 1, 2021

Just ran into this with adjust_hue:

Traceback (most recent call last):
  File "testhue.py", line 6, in <module>
    F.adjust_hue(im, 0.1)
  File "/Users/jamesthewlis/miniconda3/envs/unitary/lib/python3.8/site-packages/torchvision/transforms/functional.py", line 844, in adjust_hue
    return F_pil.adjust_hue(img, hue_factor)
  File "/Users/jamesthewlis/miniconda3/envs/unitary/lib/python3.8/site-packages/torchvision/transforms/functional_pil.py", line 96, in adjust_hue
    np_h = np.array(h, dtype=np.uint8)
TypeError: __array__() takes 1 positional argument but 2 were given
pillow                    8.3.0            py38hee640a0_0    conda-forge

NicolasHug changed the title from "transform toTensor fails with PIL image" to "Error with PIL 8.3.0: __array__() takes 1 positional argument but 2 were given" on Jul 5, 2021
@NicolasHug
Member

Re-opening temporarily as other users might bump into this issue, this will hopefully help avoid duplicated entries like #4152

@radarhere
Contributor

Pillow 8.3.1 should now be released with a fix for this.

@pmeier
Collaborator

pmeier commented Jul 7, 2021

Pillow 8.3.1 should now be released with a fix for this.

We can close this now. Unless someone explicitly specifies Pillow==8.3.0, pip install torch torchvision will now install a compatible Pillow version.


6 participants