Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: from_sample for IO descriptor #3143

Merged
merged 13 commits into from Nov 9, 2022
17 changes: 12 additions & 5 deletions pyproject.toml
Expand Up @@ -112,17 +112,22 @@ include = [
[project.optional-dependencies]
all = [
"bentoml[aws]",
"bentoml[io-json]",
"bentoml[io-image]",
"bentoml[io-pandas]",
"bentoml[io]",
"bentoml[grpc]",
"bentoml[grpc-reflection]",
"bentoml[grpc-channelz]",
"bentoml[tracing]",
]
aws = ["fs-s3fs"]
io = [
"bentoml[io-json]",
"bentoml[io-image]",
"bentoml[io-pandas]",
"bentoml[io-file]",
] # syntatic sugar for bentoml[io-json,io-image,io-pandas,io-file]
io-file = ["filetype"] # Currently use for from_sample
io-json = ["pydantic<2"] # currently we don't have support for pydantic 2.0
io-image = ["Pillow"]
io-image = ["bentoml[io-file]", "Pillow"]
io-pandas = ["pandas", "pyarrow"]
grpc = [
# Restrict maximum version due to breaking protobuf 4.21.0 changes
Expand Down Expand Up @@ -164,6 +169,7 @@ source = ["src"]

[tool.coverage.run]
branch = true
parallel = true
source = ["src/bentoml/"]
omit = [
"src/bentoml/__main__.py",
Expand Down Expand Up @@ -202,7 +208,8 @@ exclude_lines = [
"^\\s*except ImportError",
"if __name__ == .__main__.:",
"^\\s*if TYPE_CHECKING:",
"^\\s*@overload( |$)",
"^\\s*@(t\\.)?overload( |$)",
"@(abc\\.)?abstractmethod",
]

[tool.black]
Expand Down
103 changes: 81 additions & 22 deletions src/bentoml/_internal/io_descriptors/base.py
Expand Up @@ -19,6 +19,7 @@
from ..types import LazyType
from ..context import InferenceApiContext as Context
from ..service.openapi.specification import Schema
from ..service.openapi.specification import MediaType
from ..service.openapi.specification import Reference

InputType = (
Expand All @@ -27,6 +28,7 @@
| LazyType[t.Any]
| dict[str, t.Type[t.Any] | UnionType | LazyType[t.Any]]
)
OpenAPIResponse = dict[str, str | dict[str, MediaType] | dict[str, t.Any]]


IO_DESCRIPTOR_REGISTRY: dict[str, type[IODescriptor[t.Any]]] = {}
Expand All @@ -40,19 +42,48 @@ def from_spec(spec: dict[str, str]) -> IODescriptor[t.Any]:
return IO_DESCRIPTOR_REGISTRY[spec["id"]].from_spec(spec)


class IODescriptor(ABC, t.Generic[IOType]):
class _OpenAPIMeta:
@abstractmethod
def openapi_schema(self) -> Schema | Reference:
raise NotImplementedError

@abstractmethod
def openapi_components(self) -> dict[str, t.Any] | None:
raise NotImplementedError

@abstractmethod
def openapi_example(self) -> t.Any | None:
raise NotImplementedError

@abstractmethod
def openapi_request_body(self) -> dict[str, t.Any]:
raise NotImplementedError

@abstractmethod
def openapi_responses(self) -> dict[str, t.Any]:
raise NotImplementedError


class IODescriptor(ABC, _OpenAPIMeta, t.Generic[IOType]):
"""
IODescriptor describes the input/output data format of an InferenceAPI defined
in a :code:`bentoml.Service`. This is an abstract base class for extending new HTTP
endpoint IO descriptor types in BentoServer.
"""

__slots__ = ("_args", "_kwargs", "_proto_fields", "_mime_type", "descriptor_id")

HTTP_METHODS = ["POST"]

descriptor_id: str | None

_mime_type: str
_rpc_content_type: str = "application/grpc"
_proto_fields: tuple[ProtoField]
descriptor_id: str | None
_sample: IOType | None = None
_initialized: bool = False
_args: t.Sequence[t.Any]
_kwargs: dict[str, t.Any]

def __init_subclass__(cls, *, descriptor_id: str | None = None):
if descriptor_id is not None:
Expand All @@ -63,52 +94,80 @@ def __init_subclass__(cls, *, descriptor_id: str | None = None):
IO_DESCRIPTOR_REGISTRY[descriptor_id] = cls
cls.descriptor_id = descriptor_id

@abstractmethod
def to_spec(self) -> dict[str, t.Any]:
raise NotImplementedError

@classmethod
@abstractmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
raise NotImplementedError
def __new__(cls, *args: t.Any, **kwargs: t.Any) -> Self:
klass = object.__new__(cls)
klass.sample = t.cast(IOType, kwargs.pop("_sample", None))
klass._args = args or ()
klass._kwargs = kwargs or {}
return klass

def __getattr__(self, name: str) -> t.Any:
if not self._initialized:
self._lazy_init()
assert self._initialized
return object.__getattribute__(self, name)

def __dir__(self) -> t.Iterable[str]:
if not self._initialized:
self._lazy_init()
assert self._initialized
return object.__dir__(self)

def __repr__(self) -> str:
return self.__class__.__qualname__

@abstractmethod
def input_type(self) -> InputType:
raise NotImplementedError
def _lazy_init(self) -> None:
self._initialized = True
self.__init__(*self._args, **self._kwargs)
del self._args
del self._kwargs

@property
def sample(self) -> IOType | None:
return self._sample

@sample.setter
def sample(self, value: IOType) -> None:
self._sample = value

# NOTE: for custom types handle, use 'create_sample.register' to register
# custom types for 'from_sample'
@classmethod
@abstractmethod
def openapi_schema(self) -> Schema | Reference:
raise NotImplementedError
def from_sample(cls, sample: IOType | t.Any, **kwargs: t.Any) -> Self:
return cls.__new__(cls, _sample=sample, **kwargs)

@property
def mime_type(self) -> str:
return self._mime_type

@abstractmethod
def openapi_components(self) -> dict[str, t.Any] | None:
def to_spec(self) -> dict[str, t.Any]:
raise NotImplementedError

@classmethod
@abstractmethod
def openapi_request_body(self) -> dict[str, t.Any]:
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
raise NotImplementedError

@abstractmethod
def openapi_responses(self) -> dict[str, t.Any]:
def input_type(self) -> InputType:
raise NotImplementedError

@abstractmethod
async def from_http_request(self, request: Request) -> IOType:
...
raise NotImplementedError

@abstractmethod
async def to_http_response(
self, obj: IOType, ctx: Context | None = None
) -> Response:
...
raise NotImplementedError

@abstractmethod
async def from_proto(self, field: t.Any) -> IOType:
...
raise NotImplementedError

@abstractmethod
async def to_proto(self, obj: IOType) -> t.Any:
...
raise NotImplementedError
41 changes: 33 additions & 8 deletions src/bentoml/_internal/io_descriptors/file.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import io
import os
import typing as t
import logging
from typing import TYPE_CHECKING
Expand All @@ -12,10 +13,12 @@

from .base import IODescriptor
from ..types import FileLike
from ..utils import resolve_user_filepath
from ..utils.http import set_cookies
from ...exceptions import BadInput
from ...exceptions import InvalidArgument
from ...exceptions import BentoMLException
from ...exceptions import MissingDependencyException
from ..service.openapi import SUCCESS_DESCRIPTION
from ..service.openapi.specification import Schema
from ..service.openapi.specification import MediaType
Expand Down Expand Up @@ -112,19 +115,35 @@ async def predict(input_pdf: io.BytesIO[Any]) -> NDArray[Any]:

_proto_fields = ("file",)

def __new__( # pylint: disable=arguments-differ # returning subclass from new
cls, kind: FileKind = "binaryio", mime_type: str | None = None
def __new__(
cls, kind: FileKind = "binaryio", mime_type: str | None = None, **kwargs: t.Any
) -> File:
mime_type = mime_type if mime_type is not None else "application/octet-stream"
if kind == "binaryio":
res = object.__new__(BytesIOFile)
res = super().__new__(BytesIOFile, **kwargs)
else:
raise ValueError(f"invalid File kind '{kind}'")
res._mime_type = mime_type
return res

def to_spec(self) -> dict[str, t.Any]:
raise NotImplementedError
@classmethod
def from_sample(cls, sample: FileType | str, kind: FileKind = "binaryio") -> Self:
try:
import filetype
except ModuleNotFoundError:
raise MissingDependencyException(
"'filetype' is required to use 'from_sample'. Install it with 'pip install bentoml[io-file]'."
)
if isinstance(sample, t.IO):
sample = FileLike[bytes](sample, "<sample>")
elif isinstance(sample, (str, os.PathLike)):
p = resolve_user_filepath(sample, ctx=None)
with open(p, "rb") as f:
sample = FileLike[bytes](f, "<sample>")

return super().from_sample(
sample, kind=kind, mime_type=filetype.guess_mime(sample)
)

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
Expand All @@ -141,6 +160,9 @@ def openapi_schema(self) -> Schema:
def openapi_components(self) -> dict[str, t.Any] | None:
pass

def openapi_example(self):
pass

def openapi_request_body(self) -> dict[str, t.Any]:
return {
"content": {self._mime_type: MediaType(schema=self.openapi_schema())},
Expand Down Expand Up @@ -192,7 +214,10 @@ async def to_proto(self, obj: FileType) -> pb.File:
async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
raise NotImplementedError

async def from_http_request(self, request: Request) -> t.IO[bytes]:
async def from_http_request(self, request: Request) -> FileLike[bytes]:
raise NotImplementedError

def to_spec(self) -> dict[str, t.Any]:
raise NotImplementedError


Expand All @@ -206,7 +231,7 @@ def to_spec(self) -> dict[str, t.Any]:
},
}

async def from_http_request(self, request: Request) -> t.IO[bytes]:
async def from_http_request(self, request: Request) -> FileLike[bytes]:
content_type, _ = parse_options_header(request.headers["content-type"])
if content_type.decode("utf-8") == "multipart/form-data":
form = await request.form()
Expand All @@ -228,7 +253,7 @@ async def from_http_request(self, request: Request) -> t.IO[bytes]:
return res # type: ignore
if content_type.decode("utf-8") == self._mime_type:
body = await request.body()
return t.cast(t.IO[bytes], FileLike(io.BytesIO(body), "<request body>"))
return FileLike[bytes](io.BytesIO(body), "<request body>")
raise BentoMLException(
f"File should have Content-Type '{self._mime_type}' or 'multipart/form-data', got {content_type} instead"
)
Expand Down