Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type annotations for media #4706

Merged
merged 1 commit into from
May 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
88 changes: 65 additions & 23 deletions lib/streamlit/elements/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,36 @@

import io
import re
from typing import cast
from typing import cast, Optional, TYPE_CHECKING, Union
from typing_extensions import Final, TypeAlias

from validators import url

import streamlit
from streamlit import type_util
from streamlit.in_memory_file_manager import in_memory_file_manager
from streamlit.proto.Audio_pb2 import Audio as AudioProto
from streamlit.proto.Video_pb2 import Video as VideoProto

if TYPE_CHECKING:
from typing import Any

from numpy import typing as npt

from streamlit.delta_generator import DeltaGenerator


Data: TypeAlias = Union[
str, bytes, io.BytesIO, io.RawIOBase, io.BufferedReader, "npt.NDArray[Any]", None
]


class MediaMixin:
def audio(self, data, format="audio/wav", start_time=0):
def audio(
self,
data: Data,
format: str = "audio/wav",
start_time: int = 0,
) -> "DeltaGenerator":
"""Display an audio player.

Parameters
Expand Down Expand Up @@ -57,9 +74,14 @@ def audio(self, data, format="audio/wav", start_time=0):
audio_proto = AudioProto()
coordinates = self.dg._get_delta_path_str()
marshall_audio(coordinates, audio_proto, data, format, start_time)
return self.dg._enqueue("audio", audio_proto)

def video(self, data, format="video/mp4", start_time=0):
return cast("DeltaGenerator", self.dg._enqueue("audio", audio_proto))

def video(
self,
data: Data,
format: str = "video/mp4",
start_time: int = 0,
) -> "DeltaGenerator":
"""Display a video player.

Parameters
Expand Down Expand Up @@ -98,17 +120,17 @@ def video(self, data, format="video/mp4", start_time=0):
video_proto = VideoProto()
coordinates = self.dg._get_delta_path_str()
marshall_video(coordinates, video_proto, data, format, start_time)
return self.dg._enqueue("video", video_proto)
return cast("DeltaGenerator", self.dg._enqueue("video", video_proto))

@property
def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
def dg(self) -> "DeltaGenerator":
"""Get our DeltaGenerator."""
return cast("streamlit.delta_generator.DeltaGenerator", self)
return cast("DeltaGenerator", self)


# Regular expression explained at https://regexr.com/4n2l2 Covers any youtube
# URL (incl. shortlinks and embed links) and extracts its code.
YOUTUBE_RE = re.compile(
YOUTUBE_RE: Final = re.compile(
# Protocol
r"http(?:s?):\/\/"
# Domain
Expand All @@ -118,15 +140,15 @@ def dg(self) -> "streamlit.delta_generator.DeltaGenerator":
)


def _reshape_youtube_url(url):
def _reshape_youtube_url(url: str) -> Optional[str]:
"""Return whether URL is any kind of YouTube embed or watch link. If so,
reshape URL into an embed link suitable for use in an iframe.

If not a YouTube URL, return None.

Parameters
----------
url : str or bytes
url : str

Example
-------
Expand All @@ -141,7 +163,12 @@ def _reshape_youtube_url(url):
return None


def _marshall_av_media(coordinates, proto, data, mimetype):
def _marshall_av_media(
coordinates: str,
proto: Union[AudioProto, VideoProto],
data: Data,
mimetype: str,
) -> None:
"""Fill audio or video proto based on contents of data.

Given a string, check if it's a url; if so, send it out without modification.
Expand All @@ -159,29 +186,38 @@ def _marshall_av_media(coordinates, proto, data, mimetype):
proto.url = this_file.url
return

data_as_bytes: bytes
if data is None:
# Allow empty values so media players can be shown without media.
return

# Assume bytes; try methods until we run out.
if isinstance(data, bytes):
pass
elif isinstance(data, bytes):
data_as_bytes = data
elif isinstance(data, io.BytesIO):
data.seek(0)
data = data.getvalue()
data_as_bytes = data.getvalue()
elif isinstance(data, io.RawIOBase) or isinstance(data, io.BufferedReader):
data.seek(0)
data = data.read()
read_data = data.read()
if read_data is None:
return
else:
data_as_bytes = read_data
elif type_util.is_type(data, "numpy.ndarray"):
data = data.tobytes()
data_as_bytes = data.tobytes()
else:
raise RuntimeError("Invalid binary data format: %s" % type(data))

this_file = in_memory_file_manager.add(data, mimetype, coordinates)
this_file = in_memory_file_manager.add(data_as_bytes, mimetype, coordinates)
proto.url = this_file.url


def marshall_video(coordinates, proto, data, mimetype="video/mp4", start_time=0):
def marshall_video(
coordinates: str,
proto: VideoProto,
data: Data,
mimetype: str = "video/mp4",
start_time: int = 0,
) -> None:
"""Marshalls a video proto, using url processors as needed.

Parameters
Expand Down Expand Up @@ -218,7 +254,13 @@ def marshall_video(coordinates, proto, data, mimetype="video/mp4", start_time=0)
_marshall_av_media(coordinates, proto, data, mimetype)


def marshall_audio(coordinates, proto, data, mimetype="audio/wav", start_time=0):
def marshall_audio(
coordinates: str,
proto: AudioProto,
data: Data,
mimetype: str = "audio/wav",
start_time: int = 0,
) -> None:
"""Marshalls an audio proto, using data and url processors as needed.

Parameters
Expand Down
8 changes: 7 additions & 1 deletion lib/tests/streamlit/help_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@ def test_deltagenerator_func(self):
self.assertEqual("audio", ds.name)
self.assertEqual("streamlit", ds.module)
self.assertEqual("<class 'method'>", ds.type)
self.assertEqual("(data, format='audio/wav', start_time=0)", ds.signature)
self.assertEqual(
"(data: Union[str, bytes, _io.BytesIO, io.RawIOBase, "
"_io.BufferedReader, ForwardRef('npt.NDArray[Any]'), NoneType], "
"format: str = 'audio/wav', start_time: int = 0) -> "
"'DeltaGenerator'",
ds.signature,
)
self.assertTrue(ds.doc_string.startswith("Display an audio player"))

def test_unwrapped_deltagenerator_func(self):
Expand Down