Skip to content

Commit

Permalink
feat: openapi and dispatcher fix
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm committed Nov 8, 2022
1 parent a6c77bd commit dafcac5
Show file tree
Hide file tree
Showing 11 changed files with 267 additions and 183 deletions.
29 changes: 18 additions & 11 deletions src/bentoml/_internal/io_descriptors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@


@singledispatchmethod
def create_sample(self: IODescriptor[t.Any], value: t.Any) -> None:
def set_sample(self: IODescriptor[t.Any], value: t.Any) -> IODescriptor[t.Any]:
raise InvalidArgument(
f"Unsupported sample type: '{type(value)}' (value: {value}). To register type '{type(value)}' to {self.__class__.__name__} implement a dispatch function and register types to 'create_sample.register'"
f"Unsupported sample type: '{type(value)}' (value: {value}). To register type '{type(value)}' to {self.__class__.__name__} implement a dispatch function and register types to 'set_sample.register'"
)


Expand All @@ -65,7 +65,7 @@ class IODescriptor(ABC, t.Generic[IOType]):
_rpc_content_type: str = "application/grpc"
_proto_fields: tuple[ProtoField]
_sample: IOType | None = None
_create_sample: singledispatchmethod[None] = create_sample
_set_sample: singledispatchmethod["IODescriptor[t.Any]"] = set_sample

def __init_subclass__(cls, *, descriptor_id: str | None = None):
if descriptor_id is not None:
Expand All @@ -76,11 +76,14 @@ def __init_subclass__(cls, *, descriptor_id: str | None = None):
IO_DESCRIPTOR_REGISTRY[descriptor_id] = cls
cls.descriptor_id = descriptor_id

def __new__(cls, *args: t.Any, **kwargs: t.Any):
def __new__(cls, *args: t.Any, **kwargs: t.Any) -> Self:
sample = kwargs.pop("_sample", None)
kls = super().__new__(cls)
if sample is not None:
kls._create_sample(sample)
kls = object.__new__(cls)
if sample is None:
set_sample.register(type(None), lambda self, _: self)
kls = kls._set_sample(sample)
# TODO: lazy init
kls.__init__(*args, **kwargs)
return kls

@property
Expand All @@ -98,6 +101,10 @@ def sample(self, value: IOType) -> None:
def from_sample(cls, sample: IOType | t.Any, **kwargs: t.Any) -> Self:
return cls.__new__(cls, _sample=sample, **kwargs)

@property
def mime_type(self) -> str:
return self._mime_type

@abstractmethod
def to_spec(self) -> dict[str, t.Any]:
raise NotImplementedError
Expand All @@ -118,14 +125,14 @@ def input_type(self) -> InputType:
def openapi_schema(self) -> Schema | Reference:
raise NotImplementedError

def openapi_example(self) -> t.Any:
if self.sample is not None:
return self.sample

@abstractmethod
def openapi_components(self) -> dict[str, t.Any] | None:
raise NotImplementedError

@abstractmethod
def openapi_example(self) -> t.Any | None:
raise NotImplementedError

@abstractmethod
def openapi_request_body(self) -> dict[str, t.Any]:
raise NotImplementedError
Expand Down
32 changes: 19 additions & 13 deletions src/bentoml/_internal/io_descriptors/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from starlette.responses import Response
from starlette.datastructures import UploadFile

from .base import set_sample
from .base import IODescriptor
from .base import create_sample
from ..types import FileLike
from ..utils import resolve_user_filepath
from ..utils.http import set_cookies
Expand Down Expand Up @@ -140,24 +140,27 @@ def from_sample(cls, sample: FileType | str, kind: FileKind = "binaryio") -> Sel
sample, kind=kind, mime_type=filetype.guess_mime(sample)
)

@create_sample.register(type(FileLike))
def _(self, sample: FileLike[bytes]) -> None:
self.sample = sample
@set_sample.register(type(FileLike))
def _(cls, sample: FileLike[bytes]):
cls.sample = sample
return cls

@create_sample.register(t.IO)
def _(self, sample: t.IO[t.Any]) -> None:
if isinstance(self, File):
self.sample = FileLike[bytes](sample, "<sample>")
@set_sample.register(t.IO)
def _(cls, sample: t.IO[t.Any]):
if isinstance(cls, File):
cls.sample = FileLike[bytes](sample, "<sample>")
return cls

@create_sample.register(str)
@create_sample.register(os.PathLike)
def _(self, sample: str) -> None:
@set_sample.register(str)
@set_sample.register(os.PathLike)
def _(cls, sample: str):
# This is to ensure we can register same type with different
# implementation across different IO descriptors.
if isinstance(self, File):
if isinstance(cls, File):
p = resolve_user_filepath(sample, ctx=None)
with open(p, "rb") as f:
self.sample = FileLike[bytes](f, "<sample>")
cls.sample = FileLike[bytes](f, "<sample>")
return cls

@classmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
Expand All @@ -174,6 +177,9 @@ def openapi_schema(self) -> Schema:
def openapi_components(self) -> dict[str, t.Any] | None:
pass

def openapi_example(self):
pass

def openapi_request_body(self) -> dict[str, t.Any]:
return {
"content": {self._mime_type: MediaType(schema=self.openapi_schema())},
Expand Down
34 changes: 21 additions & 13 deletions src/bentoml/_internal/io_descriptors/image.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import io
import os
import typing as t
import tempfile
import functools
from typing import TYPE_CHECKING
from urllib.parse import quote
Expand All @@ -11,8 +13,8 @@
from starlette.responses import Response
from starlette.datastructures import UploadFile

from .base import set_sample
from .base import IODescriptor
from .base import create_sample
from ..types import LazyType
from ..utils import LazyLoader
from ..utils import resolve_user_filepath
Expand Down Expand Up @@ -234,17 +236,19 @@ def from_sample(

if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(sample):

@create_sample.register(np.ndarray)
def _(self: Self, sample: ext.NpNDArray) -> None:
if isinstance(self, Image):
self.sample = PIL.Image.fromarray(sample, mode=self._pilmode)
@set_sample.register(np.ndarray)
def _(cls: Self, sample: ext.NpNDArray):
if isinstance(cls, Image):
cls.sample = PIL.Image.fromarray(sample, mode=pilmode)
return cls

elif LazyType["PIL.Image.Image"]("PIL.Image.Image").isinstance(sample):

@create_sample.register(PIL.Image.Image)
def _(self: Self, sample: PIL.Image.Image) -> None:
if isinstance(self, Image):
self.sample = sample
@set_sample.register(PIL.Image.Image)
def _(cls: Self, sample: PIL.Image.Image):
if isinstance(cls, Image):
cls.sample = sample
return cls

return super().from_sample(
sample,
Expand All @@ -253,15 +257,16 @@ def _(self: Self, sample: PIL.Image.Image) -> None:
allowed_mime_types=allowed_mime_types,
)

@create_sample.register(str)
def _(self, sample: str) -> None:
if isinstance(self, Image):
@set_sample.register(str)
def _(cls, sample: str):
if isinstance(cls, Image):
p = resolve_user_filepath(sample, ctx=None)
try:
with open(p, "rb") as f:
self.sample = PIL.Image.open(f)
cls.sample = PIL.Image.open(f)
except PIL.UnidentifiedImageError as err:
raise BadInput(f"Failed to parse sample image file: {err}") from None
return cls

def to_spec(self) -> dict[str, t.Any]:
return {
Expand Down Expand Up @@ -289,6 +294,9 @@ def openapi_schema(self) -> Schema:
def openapi_components(self) -> dict[str, t.Any] | None:
pass

def openapi_example(self):
pass

def openapi_request_body(self) -> dict[str, t.Any]:
return {
"content": {
Expand Down
44 changes: 27 additions & 17 deletions src/bentoml/_internal/io_descriptors/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from starlette.requests import Request
from starlette.responses import Response

from .base import set_sample
from .base import IODescriptor
from .base import create_sample
from ..types import LazyType
from ..utils import LazyLoader
from ..utils import bentoml_cattr
Expand Down Expand Up @@ -213,24 +213,27 @@ def from_sample(
if LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(sample):
pydantic_model = sample.__class__

@create_sample.register(pydantic.BaseModel)
def _(self: Self, sample: pydantic.BaseModel):
if isinstance(self, JSON):
self.sample = sample
@set_sample.register(pydantic.BaseModel)
def _(cls: Self, sample: pydantic.BaseModel):
if isinstance(cls, JSON):
cls.sample = sample
return cls

return super().from_sample(
sample, pydantic_model=pydantic_model, json_encoder=json_encoder
)

@create_sample.register(dict)
def _(self, sample: dict[str, t.Any]):
if isinstance(self, JSON):
self.sample = sample
@set_sample.register(dict)
def _(cls, sample: dict[str, t.Any]):
if isinstance(cls, JSON):
cls.sample = sample
return cls

@create_sample.register(str)
def _(self, sample: str):
if isinstance(self, JSON):
self.sample = json.loads(sample)
@set_sample.register(str)
def _(cls, sample: str):
if isinstance(cls, JSON):
cls.sample = json.loads(sample)
return cls

def to_spec(self) -> dict[str, t.Any]:
return {
Expand Down Expand Up @@ -282,7 +285,7 @@ def openapi_components(self) -> dict[str, t.Any] | None:

return {"schemas": pydantic_components_schema(self._pydantic_model)}

def openapi_example(self) -> t.Any:
def openapi_example(self):
if self.sample is not None:
if LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(
self.sample
Expand All @@ -299,19 +302,26 @@ def openapi_example(self) -> t.Any:
)
elif isinstance(self.sample, dict):
return self.sample
return

def openapi_request_body(self) -> dict[str, t.Any]:
return {
"content": {self._mime_type: MediaType(schema=self.openapi_schema())},
"content": {
self._mime_type: MediaType(
schema=self.openapi_schema(), example=self.openapi_example()
)
},
"required": True,
"x-bentoml-io-descriptor": self.to_spec(),
}

def openapi_responses(self) -> OpenAPIResponse:
return {
"description": SUCCESS_DESCRIPTION,
"content": {self._mime_type: MediaType(schema=self.openapi_schema())},
"content": {
self._mime_type: MediaType(
schema=self.openapi_schema(), example=self.openapi_example()
)
},
"x-bentoml-io-descriptor": self.to_spec(),
}

Expand Down
19 changes: 15 additions & 4 deletions src/bentoml/_internal/io_descriptors/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __repr__(self) -> str:

@classmethod
def from_sample(cls, sample: dict[str, t.Any]) -> Self:
pass
return cls(**sample)

def input_type(
self,
Expand Down Expand Up @@ -222,17 +222,28 @@ def openapi_schema(self) -> Schema:
def openapi_components(self) -> dict[str, t.Any] | None:
pass

def openapi_example(self) -> t.Any:
return {args: io.openapi_example() for args, io in self._inputs.items()}

def openapi_request_body(self) -> dict[str, t.Any]:
return {
"content": {self._mime_type: MediaType(schema=self.openapi_schema())},
"content": {
self._mime_type: MediaType(
schema=self.openapi_schema(), example=self.openapi_example()
)
},
"required": True,
"x-bentoml-io-descriptor": self.to_spec(),
}

def openapi_responses(self) -> OpenAPIResponse:
return {
"description": SUCCESS_DESCRIPTION,
"content": {self._mime_type: MediaType(schema=self.openapi_schema())},
"content": {
self._mime_type: MediaType(
schema=self.openapi_schema(), example=self.openapi_example()
)
},
"x-bentoml-io-descriptor": self.to_spec(),
}

Expand All @@ -249,7 +260,7 @@ async def from_http_request(self, request: Request) -> dict[str, t.Any]:
for field, descriptor in self._inputs.items():
if field not in form_values:
break
res[field] = descriptor.from_http_request(form_values[field])
res[field] = await descriptor.from_http_request(form_values[field])
else: # NOTE: This is similar to goto, when there is no break.
to_populate = zip(self._inputs.values(), form_values.values())
reqs = await asyncio.gather(
Expand Down

0 comments on commit dafcac5

Please sign in to comment.