Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type annotations for image #4708

Merged
merged 3 commits into from
May 10, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
143 changes: 94 additions & 49 deletions lib/streamlit/elements/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,53 @@
import imghdr
import io
import mimetypes
from typing import cast
from typing import cast, List, Optional, Sequence, TYPE_CHECKING, Tuple, Union
from typing_extensions import Final, Literal, TypeAlias
from urllib.parse import urlparse
import re

import numpy as np
from PIL import Image, ImageFile

import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.in_memory_file_manager import in_memory_file_manager
from streamlit.proto.Image_pb2 import ImageList as ImageListProto

LOGGER = get_logger(__name__)
if TYPE_CHECKING:
import numpy.typing as npt
from typing import Any
from streamlit.delta_generator import DeltaGenerator

LOGGER: Final = get_logger(__name__)

# This constant is related to the frontend maximum content width specified
# in App.jsx main container
# 730 is the max width of element-container in the frontend, and 2x is for high
# DPI.
MAXIMUM_CONTENT_WIDTH = 2 * 730
MAXIMUM_CONTENT_WIDTH: Final[int] = 2 * 730

PILImage: TypeAlias = Union[ImageFile.ImageFile, Image.Image]
AtomicImage: TypeAlias = Union[PILImage, "npt.NDArray[Any]", io.BytesIO, str]
ImageOrImageList: TypeAlias = Union[AtomicImage, List[AtomicImage]]
UseColumnWith: TypeAlias = Optional[Union[Literal["auto", "always", "never"], bool]]
Channels: TypeAlias = Literal["RGB", "BGR"]
OutputFormat: TypeAlias = Literal["JPEG", "PNG", "auto"]


class ImageMixin:
def image(
self,
image,
caption=None,
width=None,
use_column_width=None,
clamp=False,
channels="RGB",
output_format="auto",
):
image: ImageOrImageList,
# TODO: Narrow type of caption, dependent on type of image,
# by way of overload
caption: Optional[Union[str, List[str]]] = None,
width: Optional[int] = None,
use_column_width: UseColumnWith = None,
clamp: bool = False,
channels: Channels = "RGB",
output_format: OutputFormat = "auto",
) -> "DeltaGenerator":
"""Display an image or list of images.

Parameters
Expand Down Expand Up @@ -126,25 +140,34 @@ def image(
channels,
output_format,
)
return self.dg._enqueue("imgs", image_list_proto)
return cast(
"DeltaGenerator",
self.dg._enqueue("imgs", image_list_proto),
)

@property
def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
def dg(self) -> "DeltaGenerator":
"""Get our DeltaGenerator."""
return cast("streamlit.delta_generator.DeltaGenerator", self)
return cast("DeltaGenerator", self)


def _image_may_have_alpha_channel(image):
def _image_may_have_alpha_channel(image: PILImage) -> bool:
if image.mode in ("RGBA", "LA", "P"):
return True
else:
return False


def _format_from_image_type(image, output_format):
def _format_from_image_type(
image: PILImage,
output_format: str,
) -> Literal["JPEG", "PNG"]:
output_format = output_format.upper()
if output_format == "JPEG" or output_format == "PNG":
return output_format
return cast(
Literal["JPEG", "PNG"],
output_format,
)

# We are forgiving on the spelling of JPEG
if output_format == "JPG":
Expand All @@ -156,7 +179,11 @@ def _format_from_image_type(image, output_format):
return "JPEG"


def _PIL_to_bytes(image, format="JPEG", quality=100):
def _PIL_to_bytes(
image: PILImage,
format: Literal["JPEG", "PNG"] = "JPEG",
quality: int = 100,
) -> bytes:
tmp = io.BytesIO()

# User must have specified JPEG, so we must convert it
Expand All @@ -168,23 +195,26 @@ def _PIL_to_bytes(image, format="JPEG", quality=100):
return tmp.getvalue()


def _BytesIO_to_bytes(data):
def _BytesIO_to_bytes(data: io.BytesIO) -> bytes:
data.seek(0)
return data.getvalue()


def _np_array_to_bytes(array, output_format="JPEG"):
def _np_array_to_bytes(
array: "npt.NDArray[Any]",
output_format="JPEG",
) -> bytes:
img = Image.fromarray(array.astype(np.uint8))
format = _format_from_image_type(img, output_format)

return _PIL_to_bytes(img, format)


def _4d_to_list_3d(array):
def _4d_to_list_3d(array: "npt.NDArray[Any]") -> List["npt.NDArray[Any]"]:
return [array[i, :, :, :] for i in range(0, array.shape[0])]


def _verify_np_shape(array):
def _verify_np_shape(array: "npt.NDArray[Any]") -> "npt.NDArray[Any]":
if len(array.shape) not in (2, 3):
raise StreamlitAPIException("Numpy shape has to be of length 2 or 3.")
if len(array.shape) == 3 and array.shape[-1] not in (1, 3, 4):
Expand All @@ -200,7 +230,11 @@ def _verify_np_shape(array):
return array


def _normalize_to_bytes(data, width, output_format):
def _normalize_to_bytes(
data,
width: int,
output_format: OutputFormat,
) -> Tuple[bytes, str]:
image = Image.open(io.BytesIO(data))
actual_width, actual_height = image.size
format = _format_from_image_type(image, output_format)
Expand All @@ -225,7 +259,7 @@ def _normalize_to_bytes(data, width, output_format):
return data, mimetype


def _clip_image(image, clamp):
def _clip_image(image: "npt.NDArray[Any]", clamp: bool) -> "npt.NDArray[Any]":
data = image
if issubclass(image.dtype.type, np.floating):
if clamp:
Expand All @@ -244,7 +278,13 @@ def _clip_image(image, clamp):


def image_to_url(
image, width, clamp, channels, output_format, image_id, allow_emoji=False
image: AtomicImage,
width: int,
clamp: bool,
channels: Channels,
output_format: OutputFormat,
image_id: str,
allow_emoji: bool = False,
):
# PIL Images
if isinstance(image, ImageFile.ImageFile) or isinstance(image, Image.Image):
Expand All @@ -258,20 +298,25 @@ def image_to_url(
data = _BytesIO_to_bytes(image)

# Numpy Arrays (ie opencv)
elif type(image) is np.ndarray:
data = _verify_np_shape(image)
data = _clip_image(data, clamp)
elif isinstance(image, np.ndarray):
image = _clip_image(
_verify_np_shape(image),
clamp,
)

if channels == "BGR":
if len(data.shape) == 3:
data = data[:, :, [2, 1, 0]]
if len(image.shape) == 3:
image = image[:, :, [2, 1, 0]]
else:
raise StreamlitAPIException(
'When using `channels="BGR"`, the input image should '
"have exactly 3 color channels"
)

data = _np_array_to_bytes(data, output_format=output_format)
data = _np_array_to_bytes(
array=cast("npt.NDArray[Any]", image),
output_format=output_format,
)

# Strings
elif isinstance(image, str):
Expand Down Expand Up @@ -305,33 +350,33 @@ def image_to_url(


def marshall_images(
coordinates,
image,
caption,
width,
proto_imgs,
clamp,
channels="RGB",
output_format="auto",
):
channels = channels.upper()
coordinates: str,
image: ImageOrImageList,
caption: Optional[Union[str, "npt.NDArray[Any]", List[str]]],
width: int,
proto_imgs: ImageListProto,
clamp: bool,
channels: Channels = "RGB",
output_format: OutputFormat = "auto",
) -> None:
channels = cast(Channels, channels.upper())

# Turn single image and caption into one element list.
if type(image) is list:
images: Sequence[AtomicImage]
if isinstance(image, list):
images = image
elif isinstance(image, np.ndarray) and len(image.shape) == 4:
images = _4d_to_list_3d(image)
else:
if type(image) == np.ndarray and len(image.shape) == 4:
images = _4d_to_list_3d(image)
else:
images = [image]
images = [image]

if type(caption) is list:
captions = caption
captions: Sequence[Optional[str]] = caption
else:
if isinstance(caption, str):
captions = [caption]
# You can pass in a 1-D Numpy array as captions.
elif type(caption) == np.ndarray and len(caption.shape) == 1:
elif isinstance(caption, np.ndarray) and len(caption.shape) == 1:
captions = caption.tolist()
# If there are no captions then make the captions list the same size
# as the images list.
Expand Down