Skip to content

Commit

Permalink
fix: not using singledispatch
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm committed Nov 8, 2022
1 parent cf6c4a1 commit 1eec99d
Show file tree
Hide file tree
Showing 12 changed files with 232 additions and 193 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ source = ["src"]

[tool.coverage.run]
branch = true
parallel = true
source = ["src/bentoml/"]
omit = [
"src/bentoml/__main__.py",
Expand Down Expand Up @@ -208,7 +209,8 @@ exclude_lines = [
"^\\s*except ImportError",
"if __name__ == .__main__.:",
"^\\s*if TYPE_CHECKING:",
"^\\s*@overload( |$)",
"^\\s*@(t\\.)?overload( |$)",
"@(abc\\.)?abstractmethod",
]

[tool.black]
Expand Down
35 changes: 14 additions & 21 deletions src/bentoml/_internal/io_descriptors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from abc import abstractmethod
from typing import TYPE_CHECKING

from ..utils import singledispatchmethod
from ...exceptions import InvalidArgument

if TYPE_CHECKING:
Expand Down Expand Up @@ -37,13 +36,6 @@
IOType = t.TypeVar("IOType")


@singledispatchmethod
def set_sample(self: IODescriptor[t.Any], value: t.Any) -> None:
raise InvalidArgument(
f"Unsupported sample type: '{type(value)}' (value: {value}). To register type '{type(value)}' to {self.__class__.__name__} implement a dispatch function and register types to 'set_sample.register'"
)


def from_spec(spec: dict[str, str]) -> IODescriptor[t.Any]:
if "id" not in spec:
raise InvalidArgument(f"IO descriptor spec ({spec}) missing ID.")
Expand Down Expand Up @@ -103,13 +95,10 @@ def __init_subclass__(cls, *, descriptor_id: str | None = None):
cls.descriptor_id = descriptor_id

def __new__(cls, *args: t.Any, **kwargs: t.Any) -> Self:
sample = kwargs.pop("_sample", None)
klass = object.__new__(cls)
if sample is None:
set_sample.register(type(None), lambda self, _: self)
klass._set_sample(sample)
klass._args = args
klass._kwargs = kwargs
klass.sample = t.cast(IOType, kwargs.pop("_sample", None))
klass._args = args or ()
klass._kwargs = kwargs or {}
return klass

def __getattr__(self, name: str) -> t.Any:
Expand All @@ -118,17 +107,21 @@ def __getattr__(self, name: str) -> t.Any:
assert self._initialized
return object.__getattribute__(self, name)

def __dir__(self) -> t.Iterable[str]:
if not self._initialized:
self._lazy_init()
assert self._initialized
return object.__dir__(self)

def __repr__(self) -> str:
return self.__class__.__qualname__

def _lazy_init(self) -> None:
self.__init__(*self._args, **self._kwargs)
self._initialized = True
self.__init__(*self._args, **self._kwargs)
del self._args
del self._kwargs

_set_sample: singledispatchmethod[None] = set_sample

@property
def sample(self) -> IOType | None:
return self._sample
Expand Down Expand Up @@ -163,18 +156,18 @@ def input_type(self) -> InputType:

@abstractmethod
async def from_http_request(self, request: Request) -> IOType:
...
raise NotImplementedError

@abstractmethod
async def to_http_response(
self, obj: IOType, ctx: Context | None = None
) -> Response:
...
raise NotImplementedError

@abstractmethod
async def from_proto(self, field: t.Any) -> IOType:
...
raise NotImplementedError

@abstractmethod
async def to_proto(self, obj: IOType) -> t.Any:
...
raise NotImplementedError
26 changes: 6 additions & 20 deletions src/bentoml/_internal/io_descriptors/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from starlette.responses import Response
from starlette.datastructures import UploadFile

from .base import set_sample
from .base import IODescriptor
from ..types import FileLike
from ..utils import resolve_user_filepath
Expand Down Expand Up @@ -135,30 +134,17 @@ def from_sample(cls, sample: FileType | str, kind: FileKind = "binaryio") -> Sel
raise MissingDependencyException(
"'filetype' is required to use 'from_sample'. Install it with 'pip install bentoml[io-file]'."
)
if isinstance(sample, t.IO):
sample = FileLike[bytes](sample, "<sample>")
elif isinstance(sample, (str, os.PathLike)):
p = resolve_user_filepath(sample, ctx=None)
with open(p, "rb") as f:
sample = FileLike[bytes](f, "<sample>")

return super().from_sample(
sample, kind=kind, mime_type=filetype.guess_mime(sample)
)

@set_sample.register(type(FileLike))
def _(cls, sample: FileLike[bytes]):
cls.sample = sample

@set_sample.register(t.IO)
def _(cls, sample: t.IO[t.Any]):
if isinstance(cls, File):
cls.sample = FileLike[bytes](sample, "<sample>")

@set_sample.register(str)
@set_sample.register(os.PathLike)
def _(cls, sample: str):
# This is to ensure we can register same type with different
# implementation across different IO descriptors.
if isinstance(cls, File):
p = resolve_user_filepath(sample, ctx=None)
with open(p, "rb") as f:
cls.sample = FileLike[bytes](f, "<sample>")

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
if "args" not in spec:
Expand Down
35 changes: 9 additions & 26 deletions src/bentoml/_internal/io_descriptors/image.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from __future__ import annotations

import io
import os
import typing as t
import tempfile
import functools
from typing import TYPE_CHECKING
from urllib.parse import quote
Expand All @@ -13,7 +11,6 @@
from starlette.responses import Response
from starlette.datastructures import UploadFile

from .base import set_sample
from .base import IODescriptor
from ..types import LazyType
from ..utils import LazyLoader
Expand All @@ -33,7 +30,6 @@
from types import UnionType

import PIL
import numpy as np
import PIL.Image
from typing_extensions import Self

Expand All @@ -54,7 +50,6 @@
PIL = LazyLoader("PIL", globals(), "PIL", exc_msg=PIL_EXC_MSG)
PIL.Image = LazyLoader("PIL.Image", globals(), "PIL.Image", exc_msg=PIL_EXC_MSG)

np = LazyLoader("np", globals(), "numpy")
pb, _ = import_generated_stubs()

# NOTES: we will keep type in quotation to avoid backward compatibility
Expand Down Expand Up @@ -235,18 +230,16 @@ def from_sample(
raise InvalidArgument(f"{sample} is not a valid image file type.")

if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(sample):

@set_sample.register(np.ndarray)
def _(cls: Self, sample: ext.NpNDArray):
if isinstance(cls, Image):
cls.sample = PIL.Image.fromarray(sample, mode=pilmode)

sample = PIL.Image.fromarray(sample, mode=pilmode)
elif LazyType["PIL.Image.Image"]("PIL.Image.Image").isinstance(sample):

@set_sample.register(PIL.Image.Image)
def _(cls: Self, sample: PIL.Image.Image):
if isinstance(cls, Image):
cls.sample = sample
sample = sample
elif isinstance(sample, str):
p = resolve_user_filepath(sample, ctx=None)
try:
with open(p, "rb") as f:
cls.sample = PIL.Image.open(f)
except PIL.UnidentifiedImageError as err:
raise BadInput(f"Failed to parse sample image file: {err}") from None

return super().from_sample(
sample,
Expand All @@ -255,16 +248,6 @@ def _(cls: Self, sample: PIL.Image.Image):
allowed_mime_types=allowed_mime_types,
)

@set_sample.register(str)
def _(cls, sample: str):
if isinstance(cls, Image):
p = resolve_user_filepath(sample, ctx=None)
try:
with open(p, "rb") as f:
cls.sample = PIL.Image.open(f)
except PIL.UnidentifiedImageError as err:
raise BadInput(f"Failed to parse sample image file: {err}") from None

def to_spec(self) -> dict[str, t.Any]:
return {
"id": self.descriptor_id,
Expand Down
34 changes: 18 additions & 16 deletions src/bentoml/_internal/io_descriptors/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from starlette.requests import Request
from starlette.responses import Response

from .base import set_sample
from .base import IODescriptor
from ..types import LazyType
from ..utils import LazyLoader
Expand Down Expand Up @@ -212,26 +211,29 @@ def from_sample(
pydantic_model: t.Type[pydantic.BaseModel] | None = None
if LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(sample):
pydantic_model = sample.__class__

@set_sample.register(pydantic.BaseModel)
def _(cls: Self, sample: pydantic.BaseModel):
if isinstance(cls, JSON):
cls.sample = sample
elif isinstance(sample, str):
try:
sample = json.loads(sample)
except json.JSONDecodeError as e:
raise BadInput(
f"Unable to parse JSON string. Please make sure the input is a valid JSON string: {e}"
) from None
elif isinstance(sample, bytes):
try:
sample = json.loads(sample.decode())
except json.JSONDecodeError as e:
raise BadInput(
f"Unable to parse JSON bytes. Please make sure the input is a valid JSON bytes: {e}"
) from None
elif not isinstance(sample, (dict, list)):
raise BadInput(
f"Unable to infer JSON type from sample: {sample}. Please make sure the input is a valid JSON object."
)

return super().from_sample(
sample, pydantic_model=pydantic_model, json_encoder=json_encoder
)

@set_sample.register(dict)
def _(cls, sample: dict[str, t.Any]):
if isinstance(cls, JSON):
cls.sample = sample

@set_sample.register(str)
def _(cls, sample: str):
if isinstance(cls, JSON):
cls.sample = json.loads(sample)

def to_spec(self) -> dict[str, t.Any]:
return {
"id": self.descriptor_id,
Expand Down
6 changes: 4 additions & 2 deletions src/bentoml/_internal/io_descriptors/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,10 @@ def __repr__(self) -> str:
return f"Multipart({','.join([f'{k}={v}' for k,v in zip(self._inputs, map(repr, self._inputs.values()))])})"

@classmethod
def from_sample(cls, sample: dict[str, t.Any]) -> Self:
return cls(**sample)
def from_sample(
cls, sample: dict[str, t.Any] # pylint: disable=unused-argument
) -> Self:
raise NotImplementedError("'from_sample' is not supported for Multipart.")

def input_type(
self,
Expand Down
35 changes: 18 additions & 17 deletions src/bentoml/_internal/io_descriptors/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from starlette.requests import Request
from starlette.responses import Response

from .base import set_sample
from .base import IODescriptor
from ..types import LazyType
from ..utils import LazyLoader
Expand Down Expand Up @@ -217,6 +216,15 @@ def __init__(
shape: tuple[int, ...] | None = None,
enforce_shape: bool = False,
):
if enforce_dtype and not dtype:
raise InvalidArgument(
"'dtype' must be specified when 'enforce_dtype=True'"
) from None
if enforce_shape and not shape:
raise InvalidArgument(
"'shape' must be specified when 'enforce_shape=True'"
) from None

if dtype and not isinstance(dtype, np.dtype):
# Convert from primitive type or type string, e.g.: np.dtype(float) or np.dtype("float64")
try:
Expand Down Expand Up @@ -429,29 +437,22 @@ async def predict(input: NDArray[np.int16]) -> NDArray[Any]:
raise BentoMLException(
"'NumpyNdarray.from_sample()' expects a 'numpy.array', not 'numpy.generic'."
) from None
try:
if not isinstance(sample, np.ndarray):
sample = np.array(sample)
except ValueError:
raise BentoMLException(
f"Failed to create a 'numpy.ndarray' from given sample {sample}"
) from None

return super().from_sample(
sample,
shape=sample.shape,
dtype=sample.dtype,
enforce_dtype=enforce_dtype,
enforce_shape=enforce_shape,
)

@set_sample.register(np.ndarray)
def _(cls, sample: ext.NpNDArray):
if isinstance(cls, NumpyNdarray):
cls.sample = sample
cls._shape = sample.shape
cls._dtype = sample.dtype

@set_sample.register(list)
@set_sample.register(tuple)
def _(cls, sample: t.Sequence[t.Any]):
if isinstance(cls, NumpyNdarray):
__ = np.array(sample)
cls.sample = __
cls._shape = __.shape
cls._dtype = __.dtype

async def from_proto(self, field: pb.NDArray | bytes) -> ext.NpNDArray:
"""
Process incoming protobuf request and convert it to ``numpy.ndarray``
Expand Down

0 comments on commit 1eec99d

Please sign in to comment.