Skip to content

Commit

Permalink
Add type annotations for image (#4708)
Browse files Browse the repository at this point in the history
* Add type annotations for image

* Standardise type narrowing

* Fix typo
  • Loading branch information
harahu committed May 10, 2022
1 parent af5a25b commit d7d0257
Showing 1 changed file with 94 additions and 49 deletions.
143 changes: 94 additions & 49 deletions lib/streamlit/elements/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,53 @@
import imghdr
import io
import mimetypes
from typing import cast
from typing import cast, List, Optional, Sequence, TYPE_CHECKING, Tuple, Union
from typing_extensions import Final, Literal, TypeAlias
from urllib.parse import urlparse
import re

import numpy as np
from PIL import Image, ImageFile

import streamlit
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.in_memory_file_manager import in_memory_file_manager
from streamlit.proto.Image_pb2 import ImageList as ImageListProto

LOGGER = get_logger(__name__)
if TYPE_CHECKING:
import numpy.typing as npt
from typing import Any
from streamlit.delta_generator import DeltaGenerator

LOGGER: Final = get_logger(__name__)

# This constant is related to the frontend maximum content width specified
# in App.jsx main container
# 730 is the max width of element-container in the frontend, and 2x is for high
# DPI.
MAXIMUM_CONTENT_WIDTH = 2 * 730
MAXIMUM_CONTENT_WIDTH: Final[int] = 2 * 730

PILImage: TypeAlias = Union[ImageFile.ImageFile, Image.Image]
AtomicImage: TypeAlias = Union[PILImage, "npt.NDArray[Any]", io.BytesIO, str]
ImageOrImageList: TypeAlias = Union[AtomicImage, List[AtomicImage]]
UseColumnWith: TypeAlias = Optional[Union[Literal["auto", "always", "never"], bool]]
Channels: TypeAlias = Literal["RGB", "BGR"]
OutputFormat: TypeAlias = Literal["JPEG", "PNG", "auto"]


class ImageMixin:
def image(
self,
image,
caption=None,
width=None,
use_column_width=None,
clamp=False,
channels="RGB",
output_format="auto",
):
image: ImageOrImageList,
# TODO: Narrow type of caption, dependent on type of image,
# by way of overload
caption: Optional[Union[str, List[str]]] = None,
width: Optional[int] = None,
use_column_width: UseColumnWith = None,
clamp: bool = False,
channels: Channels = "RGB",
output_format: OutputFormat = "auto",
) -> "DeltaGenerator":
"""Display an image or list of images.
Parameters
Expand Down Expand Up @@ -126,25 +140,34 @@ def image(
channels,
output_format,
)
return self.dg._enqueue("imgs", image_list_proto)
return cast(
"DeltaGenerator",
self.dg._enqueue("imgs", image_list_proto),
)

@property
def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
def dg(self) -> "DeltaGenerator":
"""Get our DeltaGenerator."""
return cast("streamlit.delta_generator.DeltaGenerator", self)
return cast("DeltaGenerator", self)


def _image_may_have_alpha_channel(image):
def _image_may_have_alpha_channel(image: PILImage) -> bool:
if image.mode in ("RGBA", "LA", "P"):
return True
else:
return False


def _format_from_image_type(image, output_format):
def _format_from_image_type(
image: PILImage,
output_format: str,
) -> Literal["JPEG", "PNG"]:
output_format = output_format.upper()
if output_format == "JPEG" or output_format == "PNG":
return output_format
return cast(
Literal["JPEG", "PNG"],
output_format,
)

# We are forgiving on the spelling of JPEG
if output_format == "JPG":
Expand All @@ -156,7 +179,11 @@ def _format_from_image_type(image, output_format):
return "JPEG"


def _PIL_to_bytes(image, format="JPEG", quality=100):
def _PIL_to_bytes(
image: PILImage,
format: Literal["JPEG", "PNG"] = "JPEG",
quality: int = 100,
) -> bytes:
tmp = io.BytesIO()

# User must have specified JPEG, so we must convert it
Expand All @@ -168,23 +195,26 @@ def _PIL_to_bytes(image, format="JPEG", quality=100):
return tmp.getvalue()


def _BytesIO_to_bytes(data):
def _BytesIO_to_bytes(data: io.BytesIO) -> bytes:
data.seek(0)
return data.getvalue()


def _np_array_to_bytes(array, output_format="JPEG"):
def _np_array_to_bytes(
array: "npt.NDArray[Any]",
output_format="JPEG",
) -> bytes:
img = Image.fromarray(array.astype(np.uint8))
format = _format_from_image_type(img, output_format)

return _PIL_to_bytes(img, format)


def _4d_to_list_3d(array):
def _4d_to_list_3d(array: "npt.NDArray[Any]") -> List["npt.NDArray[Any]"]:
return [array[i, :, :, :] for i in range(0, array.shape[0])]


def _verify_np_shape(array):
def _verify_np_shape(array: "npt.NDArray[Any]") -> "npt.NDArray[Any]":
if len(array.shape) not in (2, 3):
raise StreamlitAPIException("Numpy shape has to be of length 2 or 3.")
if len(array.shape) == 3 and array.shape[-1] not in (1, 3, 4):
Expand All @@ -200,7 +230,11 @@ def _verify_np_shape(array):
return array


def _normalize_to_bytes(data, width, output_format):
def _normalize_to_bytes(
data,
width: int,
output_format: OutputFormat,
) -> Tuple[bytes, str]:
image = Image.open(io.BytesIO(data))
actual_width, actual_height = image.size
format = _format_from_image_type(image, output_format)
Expand All @@ -225,7 +259,7 @@ def _normalize_to_bytes(data, width, output_format):
return data, mimetype


def _clip_image(image, clamp):
def _clip_image(image: "npt.NDArray[Any]", clamp: bool) -> "npt.NDArray[Any]":
data = image
if issubclass(image.dtype.type, np.floating):
if clamp:
Expand All @@ -244,7 +278,13 @@ def _clip_image(image, clamp):


def image_to_url(
image, width, clamp, channels, output_format, image_id, allow_emoji=False
image: AtomicImage,
width: int,
clamp: bool,
channels: Channels,
output_format: OutputFormat,
image_id: str,
allow_emoji: bool = False,
):
# PIL Images
if isinstance(image, ImageFile.ImageFile) or isinstance(image, Image.Image):
Expand All @@ -258,20 +298,25 @@ def image_to_url(
data = _BytesIO_to_bytes(image)

# Numpy Arrays (ie opencv)
elif type(image) is np.ndarray:
data = _verify_np_shape(image)
data = _clip_image(data, clamp)
elif isinstance(image, np.ndarray):
image = _clip_image(
_verify_np_shape(image),
clamp,
)

if channels == "BGR":
if len(data.shape) == 3:
data = data[:, :, [2, 1, 0]]
if len(image.shape) == 3:
image = image[:, :, [2, 1, 0]]
else:
raise StreamlitAPIException(
'When using `channels="BGR"`, the input image should '
"have exactly 3 color channels"
)

data = _np_array_to_bytes(data, output_format=output_format)
data = _np_array_to_bytes(
array=cast("npt.NDArray[Any]", image),
output_format=output_format,
)

# Strings
elif isinstance(image, str):
Expand Down Expand Up @@ -305,33 +350,33 @@ def image_to_url(


def marshall_images(
coordinates,
image,
caption,
width,
proto_imgs,
clamp,
channels="RGB",
output_format="auto",
):
channels = channels.upper()
coordinates: str,
image: ImageOrImageList,
caption: Optional[Union[str, "npt.NDArray[Any]", List[str]]],
width: int,
proto_imgs: ImageListProto,
clamp: bool,
channels: Channels = "RGB",
output_format: OutputFormat = "auto",
) -> None:
channels = cast(Channels, channels.upper())

# Turn single image and caption into one element list.
if type(image) is list:
images: Sequence[AtomicImage]
if isinstance(image, list):
images = image
elif isinstance(image, np.ndarray) and len(image.shape) == 4:
images = _4d_to_list_3d(image)
else:
if type(image) == np.ndarray and len(image.shape) == 4:
images = _4d_to_list_3d(image)
else:
images = [image]
images = [image]

if type(caption) is list:
captions = caption
captions: Sequence[Optional[str]] = caption
else:
if isinstance(caption, str):
captions = [caption]
# You can pass in a 1-D Numpy array as captions.
elif type(caption) == np.ndarray and len(caption.shape) == 1:
elif isinstance(caption, np.ndarray) and len(caption.shape) == 1:
captions = caption.tolist()
# If there are no captions then make the captions list the same size
# as the images list.
Expand Down

0 comments on commit d7d0257

Please sign in to comment.