Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Image.__array__ take optional dtype argument #5572

Merged
merged 3 commits into from Jul 5, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions Tests/test_image_array.py
Expand Up @@ -14,6 +14,10 @@ def test(mode):
ai = numpy.array(im.convert(mode))
return ai.shape, ai.dtype.str, ai.nbytes

def test_with_dtype(dtype):
ai = numpy.array(im, dtype=dtype)
assert ai.dtype == dtype

# assert test("1") == ((100, 128), '|b1', 1600))
assert test("L") == ((100, 128), "|u1", 12800)

Expand All @@ -27,6 +31,9 @@ def test(mode):
assert test("RGBA") == ((100, 128, 4), "|u1", 51200)
assert test("RGBX") == ((100, 128, 4), "|u1", 51200)

test_with_dtype(numpy.float)
radarhere marked this conversation as resolved.
Show resolved Hide resolved
test_with_dtype(numpy.uint8)

with Image.open("Tests/images/truncated_jpeg.jpg") as im_truncated:
with pytest.raises(OSError):
numpy.array(im_truncated)
Expand Down
7 changes: 5 additions & 2 deletions src/PIL/Image.py
Expand Up @@ -681,7 +681,7 @@ def _repr_png_(self):
raise ValueError("Could not save to PNG for display") from e
return b.getvalue()

def __array__(self):
def __array__(self, dtype=None):
# numpy array interface support
import numpy as np

Expand All @@ -700,7 +700,10 @@ def __array__(self):
class ArrayData:
__array_interface__ = new

return np.array(ArrayData())
arr = np.array(ArrayData())
if dtype is not None:
arr = arr.astype(dtype)
return arr
t-vi marked this conversation as resolved.
Show resolved Hide resolved

def __getstate__(self):
return [self.info, self.mode, self.size, self.getpalette(), self.tobytes()]
Expand Down