Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: io descriptor backward compatibility #3327

Merged
merged 3 commits into from Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/bentoml/_internal/io_descriptors/base.py
Expand Up @@ -28,7 +28,7 @@
| LazyType[t.Any]
| dict[str, t.Type[t.Any] | UnionType | LazyType[t.Any]]
)
OpenAPIResponse = dict[str, str | dict[str, MediaType] | dict[str, t.Any]]
OpenAPIResponse = dict[str, str | dict[str, t.Any]]


IO_DESCRIPTOR_REGISTRY: dict[str, type[IODescriptor[t.Any]]] = {}
Expand All @@ -37,8 +37,12 @@


def from_spec(spec: dict[str, t.Any]) -> IODescriptor[t.Any]:
if spec["id"] is None:
raise BentoMLException("No IO descriptor spec found.")

if "id" not in spec:
raise InvalidArgument(f"IO descriptor spec ({spec}) missing ID.")

return IO_DESCRIPTOR_REGISTRY[spec["id"]].from_spec(spec)


Expand Down Expand Up @@ -123,12 +127,10 @@ def _from_sample(self, sample: t.Any) -> IOType:
def mime_type(self) -> str:
return self._mime_type

@abstractmethod
def to_spec(self) -> dict[str, t.Any]:
raise NotImplementedError
def to_spec(self) -> dict[str, t.Any] | None:
return None

@classmethod
@abstractmethod
def from_spec(cls, spec: dict[str, t.Any]) -> Self:
raise NotImplementedError

Expand Down
3 changes: 0 additions & 3 deletions src/bentoml/_internal/io_descriptors/file.py
Expand Up @@ -205,9 +205,6 @@ async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
async def from_http_request(self, request: Request) -> FileLike[bytes]:
raise NotImplementedError

def to_spec(self) -> dict[str, t.Any]:
raise NotImplementedError


class BytesIOFile(File, descriptor_id=None):
def to_spec(self) -> dict[str, t.Any]:
Expand Down
32 changes: 20 additions & 12 deletions src/bentoml/client.py
Expand Up @@ -94,18 +94,26 @@ def from_url(server_url: str) -> Client:
raise BentoMLException(
f"Malformed BentoML spec received from BentoML server {server_url}"
)
dummy_service.apis[meth_spec["x-bentoml-name"]] = InferenceAPI(
None,
bentoml.io.from_spec(
meth_spec["requestBody"]["x-bentoml-io-descriptor"]
),
bentoml.io.from_spec(
meth_spec["responses"]["200"]["x-bentoml-io-descriptor"]
),
name=meth_spec["x-bentoml-name"],
doc=meth_spec["description"],
route=route.lstrip("/"),
)
try:
api = InferenceAPI(
None,
bentoml.io.from_spec(
meth_spec["requestBody"]["x-bentoml-io-descriptor"]
),
bentoml.io.from_spec(
meth_spec["responses"]["200"]["x-bentoml-io-descriptor"]
),
name=meth_spec["x-bentoml-name"],
doc=meth_spec["description"],
route=route.lstrip("/"),
)
dummy_service.apis[meth_spec["x-bentoml-name"]] = api
except BentoMLException as e:
logger.error(
"Failed to instantiate client for API %s: ",
meth_spec["x-bentoml-name"],
e,
)

res = HTTPClient(dummy_service, server_url)
res.server_url = server_url
Expand Down
109 changes: 109 additions & 0 deletions tests/unit/_internal/io/test_custom.py
@@ -0,0 +1,109 @@
from __future__ import annotations

import typing as t
from typing import TYPE_CHECKING

from starlette.requests import Request
from starlette.responses import Response

import bentoml
from bentoml.io import IODescriptor
from bentoml.exceptions import BentoMLException
from bentoml._internal.utils.http import set_cookies
from bentoml._internal.service.openapi import SUCCESS_DESCRIPTION
from bentoml._internal.service.openapi.specification import Schema
from bentoml._internal.service.openapi.specification import MediaType

if TYPE_CHECKING:
from google.protobuf import wrappers_pb2

from bentoml._internal.context import InferenceApiContext as Context
from bentoml._internal.io_descriptors.base import OpenAPIResponse


# testing the minimal required IO descriptor to ensure we don't break
# compatibility with custom descriptors when implementing new features
# in IODescriptor.
class CustomDescriptor(IODescriptor[str]):
_mime_type = "text/custom"

def __init__(self, *args: t.Any, **kwargs: t.Any):
if args or kwargs:
raise BentoMLException(
f"'{self.__class__.__name__}' is not designed to take any args or kwargs during initialization."
) from None

def input_type(self) -> t.Type[str]:
return str

def _from_sample(self, sample: str | bytes) -> str:
if isinstance(sample, bytes):
sample = sample.decode("utf-8")
return sample

def openapi_schema(self) -> Schema:
return Schema(type="string")

def openapi_components(self) -> dict[str, t.Any] | None:
pass

def openapi_example(self):
return str(self.sample)

def openapi_request_body(self) -> dict[str, t.Any]:
return {
"content": {
self._mime_type: MediaType(
schema=self.openapi_schema(), example=self.openapi_example()
)
},
"required": True,
"x-bentoml-io-descriptor": self.to_spec(),
}

def openapi_responses(self) -> OpenAPIResponse:
return {
"description": SUCCESS_DESCRIPTION,
"content": {
self._mime_type: MediaType(
schema=self.openapi_schema(), example=self.openapi_example()
)
},
}

async def from_http_request(self, request: Request) -> str:
body = await request.body()
return body.decode("cp1252")

async def to_http_response(self, obj: str, ctx: Context | None = None) -> Response:
if ctx is not None:
res = Response(
obj,
media_type=self._mime_type,
headers=ctx.response.metadata, # type: ignore (bad starlette types)
status_code=ctx.response.status_code,
)
set_cookies(res, ctx.response.cookies)
return res
else:
return Response(obj, media_type=self._mime_type)

async def from_proto(self, field: wrappers_pb2.StringValue | bytes) -> str:
if isinstance(field, bytes):
return field.decode("cp1252")
else:
assert isinstance(field, wrappers_pb2.StringValue)
return field.value

async def to_proto(self, obj: str) -> wrappers_pb2.StringValue:
return wrappers_pb2.StringValue(value=obj)


def test_custom_io_descriptor():
svc = bentoml.Service("test")

@svc.api(input=CustomDescriptor(), output=CustomDescriptor())
def descriptor_test_api(inp):
return inp

svc.asgi_app