diff --git a/bentoml/_internal/external_typing/__init__.py b/bentoml/_internal/external_typing/__init__.py index 6d0dcfbcd17..9087543eb1b 100644 --- a/bentoml/_internal/external_typing/__init__.py +++ b/bentoml/_internal/external_typing/__init__.py @@ -8,6 +8,8 @@ from pandas import Series as PdSeries from pandas import DataFrame as PdDataFrame + from pandas._typing import Dtype as PdDType + from pandas._typing import DtypeArg as PdDTypeArg from pyarrow.plasma import ObjectID from pyarrow.plasma import PlasmaClient @@ -17,7 +19,7 @@ # numpy is always required by bentoml from numpy import generic as NpGeneric from numpy.typing import NDArray as _NDArray - from numpy.typing import DTypeLike as NpDTypeLike # type: ignore (incomplete numpy types) + from numpy.typing import DTypeLike as NpDTypeLike NpNDArray = _NDArray[t.Any] @@ -36,6 +38,8 @@ __all__ = [ "PdSeries", "PdDataFrame", + "PdDType", + "PdDTypeArg", "DataFrameOrient", "SeriesOrient", "ObjectID", diff --git a/bentoml/_internal/io_descriptors/base.py b/bentoml/_internal/io_descriptors/base.py index 09c669b6985..c558034a84c 100644 --- a/bentoml/_internal/io_descriptors/base.py +++ b/bentoml/_internal/io_descriptors/base.py @@ -3,14 +3,22 @@ import typing as t from abc import ABC from abc import abstractmethod +from typing import overload from typing import TYPE_CHECKING if TYPE_CHECKING: from types import UnionType + from google.protobuf import message + from google.protobuf import struct_pb2 + from google.protobuf import wrappers_pb2 from typing_extensions import Self from starlette.requests import Request from starlette.responses import Response + from google.protobuf.internal.containers import MessageMap + + from bentoml.grpc.types import ProtoField + from bentoml.grpc.v1alpha1 import service_pb2 as pb from ..types import LazyType from ..context import InferenceApiContext as Context @@ -39,20 +47,31 @@ class IODescriptor(ABC, t.Generic[IOType]): HTTP_METHODS = ["POST"] - _init_str: str = "" - _mime_type: str - - def __new__(cls: t.Type[Self], *args: t.Any, **kwargs: t.Any) -> Self: + _rpc_content_type: str + _proto_field: str + + def __new__( # pylint: disable=unused-argument + cls: t.Type[Self], + *args: t.Any, + **kwargs: t.Any, + ) -> Self: self = super().__new__(cls) - # default mime type is application/json - self._mime_type = "application/json" - self._init_str = cls.__qualname__ + # default grpc content type is application/grpc + self._rpc_content_type = "application/grpc" return self + @property + def accepted_proto_fields(self) -> ProtoField: + """ + Returns the proto field that this IODescriptor can accept. + Note that all proto fields also accept '_internal_bytes_contents'. + """ + return t.cast("ProtoField", self._proto_field) + def __repr__(self) -> str: - return self._init_str + return self.__class__.__qualname__ @abstractmethod def input_type(self) -> InputType: @@ -83,3 +102,74 @@ async def to_http_response( self, obj: IOType, ctx: Context | None = None ) -> Response: ... + + @overload + @abstractmethod + async def from_proto( + self, + field: wrappers_pb2.StringValue | pb.Part | bytes, + *, + _use_internal_bytes_contents: bool, + ) -> IOType: + ... + + @overload + @abstractmethod + async def from_proto( + self, + field: struct_pb2.Value | pb.Part | bytes, + *, + _use_internal_bytes_contents: bool, + ) -> IOType: + ... + + @overload + @abstractmethod + async def from_proto( + self, field: MessageMap[str, pb.Part], *, _use_internal_bytes_contents: bool + ) -> IOType: + ...
+ + @overload + @abstractmethod + async def from_proto( + self, field: pb.NDArray | pb.Part | bytes, *, _use_internal_bytes_contents: bool + ) -> IOType: + ... + + @overload + @abstractmethod + async def from_proto( + self, field: pb.File | pb.Part | bytes, *, _use_internal_bytes_contents: bool + ) -> IOType: + ... + + @overload + @abstractmethod + async def from_proto( + self, + field: pb.DataFrame | pb.Part | bytes, + *, + _use_internal_bytes_contents: bool, + ) -> IOType: + ... + + @overload + @abstractmethod + async def from_proto( + self, field: pb.Series | pb.Part | bytes, *, _use_internal_bytes_contents: bool + ) -> IOType: + ... + + @abstractmethod + async def from_proto( + self, + field: message.Message | bytes | MessageMap[str, pb.Part], + *, + _use_internal_bytes_contents: bool = False, + ) -> IOType: + ... + + @abstractmethod + async def to_proto(self, obj: IOType) -> MessageMap[str, pb.Part] | message.Message: + ... diff --git a/bentoml/_internal/io_descriptors/file.py b/bentoml/_internal/io_descriptors/file.py index 5a21d0ab907..f28cd586f78 100644 --- a/bentoml/_internal/io_descriptors/file.py +++ b/bentoml/_internal/io_descriptors/file.py @@ -13,6 +13,7 @@ from .base import IODescriptor from ..types import FileLike from ..utils.http import set_cookies +from ...exceptions import BadInput from ...exceptions import BentoMLException from ..service.openapi import SUCCESS_DESCRIPTION from ..service.openapi.specification import Schema @@ -23,10 +24,17 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: + from bentoml.grpc.v1alpha1 import service_pb2 as pb + from ..context import InferenceApiContext as Context FileKind: t.TypeAlias = t.Literal["binaryio", "textio"] -FileType: t.TypeAlias = t.Union[io.IOBase, t.IO[bytes], FileLike[bytes]] +else: + from bentoml.grpc.utils import import_generated_stubs + + pb, _ = import_generated_stubs() + +FileType = t.Union[io.IOBase, t.IO[bytes], FileLike[bytes]] class File(IODescriptor[FileType]): @@ -100,16 +108,16 @@ async def predict(input_pdf: io.BytesIO[Any]) -> NDArray[Any]: """ + _proto_field: str = "file" + def __new__( # pylint: disable=arguments-differ # returning subclass from new cls, kind: FileKind = "binaryio", mime_type: str | None = None ) -> File: mime_type = mime_type if mime_type is not None else "application/octet-stream" - if kind == "binaryio": res = object.__new__(BytesIOFile) else: raise ValueError(f"invalid File kind '{kind}'") - res._mime_type = mime_type return res @@ -134,11 +142,7 @@ def openapi_responses(self) -> OpenAPIResponse: content={self._mime_type: MediaType(schema=self.openapi_schema())}, ) - async def to_http_response( - self, - obj: FileType, - ctx: Context | None = None, - ): + async def to_http_response(self, obj: FileType, ctx: Context | None = None): if isinstance(obj, bytes): body = obj else: @@ -155,6 +159,36 @@ async def to_http_response( res = Response(body) return res + async def to_proto(self, obj: FileType) -> pb.File: + from bentoml.grpc.utils import mimetype_to_filetype_pb_map + + if isinstance(obj, bytes): + body = obj + else: + body = obj.read() + + try: + kind = mimetype_to_filetype_pb_map()[self._mime_type] + except KeyError: + raise BadInput( + f"{self._mime_type} doesn't have a corresponding File 'kind'" + ) from None + + return pb.File(kind=kind, content=body) + + if TYPE_CHECKING: + + async def from_proto( + self, + field: pb.File | pb.Part | bytes, + *, + _use_internal_bytes_contents: bool = False, + ) -> FileLike[bytes]: + ... 
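Reviewer note on the contract above: the `@overload` stack only narrows `field` per descriptor; at runtime every concrete IO descriptor implements a single `from_proto`/`to_proto` pair that (1) unwraps a `pb.Part` into its own proto field and (2) honors the `_use_internal_bytes_contents` escape hatch for raw serialized payloads. A minimal usage sketch of that contract, using the `Text` descriptor added later in this diff (this assumes the generated v1alpha1 stubs import cleanly outside a Bento build; it is an illustration, not part of the change):

```python
import asyncio

from bentoml.io import Text
from bentoml.grpc.utils import import_generated_stubs

pb, _ = import_generated_stubs()


async def round_trip() -> None:
    io_ = Text()
    # to_proto emits the typed message for this descriptor's proto field
    # (a google.protobuf.StringValue for Text).
    msg = await io_.to_proto("hello")
    # from_proto accepts the typed message directly...
    assert await io_.from_proto(msg) == "hello"
    # ...or wrapped in a pb.Part, which is how Multipart dispatches fields...
    assert await io_.from_proto(pb.Part(text=msg)) == "hello"
    # ...or raw bytes when the caller opts into serialized contents.
    assert await io_.from_proto(b"hello", _use_internal_bytes_contents=True) == "hello"


asyncio.run(round_trip())
```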
+ + async def from_http_request(self, request: Request) -> t.IO[bytes]: + ... + class BytesIOFile(File): async def from_http_request(self, request: Request) -> t.IO[bytes]: @@ -183,3 +217,37 @@ async def from_http_request(self, request: Request) -> t.IO[bytes]: raise BentoMLException( f"File should have Content-Type '{self._mime_type}' or 'multipart/form-data', got {content_type} instead" ) + + async def from_proto( + self, + field: pb.File | pb.Part | bytes, + *, + _use_internal_bytes_contents: bool = False, + ) -> FileLike[bytes]: + from bentoml.grpc.utils import filetype_pb_to_mimetype_map + + mapping = filetype_pb_to_mimetype_map() + # check if the request message has the correct field + if not _use_internal_bytes_contents: + if isinstance(field, pb.Part): + field = field.file + assert isinstance(field, pb.File) + if field.kind: + try: + mime_type = mapping[field.kind] + if mime_type != self._mime_type: + raise BadInput( + f"Inferred mime_type from 'kind' is '{mime_type}', while '{repr(self)}' is expecting '{self._mime_type}'", + ) + except KeyError: + raise BadInput( + f"{field.kind} is not a valid File kind. Accepted file kinds: {[names for names, _ in pb.File.FileType.items()]}", + ) from None + content = field.content + if not content: + raise BadInput("Content is empty!") from None + else: + assert isinstance(field, bytes) + content = field + + return FileLike[bytes](io.BytesIO(content), "") diff --git a/bentoml/_internal/io_descriptors/image.py b/bentoml/_internal/io_descriptors/image.py index af49128e0c1..3bf18149c6e 100644 --- a/bentoml/_internal/io_descriptors/image.py +++ b/bentoml/_internal/io_descriptors/image.py @@ -15,7 +15,6 @@ from ..utils.http import set_cookies from ...exceptions import BadInput from ...exceptions import InvalidArgument -from ...exceptions import InternalServerError from ..service.openapi import SUCCESS_DESCRIPTION from ..service.openapi.specification import Schema from ..service.openapi.specification import Response as OpenAPIResponse @@ -25,8 +24,11 @@ if TYPE_CHECKING: from types import UnionType + import PIL import PIL.Image + from bentoml.grpc.v1alpha1 import service_pb2 as pb + from .. import external_typing as ext from ..context import InferenceApiContext as Context @@ -34,17 +36,21 @@ "1", "CMYK", "F", "HSV", "I", "L", "LAB", "P", "RGB", "RGBA", "RGBX", "YCbCr" ] else: + from bentoml.grpc.utils import import_generated_stubs # NOTE: pillow-simd only benefits users who want to do preprocessing # TODO: add options for users to choose between simd and native mode - _exc = f"'Pillow' is required to use {__name__}. Install with: 'pip install -U Pillow'." + _exc = "'Pillow' is required to use the Image IO descriptor. Install it with: 'pip install -U Pillow'." PIL = LazyLoader("PIL", globals(), "PIL", exc_msg=_exc) PIL.Image = LazyLoader("PIL.Image", globals(), "PIL.Image", exc_msg=_exc) + pb, _ = import_generated_stubs() + + # NOTE: we keep the types quoted to preserve backward compatibility # with numpy < 1.20, since we use the latest stubs from numpy's main branch, # which enable a new way to type-hint an ndarray.
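The `File.from_proto` above ties the `pb.File.kind` enum to the descriptor's `mime_type` via `filetype_pb_to_mimetype_map()`, and `to_proto` (below in file.py) uses the inverse map. A hedged round-trip sketch of that coupling — it assumes `application/pdf` has a registered kind in the v1alpha1 `File.FileType` enum and that `FileLike` proxies `read()` to the wrapped buffer, as `bentoml._internal.types` suggests; adjust the mime type to one your stubs actually map:

```python
import asyncio

from bentoml.io import File
from bentoml.grpc.utils import filetype_pb_to_mimetype_map


async def round_trip() -> None:
    io_ = File(mime_type="application/pdf")
    # to_proto resolves the descriptor's mime_type to a File.FileType kind
    # through mimetype_to_filetype_pb_map(), raising BadInput for unmapped types.
    msg = await io_.to_proto(b"%PDF-1.4 ...")
    assert filetype_pb_to_mimetype_map()[msg.kind] == "application/pdf"
    # from_proto checks the inverse mapping: a kind that disagrees with the
    # descriptor's mime_type raises BadInput before content is touched.
    decoded = await io_.from_proto(msg)
    assert decoded.read() == b"%PDF-1.4 ..."


asyncio.run(round_trip())
```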
-ImageType: t.TypeAlias = t.Union["PIL.Image.Image", "ext.NpNDArray"] +ImageType = t.Union["PIL.Image.Image", "ext.NpNDArray"] DEFAULT_PIL_MODE = "RGB" @@ -137,30 +143,24 @@ async def predict_image(f: Image) -> NDArray[Any]: MIME_EXT_MAPPING: t.Dict[str, str] = {} + _proto_field: str = "file" + def __init__( self, pilmode: _Mode | None = DEFAULT_PIL_MODE, mime_type: str = "image/jpeg", ): - try: - import PIL.Image - except ImportError: - raise InternalServerError( - "`Pillow` is required to use {__name__}\n Instructions: `pip install -U Pillow`" - ) PIL.Image.init() self.MIME_EXT_MAPPING.update({v: k for k, v in PIL.Image.MIME.items()}) if mime_type.lower() not in self.MIME_EXT_MAPPING: # pragma: no cover raise InvalidArgument( - f"Invalid Image mime_type '{mime_type}', " - f"Supported mime types are {', '.join(PIL.Image.MIME.values())} " - ) + f"Invalid Image mime_type '{mime_type}'. Supported mime types are {', '.join(PIL.Image.MIME.values())}." + ) from None if pilmode is not None and pilmode not in PIL.Image.MODES: # pragma: no cover raise InvalidArgument( - f"Invalid Image pilmode '{pilmode}', " - f"Supported PIL modes are {', '.join(PIL.Image.MODES)} " - ) + f"Invalid Image pilmode '{pilmode}'. Supported PIL modes are {', '.join(PIL.Image.MODES)}." + ) from None self._mime_type = mime_type.lower() self._pilmode: _Mode | None = pilmode @@ -197,13 +197,12 @@ async def from_http_request(self, request: Request) -> ImageType: bytes_ = await request.body() else: raise BadInput( - f"{self.__class__.__name__} should get `multipart/form-data`, " - f"`{self._mime_type}` or `image/*`, got {content_type} instead" + f"{self.__class__.__name__} should get 'multipart/form-data', '{self._mime_type}' or 'image/*', got '{content_type}' instead." ) try: return PIL.Image.open(io.BytesIO(bytes_)) - except PIL.UnidentifiedImageError: - raise BadInput("Failed reading image file uploaded") from None + except PIL.UnidentifiedImageError as e: + raise BadInput(f"Failed reading image file uploaded: {e}") from None async def to_http_response( self, obj: ImageType, ctx: Context | None = None @@ -213,10 +212,9 @@ async def to_http_response( elif LazyType[PIL.Image.Image]("PIL.Image.Image").isinstance(obj): image = obj else: - raise InternalServerError( - f"Unsupported Image type received: {type(obj)}, `{self.__class__.__name__}`" - " only supports `np.ndarray` and `PIL.Image`" - ) + raise BadInput( + f"Unsupported Image type received: '{type(obj)}', the Image IO descriptor only supports 'np.ndarray' and 'PIL.Image'." + ) from None filename = f"output.{self._format.lower()}" ret = io.BytesIO() @@ -248,3 +246,60 @@ async def to_http_response( media_type=self._mime_type, headers={"content-disposition": content_disposition}, ) + + async def from_proto( + self, + field: pb.File | pb.Part | bytes, + *, + _use_internal_bytes_contents: bool = False, + ) -> ImageType: + from bentoml.grpc.utils import filetype_pb_to_mimetype_map + + mapping = filetype_pb_to_mimetype_map() + # check if the request message has the correct field + if not _use_internal_bytes_contents: + if isinstance(field, pb.Part): + field = field.file + assert isinstance(field, pb.File) + if field.kind: + try: + mime_type = mapping[field.kind] + if mime_type != self._mime_type: + raise BadInput( + f"Inferred mime_type from 'kind' is '{mime_type}', while '{repr(self)}' is expecting '{self._mime_type}'", + ) + except KeyError: + raise BadInput( + f"{field.kind} is not a valid File kind. 
Accepted file kinds: {[names for names, _ in pb.File.FileType.items()]}", + ) from None + content = field.content + if not content: + raise BadInput("Content is empty!") from None + else: + assert isinstance(field, bytes) + content = field + + return PIL.Image.open(io.BytesIO(content)) + + async def to_proto(self, obj: ImageType) -> pb.File: + from bentoml.grpc.utils import mimetype_to_filetype_pb_map + + if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(obj): + image = PIL.Image.fromarray(obj, mode=self._pilmode) + elif LazyType["PIL.Image.Image"]("PIL.Image.Image").isinstance(obj): + image = obj + else: + raise BadInput( + f"Unsupported Image type received: '{type(obj)}', '{self.__class__.__name__}' only supports 'np.ndarray' and 'PIL.Image'.", + ) from None + ret = io.BytesIO() + image.save(ret, format=self._format) + + try: + kind = mimetype_to_filetype_pb_map()[self._mime_type] + except KeyError: + raise BadInput( + f"{self._mime_type} doesn't have a corresponding File 'kind'", + ) from None + + return pb.File(kind=kind, content=ret.getvalue()) diff --git a/bentoml/_internal/io_descriptors/json.py b/bentoml/_internal/io_descriptors/json.py index 648cf7de7ec..a678623a868 100644 --- a/bentoml/_internal/io_descriptors/json.py +++ b/bentoml/_internal/io_descriptors/json.py @@ -10,12 +10,14 @@ from starlette.requests import Request from starlette.responses import Response +from bentoml.exceptions import BadInput + from .base import IODescriptor from ..types import LazyType from ..utils import LazyLoader from ..utils import bentoml_cattr +from ..utils.pkg import pkg_version_info from ..utils.http import set_cookies -from ...exceptions import BadInput from ..service.openapi import REF_PREFIX from ..service.openapi import SUCCESS_DESCRIPTION from ..service.openapi.specification import Schema @@ -28,26 +30,35 @@ import pydantic import pydantic.schema as schema + from google.protobuf import struct_pb2 + + from bentoml.grpc.v1alpha1 import service_pb2 as pb from .. import external_typing as ext from ..context import InferenceApiContext as Context - _Serializable = ext.NpNDArray | ext.PdDataFrame | t.Type[pydantic.BaseModel] | type else: + from bentoml.grpc.utils import import_generated_stubs + + pb, _ = import_generated_stubs() _exc_msg = "'pydantic' must be installed to use 'pydantic_model'. Install with 'pip install pydantic'." pydantic = LazyLoader("pydantic", globals(), "pydantic", exc_msg=_exc_msg) schema = LazyLoader("schema", globals(), "pydantic.schema", exc_msg=_exc_msg) + # lazily load the generated protobuf module. + struct_pb2 = LazyLoader("struct_pb2", globals(), "google.protobuf.struct_pb2") -JSONType = t.Union[str, t.Dict[str, t.Any], "pydantic.BaseModel", None] + # lazily load numpy for processing ndarrays.
+ np = LazyLoader("np", globals(), "numpy") -MIME_TYPE_JSON = "application/json" + +JSONType = t.Union[str, t.Dict[str, t.Any], "pydantic.BaseModel", None] logger = logging.getLogger(__name__) class DefaultJsonEncoder(json.JSONEncoder): - def default(self, o: _Serializable) -> t.Any: + def default(self, o: type) -> t.Any: if dataclasses.is_dataclass(o): return dataclasses.asdict(o) if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(o): @@ -63,9 +74,8 @@ def default(self, o: _Serializable) -> t.Any: if "__root__" in obj_dict: obj_dict = obj_dict.get("__root__") return obj_dict - if attr.has(o): # type: ignore (trivial case) + if attr.has(o): return bentoml_cattr.unstructure(o) - return super().default(o) @@ -168,7 +178,9 @@ def classify(input_data: IrisFeatures) -> NDArray[Any]: :obj:`JSON`: IO Descriptor that represents JSON format. """ - _mime_type: str = MIME_TYPE_JSON + _proto_field: str = "json" + # default mime type is application/json + _mime_type = "application/json" def __init__( self, @@ -177,7 +189,11 @@ def __init__( validate_json: bool | None = None, json_encoder: t.Type[json.JSONEncoder] = DefaultJsonEncoder, ): - if pydantic_model: + if pydantic_model is not None: + if pkg_version_info("pydantic")[0] >= 2: + raise BadInput( + "pydantic 2.x is not yet supported. Add upper bound to 'pydantic': 'pip install \"pydantic<2\"'" + ) from None assert issubclass( pydantic_model, pydantic.BaseModel ), "'pydantic_model' must be a subclass of 'pydantic.BaseModel'." @@ -236,12 +252,12 @@ async def from_http_request(self, request: Request) -> JSONType: except json.JSONDecodeError as e: raise BadInput(f"Invalid JSON input received: {e}") from None - if self._pydantic_model is not None: + if self._pydantic_model: try: pydantic_model = self._pydantic_model.parse_obj(json_obj) return pydantic_model except pydantic.ValidationError as e: - raise BadInput(f"Invalid JSON input received: {e}") from e + raise BadInput(f"Invalid JSON input received: {e}") from None else: return json_obj @@ -269,11 +285,67 @@ async def to_http_response( if ctx is not None: res = Response( json_str, - media_type=MIME_TYPE_JSON, + media_type=self._mime_type, headers=ctx.response.metadata, # type: ignore (bad starlette types) status_code=ctx.response.status_code, ) set_cookies(res, ctx.response.cookies) return res else: - return Response(json_str, media_type=MIME_TYPE_JSON) + return Response(json_str, media_type=self._mime_type) + + async def from_proto( + self, + field: struct_pb2.Value | pb.Part | bytes, + *, + _use_internal_bytes_contents: bool = False, + ) -> JSONType: + from google.protobuf.json_format import MessageToDict + + if not _use_internal_bytes_contents: + if isinstance(field, pb.Part): + field = field.json + assert isinstance(field, struct_pb2.Value) + parsed = MessageToDict(field, preserving_proto_field_name=True) + + if self._pydantic_model: + try: + return self._pydantic_model.parse_obj(parsed) + except pydantic.ValidationError as e: + raise BadInput(f"Invalid JSON input received: {e}") from None + else: + assert isinstance(field, bytes) + content = field + if self._pydantic_model: + try: + return self._pydantic_model.parse_raw(content) + except pydantic.ValidationError as e: + raise BadInput(f"Invalid JSON input received: {e}") from None + + try: + parsed = json.loads(content) + except json.JSONDecodeError as e: + raise BadInput(f"Invalid JSON input received: {e}") from None + + return parsed + + async def to_proto(self, obj: JSONType) -> struct_pb2.Value: + if 
LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(obj): + obj = obj.dict() + msg = struct_pb2.Value() + # To handle None cases. + if obj is not None: + from google.protobuf.json_format import ParseDict + + if isinstance(obj, (dict, str, list, float, int, bool)): + # ParseDict handles google.protobuf.Struct type + # directly if given object has a supported type + ParseDict(obj, msg) + else: + # If given object doesn't have a supported type, we will + # use given JSON encoder to convert it to dictionary + # and then parse it to google.protobuf.Struct. + # Note that if a custom JSON encoder is used, it mustn't + # take any arguments. + ParseDict(self._json_encoder().default(obj), msg) + return msg diff --git a/bentoml/_internal/io_descriptors/multipart.py b/bentoml/_internal/io_descriptors/multipart.py index c6f9190dd88..499e48d4b8b 100644 --- a/bentoml/_internal/io_descriptors/multipart.py +++ b/bentoml/_internal/io_descriptors/multipart.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing as t +import asyncio from typing import TYPE_CHECKING from starlette.requests import Request @@ -21,11 +22,19 @@ if TYPE_CHECKING: from types import UnionType + from google.protobuf.internal.containers import MessageMap + + from bentoml.grpc.v1alpha1 import service_pb2 as pb + from ..types import LazyType from ..context import InferenceApiContext as Context +else: + from bentoml.grpc.utils import import_generated_stubs + pb, _ = import_generated_stubs() -class Multipart(IODescriptor[t.Any]): + +class Multipart(IODescriptor[t.Dict[str, t.Any]]): """ :obj:`Multipart` defines API specification for the inputs/outputs of a Service, where inputs/outputs of a Service can receive/send a **multipart** request/responses as specified in your API function signature. @@ -153,14 +162,18 @@ async def predict( :obj:`Multipart`: IO Descriptor that represents a Multipart request/response. """ + _proto_field = "multipart" + _mime_type = "multipart/form-data" + def __init__(self, **inputs: IODescriptor[t.Any]): - for descriptor in inputs.values(): - if isinstance(descriptor, Multipart): # pragma: no cover - raise InvalidArgument( - "Multipart IO can not contain nested Multipart IO descriptor" - ) - self._inputs: dict[str, t.Any] = inputs - self._mime_type = "multipart/form-data" + if any(isinstance(descriptor, Multipart) for descriptor in inputs.values()): + raise InvalidArgument( + "Multipart IO can not contain nested Multipart IO descriptor" + ) from None + self._inputs = inputs + + def __repr__(self) -> str: + return f"Multipart({','.join([f'{k}={v}' for k,v in zip(self._inputs, map(repr, self._inputs.values()))])})" def input_type( self, @@ -171,7 +184,7 @@ def input_type( if isinstance(inp_type, dict): raise TypeError( "A multipart descriptor cannot take a multi-valued I/O descriptor as input" - ) + ) from None res[k] = inp_type return res @@ -202,22 +215,74 @@ async def from_http_request(self, request: Request) -> dict[str, t.Any]: if ctype != b"multipart/form-data": raise BentoMLException( f"{self.__class__.__name__} only accepts `multipart/form-data` as Content-Type header, got {ctype} instead." 
- ) - - res: dict[str, t.Any] = dict() - reqs = await populate_multipart_requests(request) + ) from None - for k, i in self._inputs.items(): - req = reqs[k] - v = await i.from_http_request(req) - res[k] = v - return res + to_populate = zip( + self._inputs.values(), (await populate_multipart_requests(request)).values() + ) + reqs = await asyncio.gather( + *tuple(io_.from_http_request(req) for io_, req in to_populate) + ) + return dict(zip(self._inputs, reqs)) async def to_http_response( self, obj: dict[str, t.Any], ctx: Context | None = None ) -> Response: - res_mapping: dict[str, Response] = {} - for k, io_ in self._inputs.items(): - data = obj[k] - res_mapping[k] = await io_.to_http_response(data, ctx) - return await concat_to_multipart_response(res_mapping, ctx) + resps = await asyncio.gather( + *tuple( + io_.to_http_response(obj[key], ctx) for key, io_ in self._inputs.items() + ) + ) + return await concat_to_multipart_response(dict(zip(self._inputs, resps)), ctx) + + def validate_input_mapping( + self, field: MessageMap[str, pb.Part] | dict[str, t.Any] + ) -> None: + if len(set(field) - set(self._inputs)) != 0: + raise InvalidArgument( + f"'{repr(self)}' accepts the following keys: {set(self._inputs)}. Given {field.__class__.__qualname__} has invalid fields: {set(field) - set(self._inputs)}", + ) from None + + async def from_proto( + self, + field: MessageMap[str, pb.Part], + *, + _use_internal_bytes_contents: bool = False, + ) -> dict[str, t.Any]: + if _use_internal_bytes_contents: + raise InvalidArgument( + f"cannot use '_internal_bytes_contents' with {self.__class__.__name__}" + ) from None + self.validate_input_mapping(field) + to_populate = zip(self._inputs.values(), field.values()) + reqs = await asyncio.gather( + *tuple( + io_.from_proto( + input_pb, _use_internal_bytes_contents=_use_internal_bytes_contents + ) + for io_, input_pb in to_populate + ) + ) + return dict(zip(field, reqs)) + + async def to_proto(self, obj: dict[str, t.Any]) -> MessageMap[str, pb.Part]: + self.validate_input_mapping(obj) + # look each descriptor up by key so the ordering of 'obj' cannot + # misalign values, and build one Part per key with only its own field set. + resps = await asyncio.gather( + *tuple(self._inputs[key].to_proto(obj[key]) for key in obj) + ) + + return t.cast( + "MessageMap[str, pb.Part]", + { + key: pb.Part(**{self._inputs[key].accepted_proto_fields: resp}) + for key, resp in zip(obj, resps) + }, + ) diff --git a/bentoml/_internal/io_descriptors/numpy.py b/bentoml/_internal/io_descriptors/numpy.py index 9ae561d47ae..21e11ab7ffa 100644 --- a/bentoml/_internal/io_descriptors/numpy.py +++ b/bentoml/_internal/io_descriptors/numpy.py @@ -4,19 +4,20 @@ import typing as t import logging from typing import TYPE_CHECKING +from functools import lru_cache from starlette.requests import Request from starlette.responses import Response from .base import IODescriptor -from .json import MIME_TYPE_JSON from ..types import LazyType +from ..utils import LazyLoader from ..utils.http import set_cookies from ...exceptions import BadInput +from ...exceptions import InvalidArgument from ...exceptions import BentoMLException -from ...exceptions import InternalServerError +from ...exceptions import UnprocessableEntity from ..service.openapi import SUCCESS_DESCRIPTION -from ..utils.lazy_loader import LazyLoader from ..service.openapi.specification import Schema from ..service.openapi.specification import Response as OpenAPIResponse from ..service.openapi.specification import MediaType @@ -25,18 +26,79 @@ if TYPE_CHECKING: import numpy as np + from bentoml.grpc.v1alpha1 import 
service_pb2 as pb + from .. import external_typing as ext from ..context import InferenceApiContext as Context else: + from bentoml.grpc.utils import import_generated_stubs + + pb, _ = import_generated_stubs() np = LazyLoader("np", globals(), "numpy") logger = logging.getLogger(__name__) -def _is_matched_shape( - left: t.Optional[t.Tuple[int, ...]], - right: t.Optional[t.Tuple[int, ...]], -) -> bool: # pragma: no cover +# TODO: support the following types for protobuf message: +# - support complex64, complex128, object and struct types +# - BFLOAT16, QINT32, QINT16, QUINT16, QINT8, QUINT8 +# +# For int16, uint16, int8, uint8 -> specify types in NumpyNdarray + using int_values. +# +# For bfloat16, half (float16) -> specify types in NumpyNdarray + using float_values. +# +# for string_values, use <U for np.dtype instead of object. +FIELDPB_TO_NPDTYPE_NAME_MAP = { + "bool_values": "bool", + "float_values": "float32", + "string_values": "<U", + "double_values": "float64", + "int32_values": "int32", + "int64_values": "int64", + "uint32_values": "uint32", + "uint64_values": "uint64", +} + + +@lru_cache(maxsize=1) +def dtypepb_to_npdtype_map() -> dict[pb.NDArray.DType.ValueType, ext.NpDTypeLike]: + # pb.NDArray.Dtype -> np.dtype + return { + pb.NDArray.DTYPE_FLOAT: np.dtype("float32"), + pb.NDArray.DTYPE_DOUBLE: np.dtype("double"), + pb.NDArray.DTYPE_INT32: np.dtype("int32"), + pb.NDArray.DTYPE_INT64: np.dtype("int64"), + pb.NDArray.DTYPE_UINT32: np.dtype("uint32"), + pb.NDArray.DTYPE_UINT64: np.dtype("uint64"), + pb.NDArray.DTYPE_BOOL: np.dtype("bool"), + pb.NDArray.DTYPE_STRING: np.dtype("<U"), + } + + +@lru_cache(maxsize=1) +def dtypepb_to_fieldpb_map() -> dict[pb.NDArray.DType.ValueType, str]: + return {k: npdtype_to_fieldpb_map()[v] for k, v in dtypepb_to_npdtype_map().items()} + + +@lru_cache(maxsize=1) +def fieldpb_to_npdtype_map() -> dict[str, ext.NpDTypeLike]: + # str -> np.dtype + return {k: np.dtype(v) for k, v in FIELDPB_TO_NPDTYPE_NAME_MAP.items()} + + +@lru_cache(maxsize=1) +def npdtype_to_dtypepb_map() -> dict[ext.NpDTypeLike, pb.NDArray.DType.ValueType]: + # np.dtype -> pb.NDArray.Dtype + return {v: k for k, v in dtypepb_to_npdtype_map().items()} + + +@lru_cache(maxsize=1) +def npdtype_to_fieldpb_map() -> dict[ext.NpDTypeLike, str]: + # np.dtype -> str + return {v: k for k, v in fieldpb_to_npdtype_map().items()} + + +def _is_matched_shape(left: tuple[int, ...], right: tuple[int, ...]) -> bool: if (left is None) or (right is None): return False @@ -52,6 +114,7 @@ def _is_matched_shape( return True +# TODO: when updating docs, add examples with gRPCurl class NumpyNdarray(IODescriptor["ext.NpNDArray"]): """ :obj:`NumpyNdarray` defines API specification for the inputs/outputs of a Service, where @@ -135,6 +198,9 @@ async def predict(input_array: np.ndarray) -> np.ndarray: :obj:`~bentoml._internal.io_descriptors.IODescriptor`: IO Descriptor that represents a :code:`np.ndarray`.
""" + _proto_field: str = "ndarray" + _mime_type = "application/json" + def __init__( self, dtype: str | ext.NpDTypeLike | None = None, @@ -143,15 +209,11 @@ def __init__( enforce_shape: bool = False, ): if dtype and not isinstance(dtype, np.dtype): - # Convert from primitive type or type string, e.g.: - # np.dtype(float) - # np.dtype("float64") + # Convert from primitive type or type string, e.g.: np.dtype(float) or np.dtype("float64") try: dtype = np.dtype(dtype) except TypeError as e: - raise BentoMLException( - f'NumpyNdarray: Invalid dtype "{dtype}": {e}' - ) from e + raise UnprocessableEntity(f'Invalid dtype "{dtype}": {e}') from None self._dtype = dtype self._shape = shape @@ -160,6 +222,15 @@ def __init__( self._sample_input = None + if self._enforce_dtype and not self._dtype: + raise InvalidArgument( + "'dtype' must be specified when 'enforce_dtype=True'" + ) from None + if self._enforce_shape and not self._shape: + raise InvalidArgument( + "'shape' must be specified when 'enforce_shape=True'" + ) from None + def _openapi_types(self) -> str: # convert numpy dtypes to openapi compatible types. var_type = "integer" @@ -195,7 +266,9 @@ def openapi_components(self) -> dict[str, t.Any] | None: def openapi_example(self) -> t.Any: if self.sample_input is not None: if isinstance(self.sample_input, np.generic): - raise BadInput("NumpyNdarray: sample_input must be a numpy array.") + raise BadInput( + "NumpyNdarray: sample_input must be a numpy array." + ) from None return self.sample_input.tolist() return @@ -219,33 +292,33 @@ def openapi_responses(self) -> OpenAPIResponse: }, ) - def _verify_ndarray( - self, obj: ext.NpNDArray, exception_cls: t.Type[Exception] = BadInput + def validate_array( + self, arr: ext.NpNDArray, exception_cls: t.Type[Exception] = BadInput ) -> ext.NpNDArray: - if self._dtype is not None and self._dtype != obj.dtype: + if self._dtype is not None and self._dtype != arr.dtype: # ‘same_kind’ means only safe casts or casts within a kind, like float64 # to float32, are allowed. - if np.can_cast(obj.dtype, self._dtype, casting="same_kind"): - obj = obj.astype(self._dtype, casting="same_kind") # type: ignore + if np.can_cast(arr.dtype, self._dtype, casting="same_kind"): + arr = arr.astype(self._dtype, casting="same_kind") # type: ignore else: - msg = f'{self.__class__.__name__}: Expecting ndarray of dtype "{self._dtype}", but "{obj.dtype}" was received.' + msg = f'{self.__class__.__name__}: Expecting ndarray of dtype "{self._dtype}", but "{arr.dtype}" was received.' if self._enforce_dtype: - raise exception_cls(msg) + raise exception_cls(msg) from None else: logger.debug(msg) - if self._shape is not None and not _is_matched_shape(self._shape, obj.shape): - msg = f'{self.__class__.__name__}: Expecting ndarray of shape "{self._shape}", but "{obj.shape}" was received.' + if self._shape is not None and not _is_matched_shape(self._shape, arr.shape): + msg = f'{self.__class__.__name__}: Expecting ndarray of shape "{self._shape}", but "{arr.shape}" was received.' if self._enforce_shape: - raise exception_cls(msg) + raise exception_cls(msg) from None try: - obj = obj.reshape(self._shape) + arr = arr.reshape(self._shape) except ValueError as e: logger.debug(f"{msg} Failed to reshape: {e}.") - return obj + return arr - async def from_http_request(self, request: Request) -> "ext.NpNDArray": + async def from_http_request(self, request: Request) -> ext.NpNDArray: """ Process incoming requests and convert incoming objects to ``numpy.ndarray``. 
@@ -262,7 +335,7 @@ async def from_http_request(self, request: Request) -> "ext.NpNDArray": except ValueError: res = np.array(obj) - return self._verify_ndarray(res) + return self.validate_array(res) async def to_http_response(self, obj: ext.NpNDArray, ctx: Context | None = None): """ @@ -276,18 +349,19 @@ async def to_http_response(self, obj: ext.NpNDArray, ctx: Context | None = None) HTTP Response of type ``starlette.responses.Response``. This can be accessed via cURL or any external web traffic. """ - obj = self._verify_ndarray(obj, InternalServerError) + obj = self.validate_array(obj) + if ctx is not None: res = Response( json.dumps(obj.tolist()), - media_type=MIME_TYPE_JSON, + media_type=self._mime_type, headers=ctx.response.metadata, # type: ignore (bad starlette types) status_code=ctx.response.status_code, ) set_cookies(res, ctx.response.cookies) return res else: - return Response(json.dumps(obj.tolist()), media_type=MIME_TYPE_JSON) + return Response(json.dumps(obj.tolist()), media_type=self._mime_type) @classmethod def from_sample( @@ -336,8 +410,8 @@ async def predict(input: NDArray[np.int16]) -> NDArray[Any]: """ if isinstance(sample_input, np.generic): raise BentoMLException( - "NumpyNdarray.from_sample() expects a numpy.array, not numpy.generic." - ) + "'NumpyNdarray.from_sample()' expects a 'numpy.array', not 'numpy.generic'." + ) from None inst = cls( dtype=sample_input.dtype, @@ -348,3 +422,96 @@ async def predict(input: NDArray[np.int16]) -> NDArray[Any]: inst.sample_input = sample_input return inst + + async def from_proto( + self, + field: pb.NDArray | pb.Part | bytes, + *, + _use_internal_bytes_contents: bool = False, + ) -> ext.NpNDArray: + """ + Process an incoming protobuf request and convert it to ``numpy.ndarray``. + + Args: + field: incoming protobuf message: a ``pb.NDArray``, a ``pb.Part`` wrapping one, or raw ``bytes``. + _use_internal_bytes_contents: whether ``field`` carries raw serialized bytes. + + Returns: + a ``numpy.ndarray`` object. This can then be used + inside user-defined logic. + """ + if not _use_internal_bytes_contents: + if isinstance(field, pb.Part): + field = field.ndarray + assert isinstance(field, pb.NDArray) + if field.dtype == pb.NDArray.DTYPE_UNSPECIFIED: + dtype = None + else: + try: + dtype = dtypepb_to_npdtype_map()[field.dtype] + except KeyError: + raise BadInput(f"{field.dtype} is invalid.") from None + if dtype is not None: + values_array = getattr(field, dtypepb_to_fieldpb_map()[field.dtype]) + else: + fieldpb = [ + f.name for f, _ in field.ListFields() if f.name.endswith("_values") + ] + if len(fieldpb) == 0: + # input message doesn't have any fields. + return np.empty(shape=field.shape or 0) + elif len(fieldpb) > 1: + # when more than one values field is set in the proto. + raise BadInput( + f"Array contents must be set in exactly one values field; got multiple: '{fieldpb}'.", + ) from None + + dtype: ext.NpDTypeLike = fieldpb_to_npdtype_map()[fieldpb[0]] + values_array = getattr(field, fieldpb[0]) + try: + array = np.array(values_array, dtype=dtype) + except ValueError: + array = np.array(values_array) + + if field.shape: + array = np.reshape(array, field.shape) + else: + assert isinstance(field, bytes) + if not self._dtype: + raise UnprocessableEntity( + "'_internal_bytes_contents' requires specifying 'dtype'." + ) from None + + dtype: ext.NpDTypeLike = self._dtype + array = np.frombuffer(field, dtype=self._dtype) + + return self.validate_array(array) + + async def to_proto(self, obj: ext.NpNDArray) -> pb.NDArray: + """ + Process the given object and convert it to a gRPC protobuf response.
+ + Args: + obj: ``np.ndarray`` that will be serialized to protobuf. + + Returns: + ``pb.NDArray``: + Protobuf representation of the given ``np.ndarray``. + """ + obj = self.validate_array(obj) + + try: + fieldpb = npdtype_to_fieldpb_map()[obj.dtype] + dtypepb = npdtype_to_dtypepb_map()[obj.dtype] + return pb.NDArray( + dtype=dtypepb, + shape=tuple(obj.shape), + **{fieldpb: obj.ravel().tolist()}, + ) + except KeyError: + raise BadInput( + f"Unsupported dtype '{obj.dtype}' for response message.", + ) from None diff --git a/bentoml/_internal/io_descriptors/pandas.py b/bentoml/_internal/io_descriptors/pandas.py index b01256bfc5e..2988cfc8553 100644 --- a/bentoml/_internal/io_descriptors/pandas.py +++ b/bentoml/_internal/io_descriptors/pandas.py @@ -4,19 +4,20 @@ import typing as t import logging import functools -import importlib.util from enum import Enum from typing import TYPE_CHECKING +from concurrent.futures import ThreadPoolExecutor from starlette.requests import Request from starlette.responses import Response from .base import IODescriptor -from .json import MIME_TYPE_JSON from ..types import LazyType +from ..utils.pkg import find_spec from ..utils.http import set_cookies from ...exceptions import BadInput from ...exceptions import InvalidArgument +from ...exceptions import UnprocessableEntity from ...exceptions import MissingDependencyException from ..service.openapi import SUCCESS_DESCRIPTION from ..utils.lazy_loader import LazyLoader @@ -28,15 +29,20 @@ if TYPE_CHECKING: import pandas as pd + from bentoml.grpc.v1alpha1 import service_pb2 as pb + from .. import external_typing as ext from ..context import InferenceApiContext as Context else: + from bentoml.grpc.utils import import_generated_stubs + + pb, _ = import_generated_stubs() pd = LazyLoader( "pd", globals(), "pandas", - exc_msg="`pandas` is required to use PandasDataFrame or PandasSeries. Install with `pip install -U pandas`", + exc_msg='"pandas" is required to use PandasDataFrame or PandasSeries. 
Install with "pip install -U pandas"', ) logger = logging.getLogger(__name__) @@ -45,9 +51,9 @@ # Check for parquet support @functools.lru_cache(maxsize=1) def get_parquet_engine() -> str: - if importlib.util.find_spec("pyarrow") is not None: + if find_spec("pyarrow") is not None: return "pyarrow" - elif importlib.util.find_spec("fastparquet") is not None: + elif find_spec("fastparquet") is not None: return "fastparquet" else: logger.warning( @@ -72,9 +78,7 @@ def _openapi_types(item: str) -> str: # pragma: no cover return "object" -def _openapi_schema( - dtype: bool | dict[str, t.Any] | None -) -> Schema: # pragma: no cover +def _openapi_schema(dtype: bool | ext.PdDTypeArg | None) -> Schema: # pragma: no cover if isinstance(dtype, dict): return Schema( type="object", @@ -111,15 +115,12 @@ def _infer_serialization_format_from_request( return SerializationFormat.CSV elif content_type: logger.debug( - "Unknown content-type (%s), falling back to %s serialization format.", - content_type, - default_format, + f"Unknown content-type ('{content_type}'), falling back to '{default_format}' serialization format.", ) return default_format else: logger.debug( - "Content-type not specified, falling back to %s serialization format.", - default_format, + f"Content-type not specified, falling back to '{default_format}' serialization format.", ) return default_format @@ -203,7 +204,7 @@ def predict(input_arr): - :obj:`split` - :code:`dict[str, Any]` ↦ {``idx`` ↠ ``[idx]``, ``columns`` ↠ ``[columns]``, ``data`` ↠ ``[values]``} - :obj:`records` - :code:`list[Any]` ↦ [{``column`` ↠ ``value``}, ..., {``column`` ↠ ``value``}] - :obj:`index` - :code:`dict[str, Any]` ↦ {``idx`` ↠ {``column`` ↠ ``value``}} - - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` -> {``index`` ↠ ``value``}} + - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` ↠ {``index`` ↠ ``value``}} - :obj:`values` - :code:`dict[str, Any]` ↦ Values arrays columns: List of columns name that users wish to update. apply_column_names: Whether to update incoming DataFrame columns. If :code:`apply_column_names=True`, @@ -248,12 +249,14 @@ def predict(input_df: pd.DataFrame) -> pd.DataFrame: :obj:`PandasDataFrame`: IO Descriptor that represents a :code:`pd.DataFrame`. """ + _proto_field: str = "dataframe" + def __init__( self, orient: ext.DataFrameOrient = "records", - apply_column_names: bool = False, columns: list[str] | None = None, - dtype: bool | dict[str, t.Any] | None = None, + apply_column_names: bool = False, + dtype: bool | ext.PdDTypeArg | None = None, enforce_dtype: bool = False, shape: tuple[int, ...] | None = None, enforce_shape: bool = False, @@ -324,49 +327,21 @@ async def from_http_request(self, request: Request) -> ext.PdDataFrame: _validate_serialization_format(serialization_format) obj = await request.body() - if self._enforce_dtype: - if self._dtype is None: - logger.warning( - "`dtype` is None or undefined, while `enforce_dtype`=True" - ) - # TODO(jiang): check dtype - if serialization_format is SerializationFormat.JSON: + assert not isinstance(self._dtype, bool) res = pd.read_json(io.BytesIO(obj), dtype=self._dtype, orient=self._orient) elif serialization_format is SerializationFormat.PARQUET: res = pd.read_parquet(io.BytesIO(obj), engine=get_parquet_engine()) elif serialization_format is SerializationFormat.CSV: + assert not isinstance(self._dtype, bool) res: ext.PdDataFrame = pd.read_csv(io.BytesIO(obj), dtype=self._dtype) else: raise InvalidArgument( f"Unknown serialization format ({serialization_format})." 
- ) + ) from None assert isinstance(res, pd.DataFrame) - - if self._apply_column_names: - if self._columns is None: - logger.warning( - "`columns` is None or undefined, while `apply_column_names`=True" - ) - elif len(self._columns) != res.shape[1]: - raise BadInput( - "length of `columns` does not match the columns of incoming data" - ) - else: - res.columns = pd.Index(self._columns) - if self._enforce_shape: - if self._shape is None: - logger.warning( - "`shape` is None or undefined, while `enforce_shape`=True" - ) - else: - assert all( - left == right - for left, right in zip(self._shape, res.shape) # type: ignore (shape type) - if left != -1 and right != -1 - ), f"incoming has shape {res.shape} where enforced shape to be {self._shape}" - return res + return self.validate_dataframe(res) async def to_http_response( self, obj: ext.PdDataFrame, ctx: Context | None = None @@ -381,6 +356,7 @@ async def to_http_response( HTTP Response of type `starlette.responses.Response`. This can be accessed via cURL or any external web traffic. """ + obj = self.validate_dataframe(obj) # For the response it doesn't make sense to enforce the same serialization format as specified # by the request's headers['content-type']. Instead we simply use the _default_format. @@ -399,7 +375,7 @@ async def to_http_response( else: raise InvalidArgument( f"Unknown serialization format ({serialization_format})." - ) + ) from None if ctx is not None: res = Response( @@ -420,7 +396,7 @@ def from_sample( orient: ext.DataFrameOrient = "records", apply_column_names: bool = True, enforce_shape: bool = True, - enforce_dtype: bool = False, + enforce_dtype: bool = True, default_format: t.Literal["json", "parquet", "csv"] = "json", ) -> PandasDataFrame: """ @@ -435,7 +411,7 @@ def from_sample( - :obj:`split` - :code:`dict[str, Any]` ↦ {``idx`` ↠ ``[idx]``, ``columns`` ↠ ``[columns]``, ``data`` ↠ ``[values]``} - :obj:`records` - :code:`list[Any]` ↦ [{``column`` ↠ ``value``}, ..., {``column`` ↠ ``value``}] - :obj:`index` - :code:`dict[str, Any]` ↦ {``idx`` ↠ {``column`` ↠ ``value``}} - - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` -> {``index`` ↠ ``value``}} + - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` ↠ {``index`` ↠ ``value``}} - :obj:`values` - :code:`dict[str, Any]` ↦ Values arrays apply_column_names: Update incoming DataFrame columns. ``columns`` must be specified at function signature. If you don't want to enforce a specific columns @@ -469,22 +445,152 @@ def from_sample( @svc.api(input=input_spec, output=PandasDataFrame()) def predict(inputs: pd.DataFrame) -> pd.DataFrame: ... 
""" - columns = [str(x) for x in list(sample_input.columns)] - inst = cls( orient=orient, enforce_shape=enforce_shape, shape=sample_input.shape, apply_column_names=apply_column_names, - columns=columns, + columns=[str(x) for x in list(sample_input.columns)], enforce_dtype=enforce_dtype, - dtype=None, # TODO: not breaking atm + dtype=True, # set to True to infer from given input default_format=default_format, ) inst.sample_input = sample_input return inst + def validate_dataframe( + self, dataframe: ext.PdDataFrame, exception_cls: t.Type[Exception] = BadInput + ) -> ext.PdDataFrame: + + if not LazyType["ext.PdDataFrame"]("pd.DataFrame").isinstance(dataframe): + raise InvalidArgument( + f"return object is not of type 'pd.DataFrame', got type '{type(dataframe)}' instead" + ) from None + + # TODO: dtype check + # if self._dtype is not None and self._dtype != dataframe.dtypes: + # msg = f'{self.__class__.__name__}: Expecting DataFrame of dtype "{self._dtype}", but "{dataframe.dtypes}" was received.' + # if self._enforce_dtype: + # raise exception_cls(msg) from None + + if self._columns is not None and len(self._columns) != dataframe.shape[1]: + msg = f"length of 'columns' ({len(self._columns)}) does not match the # of columns of incoming data." + if self._apply_column_names: + raise BadInput(msg) from None + else: + logger.debug(msg) + dataframe.columns = pd.Index(self._columns) + + # TODO: convert from wide to long format (melt()) + if self._shape is not None and self._shape != dataframe.shape: + msg = f'{self.__class__.__name__}: Expecting DataFrame of shape "{self._shape}", but "{dataframe.shape}" was received.' + if self._enforce_shape and not all( + left == right + for left, right in zip(self._shape, dataframe.shape) + if left != -1 and right != -1 + ): + raise exception_cls(msg) from None + + return dataframe + + async def from_proto( + self, + field: pb.DataFrame | pb.Part | bytes, + *, + _use_internal_bytes_contents: bool = False, + ) -> ext.PdDataFrame: + """ + Process incoming protobuf request and convert it to ``pandas.DataFrame`` + + Args: + request: Incoming RPC request message. + context: grpc.ServicerContext + + Returns: + a ``pandas.DataFrame`` object. This can then be used + inside users defined logics. + """ + # TODO: support different serialization format + if not _use_internal_bytes_contents: + if isinstance(field, pb.Part): + field = field.dataframe + # note that there is a current bug where we don't check for + # dtype of given fields per Series to match with types of a given + # columns, hence, this would result in a wrong DataFrame that is not + # expected by our users. + assert isinstance(field, pb.DataFrame) + # columns orient: { column_name : {index : columns.series._value}} + if self._orient != "columns": + raise BadInput( + f"'dataframe' field currently only supports 'columns' orient. Make sure to set 'orient=columns' in {self.__class__.__name__}." + ) from None + data: list[t.Any] = [] + + def process_columns_contents(content: pb.Series) -> dict[str, t.Any]: + # To be use inside a ThreadPoolExecutor to handle + # large tabular data + if len(content.ListFields()) != 1: + raise BadInput( + f"Array contents can only be one of given values key. Use one of '{list(map(lambda f: f[0].name,content.ListFields()))}' instead." 
+ ) from None + return {str(i): c for i, c in enumerate(content.ListFields()[0][1])} + + with ThreadPoolExecutor(max_workers=10) as executor: + futures = executor.map(process_columns_contents, field.columns) + data.extend([i for i in list(futures)]) + dataframe = pd.DataFrame( + dict(zip(field.column_names, data)), + columns=t.cast(t.List[str], field.column_names), + ) + else: + # TODO: handle _internal_bytes_contents for dataframe + assert isinstance(field, bytes) + raise NotImplementedError( + 'Currently not yet implemented. Use "dataframe" instead.' + ) + return self.validate_dataframe(dataframe) + + async def to_proto(self, obj: ext.PdDataFrame) -> pb.DataFrame: + """ + Process given objects and convert it to grpc protobuf response. + + Args: + obj: ``pandas.DataFrame`` that will be serialized to protobuf + context: grpc.aio.ServicerContext from grpc.aio.Server + Returns: + ``service_pb2.Response``: + Protobuf representation of given ``pandas.DataFrame`` + """ + from bentoml._internal.io_descriptors.numpy import npdtype_to_fieldpb_map + + # TODO: support different serialization format + obj = self.validate_dataframe(obj) + mapping = npdtype_to_fieldpb_map() + # note that this is not safe, since we are not checking the dtype of the series + # FIXME(aarnphm): validate and handle mix columns dtype + # currently we don't support ExtensionDtype + columns_name: list[str] = list(map(str, obj.columns)) + not_supported: list[ext.PdDType] = list( + filter( + lambda x: x not in mapping, + map(lambda x: t.cast("ext.PdSeries", obj[x]).dtype, columns_name), + ) + ) + if len(not_supported) > 0: + raise UnprocessableEntity( + f'dtype in column "{obj.columns}" is not currently supported.' + ) from None + return pb.DataFrame( + column_names=columns_name, + columns=[ + pb.Series( + **{mapping[t.cast("ext.NpDTypeLike", obj[col].dtype)]: obj[col]} + ) + for col in columns_name + ], + ) + class PandasSeries(IODescriptor["ext.PdSeries"]): """ @@ -551,7 +657,7 @@ def predict(input_arr): - :obj:`split` - :code:`dict[str, Any]` ↦ {``idx`` ↠ ``[idx]``, ``columns`` ↠ ``[columns]``, ``data`` ↠ ``[values]``} - :obj:`records` - :code:`list[Any]` ↦ [{``column`` ↠ ``value``}, ..., {``column`` ↠ ``value``}] - :obj:`index` - :code:`dict[str, Any]` ↦ {``idx`` ↠ {``column`` ↠ ``value``}} - - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` -> {``index`` ↠ ``value``}} + - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` ↠ {``index`` ↠ ``value``}} - :obj:`values` - :code:`dict[str, Any]` ↦ Values arrays columns: List of columns name that users wish to update. apply_column_names (`bool`, `optional`, default to :code:`False`): @@ -573,8 +679,8 @@ def predict(input_arr): from bentoml.io import PandasSeries @svc.api(input=PandasSeries(shape=(51,10), enforce_shape=True), output=PandasSeries()) - def infer(input_df: pd.DataFrame) -> pd.DataFrame: - # if input_df have shape (40,9), it will throw out errors + def infer(input_series: pd.Series) -> pd.Series: + # if input_series have shape (40,9), it will throw out errors ... enforce_shape: Whether to enforce a certain shape. If ``enforce_shape=True`` then ``shape`` must be specified. @@ -582,12 +688,13 @@ def infer(input_df: pd.DataFrame) -> pd.DataFrame: :obj:`PandasSeries`: IO Descriptor that represents a :code:`pd.Series`. 
""" - _mime_type: str = MIME_TYPE_JSON + _proto_field: str = "series" + _mime_type = "application/json" def __init__( self, orient: ext.SeriesOrient = "records", - dtype: bool | dict[str, t.Any] | None = None, + dtype: ext.PdDTypeArg | None = None, enforce_dtype: bool = False, shape: tuple[int, ...] | None = None, enforce_shape: bool = False, @@ -630,29 +737,13 @@ async def from_http_request(self, request: Request) -> ext.PdSeries: a ``pd.Series`` object. This can then be used inside users defined logics. """ obj = await request.body() - if self._enforce_dtype: - if self._dtype is None: - logger.warning( - "`dtype` is None or undefined, while `enforce_dtype=True`" - ) - - # TODO(jiang): check dtypes when enforce_dtype is set - res = pd.read_json(obj, typ="series", orient=self._orient, dtype=self._dtype) - - assert isinstance(res, pd.Series) - - if self._enforce_shape: - if self._shape is None: - logger.warning( - "`shape` is None or undefined, while `enforce_shape`=True" - ) - else: - assert all( - left == right - for left, right in zip(self._shape, res.shape) - if left != -1 and right != -1 - ), f"incoming has shape {res.shape} where enforced shape to be {self._shape}" - return res + res: ext.PdSeries = pd.read_json( + obj, + typ="series", + orient=self._orient, + dtype=self._dtype, + ) + return self.validate_series(res) async def to_http_response( self, obj: t.Any, ctx: Context | None = None @@ -665,19 +756,48 @@ async def to_http_response( Returns: HTTP Response of type ``starlette.responses.Response``. This can be accessed via cURL or any external web traffic. """ - if not LazyType["ext.PdSeries"](pd.Series).isinstance(obj): - raise InvalidArgument( - f"return object is not of type `pd.Series`, got type {type(obj)} instead" - ) - + obj = self.validate_series(obj) if ctx is not None: res = Response( obj.to_json(orient=self._orient), - media_type=MIME_TYPE_JSON, + media_type=self._mime_type, headers=ctx.response.headers, # type: ignore (bad starlette types) status_code=ctx.response.status_code, ) set_cookies(res, ctx.response.cookies) return res else: - return Response(obj.to_json(orient=self._orient), media_type=MIME_TYPE_JSON) + return Response( + obj.to_json(orient=self._orient), media_type=self._mime_type + ) + + def validate_series( + self, series: ext.PdSeries, exception_cls: t.Type[Exception] = BadInput + ) -> ext.PdSeries: + # TODO: dtype check + if not LazyType["ext.PdSeries"]("pd.Series").isinstance(series): + raise InvalidArgument( + f"return object is not of type 'pd.Series', got type '{type(series)}' instead" + ) from None + # TODO: convert from wide to long format (melt()) + if self._shape is not None and self._shape != series.shape: + msg = f"{self.__class__.__name__}: Expecting Series of shape '{self._shape}', but '{series.shape}' was received." 
+ if self._enforce_shape and not all( + left == right + for left, right in zip(self._shape, series.shape) + if left != -1 and right != -1 + ): + raise exception_cls(msg) from None + + return series + + async def from_proto( + self, + field: pb.Series | pb.Part | bytes, + *, + _use_internal_bytes_contents: bool = False, + ) -> ext.PdSeries: + raise NotImplementedError("Currently not yet implemented.") + + async def to_proto(self, obj: ext.PdSeries) -> pb.Series: + raise NotImplementedError("Currently not yet implemented.") diff --git a/bentoml/_internal/io_descriptors/text.py b/bentoml/_internal/io_descriptors/text.py index db372250cb2..e7dad38b6bf 100644 --- a/bentoml/_internal/io_descriptors/text.py +++ b/bentoml/_internal/io_descriptors/text.py @@ -11,14 +11,23 @@ from .base import IODescriptor from ..utils.http import set_cookies from ..service.openapi import SUCCESS_DESCRIPTION +from ..utils.lazy_loader import LazyLoader +from ..service.openapi.specification import Schema +from ..service.openapi.specification import Response as OpenAPIResponse from ..service.openapi.specification import MediaType +from ..service.openapi.specification import RequestBody if TYPE_CHECKING: + from google.protobuf import wrappers_pb2 + + from bentoml.grpc.v1alpha1 import service_pb2 as pb + from ..context import InferenceApiContext as Context +else: + from bentoml.grpc.utils import import_generated_stubs -from ..service.openapi.specification import Schema -from ..service.openapi.specification import Response as OpenAPIResponse -from ..service.openapi.specification import RequestBody + pb, _ = import_generated_stubs() + wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2") MIME_TYPE = "text/plain" @@ -86,13 +95,14 @@ def predict(text: str) -> str: :obj:`Text`: IO Descriptor that represents strings type. """ + _proto_field: str = "text" + _mime_type = MIME_TYPE + def __init__(self, *args: t.Any, **kwargs: t.Any): if args or kwargs: raise BentoMLException( - "'Text' is not designed to take any args or kwargs during initialization." - ) - - self._mime_type = MIME_TYPE + f"'{self.__class__.__name__}' is not designed to take any args or kwargs during initialization." 
+ ) from None def input_type(self) -> t.Type[str]: return str @@ -123,11 +133,29 @@ async def to_http_response(self, obj: str, ctx: Context | None = None) -> Response: if ctx is not None: res = Response( obj, - media_type=MIME_TYPE, + media_type=self._mime_type, headers=ctx.response.metadata, # type: ignore (bad starlette types) status_code=ctx.response.status_code, ) set_cookies(res, ctx.response.cookies) return res else: - return Response(obj, media_type=MIME_TYPE) + return Response(obj, media_type=self._mime_type) + + async def from_proto( + self, + field: wrappers_pb2.StringValue | pb.Part | bytes, + *, + _use_internal_bytes_contents: bool = False, + ) -> str: + if not _use_internal_bytes_contents: + if isinstance(field, pb.Part): + field = field.text + assert isinstance(field, wrappers_pb2.StringValue) + return field.value + else: + assert isinstance(field, bytes) + return field.decode("utf-8") + + async def to_proto(self, obj: str) -> wrappers_pb2.StringValue: + return wrappers_pb2.StringValue(value=obj) diff --git a/tests/unit/_internal/io/test_numpy.py b/tests/unit/_internal/io/test_numpy.py index 811891af763..2ea623e59ae 100644 --- a/tests/unit/_internal/io/test_numpy.py +++ b/tests/unit/_internal/io/test_numpy.py @@ -32,7 +32,7 @@ def test_invalid_dtype(): generic = ExampleGeneric("asdf") with pytest.raises(BentoMLException) as e: _ = NumpyNdarray.from_sample(generic) # type: ignore (test exception) - assert "expects a numpy.array" in str(e.value) + assert "expects a 'numpy.array'" in str(e.value) @pytest.mark.parametrize("dtype, expected", [("float", "number"), (">U8", "integer")]) @@ -82,22 +82,20 @@ def test_numpy_openapi_responses(): def test_verify_numpy_ndarray(caplog: LogCaptureFixture): - partial_check = partial( - from_example._verify_ndarray, exception_cls=BentoMLException # type: ignore (test internal check) - ) + partial_check = partial(from_example.validate_array, exception_cls=BentoMLException) with pytest.raises(BentoMLException) as ex: partial_check(np.array(["asdf"])) - assert f'Expecting ndarray of dtype "{from_example._dtype}"' in str(ex.value) # type: ignore (testing message) + assert f'Expecting ndarray of dtype "{from_example._dtype}"' in str(ex.value) with pytest.raises(BentoMLException) as e: partial_check(np.array([[1]])) - assert f'Expecting ndarray of shape "{from_example._shape}"' in str(e.value) # type: ignore (testing message) + assert f'Expecting ndarray of shape "{from_example._shape}"' in str(e.value) - # test cases whwere reshape is failed + # test cases where reshape fails example = NumpyNdarray.from_sample(np.ones((2, 2, 3))) - example._enforce_shape = False # type: ignore (test internal check) - example._enforce_dtype = False # type: ignore (test internal check) + example._enforce_shape = False + example._enforce_dtype = False with caplog.at_level(logging.DEBUG): - example._verify_ndarray(np.array("asdf")) + example.validate_array(np.array("asdf"))
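Closing note: the `# TODO: when updating docs, add examples with gRPCurl` in numpy.py could be seeded from a Python-level round trip such as the following. This is a sketch against the descriptor API only; the `*_values` field names follow the `FIELDPB_TO_NPDTYPE_NAME_MAP` in this diff and are not a stable public contract:

```python
import asyncio

import numpy as np

from bentoml.io import NumpyNdarray
from bentoml.grpc.utils import import_generated_stubs

pb, _ = import_generated_stubs()


async def round_trip() -> None:
    io_ = NumpyNdarray(dtype="float32", shape=(2, 2))
    # to_proto flattens the array into the dtype-specific *_values field.
    msg = await io_.to_proto(np.ones((2, 2), dtype="float32"))
    assert msg.dtype == pb.NDArray.DTYPE_FLOAT
    assert list(msg.float_values) == [1.0, 1.0, 1.0, 1.0]

    # from_proto accepts the message bare or wrapped in a pb.Part...
    arr = await io_.from_proto(pb.Part(ndarray=msg))
    assert arr.shape == (2, 2)

    # ...or raw bytes, in which case 'dtype' is mandatory and validate_array
    # reshapes the flat buffer to the descriptor's declared shape.
    raw = np.ones(4, dtype="float32").tobytes()
    assert (await io_.from_proto(raw, _use_internal_bytes_contents=True)).shape == (2, 2)


asyncio.run(round_trip())
```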