Skip to content

Commit

Permalink
feat: from_sample
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm committed Oct 26, 2022
1 parent 2c40bdf commit ee88749
Show file tree
Hide file tree
Showing 9 changed files with 225 additions and 105 deletions.
18 changes: 11 additions & 7 deletions src/bentoml/_internal/io_descriptors/base.py
Expand Up @@ -54,7 +54,7 @@ class IODescriptor(ABC, t.Generic[IOType]):
_mime_type: str
_rpc_content_type: str = "application/grpc"
_proto_fields: tuple[ProtoField]
_sample_input: IOType | None = None
_sample: IOType | None = None
descriptor_id: str | None

def __init_subclass__(cls, *, descriptor_id: str | None = None):
Expand All @@ -67,12 +67,12 @@ def __init_subclass__(cls, *, descriptor_id: str | None = None):
cls.descriptor_id = descriptor_id

@property
def sample_input(self) -> IOType | None:
return self._sample_input
def sample(self) -> IOType | None:
return self._sample

@sample_input.setter
def sample_input(self, value: IOType) -> None:
self._sample_input = value
@sample.setter
def sample(self, value: IOType) -> None:
self._sample = value

@abstractmethod
def to_spec(self) -> dict[str, t.Any]:
Expand All @@ -94,6 +94,10 @@ def input_type(self) -> InputType:
def openapi_schema(self) -> Schema | Reference:
raise NotImplementedError

def openapi_example(self) -> t.Any:
    """Return the stored sample, if any, as this descriptor's OpenAPI example.

    Returns:
        The value of ``self.sample`` when it is not ``None``; otherwise ``None``.
    """
    if self.sample is not None:
        return self.sample
    # Spell out the fall-through so every path states its return value.
    return None

@abstractmethod
def openapi_components(self) -> dict[str, t.Any] | None:
raise NotImplementedError
Expand Down Expand Up @@ -126,5 +130,5 @@ async def to_proto(self, obj: IOType) -> t.Any:

@classmethod
@abstractmethod
def from_sample(cls, sample_input: IOType, **kwargs: t.Any) -> Self:
def from_sample(cls, sample: IOType, **kwargs: t.Any) -> Self:
...
24 changes: 18 additions & 6 deletions src/bentoml/_internal/io_descriptors/file.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import io
import os
import typing as t
import logging
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -122,12 +123,23 @@ def __new__(cls, kind: FileKind = "binaryio", mime_type: str | None = None) -> F
return res

@classmethod
def from_sample(cls, sample: FileType | str, kind: FileKind = "binaryio") -> Self:
    """Create a descriptor from a sample file and attach the sample to it.

    Args:
        sample: A ``FileLike`` instance, an open file object, or the path of
            an existing file.
        kind: Passed through unchanged to the constructor.

    Returns:
        A new descriptor built with the MIME type guessed from ``sample`` and
        with its ``sample`` attribute set to a ``FileLike`` wrapper.

    Raises:
        InvalidArgument: If ``sample`` is none of the accepted kinds (this
            includes a string that is not an existing path).
    """
    import filetype

    mime_type: str | None = filetype.guess_mime(sample)

    kls = cls(kind=kind, mime_type=mime_type)

    if isinstance(sample, FileLike):
        kls.sample = sample
    elif isinstance(sample, (io.IOBase, t.IO)):
        # Concrete file objects (open(), io.BytesIO, ...) derive from
        # io.IOBase and are not instances of typing.IO, so checking typing.IO
        # alone rejected every real stream.
        kls.sample = FileLike[bytes](sample, "<sample>")
    elif isinstance(sample, str) and os.path.exists(sample):
        # Copy the content into memory: the handle is closed as soon as the
        # with-block exits, so wrapping it directly would store a closed file.
        with open(sample, "rb") as f:
            kls.sample = FileLike[bytes](io.BytesIO(f.read()), "<sample>")
    else:
        raise InvalidArgument(f"Unknown sample type: '{sample}'")

    return kls

@classmethod
Expand Down Expand Up @@ -196,7 +208,7 @@ async def to_proto(self, obj: FileType) -> pb.File:
async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
raise NotImplementedError

async def from_http_request(self, request: Request) -> t.IO[bytes]:
async def from_http_request(self, request: Request) -> FileLike[bytes]:
raise NotImplementedError

def to_spec(self) -> dict[str, t.Any]:
Expand All @@ -213,7 +225,7 @@ def to_spec(self) -> dict[str, t.Any]:
},
}

async def from_http_request(self, request: Request) -> t.IO[bytes]:
async def from_http_request(self, request: Request) -> FileLike[bytes]:
content_type, _ = parse_options_header(request.headers["content-type"])
if content_type.decode("utf-8") == "multipart/form-data":
form = await request.form()
Expand All @@ -235,7 +247,7 @@ async def from_http_request(self, request: Request) -> t.IO[bytes]:
return res # type: ignore
if content_type.decode("utf-8") == self._mime_type:
body = await request.body()
return t.cast(t.IO[bytes], FileLike(io.BytesIO(body), "<request body>"))
return FileLike[bytes](io.BytesIO(body), "<request body>")
raise BentoMLException(
f"File should have Content-Type '{self._mime_type}' or 'multipart/form-data', got {content_type} instead"
)
Expand Down
84 changes: 59 additions & 25 deletions src/bentoml/_internal/io_descriptors/image.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import io
import os
import typing as t
import functools
from typing import TYPE_CHECKING
Expand All @@ -21,18 +22,20 @@
from ..service.openapi import SUCCESS_DESCRIPTION
from ..service.openapi.specification import Schema
from ..service.openapi.specification import MediaType
from ..service.openapi.specification import RequestBody

PIL_EXC_MSG = "'Pillow' is required to use the Image IO descriptor. Install with 'pip install bentoml[io-image]'."

if TYPE_CHECKING:
from types import UnionType

import PIL
import PIL.Image
from typing_extensions import Self

from bentoml.grpc.v1alpha1 import service_pb2 as pb
from .base import OpenAPIResponse

from .. import external_typing as ext
from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context

_Mode = t.Literal[
Expand All @@ -43,9 +46,8 @@

# NOTE: pillow-simd only benefits users who want to do preprocessing
# TODO: add options for users to choose between simd and native mode
_exc = "'Pillow' is required to use the Image IO descriptor. Install it with: 'pip install -U Pillow'."
PIL = LazyLoader("PIL", globals(), "PIL", exc_msg=_exc)
PIL.Image = LazyLoader("PIL.Image", globals(), "PIL.Image", exc_msg=_exc)
PIL = LazyLoader("PIL", globals(), "PIL", exc_msg=PIL_EXC_MSG)
PIL.Image = LazyLoader("PIL.Image", globals(), "PIL.Image", exc_msg=PIL_EXC_MSG)

pb, _ = import_generated_stubs()

Expand All @@ -58,10 +60,7 @@
DEFAULT_PIL_MODE = "RGB"


PIL_WRITE_ONLY_FORMATS = {
"PALM",
"PDF",
}
PIL_WRITE_ONLY_FORMATS = {"PALM", "PDF"}
READABLE_MIMES: set[str] = None # type: ignore (lazy constant)
MIME_EXT_MAPPING: dict[str, str] = None # type: ignore (lazy constant)

Expand All @@ -74,9 +73,7 @@ def initialize_pillow():
try:
import PIL.Image
except ImportError:
raise InternalServerError(
f"'Pillow' is required to use {__name__}. Install Pillow with 'pip install bentoml[io-image]'"
)
raise InternalServerError(PIL_EXC_MSG)

PIL.Image.init()
MIME_EXT_MAPPING = {v: k for k, v in PIL.Image.MIME.items()} # type: ignore (lazy constant)
Expand Down Expand Up @@ -213,6 +210,41 @@ def __init__(
self._pilmode: _Mode | None = pilmode
self._format: str = MIME_EXT_MAPPING[self._mime_type]

@classmethod
def from_sample(
    cls,
    sample: ImageType | str,
    pilmode: _Mode | None = DEFAULT_PIL_MODE,
    *,
    allowed_mime_types: t.Iterable[str] | None = None,
) -> Self:
    """Create a descriptor from a sample image and attach the sample to it.

    Args:
        sample: Path of an existing image file, a numpy array, or a
            ``PIL.Image.Image``.
        pilmode: Passed to the constructor; also used as ``mode`` when a
            numpy array is converted with ``PIL.Image.fromarray``.
        allowed_mime_types: Passed through unchanged to the constructor.

    Returns:
        A new descriptor whose MIME type comes from sniffing ``sample`` and
        whose ``sample`` attribute holds a PIL image.

    Raises:
        InvalidArgument: If ``sample`` is not recognised as an image, or is
            none of the accepted kinds.
        BadInput: If Pillow cannot identify the file at the given path.
    """
    from filetype.match import image_match

    # NOTE(review): the sniffing below runs before the ndarray / PIL.Image
    # branches; confirm that image_match accepts those inputs, otherwise
    # those two branches are never reached.
    img_type = image_match(sample)
    if img_type is None:
        raise InvalidArgument(f"{sample} is not a valid image file type.")

    kls = cls(
        mime_type=img_type.mime,
        pilmode=pilmode,
        allowed_mime_types=allowed_mime_types,
    )

    if isinstance(sample, str) and os.path.exists(sample):
        try:
            with open(sample, "rb") as f:
                image = PIL.Image.open(f)
                # Image.open only reads the header; decode the pixel data now,
                # while the file is still open, so the sample stays usable
                # after the with-block closes the handle.
                image.load()
        except PIL.UnidentifiedImageError as err:
            raise BadInput(f"Failed to parse sample image file: {err}") from None
        kls.sample = image
    elif LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(sample):
        kls.sample = PIL.Image.fromarray(sample, mode=pilmode)
    elif LazyType["PIL.Image.Image"]("PIL.Image.Image").isinstance(sample):
        kls.sample = sample
    else:
        raise InvalidArgument(f"Unknown sample type: '{sample}'")

    return kls

def to_spec(self) -> dict[str, t.Any]:
return {
"id": self.descriptor_id,
Expand All @@ -224,7 +256,7 @@ def to_spec(self) -> dict[str, t.Any]:
}

@classmethod
def from_spec(cls) -> Self:
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in Image spec: {spec}")

Expand All @@ -239,20 +271,22 @@ def openapi_schema(self) -> Schema:
def openapi_components(self) -> dict[str, t.Any] | None:
pass

def openapi_request_body(self) -> RequestBody:
return RequestBody(
content={
def openapi_request_body(self) -> dict[str, t.Any]:
return {
"content": {
mtype: MediaType(schema=self.openapi_schema())
for mtype in self._allowed_mimes
},
required=True,
)
"required": True,
"x-bentoml-io-descriptor": self.to_spec(),
}

def openapi_responses(self) -> OpenAPIResponse:
return OpenAPIResponse(
description=SUCCESS_DESCRIPTION,
content={self._mime_type: MediaType(schema=self.openapi_schema())},
)
return {
"description": SUCCESS_DESCRIPTION,
"content": {self._mime_type: MediaType(schema=self.openapi_schema())},
"x-bentoml-io-descriptor": self.to_spec(),
}

async def from_http_request(self, request: Request) -> ImageType:
content_type, _ = parse_options_header(request.headers["content-type"])
Expand Down Expand Up @@ -315,15 +349,15 @@ async def from_http_request(self, request: Request) -> ImageType:

try:
return PIL.Image.open(io.BytesIO(bytes_))
except PIL.UnidentifiedImageError: # type: ignore (bad pillow types)
raise BadInput("Failed to parse uploaded image file") from None
except PIL.UnidentifiedImageError as err:
raise BadInput(f"Failed to parse uploaded image file: {err}") from None

async def to_http_response(
self, obj: ImageType, ctx: Context | None = None
) -> Response:
if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(obj):
image = PIL.Image.fromarray(obj, mode=self._pilmode)
elif LazyType[PIL.Image.Image]("PIL.Image.Image").isinstance(obj):
elif LazyType["PIL.Image.Image"]("PIL.Image.Image").isinstance(obj):
image = obj
else:
raise BadInput(
Expand Down
52 changes: 44 additions & 8 deletions src/bentoml/_internal/io_descriptors/json.py
Expand Up @@ -10,35 +10,36 @@
from starlette.requests import Request
from starlette.responses import Response

from bentoml.exceptions import BadInput

from .base import IODescriptor
from ..types import LazyType
from ..utils import LazyLoader
from ..utils import bentoml_cattr
from ..utils.pkg import pkg_version_info
from ..utils.http import set_cookies
from ...exceptions import BadInput
from ...exceptions import InvalidArgument
from ..service.openapi import REF_PREFIX
from ..service.openapi import SUCCESS_DESCRIPTION
from ..service.openapi.specification import Schema
from ..service.openapi.specification import Response as OpenAPIResponse
from ..service.openapi.specification import MediaType
from ..service.openapi.specification import RequestBody

EXC_MSG = "'pydantic' must be installed to use 'pydantic_model'. Install with 'pip install bentoml[io-json]'."

if TYPE_CHECKING:
from types import UnionType

import pydantic
import pydantic.schema as schema
from google.protobuf import struct_pb2
from typing_extensions import Self

from .. import external_typing as ext
from .base import OpenAPIResponse
from ..context import InferenceApiContext as Context

else:
_exc_msg = "'pydantic' must be installed to use 'pydantic_model'. Install with 'pip install pydantic'."
pydantic = LazyLoader("pydantic", globals(), "pydantic", exc_msg=_exc_msg)
schema = LazyLoader("schema", globals(), "pydantic.schema", exc_msg=_exc_msg)
pydantic = LazyLoader("pydantic", globals(), "pydantic", exc_msg=EXC_MSG)
schema = LazyLoader("schema", globals(), "pydantic.schema", exc_msg=EXC_MSG)
# lazy load our proto generated.
struct_pb2 = LazyLoader("struct_pb2", globals(), "google.protobuf.struct_pb2")
# lazy load numpy for processing ndarray.
Expand Down Expand Up @@ -200,6 +201,22 @@ def __init__(
"'validate_json' option from 'bentoml.io.JSON' has been deprecated. Use a Pydantic model to specify validation options instead."
)

@classmethod
def from_sample(
    cls,
    sample: JSONType,
    *,
    json_encoder: t.Type[json.JSONEncoder] = DefaultJsonEncoder,
) -> Self:
    """Create a descriptor from a sample value and attach the sample to it.

    When ``sample`` is a pydantic model instance, its class is passed to the
    constructor as ``pydantic_model``; otherwise ``None`` is passed.

    Args:
        sample: The example value to store on the descriptor.
        json_encoder: Passed through unchanged to the constructor.

    Returns:
        The new descriptor with ``sample`` set.
    """
    is_model = LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(sample)
    model_cls: t.Type[pydantic.BaseModel] | None = (
        sample.__class__ if is_model else None
    )

    descriptor = cls(pydantic_model=model_cls, json_encoder=json_encoder)
    descriptor.sample = sample
    return descriptor

def to_spec(self) -> dict[str, t.Any]:
return {
"id": self.descriptor_id,
Expand Down Expand Up @@ -250,7 +267,26 @@ def openapi_components(self) -> dict[str, t.Any] | None:

return {"schemas": pydantic_components_schema(self._pydantic_model)}

def openapi_request_body(self) -> RequestBody:
def openapi_example(self) -> t.Any:
    """Convert the stored sample into a value usable as an OpenAPI example.

    Returns:
        * ``sample.dict()`` for a pydantic model instance;
        * the compact JSON encoding of the sample for a ``str``;
        * the sample itself for a ``dict``;
        * ``None`` when no sample is set or it is of any other type.
    """
    example = self.sample
    if example is None:
        return None

    if LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(example):
        return example.dict()

    if isinstance(example, str):
        # Compact separators and no indentation keep the output on one line.
        return json.dumps(
            example,
            cls=self._json_encoder,
            ensure_ascii=False,
            allow_nan=False,
            indent=None,
            separators=(",", ":"),
        )

    if isinstance(example, dict):
        return example

    return None

def openapi_request_body(self) -> dict[str, t.Any]:
return {
"content": {self._mime_type: MediaType(schema=self.openapi_schema())},
"required": True,
Expand Down

0 comments on commit ee88749

Please sign in to comment.