Skip to content

Commit

Permalink
feat: from_sample for IO descriptor (#3143)
Browse files Browse the repository at this point in the history
* feat: from_sample impl

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* feat: from_sample

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* fix: sample shouldn't be a memoryview [skip ci]

depends on #3144 to be merged.

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: move from_sample deps to optional via io-file

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: dispatch by types for sample

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* feat: openapi and dispatcher fix

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: refactor and __slots__ implementation for Base

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* fix: types

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* fix: tests [skip ci]

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* fix: not using singledispatch

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* fix: tests (incremental)

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* fix: different output

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

* chore: update _from_sample implementation

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>

Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm committed Nov 9, 2022
1 parent bd092f2 commit 39faf23
Show file tree
Hide file tree
Showing 14 changed files with 586 additions and 282 deletions.
17 changes: 12 additions & 5 deletions pyproject.toml
Expand Up @@ -112,17 +112,22 @@ include = [
[project.optional-dependencies]
all = [
"bentoml[aws]",
"bentoml[io-json]",
"bentoml[io-image]",
"bentoml[io-pandas]",
"bentoml[io]",
"bentoml[grpc]",
"bentoml[grpc-reflection]",
"bentoml[grpc-channelz]",
"bentoml[tracing]",
]
aws = ["fs-s3fs"]
io = [
"bentoml[io-json]",
"bentoml[io-image]",
"bentoml[io-pandas]",
"bentoml[io-file]",
] # syntactic sugar for bentoml[io-json,io-image,io-pandas,io-file]
io-file = ["filetype"] # Currently used for from_sample
io-json = ["pydantic<2"] # currently we don't have support for pydantic 2.0
io-image = ["Pillow"]
io-image = ["bentoml[io-file]", "Pillow"]
io-pandas = ["pandas", "pyarrow"]
grpc = [
# Restrict maximum version due to breaking protobuf 4.21.0 changes
Expand Down Expand Up @@ -164,6 +169,7 @@ source = ["src"]

[tool.coverage.run]
branch = true
parallel = true
source = ["src/bentoml/"]
omit = [
"src/bentoml/__main__.py",
Expand Down Expand Up @@ -202,7 +208,8 @@ exclude_lines = [
"^\\s*except ImportError",
"if __name__ == .__main__.:",
"^\\s*if TYPE_CHECKING:",
"^\\s*@overload( |$)",
"^\\s*@(t\\.)?overload( |$)",
"@(abc\\.)?abstractmethod",
]

[tool.black]
Expand Down
101 changes: 81 additions & 20 deletions src/bentoml/_internal/io_descriptors/base.py
Expand Up @@ -19,6 +19,7 @@
from ..types import LazyType
from ..context import InferenceApiContext as Context
from ..service.openapi.specification import Schema
from ..service.openapi.specification import MediaType
from ..service.openapi.specification import Reference

InputType = (
Expand All @@ -27,6 +28,7 @@
| LazyType[t.Any]
| dict[str, t.Type[t.Any] | UnionType | LazyType[t.Any]]
)
OpenAPIResponse = dict[str, str | dict[str, MediaType] | dict[str, t.Any]]


IO_DESCRIPTOR_REGISTRY: dict[str, type[IODescriptor[t.Any]]] = {}
Expand All @@ -40,19 +42,48 @@ def from_spec(spec: dict[str, str]) -> IODescriptor[t.Any]:
return IO_DESCRIPTOR_REGISTRY[spec["id"]].from_spec(spec)


class IODescriptor(ABC, t.Generic[IOType]):
class _OpenAPIMeta:
    """Collection of the OpenAPI-related hooks of an IO descriptor.

    Every hook is decorated with ``abstractmethod`` and raises
    ``NotImplementedError`` when invoked on this class directly.
    """

    @abstractmethod
    def openapi_schema(self) -> Schema | Reference:
        """Return a ``Schema`` or a ``Reference``; not implemented here."""
        raise NotImplementedError

    @abstractmethod
    def openapi_components(self) -> dict[str, t.Any] | None:
        """Return a components mapping or ``None``; not implemented here."""
        raise NotImplementedError

    @abstractmethod
    def openapi_example(self) -> t.Any | None:
        """Return an example value or ``None``; not implemented here."""
        raise NotImplementedError

    @abstractmethod
    def openapi_request_body(self) -> dict[str, t.Any]:
        """Return the request-body mapping; not implemented here."""
        raise NotImplementedError

    @abstractmethod
    def openapi_responses(self) -> dict[str, t.Any]:
        """Return the responses mapping; not implemented here."""
        raise NotImplementedError


class IODescriptor(ABC, _OpenAPIMeta, t.Generic[IOType]):
"""
IODescriptor describes the input/output data format of an InferenceAPI defined
in a :code:`bentoml.Service`. This is an abstract base class for extending new HTTP
endpoint IO descriptor types in BentoServer.
"""

__slots__ = ("_args", "_kwargs", "_proto_fields", "_mime_type", "descriptor_id")

HTTP_METHODS = ["POST"]

descriptor_id: str | None

_mime_type: str
_rpc_content_type: str = "application/grpc"
_proto_fields: tuple[ProtoField]
descriptor_id: str | None
_sample: IOType | None = None
_initialized: bool = False
_args: t.Sequence[t.Any]
_kwargs: dict[str, t.Any]

def __init_subclass__(cls, *, descriptor_id: str | None = None):
if descriptor_id is not None:
Expand All @@ -63,52 +94,82 @@ def __init_subclass__(cls, *, descriptor_id: str | None = None):
IO_DESCRIPTOR_REGISTRY[descriptor_id] = cls
cls.descriptor_id = descriptor_id

@abstractmethod
def to_spec(self) -> dict[str, t.Any]:
raise NotImplementedError
if TYPE_CHECKING:

@classmethod
@abstractmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
raise NotImplementedError
def __init__(self, **kwargs: t.Any) -> None:
...

# Fallback attribute hook: Python only calls __getattr__ after the normal
# attribute lookup has failed. If the descriptor has not been initialized yet,
# the deferred initialization is run first; the lookup is then retried through
# object.__getattribute__, which raises AttributeError if the name is still
# missing.
def __getattr__(self, name: str) -> t.Any:
if not self._initialized:
self._lazy_init()
# _lazy_init sets the flag as its first statement, so this must hold.
assert self._initialized
return object.__getattribute__(self, name)

# Listing attributes triggers the deferred initialization when it has not run
# yet, then delegates to object.__dir__.
def __dir__(self) -> t.Iterable[str]:
if not self._initialized:
self._lazy_init()
# _lazy_init sets the flag as its first statement, so this must hold.
assert self._initialized
return object.__dir__(self)

def __repr__(self) -> str:
    """Return the qualified name of the instance's class."""
    return type(self).__qualname__

@abstractmethod
def input_type(self) -> InputType:
raise NotImplementedError
# Run the deferred __init__ with the constructor arguments saved in
# self._args / self._kwargs, then discard those saved arguments.
# NOTE(review): where _args / _kwargs get populated is not visible in this
# view -- confirm against the constructor / __new__.
def _lazy_init(self) -> None:
# The flag is set *before* __init__ is called -- presumably so attribute
# lookups made during __init__ do not re-enter _lazy_init via __getattr__.
self._initialized = True
self.__init__(*self._args, **self._kwargs)
# The saved arguments are removed once __init__ has consumed them.
del self._args
del self._kwargs

@property
def sample(self) -> IOType | None:
    """The sample value stored on this descriptor (``_sample``)."""
    return self._sample

@sample.setter
def sample(self, value: IOType) -> None:
    """Store ``value`` as the descriptor's sample."""
    self._sample = value

@classmethod
def from_sample(cls, sample: IOType | t.Any, **kwargs: t.Any) -> Self:
    """Build a descriptor from a sample value.

    The descriptor is constructed with ``**kwargs``; the sample is then
    passed through the instance's ``_from_sample`` hook and the result is
    assigned to ``sample`` on the new instance.

    Args:
        sample: The sample value handed to ``_from_sample``.
        **kwargs: Keyword arguments forwarded to the class constructor.

    Returns:
        The newly created descriptor carrying the processed sample.
    """
    descriptor = cls(**kwargs)
    descriptor.sample = descriptor._from_sample(sample)
    return descriptor

@abstractmethod
def openapi_schema(self) -> Schema | Reference:
def _from_sample(self, sample: t.Any) -> IOType:
raise NotImplementedError

@property
def mime_type(self) -> str:
    """The MIME type stored on this descriptor (``_mime_type``)."""
    return self._mime_type

@abstractmethod
def openapi_components(self) -> dict[str, t.Any] | None:
def to_spec(self) -> dict[str, t.Any]:
raise NotImplementedError

@classmethod
@abstractmethod
def openapi_request_body(self) -> dict[str, t.Any]:
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
raise NotImplementedError

@abstractmethod
def openapi_responses(self) -> dict[str, t.Any]:
def input_type(self) -> InputType:
raise NotImplementedError

@abstractmethod
async def from_http_request(self, request: Request) -> IOType:
...
raise NotImplementedError

@abstractmethod
async def to_http_response(
self, obj: IOType, ctx: Context | None = None
) -> Response:
...
raise NotImplementedError

@abstractmethod
async def from_proto(self, field: t.Any) -> IOType:
...
raise NotImplementedError

@abstractmethod
async def to_proto(self, obj: IOType) -> t.Any:
...
raise NotImplementedError
39 changes: 31 additions & 8 deletions src/bentoml/_internal/io_descriptors/file.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import io
import os
import typing as t
import logging
from typing import TYPE_CHECKING
Expand All @@ -12,10 +13,12 @@

from .base import IODescriptor
from ..types import FileLike
from ..utils import resolve_user_filepath
from ..utils.http import set_cookies
from ...exceptions import BadInput
from ...exceptions import InvalidArgument
from ...exceptions import BentoMLException
from ...exceptions import MissingDependencyException
from ..service.openapi import SUCCESS_DESCRIPTION
from ..service.openapi.specification import Schema
from ..service.openapi.specification import MediaType
Expand Down Expand Up @@ -112,19 +115,33 @@ async def predict(input_pdf: io.BytesIO[Any]) -> NDArray[Any]:

_proto_fields = ("file",)

# Allocator that returns a concrete subclass chosen by `kind`; only
# "binaryio" is accepted (any other value raises ValueError). The resolved
# MIME type is stored on the new object as `_mime_type`.
# NOTE(review): two signatures and two `res = ...` allocations appear below;
# this looks like the removed and the added line of a diff shown back to
# back -- confirm which variant the real file keeps.
# NOTE(review): the second allocation forwards **kwargs to super().__new__;
# if that resolves to object.__new__, non-empty kwargs would raise TypeError
# -- confirm the base class defines a __new__ that accepts them.
def __new__( # pylint: disable=arguments-differ # returning subclass from new
cls, kind: FileKind = "binaryio", mime_type: str | None = None
def __new__(
cls, kind: FileKind = "binaryio", mime_type: str | None = None, **kwargs: t.Any
) -> File:
# Default to the generic binary MIME type when the caller gives none.
mime_type = mime_type if mime_type is not None else "application/octet-stream"
if kind == "binaryio":
res = object.__new__(BytesIOFile)
res = super().__new__(BytesIOFile, **kwargs)
else:
raise ValueError(f"invalid File kind '{kind}'")
res._mime_type = mime_type
return res

def to_spec(self) -> dict[str, t.Any]:
raise NotImplementedError
def _from_sample(self, sample: FileType | str) -> FileType:
    """Normalize ``sample`` into a file-like object and record its MIME type.

    Args:
        sample: An open binary stream, or a path (``str`` / ``os.PathLike``)
            to a file on disk.

    Returns:
        The sample wrapped in ``FileLike[bytes]``. A value that is neither a
        stream nor a path is returned unchanged.

    Raises:
        MissingDependencyException: If the optional ``filetype`` package is
            not installed.
    """
    try:
        import filetype
    except ModuleNotFoundError as err:
        raise MissingDependencyException(
            "'filetype' is required to use 'from_sample'. Install it with 'pip install bentoml[io-file]'."
        ) from err

    guessed: str | None = None
    # Concrete streams such as io.BytesIO or files returned by open() derive
    # from io.IOBase, not from typing.IO, so both are checked.
    if isinstance(sample, (t.IO, io.IOBase)):
        guessed = filetype.guess_mime(sample)
        sample = FileLike[bytes](sample, "<sample>")
    elif isinstance(sample, (str, os.PathLike)):
        p = resolve_user_filepath(sample, ctx=None)
        guessed = filetype.guess_mime(p)
        with open(p, "rb") as f:
            # Buffer the content in memory: the handle is closed when the
            # `with` block exits, and the returned object must stay readable.
            sample = FileLike[bytes](io.BytesIO(f.read()), "<sample>")
    if guessed is not None:
        # guess_mime yields None for unrecognized content; in that case the
        # MIME type already configured on the descriptor is kept.
        self._mime_type = guessed
    return sample

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
Expand All @@ -141,6 +158,9 @@ def openapi_schema(self) -> Schema:
def openapi_components(self) -> dict[str, t.Any] | None:
    """Return ``None``: this descriptor contributes no OpenAPI components."""
    return None

def openapi_example(self):
    """Return ``None``: no OpenAPI example is provided here."""
    return None

def openapi_request_body(self) -> dict[str, t.Any]:
return {
"content": {self._mime_type: MediaType(schema=self.openapi_schema())},
Expand Down Expand Up @@ -192,7 +212,10 @@ async def to_proto(self, obj: FileType) -> pb.File:
async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
    """Convert a protobuf field into a file-like object.

    This implementation always raises ``NotImplementedError``.
    """
    raise NotImplementedError

# NOTE(review): two signatures appear back to back (return annotation
# t.IO[bytes], then FileLike[bytes]); this looks like the removed and the
# added line of a diff -- confirm which one the real file keeps.
# The body unconditionally raises NotImplementedError.
async def from_http_request(self, request: Request) -> t.IO[bytes]:
async def from_http_request(self, request: Request) -> FileLike[bytes]:
raise NotImplementedError

def to_spec(self) -> dict[str, t.Any]:
    """Serialize the descriptor to a spec mapping.

    This implementation always raises ``NotImplementedError``.
    """
    raise NotImplementedError


Expand All @@ -206,7 +229,7 @@ def to_spec(self) -> dict[str, t.Any]:
},
}

async def from_http_request(self, request: Request) -> t.IO[bytes]:
async def from_http_request(self, request: Request) -> FileLike[bytes]:
content_type, _ = parse_options_header(request.headers["content-type"])
if content_type.decode("utf-8") == "multipart/form-data":
form = await request.form()
Expand All @@ -228,7 +251,7 @@ async def from_http_request(self, request: Request) -> t.IO[bytes]:
return res # type: ignore
if content_type.decode("utf-8") == self._mime_type:
body = await request.body()
return t.cast(t.IO[bytes], FileLike(io.BytesIO(body), "<request body>"))
return FileLike[bytes](io.BytesIO(body), "<request body>")
raise BentoMLException(
f"File should have Content-Type '{self._mime_type}' or 'multipart/form-data', got {content_type} instead"
)
Expand Down

0 comments on commit 39faf23

Please sign in to comment.