From b6a41589daaf41fc67353ae139c4404eda42e751 Mon Sep 17 00:00:00 2001 From: Sauyon Lee <2347889+sauyon@users.noreply.github.com> Date: Wed, 7 Dec 2022 22:27:03 -0800 Subject: [PATCH] fix: io descriptor backward compatibility (#3327) Ensure tests for IO descriptor to be backward compatible --- src/bentoml/_internal/io_descriptors/base.py | 12 +- src/bentoml/_internal/io_descriptors/file.py | 3 - src/bentoml/client.py | 32 ++++-- tests/unit/_internal/io/test_custom.py | 109 +++++++++++++++++++ 4 files changed, 136 insertions(+), 20 deletions(-) create mode 100644 tests/unit/_internal/io/test_custom.py diff --git a/src/bentoml/_internal/io_descriptors/base.py b/src/bentoml/_internal/io_descriptors/base.py index f6c03e0f1e1..091f3e78fd5 100644 --- a/src/bentoml/_internal/io_descriptors/base.py +++ b/src/bentoml/_internal/io_descriptors/base.py @@ -28,7 +28,7 @@ | LazyType[t.Any] | dict[str, t.Type[t.Any] | UnionType | LazyType[t.Any]] ) - OpenAPIResponse = dict[str, str | dict[str, MediaType] | dict[str, t.Any]] + OpenAPIResponse = dict[str, str | dict[str, t.Any]] IO_DESCRIPTOR_REGISTRY: dict[str, type[IODescriptor[t.Any]]] = {} @@ -37,8 +37,12 @@ def from_spec(spec: dict[str, t.Any]) -> IODescriptor[t.Any]: + if spec["id"] is None: + raise BentoMLException("No IO descriptor spec found.") + if "id" not in spec: raise InvalidArgument(f"IO descriptor spec ({spec}) missing ID.") + return IO_DESCRIPTOR_REGISTRY[spec["id"]].from_spec(spec) @@ -123,12 +127,10 @@ def _from_sample(self, sample: t.Any) -> IOType: def mime_type(self) -> str: return self._mime_type - @abstractmethod - def to_spec(self) -> dict[str, t.Any]: - raise NotImplementedError + def to_spec(self) -> dict[str, t.Any] | None: + return None @classmethod - @abstractmethod def from_spec(cls, spec: dict[str, t.Any]) -> Self: raise NotImplementedError diff --git a/src/bentoml/_internal/io_descriptors/file.py b/src/bentoml/_internal/io_descriptors/file.py index ed4ccefe01c..fb0eea25570 100644 --- a/src/bentoml/_internal/io_descriptors/file.py +++ b/src/bentoml/_internal/io_descriptors/file.py @@ -205,9 +205,6 @@ async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]: async def from_http_request(self, request: Request) -> FileLike[bytes]: raise NotImplementedError - def to_spec(self) -> dict[str, t.Any]: - raise NotImplementedError - class BytesIOFile(File, descriptor_id=None): def to_spec(self) -> dict[str, t.Any]: diff --git a/src/bentoml/client.py b/src/bentoml/client.py index e75630f9e92..e22c966c19a 100644 --- a/src/bentoml/client.py +++ b/src/bentoml/client.py @@ -94,18 +94,26 @@ def from_url(server_url: str) -> Client: raise BentoMLException( f"Malformed BentoML spec received from BentoML server {server_url}" ) - dummy_service.apis[meth_spec["x-bentoml-name"]] = InferenceAPI( - None, - bentoml.io.from_spec( - meth_spec["requestBody"]["x-bentoml-io-descriptor"] - ), - bentoml.io.from_spec( - meth_spec["responses"]["200"]["x-bentoml-io-descriptor"] - ), - name=meth_spec["x-bentoml-name"], - doc=meth_spec["description"], - route=route.lstrip("/"), - ) + try: + api = InferenceAPI( + None, + bentoml.io.from_spec( + meth_spec["requestBody"]["x-bentoml-io-descriptor"] + ), + bentoml.io.from_spec( + meth_spec["responses"]["200"]["x-bentoml-io-descriptor"] + ), + name=meth_spec["x-bentoml-name"], + doc=meth_spec["description"], + route=route.lstrip("/"), + ) + dummy_service.apis[meth_spec["x-bentoml-name"]] = api + except BentoMLException as e: + logger.error( + "Failed to instantiate client for API %s: ", + meth_spec["x-bentoml-name"], + e, + ) res = HTTPClient(dummy_service, server_url) res.server_url = server_url diff --git a/tests/unit/_internal/io/test_custom.py b/tests/unit/_internal/io/test_custom.py new file mode 100644 index 00000000000..a50357b2ed1 --- /dev/null +++ b/tests/unit/_internal/io/test_custom.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING + +from starlette.requests import Request +from starlette.responses import Response + +import bentoml +from bentoml.io import IODescriptor +from bentoml.exceptions import BentoMLException +from bentoml._internal.utils.http import set_cookies +from bentoml._internal.service.openapi import SUCCESS_DESCRIPTION +from bentoml._internal.service.openapi.specification import Schema +from bentoml._internal.service.openapi.specification import MediaType + +if TYPE_CHECKING: + from google.protobuf import wrappers_pb2 + + from bentoml._internal.context import InferenceApiContext as Context + from bentoml._internal.io_descriptors.base import OpenAPIResponse + + +# testing the minimal required IO descriptor to ensure we don't break +# compatibility with custom descriptors when implementing new features +# in IODescriptor. +class CustomDescriptor(IODescriptor[str]): + _mime_type = "text/custom" + + def __init__(self, *args: t.Any, **kwargs: t.Any): + if args or kwargs: + raise BentoMLException( + f"'{self.__class__.__name__}' is not designed to take any args or kwargs during initialization." + ) from None + + def input_type(self) -> t.Type[str]: + return str + + def _from_sample(self, sample: str | bytes) -> str: + if isinstance(sample, bytes): + sample = sample.decode("utf-8") + return sample + + def openapi_schema(self) -> Schema: + return Schema(type="string") + + def openapi_components(self) -> dict[str, t.Any] | None: + pass + + def openapi_example(self): + return str(self.sample) + + def openapi_request_body(self) -> dict[str, t.Any]: + return { + "content": { + self._mime_type: MediaType( + schema=self.openapi_schema(), example=self.openapi_example() + ) + }, + "required": True, + "x-bentoml-io-descriptor": self.to_spec(), + } + + def openapi_responses(self) -> OpenAPIResponse: + return { + "description": SUCCESS_DESCRIPTION, + "content": { + self._mime_type: MediaType( + schema=self.openapi_schema(), example=self.openapi_example() + ) + }, + } + + async def from_http_request(self, request: Request) -> str: + body = await request.body() + return body.decode("cp1252") + + async def to_http_response(self, obj: str, ctx: Context | None = None) -> Response: + if ctx is not None: + res = Response( + obj, + media_type=self._mime_type, + headers=ctx.response.metadata, # type: ignore (bad starlette types) + status_code=ctx.response.status_code, + ) + set_cookies(res, ctx.response.cookies) + return res + else: + return Response(obj, media_type=self._mime_type) + + async def from_proto(self, field: wrappers_pb2.StringValue | bytes) -> str: + if isinstance(field, bytes): + return field.decode("cp1252") + else: + assert isinstance(field, wrappers_pb2.StringValue) + return field.value + + async def to_proto(self, obj: str) -> wrappers_pb2.StringValue: + return wrappers_pb2.StringValue(value=obj) + + +def test_custom_io_descriptor(): + svc = bentoml.Service("test") + + @svc.api(input=CustomDescriptor(), output=CustomDescriptor()) + def descriptor_test_api(inp): + return inp + + svc.asgi_app