diff --git a/Tests/test_image.py b/Tests/test_image.py index 7922505185b..0f73795363d 100644 --- a/Tests/test_image.py +++ b/Tests/test_image.py @@ -604,6 +604,15 @@ def test_remap_palette(self): with Image.open("Tests/images/hopper.gif") as im: assert_image_equal(im, im.remap_palette(list(range(256)))) + # Test identity transform with an RGBA palette + im = Image.new("P", (256, 1)) + for x in range(256): + im.putpixel((x, 0), x) + im.putpalette(list(range(256)) * 4, "RGBA") + im_remapped = im.remap_palette(list(range(256))) + assert_image_equal(im, im_remapped) + assert im.palette.palette == im_remapped.palette.palette + # Test illegal image mode with hopper() as im: with pytest.raises(ValueError): diff --git a/src/PIL/Image.py b/src/PIL/Image.py index 0ba2808f851..eb9239bec8c 100644 --- a/src/PIL/Image.py +++ b/src/PIL/Image.py @@ -1867,10 +1867,15 @@ def remap_palette(self, dest_map, source_palette=None): if self.mode not in ("L", "P"): raise ValueError("illegal image mode") + bands = 3 + palette_mode = "RGB" if source_palette is None: if self.mode == "P": self.load() - source_palette = self.im.getpalette("RGB")[:768] + palette_mode = self.im.getpalettemode() + if palette_mode == "RGBA": + bands = 4 + source_palette = self.im.getpalette(palette_mode, palette_mode) else: # L-mode source_palette = bytearray(i // 3 for i in range(768)) @@ -1879,7 +1884,9 @@ def remap_palette(self, dest_map, source_palette=None): # pick only the used colors from the palette for i, oldPosition in enumerate(dest_map): - palette_bytes += source_palette[oldPosition * 3 : oldPosition * 3 + 3] + palette_bytes += source_palette[ + oldPosition * bands : oldPosition * bands + bands + ] new_positions[oldPosition] = i # replace the palette color id of all pixel with the new id @@ -1905,19 +1912,23 @@ def remap_palette(self, dest_map, source_palette=None): m_im = self.copy() m_im.mode = "P" - m_im.palette = ImagePalette.ImagePalette("RGB", palette=mapping_palette * 3) + m_im.palette = ImagePalette.ImagePalette( + palette_mode, palette=mapping_palette * bands + ) # possibly set palette dirty, then # m_im.putpalette(mapping_palette, 'L') # converts to 'P' # or just force it. # UNDONE -- this is part of the general issue with palettes - m_im.im.putpalette("RGB;L", m_im.palette.tobytes()) + m_im.im.putpalette(palette_mode + ";L", m_im.palette.tobytes()) m_im = m_im.convert("L") - # Internally, we require 768 bytes for a palette. - new_palette_bytes = palette_bytes + (768 - len(palette_bytes)) * b"\x00" - m_im.putpalette(new_palette_bytes) - m_im.palette = ImagePalette.ImagePalette("RGB", palette=palette_bytes) + # Internally, we require 256 palette entries. + new_palette_bytes = ( + palette_bytes + ((256 * bands) - len(palette_bytes)) * b"\x00" + ) + m_im.putpalette(new_palette_bytes, palette_mode) + m_im.palette = ImagePalette.ImagePalette(palette_mode, palette=palette_bytes) if "transparency" in self.info: try: