Skip to content

Commit

Permalink
fix: io descriptor backward compatibility (#3327)
Browse files Browse the repository at this point in the history
Ensure tests for IO descriptor to be backward compatible
  • Loading branch information
sauyon committed Dec 8, 2022
1 parent a2a4c5e commit b6a4158
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 20 deletions.
12 changes: 7 additions & 5 deletions src/bentoml/_internal/io_descriptors/base.py
Expand Up @@ -28,7 +28,7 @@
| LazyType[t.Any]
| dict[str, t.Type[t.Any] | UnionType | LazyType[t.Any]]
)
OpenAPIResponse = dict[str, str | dict[str, MediaType] | dict[str, t.Any]]
OpenAPIResponse = dict[str, str | dict[str, t.Any]]


IO_DESCRIPTOR_REGISTRY: dict[str, type[IODescriptor[t.Any]]] = {}
Expand All @@ -37,8 +37,12 @@


def from_spec(spec: dict[str, t.Any]) -> IODescriptor[t.Any]:
if spec["id"] is None:
raise BentoMLException("No IO descriptor spec found.")

if "id" not in spec:
raise InvalidArgument(f"IO descriptor spec ({spec}) missing ID.")

return IO_DESCRIPTOR_REGISTRY[spec["id"]].from_spec(spec)


Expand Down Expand Up @@ -123,12 +127,10 @@ def _from_sample(self, sample: t.Any) -> IOType:
def mime_type(self) -> str:
return self._mime_type

@abstractmethod
def to_spec(self) -> dict[str, t.Any]:
raise NotImplementedError
def to_spec(self) -> dict[str, t.Any] | None:
return None

@classmethod
@abstractmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
raise NotImplementedError

Expand Down
3 changes: 0 additions & 3 deletions src/bentoml/_internal/io_descriptors/file.py
Expand Up @@ -205,9 +205,6 @@ async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
async def from_http_request(self, request: Request) -> FileLike[bytes]:
raise NotImplementedError

def to_spec(self) -> dict[str, t.Any]:
raise NotImplementedError


class BytesIOFile(File, descriptor_id=None):
def to_spec(self) -> dict[str, t.Any]:
Expand Down
32 changes: 20 additions & 12 deletions src/bentoml/client.py
Expand Up @@ -94,18 +94,26 @@ def from_url(server_url: str) -> Client:
raise BentoMLException(
f"Malformed BentoML spec received from BentoML server {server_url}"
)
dummy_service.apis[meth_spec["x-bentoml-name"]] = InferenceAPI(
None,
bentoml.io.from_spec(
meth_spec["requestBody"]["x-bentoml-io-descriptor"]
),
bentoml.io.from_spec(
meth_spec["responses"]["200"]["x-bentoml-io-descriptor"]
),
name=meth_spec["x-bentoml-name"],
doc=meth_spec["description"],
route=route.lstrip("/"),
)
try:
api = InferenceAPI(
None,
bentoml.io.from_spec(
meth_spec["requestBody"]["x-bentoml-io-descriptor"]
),
bentoml.io.from_spec(
meth_spec["responses"]["200"]["x-bentoml-io-descriptor"]
),
name=meth_spec["x-bentoml-name"],
doc=meth_spec["description"],
route=route.lstrip("/"),
)
dummy_service.apis[meth_spec["x-bentoml-name"]] = api
except BentoMLException as e:
logger.error(
"Failed to instantiate client for API %s: ",
meth_spec["x-bentoml-name"],
e,
)

res = HTTPClient(dummy_service, server_url)
res.server_url = server_url
Expand Down
109 changes: 109 additions & 0 deletions tests/unit/_internal/io/test_custom.py
@@ -0,0 +1,109 @@
from __future__ import annotations

import typing as t
from typing import TYPE_CHECKING

from starlette.requests import Request
from starlette.responses import Response

import bentoml
from bentoml.io import IODescriptor
from bentoml.exceptions import BentoMLException
from bentoml._internal.utils.http import set_cookies
from bentoml._internal.service.openapi import SUCCESS_DESCRIPTION
from bentoml._internal.service.openapi.specification import Schema
from bentoml._internal.service.openapi.specification import MediaType

if TYPE_CHECKING:
from google.protobuf import wrappers_pb2

from bentoml._internal.context import InferenceApiContext as Context
from bentoml._internal.io_descriptors.base import OpenAPIResponse


# testing the minimal required IO descriptor to ensure we don't break
# compatibility with custom descriptors when implementing new features
# in IODescriptor.
class CustomDescriptor(IODescriptor[str]):
_mime_type = "text/custom"

def __init__(self, *args: t.Any, **kwargs: t.Any):
if args or kwargs:
raise BentoMLException(
f"'{self.__class__.__name__}' is not designed to take any args or kwargs during initialization."
) from None

def input_type(self) -> t.Type[str]:
return str

def _from_sample(self, sample: str | bytes) -> str:
if isinstance(sample, bytes):
sample = sample.decode("utf-8")
return sample

def openapi_schema(self) -> Schema:
return Schema(type="string")

def openapi_components(self) -> dict[str, t.Any] | None:
pass

def openapi_example(self):
return str(self.sample)

def openapi_request_body(self) -> dict[str, t.Any]:
return {
"content": {
self._mime_type: MediaType(
schema=self.openapi_schema(), example=self.openapi_example()
)
},
"required": True,
"x-bentoml-io-descriptor": self.to_spec(),
}

def openapi_responses(self) -> OpenAPIResponse:
return {
"description": SUCCESS_DESCRIPTION,
"content": {
self._mime_type: MediaType(
schema=self.openapi_schema(), example=self.openapi_example()
)
},
}

async def from_http_request(self, request: Request) -> str:
body = await request.body()
return body.decode("cp1252")

async def to_http_response(self, obj: str, ctx: Context | None = None) -> Response:
if ctx is not None:
res = Response(
obj,
media_type=self._mime_type,
headers=ctx.response.metadata, # type: ignore (bad starlette types)
status_code=ctx.response.status_code,
)
set_cookies(res, ctx.response.cookies)
return res
else:
return Response(obj, media_type=self._mime_type)

async def from_proto(self, field: wrappers_pb2.StringValue | bytes) -> str:
if isinstance(field, bytes):
return field.decode("cp1252")
else:
assert isinstance(field, wrappers_pb2.StringValue)
return field.value

async def to_proto(self, obj: str) -> wrappers_pb2.StringValue:
return wrappers_pb2.StringValue(value=obj)


def test_custom_io_descriptor():
svc = bentoml.Service("test")

@svc.api(input=CustomDescriptor(), output=CustomDescriptor())
def descriptor_test_api(inp):
return inp

svc.asgi_app

0 comments on commit b6a4158

Please sign in to comment.