Skip to content

Commit

Permalink
feat: from_sample classmethod [skip ci]
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm committed Oct 25, 2022
1 parent e987397 commit 2c40bdf
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 29 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies = [
"deepmerge",
"fs",
"numpy",
"filetype",
"opentelemetry-api>=1.9.0",
"opentelemetry-instrumentation==0.33b0",
"opentelemetry-instrumentation-aiohttp-client==0.33b0",
Expand Down
5 changes: 2 additions & 3 deletions src/bentoml/_internal/bento/build_dev_bentoml_whl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ def build_bentoml_editable_wheel(target_path: str) -> None:
return

try:
from build.env import IsolatedEnvBuilder

from build import ProjectBuilder
from build.env import IsolatedEnvBuilder # isort: skip
from build import ProjectBuilder # isort: skip
except ModuleNotFoundError as e:
raise MissingDependencyException(_exc_message) from e

Expand Down
20 changes: 17 additions & 3 deletions src/bentoml/_internal/io_descriptors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
from ..types import LazyType
from ..context import InferenceApiContext as Context
from ..service.openapi.specification import Schema
from ..service.openapi.specification import Response as OpenAPIResponse
from ..service.openapi.specification import MediaType
from ..service.openapi.specification import Reference
from ..service.openapi.specification import RequestBody

InputType = (
UnionType
| t.Type[t.Any]
| LazyType[t.Any]
| dict[str, t.Type[t.Any] | UnionType | LazyType[t.Any]]
)
OpenAPIResponse = dict[str, str | dict[str, MediaType] | dict[str, t.Any]]


IO_DESCRIPTOR_REGISTRY: dict[str, type[IODescriptor[t.Any]]] = {}
Expand All @@ -54,6 +54,7 @@ class IODescriptor(ABC, t.Generic[IOType]):
_mime_type: str
_rpc_content_type: str = "application/grpc"
_proto_fields: tuple[ProtoField]
_sample_input: IOType | None = None
descriptor_id: str | None

def __init_subclass__(cls, *, descriptor_id: str | None = None):
Expand All @@ -65,6 +66,14 @@ def __init_subclass__(cls, *, descriptor_id: str | None = None):
IO_DESCRIPTOR_REGISTRY[descriptor_id] = cls
cls.descriptor_id = descriptor_id

@property
def sample_input(self) -> IOType | None:
return self._sample_input

@sample_input.setter
def sample_input(self, value: IOType) -> None:
self._sample_input = value

@abstractmethod
def to_spec(self) -> dict[str, t.Any]:
raise NotImplementedError
Expand All @@ -90,7 +99,7 @@ def openapi_components(self) -> dict[str, t.Any] | None:
raise NotImplementedError

@abstractmethod
def openapi_request_body(self) -> RequestBody:
def openapi_request_body(self) -> dict[str, t.Any]:
raise NotImplementedError

@abstractmethod
Expand All @@ -114,3 +123,8 @@ async def from_proto(self, field: t.Any) -> IOType:
@abstractmethod
async def to_proto(self, obj: IOType) -> t.Any:
...

@classmethod
@abstractmethod
def from_sample(cls, sample_input: IOType, **kwargs: t.Any) -> Self:
...
23 changes: 16 additions & 7 deletions src/bentoml/_internal/io_descriptors/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,20 @@
from ..types import FileLike
from ..utils.http import set_cookies
from ...exceptions import BadInput
from ...exceptions import InvalidArgument
from ...exceptions import BentoMLException
from ..service.openapi import SUCCESS_DESCRIPTION
from ..service.openapi.specification import Schema
from ..service.openapi.specification import Response as OpenAPIResponse
from ..service.openapi.specification import MediaType
from ..service.openapi.specification import RequestBody

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from typing_extensions import Self

from bentoml.grpc.v1alpha1 import service_pb2 as pb

from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context

FileKind: t.TypeAlias = t.Literal["binaryio", "textio"]
Expand Down Expand Up @@ -110,9 +112,7 @@ async def predict(input_pdf: io.BytesIO[Any]) -> NDArray[Any]:

_proto_fields = ("file",)

def __new__( # pylint: disable=arguments-differ # returning subclass from new
cls, kind: FileKind = "binaryio", mime_type: str | None = None
) -> File:
def __new__(cls, kind: FileKind = "binaryio", mime_type: str | None = None) -> File:
mime_type = mime_type if mime_type is not None else "application/octet-stream"
if kind == "binaryio":
res = object.__new__(BytesIOFile)
Expand All @@ -121,8 +121,14 @@ def __new__( # pylint: disable=arguments-differ # returning subclass from new
res._mime_type = mime_type
return res

def to_spec(self):
raise NotImplementedError
@classmethod
def from_sample(cls, sample_input: FileType, kind: FileKind = "binaryio") -> Self:
import filetype

mime_type: str | None = filetype.guess_mime(sample_input)
kls = cls(kind=kind, mime_type=mime_type)
kls.sample_input = sample_input
return kls

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
Expand Down Expand Up @@ -193,6 +199,9 @@ async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
async def from_http_request(self, request: Request) -> t.IO[bytes]:
raise NotImplementedError

def to_spec(self) -> dict[str, t.Any]:
raise NotImplementedError


class BytesIOFile(File, descriptor_id=None):
def to_spec(self) -> dict[str, t.Any]:
Expand Down
4 changes: 2 additions & 2 deletions src/bentoml/_internal/io_descriptors/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from ...exceptions import InternalServerError
from ..service.openapi import SUCCESS_DESCRIPTION
from ..service.openapi.specification import Schema
from ..service.openapi.specification import Response as OpenAPIResponse
from ..service.openapi.specification import MediaType
from ..service.openapi.specification import RequestBody

Expand All @@ -31,6 +30,7 @@
import PIL.Image

from bentoml.grpc.v1alpha1 import service_pb2 as pb
from .base import OpenAPIResponse

from .. import external_typing as ext
from ..context import InferenceApiContext as Context
Expand Down Expand Up @@ -75,7 +75,7 @@ def initialize_pillow():
import PIL.Image
except ImportError:
raise InternalServerError(
"`Pillow` is required to use {__name__}\n Instructions: `pip install -U Pillow`"
f"'Pillow' is required to use {__name__}. Install Pillow with 'pip install bentoml[io-image]'"
)

PIL.Image.init()
Expand Down
15 changes: 3 additions & 12 deletions src/bentoml/_internal/io_descriptors/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@
from ...exceptions import UnprocessableEntity
from ..service.openapi import SUCCESS_DESCRIPTION
from ..service.openapi.specification import Schema
from ..service.openapi.specification import Response as OpenAPIResponse
from ..service.openapi.specification import MediaType
from ..service.openapi.specification import RequestBody

if TYPE_CHECKING:
import numpy as np
from typing_extensions import Self

from bentoml.grpc.v1alpha1 import service_pb2 as pb

from .. import external_typing as ext
from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context
else:
from bentoml.grpc.utils import import_generated_stubs
Expand Down Expand Up @@ -228,8 +229,6 @@ def __init__(
self._enforce_dtype = enforce_dtype
self._enforce_shape = enforce_shape

self._sample_input = None

if self._enforce_dtype and not self._dtype:
raise InvalidArgument(
"'dtype' must be specified when 'enforce_dtype=True'"
Expand Down Expand Up @@ -269,14 +268,6 @@ def from_spec(cls, spec: dict[str, t.Any]) -> Self:
res = NumpyNdarray(**spec["args"])
return res

@property
def sample_input(self) -> ext.NpNDArray | None:
return self._sample_input

@sample_input.setter
def sample_input(self, value: ext.NpNDArray) -> None:
self._sample_input = value

def openapi_schema(self) -> Schema:
# Note that we are yet provide
# supports schemas for arrays that is > 2D.
Expand Down Expand Up @@ -407,7 +398,7 @@ def from_sample(
sample_input: ext.NpNDArray,
enforce_dtype: bool = True,
enforce_shape: bool = True,
) -> NumpyNdarray:
) -> Self:
"""
Create a :obj:`NumpyNdarray` IO Descriptor from given inputs.
Expand Down
6 changes: 4 additions & 2 deletions src/bentoml/_internal/service/openapi/specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class Operation:
description: t.Optional[str] = None
externalDocs: t.Optional[ExternalDocumentation] = None
operationId: t.Optional[str] = None
requestBody: t.Optional[t.Union[RequestBody, Reference]] = None
requestBody: t.Optional[t.Union[RequestBody, Reference, t.Dict[str, t.Any]]] = None

# Not yet supported: parameters, callbacks, deprecated, servers, security

Expand Down Expand Up @@ -253,7 +253,9 @@ class Components:
schemas: t.Dict[str, t.Union[Schema, Reference]]
responses: t.Optional[t.Dict[str, t.Union[Response, Reference]]] = None
examples: t.Optional[t.Dict[str, t.Union[Example, Reference]]] = None
requestBodies: t.Optional[t.Dict[str, t.Union[RequestBody, Reference]]] = None
requestBodies: t.Optional[
t.Dict[str, t.Union[RequestBody, Reference, t.Dict[str, t.Any]]]
] = None
links: t.Optional[t.Dict[str, t.Union[Link, Reference]]] = None

# Not yet supported: securitySchemes, callbacks, parameters, headers
Expand Down

0 comments on commit 2c40bdf

Please sign in to comment.