Skip to content

Commit

Permalink
Type annotations for type_util (#4704)
Browse files Browse the repository at this point in the history
* Type annotations for type_util

* Fix import

* Formatting

* Fix typing issues

* Formatting

* Formatting

* Add missing comma

* Add type args

* Add type arg

* Defer importing TypeGuard

* Avoid referencing Styler
  • Loading branch information
harahu committed May 9, 2022
1 parent f24afb9 commit ed7699d
Showing 1 changed file with 60 additions and 39 deletions.
99 changes: 60 additions & 39 deletions lib/streamlit/type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,35 @@
"""A bunch of useful utilities for dealing with types."""

import re
from typing import Any, Optional, Sequence, Tuple, Union, cast
from typing import (
Any,
cast,
Iterable,
Optional,
Sequence,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
from typing_extensions import Final, TypeAlias

from pandas import DataFrame, Series, Index
import numpy as np
import pyarrow as pa

from streamlit import errors

if TYPE_CHECKING:
import sympy
from pandas.io.formats.style import Styler
from typing_extensions import TypeGuard

OptionSequence = Union[Sequence[Any], DataFrame, Series, Index, np.ndarray]
Key = Union[str, int]


def is_type(obj, fqn_type_pattern):
def is_type(obj: Any, fqn_type_pattern: Union[str, "re.Pattern[str]"]) -> bool:
"""Check type without importing expensive modules.
Parameters
Expand All @@ -54,58 +70,64 @@ def is_type(obj, fqn_type_pattern):
return fqn_type_pattern.match(fqn_type) is not None


def get_fqn(the_type):
def get_fqn(the_type: type) -> str:
"""Get module.type_name for a given type."""
module = the_type.__module__
name = the_type.__qualname__
return "%s.%s" % (module, name)


def get_fqn_type(obj):
def get_fqn_type(obj: Any) -> str:
"""Get module.type_name for a given object."""
return get_fqn(type(obj))


_PANDAS_DF_TYPE_STR = "pandas.core.frame.DataFrame"
_PANDAS_INDEX_TYPE_STR = "pandas.core.indexes.base.Index"
_PANDAS_SERIES_TYPE_STR = "pandas.core.series.Series"
_PANDAS_STYLER_TYPE_STR = "pandas.io.formats.style.Styler"
_NUMPY_ARRAY_TYPE_STR = "numpy.ndarray"
_PANDAS_DF_TYPE_STR: Final = "pandas.core.frame.DataFrame"
_PANDAS_INDEX_TYPE_STR: Final = "pandas.core.indexes.base.Index"
_PANDAS_SERIES_TYPE_STR: Final = "pandas.core.series.Series"
_PANDAS_STYLER_TYPE_STR: Final = "pandas.io.formats.style.Styler"
_NUMPY_ARRAY_TYPE_STR: Final = "numpy.ndarray"

_DATAFRAME_LIKE_TYPES = (
_DATAFRAME_LIKE_TYPES: Final[Tuple[str, ...]] = (
_PANDAS_DF_TYPE_STR,
_PANDAS_INDEX_TYPE_STR,
_PANDAS_SERIES_TYPE_STR,
_PANDAS_STYLER_TYPE_STR,
_NUMPY_ARRAY_TYPE_STR,
)

_DATAFRAME_COMPATIBLE_TYPES = (
DataFrameLike: TypeAlias = Union[DataFrame, Index, Series, "Styler"]

_DATAFRAME_COMPATIBLE_TYPES: Final[Tuple[type, ...]] = (
dict,
list,
type(None),
) # type: Tuple[type, ...]
)

_BYTES_LIKE_TYPES = (
DataFrameCompatible: TypeAlias = Union[dict, list, None]

_BYTES_LIKE_TYPES: Final[Tuple[type, ...]] = (
bytes,
bytearray,
)

BytesLike: TypeAlias = Union[bytes, bytearray]


def is_dataframe(obj):
def is_dataframe(obj: Any) -> "TypeGuard[DataFrame]":
return is_type(obj, _PANDAS_DF_TYPE_STR)


def is_dataframe_like(obj):
def is_dataframe_like(obj: Any) -> "TypeGuard[DataFrameLike]":
return any(is_type(obj, t) for t in _DATAFRAME_LIKE_TYPES)


def is_dataframe_compatible(obj):
def is_dataframe_compatible(obj: Any) -> "TypeGuard[DataFrameCompatible]":
"""True if type that can be passed to convert_anything_to_df."""
return is_dataframe_like(obj) or type(obj) in _DATAFRAME_COMPATIBLE_TYPES


def is_bytes_like(obj: Any) -> bool:
def is_bytes_like(obj: Any) -> "TypeGuard[BytesLike]":
"""True if the type is considered bytes-like for the purposes of
protobuf data marshalling."""
return isinstance(obj, _BYTES_LIKE_TYPES)
Expand All @@ -125,32 +147,31 @@ def to_bytes(obj: Any) -> bytes:
raise RuntimeError(f"{obj} is not convertible to bytes")


_SYMPY_RE = re.compile(r"^sympy.*$")
_SYMPY_RE: Final = re.compile(r"^sympy.*$")


def is_sympy_expession(obj):
def is_sympy_expession(obj: Any) -> "TypeGuard[sympy.Expr]":
"""True if input is a SymPy expression."""
if not is_type(obj, _SYMPY_RE):
return False

try:
import sympy

if isinstance(obj, sympy.Expr):
return True
except:
return isinstance(obj, sympy.Expr)
except ImportError:
return False


_ALTAIR_RE = re.compile(r"^altair\.vegalite\.v\d+\.api\.\w*Chart$")
_ALTAIR_RE: Final = re.compile(r"^altair\.vegalite\.v\d+\.api\.\w*Chart$")


def is_altair_chart(obj):
def is_altair_chart(obj: Any) -> bool:
"""True if input looks like an Altair chart."""
return is_type(obj, _ALTAIR_RE)


def is_keras_model(obj):
def is_keras_model(obj: Any) -> bool:
"""True if input looks like a Keras model."""
return (
is_type(obj, "keras.engine.sequential.Sequential")
Expand All @@ -160,7 +181,7 @@ def is_keras_model(obj):
)


def is_plotly_chart(obj):
def is_plotly_chart(obj: Any) -> bool:
"""True if input looks like a Plotly chart."""
return (
is_type(obj, "plotly.graph_objs._figure.Figure")
Expand All @@ -169,7 +190,7 @@ def is_plotly_chart(obj):
)


def is_graphviz_chart(obj):
def is_graphviz_chart(obj: Any) -> bool:
"""True if input looks like a GraphViz chart."""
return (
# GraphViz < 0.18
Expand All @@ -181,21 +202,21 @@ def is_graphviz_chart(obj):
)


def _is_plotly_obj(obj):
def _is_plotly_obj(obj: Any) -> bool:
"""True if input if from a type that lives in plotly.plotly_objs."""
the_type = type(obj)
return the_type.__module__.startswith("plotly.graph_objs")


def _is_list_of_plotly_objs(obj):
def _is_list_of_plotly_objs(obj: Any) -> bool:
if type(obj) is not list:
return False
if len(obj) == 0:
return False
return all(_is_plotly_obj(item) for item in obj)


def _is_probably_plotly_dict(obj):
def _is_probably_plotly_dict(obj: Any) -> bool:
if not isinstance(obj, dict):
return False

Expand All @@ -214,15 +235,15 @@ def _is_probably_plotly_dict(obj):
return False


_FUNCTION_TYPE = type(lambda: 0)
_FUNCTION_TYPE: Final[Type[Any]] = type(lambda: 0)


def is_function(x):
def is_function(x: Any) -> bool:
"""Return True if x is a function."""
return type(x) == _FUNCTION_TYPE


def is_namedtuple(x):
def is_namedtuple(x: Any) -> bool:
t = type(x)
b = t.__bases__
if len(b) != 1 or b[0] != tuple:
Expand All @@ -233,11 +254,11 @@ def is_namedtuple(x):
return all(type(n).__name__ == "str" for n in f)


def is_pandas_styler(obj):
def is_pandas_styler(obj: Any) -> "TypeGuard[Styler]":
return is_type(obj, _PANDAS_STYLER_TYPE_STR)


def is_pydeck(obj):
def is_pydeck(obj: Any) -> bool:
"""True if input looks like a pydeck chart."""
return is_type(obj, "pydeck.bindings.deck.Deck")

Expand Down Expand Up @@ -292,7 +313,7 @@ def convert_anything_to_df(df: Any) -> DataFrame:
)


def ensure_iterable(obj):
def ensure_iterable(obj: Any) -> Iterable[Any]:
"""Try to convert different formats to something iterable. Most inputs
are assumed to be iterable, but if we have a DataFrame, we can just
select the first column to iterate over. If the input is not iterable,
Expand All @@ -308,12 +329,12 @@ def ensure_iterable(obj):
"""
if is_dataframe(obj):
return obj.iloc[:, 0]
return cast(Iterable[Any], obj.iloc[:, 0])

try:
iter(obj)
return obj
except:
return cast(Iterable[Any], obj)
except TypeError:
raise


Expand Down

0 comments on commit ed7699d

Please sign in to comment.