Skip to content

Commit

Permalink
feat: from_sample
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm committed Nov 7, 2022
1 parent 54ad1ff commit 166009f
Show file tree
Hide file tree
Showing 9 changed files with 194 additions and 67 deletions.
18 changes: 11 additions & 7 deletions src/bentoml/_internal/io_descriptors/base.py
Expand Up @@ -53,7 +53,7 @@ class IODescriptor(ABC, t.Generic[IOType]):
_mime_type: str
_rpc_content_type: str = "application/grpc"
_proto_fields: tuple[ProtoField]
_sample_input: IOType | None = None
_sample: IOType | None = None
descriptor_id: str | None

def __init_subclass__(cls, *, descriptor_id: str | None = None):
Expand All @@ -66,12 +66,12 @@ def __init_subclass__(cls, *, descriptor_id: str | None = None):
cls.descriptor_id = descriptor_id

@property
def sample_input(self) -> IOType | None:
return self._sample_input
def sample(self) -> IOType | None:
return self._sample

@sample_input.setter
def sample_input(self, value: IOType) -> None:
self._sample_input = value
@sample.setter
def sample(self, value: IOType) -> None:
self._sample = value

@abstractmethod
def to_spec(self) -> dict[str, t.Any]:
Expand All @@ -93,6 +93,10 @@ def input_type(self) -> InputType:
def openapi_schema(self) -> Schema | Reference:
raise NotImplementedError

def openapi_example(self) -> t.Any:
if self.sample is not None:
return self.sample

@abstractmethod
def openapi_components(self) -> dict[str, t.Any] | None:
raise NotImplementedError
Expand Down Expand Up @@ -125,5 +129,5 @@ async def to_proto(self, obj: IOType) -> t.Any:

@classmethod
@abstractmethod
def from_sample(cls, sample_input: IOType, **kwargs: t.Any) -> Self:
def from_sample(cls, sample: IOType, **kwargs: t.Any) -> Self:
...
24 changes: 18 additions & 6 deletions src/bentoml/_internal/io_descriptors/file.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import io
import os
import typing as t
import logging
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -122,12 +123,23 @@ def __new__(cls, kind: FileKind = "binaryio", mime_type: str | None = None) -> F
return res

@classmethod
def from_sample(cls, sample_input: FileType, kind: FileKind = "binaryio") -> Self:
def from_sample(cls, sample: FileType | str, kind: FileKind = "binaryio") -> Self:
import filetype

mime_type: str | None = filetype.guess_mime(sample_input)
mime_type: str | None = filetype.guess_mime(sample)

kls = cls(kind=kind, mime_type=mime_type)
kls.sample_input = sample_input

if isinstance(sample, FileLike):
kls.sample = sample
elif isinstance(sample, t.IO):
kls.sample = FileLike[bytes](sample, "<sample>")
elif isinstance(sample, str) and os.path.exists(sample):
with open(sample, "rb") as f:
kls.sample = FileLike[bytes](f, "<sample>")
else:
raise InvalidArgument(f"Unknown sample type: '{sample}'")

return kls

@classmethod
Expand Down Expand Up @@ -196,7 +208,7 @@ async def to_proto(self, obj: FileType) -> pb.File:
async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
raise NotImplementedError

async def from_http_request(self, request: Request) -> t.IO[bytes]:
async def from_http_request(self, request: Request) -> FileLike[bytes]:
raise NotImplementedError

def to_spec(self) -> dict[str, t.Any]:
Expand All @@ -213,7 +225,7 @@ def to_spec(self) -> dict[str, t.Any]:
},
}

async def from_http_request(self, request: Request) -> t.IO[bytes]:
async def from_http_request(self, request: Request) -> FileLike[bytes]:
content_type, _ = parse_options_header(request.headers["content-type"])
if content_type.decode("utf-8") == "multipart/form-data":
form = await request.form()
Expand All @@ -235,7 +247,7 @@ async def from_http_request(self, request: Request) -> t.IO[bytes]:
return res # type: ignore
if content_type.decode("utf-8") == self._mime_type:
body = await request.body()
return t.cast(t.IO[bytes], FileLike(io.BytesIO(body), "<request body>"))
return FileLike[bytes](io.BytesIO(body), "<request body>")
raise BentoMLException(
f"File should have Content-Type '{self._mime_type}' or 'multipart/form-data', got {content_type} instead"
)
Expand Down
59 changes: 45 additions & 14 deletions src/bentoml/_internal/io_descriptors/image.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import io
import os
import typing as t
import functools
from typing import TYPE_CHECKING
Expand All @@ -22,6 +23,8 @@
from ..service.openapi.specification import Schema
from ..service.openapi.specification import MediaType

PIL_EXC_MSG = "'Pillow' is required to use the Image IO descriptor. Install with 'pip install bentoml[io-image]'."

if TYPE_CHECKING:
from types import UnionType

Expand All @@ -30,7 +33,6 @@
from typing_extensions import Self

from bentoml.grpc.v1alpha1 import service_pb2 as pb
from .base import OpenAPIResponse

from .. import external_typing as ext
from .base import OpenAPIResponse
Expand All @@ -44,9 +46,8 @@

# NOTE: pillow-simd only benefits users who want to do preprocessing
# TODO: add options for users to choose between simd and native mode
_exc = "'Pillow' is required to use the Image IO descriptor. Install it with: 'pip install -U Pillow'."
PIL = LazyLoader("PIL", globals(), "PIL", exc_msg=_exc)
PIL.Image = LazyLoader("PIL.Image", globals(), "PIL.Image", exc_msg=_exc)
PIL = LazyLoader("PIL", globals(), "PIL", exc_msg=PIL_EXC_MSG)
PIL.Image = LazyLoader("PIL.Image", globals(), "PIL.Image", exc_msg=PIL_EXC_MSG)

pb, _ = import_generated_stubs()

Expand All @@ -59,10 +60,7 @@
DEFAULT_PIL_MODE = "RGB"


PIL_WRITE_ONLY_FORMATS = {
"PALM",
"PDF",
}
PIL_WRITE_ONLY_FORMATS = {"PALM", "PDF"}
READABLE_MIMES: set[str] = None # type: ignore (lazy constant)
MIME_EXT_MAPPING: dict[str, str] = None # type: ignore (lazy constant)

Expand All @@ -75,9 +73,7 @@ def initialize_pillow():
try:
import PIL.Image
except ImportError:
raise InternalServerError(
f"'Pillow' is required to use {__name__}. Install Pillow with 'pip install bentoml[io-image]'"
)
raise InternalServerError(PIL_EXC_MSG)

PIL.Image.init()
MIME_EXT_MAPPING = {v: k for k, v in PIL.Image.MIME.items()} # type: ignore (lazy constant)
Expand Down Expand Up @@ -214,6 +210,41 @@ def __init__(
self._pilmode: _Mode | None = pilmode
self._format: str = MIME_EXT_MAPPING[self._mime_type]

@classmethod
def from_sample(
cls,
sample: ImageType | str,
pilmode: _Mode | None = DEFAULT_PIL_MODE,
*,
allowed_mime_types: t.Iterable[str] | None = None,
) -> Self:
from filetype.match import image_match

img_type = image_match(sample)
if img_type is None:
raise InvalidArgument(f"{sample} is not a valid image file type.")

kls = cls(
mime_type=img_type.mime,
pilmode=pilmode,
allowed_mime_types=allowed_mime_types,
)

if isinstance(sample, str) and os.path.exists(sample):
try:
with open(sample, "rb") as f:
kls.sample = PIL.Image.open(f)
except PIL.UnidentifiedImageError as err:
raise BadInput(f"Failed to parse sample image file: {err}") from None
elif LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(sample):
kls.sample = PIL.Image.fromarray(sample, mode=pilmode)
elif LazyType["PIL.Image.Image"]("PIL.Image.Image").isinstance(sample):
kls.sample = sample
else:
raise InvalidArgument(f"Unknown sample type: '{sample}'")

return kls

def to_spec(self) -> dict[str, t.Any]:
return {
"id": self.descriptor_id,
Expand Down Expand Up @@ -318,15 +349,15 @@ async def from_http_request(self, request: Request) -> ImageType:

try:
return PIL.Image.open(io.BytesIO(bytes_))
except PIL.UnidentifiedImageError: # type: ignore (bad pillow types)
raise BadInput("Failed to parse uploaded image file") from None
except PIL.UnidentifiedImageError as err:
raise BadInput(f"Failed to parse uploaded image file: {err}") from None

async def to_http_response(
self, obj: ImageType, ctx: Context | None = None
) -> Response:
if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(obj):
image = PIL.Image.fromarray(obj, mode=self._pilmode)
elif LazyType[PIL.Image.Image]("PIL.Image.Image").isinstance(obj):
elif LazyType["PIL.Image.Image"]("PIL.Image.Image").isinstance(obj):
image = obj
else:
raise BadInput(
Expand Down
42 changes: 39 additions & 3 deletions src/bentoml/_internal/io_descriptors/json.py
Expand Up @@ -23,6 +23,8 @@
from ..service.openapi.specification import Schema
from ..service.openapi.specification import MediaType

EXC_MSG = "'pydantic' must be installed to use 'pydantic_model'. Install with 'pip install bentoml[io-json]'."

if TYPE_CHECKING:
from types import UnionType

Expand All @@ -36,9 +38,8 @@
from ..context import InferenceApiContext as Context

else:
_exc_msg = "'pydantic' must be installed to use 'pydantic_model'. Install with 'pip install pydantic'."
pydantic = LazyLoader("pydantic", globals(), "pydantic", exc_msg=_exc_msg)
schema = LazyLoader("schema", globals(), "pydantic.schema", exc_msg=_exc_msg)
pydantic = LazyLoader("pydantic", globals(), "pydantic", exc_msg=EXC_MSG)
schema = LazyLoader("schema", globals(), "pydantic.schema", exc_msg=EXC_MSG)
# lazy load our proto generated.
struct_pb2 = LazyLoader("struct_pb2", globals(), "google.protobuf.struct_pb2")
# lazy load numpy for processing ndarray.
Expand Down Expand Up @@ -200,6 +201,22 @@ def __init__(
"'validate_json' option from 'bentoml.io.JSON' has been deprecated. Use a Pydantic model to specify validation options instead."
)

@classmethod
def from_sample(
cls,
sample: JSONType,
*,
json_encoder: t.Type[json.JSONEncoder] = DefaultJsonEncoder,
) -> Self:
pydantic_model: t.Type[pydantic.BaseModel] | None = None
if LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(sample):
pydantic_model = sample.__class__

kls = cls(pydantic_model=pydantic_model, json_encoder=json_encoder)

kls.sample = sample
return kls

def to_spec(self) -> dict[str, t.Any]:
return {
"id": self.descriptor_id,
Expand Down Expand Up @@ -250,6 +267,25 @@ def openapi_components(self) -> dict[str, t.Any] | None:

return {"schemas": pydantic_components_schema(self._pydantic_model)}

def openapi_example(self) -> t.Any:
if self.sample is not None:
if LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(
self.sample
):
return self.sample.dict()
elif isinstance(self.sample, str):
return json.dumps(
self.sample,
cls=self._json_encoder,
ensure_ascii=False,
allow_nan=False,
indent=None,
separators=(",", ":"),
)
elif isinstance(self.sample, dict):
return self.sample
return

def openapi_request_body(self) -> dict[str, t.Any]:
return {
"content": {self._mime_type: MediaType(schema=self.openapi_schema())},
Expand Down
11 changes: 8 additions & 3 deletions src/bentoml/_internal/io_descriptors/multipart.py
Expand Up @@ -16,15 +16,16 @@
from ..utils.formparser import populate_multipart_requests
from ..utils.formparser import concat_to_multipart_response
from ..service.openapi.specification import Schema
from ..service.openapi.specification import Response as OpenAPIResponse
from ..service.openapi.specification import MediaType
from ..service.openapi.specification import RequestBody

if TYPE_CHECKING:
from types import UnionType

from typing_extensions import Self

from bentoml.grpc.v1alpha1 import service_pb2 as pb

from .base import OpenAPIResponse
from ..types import LazyType
from ..context import InferenceApiContext as Context
else:
Expand Down Expand Up @@ -174,6 +175,10 @@ def __init__(self, **inputs: IODescriptor[t.Any]):
def __repr__(self) -> str:
return f"Multipart({','.join([f'{k}={v}' for k,v in zip(self._inputs, map(repr, self._inputs.values()))])})"

@classmethod
def from_sample(cls, sample: dict[str, t.Any]) -> Self:
pass

def input_type(
self,
) -> dict[str, t.Type[t.Any] | UnionType | LazyType[t.Any]]:
Expand Down Expand Up @@ -217,7 +222,7 @@ def openapi_schema(self) -> Schema:
def openapi_components(self) -> dict[str, t.Any] | None:
pass

def openapi_request_body(self) -> RequestBody:
def openapi_request_body(self) -> dict[str, t.Any]:
return {
"content": {self._mime_type: MediaType(schema=self.openapi_schema())},
"required": True,
Expand Down

0 comments on commit 166009f

Please sign in to comment.