From d9d2fac882466f2f2fa5c32e836f599d4c59255a Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Thu, 28 Jul 2022 17:02:52 -0700 Subject: [PATCH] fix(numpy): handling numpy object and rename to generic types Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> --- bentoml/_internal/configuration/containers.py | 1 + .../configuration/default_configuration.yaml | 1 + bentoml/_internal/io_descriptors/base.py | 30 ++- bentoml/_internal/io_descriptors/file.py | 2 +- bentoml/_internal/io_descriptors/image.py | 2 +- bentoml/_internal/io_descriptors/json.py | 2 +- bentoml/_internal/io_descriptors/numpy.py | 237 +++++++----------- bentoml/_internal/io_descriptors/pandas.py | 8 +- bentoml/_internal/io_descriptors/text.py | 2 +- .../server/grpc/interceptors/__init__.py | 160 ++++-------- .../server/grpc/interceptors/access.py | 60 ----- .../server/grpc/interceptors/trace.py | 0 bentoml/_internal/server/grpc/servicer.py | 66 ++++- bentoml/_internal/server/grpc/types.py | 34 +-- bentoml/_internal/server/grpc_app.py | 37 ++- bentoml/_internal/utils/grpc/__init__.py | 55 ++-- bentoml/grpc/v1/service.proto | 171 +++---------- 17 files changed, 322 insertions(+), 546 deletions(-) delete mode 100644 bentoml/_internal/server/grpc/interceptors/access.py create mode 100644 bentoml/_internal/server/grpc/interceptors/trace.py diff --git a/bentoml/_internal/configuration/containers.py b/bentoml/_internal/configuration/containers.py index b6dc47793db..623209dbbe7 100644 --- a/bentoml/_internal/configuration/containers.py +++ b/bentoml/_internal/configuration/containers.py @@ -111,6 +111,7 @@ def _is_ip_address(addr: str) -> bool: }, "grpc": { "max_message_length": Or(int, None), + "maximum_concurrent_rpcs": Or(int, None), }, }, "runners": { diff --git a/bentoml/_internal/configuration/default_configuration.yaml b/bentoml/_internal/configuration/default_configuration.yaml index 30b50283557..ea49bbb5763 100644 --- a/bentoml/_internal/configuration/default_configuration.yaml +++ b/bentoml/_internal/configuration/default_configuration.yaml @@ -25,6 +25,7 @@ api_server: access_control_expose_headers: Null grpc: max_message_length: Null + maximum_concurrent_rpcs: Null runners: batching: diff --git a/bentoml/_internal/io_descriptors/base.py b/bentoml/_internal/io_descriptors/base.py index f7f059ba09e..9ed84e95a72 100644 --- a/bentoml/_internal/io_descriptors/base.py +++ b/bentoml/_internal/io_descriptors/base.py @@ -1,13 +1,14 @@ from __future__ import annotations import typing as t -from abc import ABC +from abc import ABCMeta from abc import abstractmethod from typing import TYPE_CHECKING if TYPE_CHECKING: from types import UnionType + from typing_extensions import Self from starlette.requests import Request from starlette.responses import Response @@ -32,7 +33,22 @@ _T = t.TypeVar("_T") -class IODescriptor(ABC, t.Generic[IOPyObj]): +class DescriptorMeta(ABCMeta): + def __new__( + cls: type[Self], + name: str, + bases: tuple[type, ...], + namespace: dict[str, t.Any], + *, + proto_fields: list[str] | None = None, + ) -> Self: + if not proto_fields: + proto_fields = [] + namespace["_proto_fields"] = proto_fields + return super().__new__(cls, name, bases, namespace) + + +class IODescriptor(t.Generic[IOPyObj], metaclass=DescriptorMeta, proto_fields=None): """ IODescriptor describes the input/output data format of an InferenceAPI defined in a :code:`bentoml.Service`. 
This is an abstract base class for extending new HTTP @@ -41,7 +57,7 @@ class IODescriptor(ABC, t.Generic[IOPyObj]): HTTP_METHODS = ["POST"] _init_str: str = "" - _proto_kind: list[str] | None = None + _proto_fields: list[str] def __new__(cls: t.Type[_T], *args: t.Any, **kwargs: t.Any) -> _T: self = super().__new__(cls) @@ -56,17 +72,13 @@ def __repr__(self) -> str: return self._init_str @property - def accepted_proto_kind(self) -> list[str]: + def accepted_proto_fields(self) -> list[str]: """ Returns a list of kinds fields that the IODescriptor can accept. Make sure to keep in sync with bentoml.grpc.v1.Value message. """ - return self._proto_kind or [] - - @accepted_proto_kind.setter - def accepted_proto_kind(self, value: list[str]): - self._proto_kind = value + return self._proto_fields @abstractmethod def input_type(self) -> InputType: diff --git a/bentoml/_internal/io_descriptors/file.py b/bentoml/_internal/io_descriptors/file.py index 663a8b6f8b5..4fd7ccc9bb8 100644 --- a/bentoml/_internal/io_descriptors/file.py +++ b/bentoml/_internal/io_descriptors/file.py @@ -24,7 +24,7 @@ FileType: t.TypeAlias = t.Union[io.IOBase, t.IO[bytes], FileLike[bytes]] -class File(IODescriptor[FileType]): +class File(IODescriptor[FileType], proto_fields=["raw_value"]): """ :code:`File` defines API specification for the inputs/outputs of a Service, where either inputs will be converted to or outputs will be converted from file-like objects as diff --git a/bentoml/_internal/io_descriptors/image.py b/bentoml/_internal/io_descriptors/image.py index 5b1ae9446ff..654d68d3b5d 100644 --- a/bentoml/_internal/io_descriptors/image.py +++ b/bentoml/_internal/io_descriptors/image.py @@ -47,7 +47,7 @@ DEFAULT_PIL_MODE = "RGB" -class Image(IODescriptor[ImageType]): +class Image(IODescriptor[ImageType], proto_fields=["raw_value"]): """ :code:`Image` defines API specification for the inputs/outputs of a Service, where either inputs will be converted to or outputs will be converted from images as specified diff --git a/bentoml/_internal/io_descriptors/json.py b/bentoml/_internal/io_descriptors/json.py index a6c917b468e..7c76736a850 100644 --- a/bentoml/_internal/io_descriptors/json.py +++ b/bentoml/_internal/io_descriptors/json.py @@ -67,7 +67,7 @@ def default(self, o: _SerializableObj) -> t.Any: return super().default(o) -class JSON(IODescriptor[JSONType]): +class JSON(IODescriptor[JSONType], proto_fields=["map_value", "raw_value"]): """ :code:`JSON` defines API specification for the inputs/outputs of a Service, where either inputs will be converted to or outputs will be converted from a JSON representation diff --git a/bentoml/_internal/io_descriptors/numpy.py b/bentoml/_internal/io_descriptors/numpy.py index d860a6da376..30c34635670 100644 --- a/bentoml/_internal/io_descriptors/numpy.py +++ b/bentoml/_internal/io_descriptors/numpy.py @@ -3,7 +3,6 @@ import json import typing as t import logging -from typing import overload from typing import TYPE_CHECKING from starlette.requests import Request @@ -33,109 +32,38 @@ logger = logging.getLogger(__name__) - -_DTYPE_TO_FIELD_MAP = { - "DT_BOOL": "bool_contents", - "DT_FLOAT": "float_contents", - "DT_COMPLEX64": "float_contents", - "DT_STRING": "string_contents", - "DT_DOUBLE": "double_contents", - "DT_COMPLEX128": "double_contents", - "DT_INT32": "int_contents", - "DT_IN16": "int_contents", - "DT_UINT16": "int_contents", - "DT_INT8": "int_contents", - "DT_UINT8": "int_contents", - "DT_HALF": "int_contents", - "DT_INT64": "long_contents", - "DT_STRUCT": "struct_contents", - 
"DT_UINT32": "uint32_contents", - "DT_UINT64": "uint64_contents", - # "DT_QINT32": "bytes_contents", - # "DT_QINT16": "bytes_contents", - # "DT_QUINT16": "bytes_contents", - # "DT_QINT8": "bytes_contents", - # "DT_QUINT8": "bytes_contents", - # "DT_BFLOAT16": "int_contents", -} - -_DTYPE_TO_STRING_MAP = { - "DT_BOOL": "bool", - "DT_FLOAT": "float32", - "DT_COMPLEX64": "complex64", - "DT_STRING": " specify types in NumpyNdarray + using int_values. +# +# For bfloat16, half (float16) -> specify types in NumpyNdarray + using float_values. +# +# for string_values, use np.dtype[t.Any] | None: - if datatype_string == "DT_UNSPECIFIED": - return - elif datatype_string in _NOT_SUPPORTED_DTYPE: - raise UnprocessableEntity(f"{datatype_string} is not yet supported.") - elif datatype_string == "DT_STRUCT": - assert ( - struct_npdtype - ), "'dtype' is required in NumpyNdarray to use in conjunction with DT_STRUCT." - return struct_npdtype - else: - return np.dtype(_DTYPE_TO_STRING_MAP[datatype_string]) - - -@overload -def get_array_value(array: dict[str, str | bytes]) -> tuple[str, bytes, bool]: - ... - - -@overload -def get_array_value( - array: dict[str, str | list[t.Any]] -) -> tuple[str, list[t.Any], bool]: - ... - - -# array_descriptor -> {"dtype": "DT_FLOAT", "float_contents": [1, 2, 3]} -def get_array_value(array: dict[str, t.Any]) -> tuple[str, list[t.Any] | bytes, bool]: +# array_descriptor -> {"float_contents": [1, 2, 3]} +def get_array_proto(array: dict[str, t.Any]) -> tuple[str, list[t.Any]]: # returns the array contents with whether the result is using bytes. - dtype = t.cast(str, array.pop("dtype")) - if _DTYPE_TO_FIELD_MAP[dtype] not in array: - if "bytes_contents" not in array: - raise BadInput( - f"{dtype} requires specifying either '{_DTYPE_TO_FIELD_MAP[dtype]}' or 'bytes_contents' in the protobuf message." - ) - content = array.pop("bytes_contents") - assert isinstance(content, bytes) - return dtype, content, True - else: - # all of the repeated fields can be represented as list. - content = t.cast(t.List[t.Any], array.pop(_DTYPE_TO_FIELD_MAP[dtype])) - return dtype, content, False + accepted_fields = list(service_pb2.Array.DESCRIPTOR.fields_by_name) + if len(set(array) - set(accepted_fields)) > 0: + raise UnprocessableEntity("Given array has unsupported fields.") + if len(array) != 1: + raise BadInput( + f"Array contents can only be one of {accepted_fields} as key. Use one of {list(array)} only." 
+        )
+    return tuple(array.items())[0]
 
 
 def _is_matched_shape(left: tuple[int, ...], right: tuple[int, ...]) -> bool:
@@ -155,7 +83,10 @@ def _is_matched_shape(left: tuple[int, ...], right: tuple[int, ...]) -> bool:
 
 
 # TODO: when updating docs, add examples with gRPCurl
-class NumpyNdarray(IODescriptor["ext.NpNDArray"]):
+class NumpyNdarray(
+    IODescriptor["ext.NpNDArray"],
+    proto_fields=["multi_dimensional_array_value", "array_value", "raw_value"],
+):
     """
     :code:`NumpyNdarray` defines API specification for the inputs/outputs of a Service, where
     either inputs will be converted to or outputs will be converted from type
@@ -227,29 +158,28 @@ def __init__(
         bytesorder: t.Literal["C", "F", "A", None] = None,
     ):
         if dtype is not None and not isinstance(dtype, np.dtype):
-            # Convert from primitive type or type string, e.g.:
-            # np.dtype(float)
-            # np.dtype("float64")
+            # Convert from a primitive type or a type string, e.g. np.dtype(float) or np.dtype("float64").
             try:
                 dtype = np.dtype(dtype)
             except TypeError as e:
-                raise BentoMLException(f'NumpyNdarray: Invalid dtype "{dtype}": {e}')
+                raise UnprocessableEntity(f'NumpyNdarray: Invalid dtype "{dtype}": {e}')
 
         self._dtype: np.dtype[t.Any] | None = dtype
         self._shape = shape
         self._enforce_dtype = enforce_dtype
         self._enforce_shape = enforce_shape
+
+        # whether to use the packed (raw bytes) representation of the array when sending protobuf messages;
+        # this means users should use 'raw_value' instead of 'array_value' or 'multi_dimensional_array_value'.
         self._packed = packed
 
-        if bytesorder not in ["C", "F", "A", None]:
+        if bytesorder and bytesorder not in ["C", "F", "A"]:
             raise BadInput(
-                f"'bytesorder' must be one of ['C', 'F', 'A', 'None'], got {bytesorder} instead."
+                f"'bytesorder' must be one of ['C', 'F', 'A'], got {bytesorder} instead."
            )
         if not bytesorder:
             bytesorder = "C"  # default from numpy (C-order)
         # https://numpy.org/doc/stable/user/basics.byteswapping.html#introduction-to-byte-ordering-and-ndarrays
-        self._bytesorder: t.Literal["C", "F", "A", None] = bytesorder
-
-        self.accepted_proto_kind = ["array_value", "ndarray_value"]
+        self._bytesorder: t.Literal["C", "F", "A"] = bytesorder
 
     def _infer_openapi_types(self) -> str:  # pragma: no cover
         if self._dtype is not None:
@@ -326,7 +256,6 @@ async def from_http_request(self, request: Request) -> ext.NpNDArray:
         inside users defined logics.
         """
         obj = await request.json()
-        res: "ext.NpNDArray"
         try:
             res = np.array(obj, dtype=self._dtype)
         except ValueError:
@@ -374,8 +303,30 @@ async def from_grpc_request(
 
         from ..utils.grpc import deserialize_proto
 
+        # TODO: deserialize_proto is pretty inefficient, but ok for a first pass.
         field, serialized = deserialize_proto(self, request)
-        if field == "ndarray_value":
+
+        if self._packed:
+            if field != "raw_value":
+                raise BentoMLException(
+                    f"'packed={self._packed}' requires the request to use 'raw_value' instead of '{field}'."
+                )
+            if not self._shape:
+                raise UnprocessableEntity("'shape' is required when 'packed' is set.")
+            metadata = serialized["metadata"]
+            if not self._dtype:
+                if "dtype" not in metadata:
+                    raise BentoMLException(
+                        f"'dtype' was found in neither {repr(self)} nor {metadata}. Either set 'dtype' on {self.__class__.__name__} or add 'dtype' to the metadata of the 'raw_value' message."
+                    )
+                dtype = metadata["dtype"]
+            else:
+                dtype = self._dtype
+            obj = np.frombuffer(serialized["content"], dtype=dtype)
+
+            return np.reshape(obj, self._shape)
+
+        if field == "multi_dimensional_array_value":
-            # {'shape': [2, 3], 'array': {'dtype': 'DT_FLOAT', ...}}
+            # {'shape': [2, 3], 'array': {'float_values': [...]}}
             if "array" not in serialized:
                 msg = "'array' cannot be None."
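A minimal sketch of the 'raw_value' round-trip that the branch above implements, written outside of gRPC — the array, dtype, and shape below are illustrative values, not BentoML defaults:

    import numpy as np

    arr = np.arange(6, dtype="float32").reshape(2, 3)

    # sender side: flatten to raw bytes; the dtype travels out-of-band,
    # mirroring the {"dtype": str(obj.dtype)} entry stored in Raw.metadata.
    payload = arr.tobytes(order="C")
    metadata = {"dtype": str(arr.dtype)}

    # receiver side: same steps as the 'raw_value' branch — frombuffer with
    # the advertised dtype, then reshape to the descriptor's declared shape
    # (which is why 'shape' is required when packed=True).
    restored = np.frombuffer(payload, dtype=metadata["dtype"]).reshape((2, 3))
    assert (restored == arr).all()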
@@ -399,11 +350,11 @@ async def from_grpc_request(
             array = serialized["array"]
 
         else:
-            # {'dtype': 'DT_FLOAT', 'float_contents': [1.0, 2.0, 3.0]}
+            # {'float_values': [1.0, 2.0, 3.0]}
             array = serialized
 
-        dtype_string, content, use_bytes = get_array_value(array)
-        dtype = get_dtype(dtype_string, struct_npdtype=self._dtype)
+        value_key, content = get_array_proto(array)
+        dtype = np.dtype(_VALUES_TO_NP_DTYPE_MAP[value_key])
         if self._dtype:
             if not self._enforce_dtype:
                 logger.warning(
@@ -417,13 +368,10 @@ async def from_grpc_request(
         else:
             self._dtype = dtype
 
-        if use_bytes:
-            res = np.frombuffer(content, dtype=self._dtype)
-        else:
-            try:
-                res = np.array(content, dtype=self._dtype)
-            except ValueError:
-                res = np.array(content)
+        try:
+            res = np.array(content, dtype=self._dtype)
+        except ValueError:
+            res = np.array(content)
 
         return self._verify_ndarray(res, BadInput)
 
@@ -443,10 +391,8 @@ async def to_grpc_response(
         from ..utils.grpc import grpc_status_code
         from ..configuration import get_debug_mode
 
-        _NPTYPE_TO_DTYPE_STRING_MAP = {
-            np.dtype(v): k for k, v in _DTYPE_TO_STRING_MAP.items()
-        }
-        dtype_string = _NPTYPE_TO_DTYPE_STRING_MAP[obj.dtype]
+        _NP_TO_VALUE_MAP = {np.dtype(v): k for k, v in _VALUES_TO_NP_DTYPE_MAP.items()}
+        value_key = _NP_TO_VALUE_MAP[obj.dtype]
 
         try:
             obj = self._verify_ndarray(obj, InternalServerError)
@@ -455,31 +401,38 @@ async def to_grpc_response(
             context.set_details(e.message)
             raise
 
-        cnt: dict[str, t.Any] = {"dtype": dtype_string}
+        response = service_pb2.Response()
+        value = service_pb2.Value()
 
-        resp = service_pb2.Response()
         if self._packed:
-            cnt.update({"bytes_contents": obj.tobytes(order=self._bytesorder)})
+            raw = service_pb2.Raw(
+                metadata={"dtype": str(obj.dtype)},
+                content=obj.tobytes(order=self._bytesorder),
+            )
+            value.raw_value.CopyFrom(raw)
         else:
             if self._bytesorder:
                 logger.warning(
                     f"'bytesorder={self._bytesorder}' is ignored when 'packed={self._packed}'."
                 )
-            cnt.update({_DTYPE_TO_FIELD_MAP[dtype_string]: obj.tolist()})
-
-        if obj.ndim == 1:
-            message = service_pb2.Array(**cnt)
-            resp.contents.array_value.CopyFrom(message)
-        else:
-            cnt["shape"] = tuple(obj.shape)
-            resp.contents.ndarray_value.CopyFrom(
-                service_pb2.NDArray(
-                    shape=tuple(obj.shape), array=service_pb2.Array(**cnt)
+            # we only need a flattened view of the array here, instead of copying it into contiguous memory.
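+            # (ravel() only avoids a copy when the buffer is already contiguous,
+            # and tolist() still materializes a Python list for the repeated field.)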
+ array = service_pb2.Array(**{value_key: obj.ravel().tolist()}) + if obj.ndim != 1: + ndarray = service_pb2.MultiDimensionalArray( + shape=tuple(obj.shape), array=array ) - ) + value.multi_dimensional_array_value.CopyFrom(ndarray) + else: + value.array_value.CopyFrom(array) + + response.output.CopyFrom(value) + if get_debug_mode(): - logger.debug(f"Response proto: \n{resp}") - return resp + logger.debug( + f"Response proto: {response.SerializeToString(deterministic=True)}" + ) + + return response def generate_protobuf(self): pass diff --git a/bentoml/_internal/io_descriptors/pandas.py b/bentoml/_internal/io_descriptors/pandas.py index f55abd2822a..66a345e38c8 100644 --- a/bentoml/_internal/io_descriptors/pandas.py +++ b/bentoml/_internal/io_descriptors/pandas.py @@ -128,7 +128,9 @@ def _validate_serialization_format(serialization_format: SerializationFormat): ) -class PandasDataFrame(IODescriptor["ext.PdDataFrame"]): +class PandasDataFrame( + IODescriptor["ext.PdDataFrame"], proto_fields=["map_value", "raw_value"] +): """ :code:`PandasDataFrame` defines API specification for the inputs/outputs of a Service, where either inputs will be converted to or outputs will be converted from type @@ -480,7 +482,9 @@ def predict(inputs: pd.DataFrame) -> pd.DataFrame:... ) -class PandasSeries(IODescriptor["ext.PdSeries"]): +class PandasSeries( + IODescriptor["ext.PdSeries"], proto_fields=["map_value", "raw_value"] +): """ :code:`PandasSeries` defines API specification for the inputs/outputs of a Service, where either inputs will be converted to or outputs will be converted from type diff --git a/bentoml/_internal/io_descriptors/text.py b/bentoml/_internal/io_descriptors/text.py index c6975fedf00..fbda5d110e1 100644 --- a/bentoml/_internal/io_descriptors/text.py +++ b/bentoml/_internal/io_descriptors/text.py @@ -18,7 +18,7 @@ MIME_TYPE = "text/plain" -class Text(IODescriptor[str]): +class Text(IODescriptor[str], proto_fields=["string_value", "raw_value"]): """ :code:`Text` defines API specification for the inputs/outputs of a Service. 
:code:`Text` represents strings for all incoming requests/outcoming responses as specified in
 diff --git a/bentoml/_internal/server/grpc/interceptors/__init__.py b/bentoml/_internal/server/grpc/interceptors/__init__.py
index 97fdee7bf78..379a71c0e47 100644
--- a/bentoml/_internal/server/grpc/interceptors/__init__.py
+++ b/bentoml/_internal/server/grpc/interceptors/__init__.py
@@ -1,22 +1,20 @@
 from __future__ import annotations
 
-import sys
 import typing as t
 import logging
-from abc import ABCMeta
-from abc import abstractmethod
+import functools
+from timeit import default_timer
 from typing import TYPE_CHECKING
 
-import grpc
 from grpc import aio
+from opentelemetry import trace
 
-from bentoml.exceptions import BentoMLException
-
-from ....utils.grpc import get_rpc_handler
-from ....utils.grpc import grpc_status_code
-from ....utils.grpc import invoke_handler_factory
+from ....utils.grpc import wrap_rpc_handler
 
 if TYPE_CHECKING:
+    from opentelemetry.sdk.trace import TracerProvider
+    from opentelemetry.trace.span import Span
+
     from ..types import Request
     from ..types import Response
     from ..types import HandlerMethod
@@ -24,122 +22,64 @@
     from ..types import HandlerCallDetails
     from ..types import BentoServicerContext
 
-AsyncClientInterceptorReturn = type(
-    "AsyncClientInterceptorReturn", (aio.Call, grpc.Future), {}
-)
-
 logger = logging.getLogger(__name__)
 
 
-class AsyncServerInterceptor(aio.ServerInterceptor, metaclass=ABCMeta):
+class AccessLogInterceptor(aio.ServerInterceptor):
     """
-    Base class for BentoService server-side interceptors.
+    An asyncio interceptor for access logging.
 
-    To implement, subclass this class and override ``intercept`` method.
-
-    Currently, only unary RPCs are supported.
+    .. TODO::
+        - Add support for streaming RPCs.
     """
 
-    @abstractmethod
-    async def intercept(
-        self,
-        method: HandlerMethod[t.Any],
-        request: Request,
-        context: BentoServicerContext,
-        method_name: str,
-    ) -> t.Any:
-        response_or_iterator = method(request, context)
-        if hasattr(response_or_iterator, "__aiter__"):
-            return response_or_iterator
-        else:
-            return await response_or_iterator
+    def __init__(self, tracer_provider: TracerProvider) -> None:
+        self.logger = logging.getLogger("bentoml.access")
+        self.tracer_provider = tracer_provider
 
     async def intercept_service(
         self,
         continuation: t.Callable[[HandlerCallDetails], t.Awaitable[RpcMethodHandler]],
         handler_call_details: HandlerCallDetails,
     ) -> RpcMethodHandler:
-        """
-        Implementation of grpc.aio.ServerInterceptor.
-        Don't override unless you know what you are doing.
-        """
         handler = await continuation(handler_call_details)
-        handler_factory, next_handler = get_rpc_handler(handler)
         method_name = handler_call_details.method
 
-        # if handler is a streaming RPC, the return handler
-        # would not be the Request message.
-        # Right now we will just pass the handler directly to
-        # the interceptor.
-        # TODO: support streaming RPCs.
-        if handler and (handler.request_streaming or handler.response_streaming):
+        if handler and (handler.response_streaming or handler.request_streaming):
             return handler
 
-        async def invoke_intercept_unary(
-            request: Request, context: BentoServicerContext
-        ) -> t.Awaitable[Response]:
-            return await self.intercept(next_handler, request, context, method_name)
-
-        return invoke_handler_factory(invoke_intercept_unary, handler_factory, handler)
-
-
-class ExceptionHandlerInterceptor(AsyncServerInterceptor):
-    """An async interceptor that handles exceptions raised via BentoService."""
-
-    async def handle_exception(
-        self,
-        ex: Exception,
-        context: BentoServicerContext,
-        method_name: str,
-    ) -> None:
-        """Handle an exception raised by a method.
-
-        Args:
-            ex: The exception raised by the method.
-            context: The context of the RPC.
-            method_name: The name of the method.
-        """
-        logger.error(
-            f"Error while invoking {method_name}: {ex}", exc_info=sys.exc_info()
-        )
-        details = f"{ex.__class__.__name__}<{str(ex)}>"
-        if isinstance(ex, BentoMLException):
-            status_code = grpc_status_code(ex)
-            details = ex.message
-        elif any(isinstance(ex, cls) for cls in (RuntimeError, TypeError)):
-            status_code = grpc.StatusCode.INTERNAL
-            details = "An error has occurred in BentoML user code when handling this request, find the error details in server logs."
-        else:
-            status_code = grpc.StatusCode.UNKNOWN
-
-        await context.abort(code=status_code, details=details)
-        raise ex
-
-    async def generate_responses(
-        self,
-        context: BentoServicerContext,
-        method_name: str,
-        response_iterator: t.AsyncIterable[Response],
-    ) -> t.AsyncGenerator[t.Any, None]:
-        """Yield all the responses, but check for errors along the way."""
-        try:
-            async for r in response_iterator:
-                yield r
-        except Exception as ex:
-            await self.handle_exception(ex, context, method_name)
-
-    async def intercept(
-        self,
-        method: HandlerMethod[t.Any],
-        request: Request,
-        context: BentoServicerContext,
-        method_name: str,
-    ) -> t.AsyncGenerator[Response, t.Any]:
-        try:
-            response_or_iterator = method(request, context)
-            if not hasattr(response_or_iterator, "__aiter__"):
-                return await response_or_iterator
-        except Exception as ex:
-            await self.handle_exception(ex, context, method_name)
-
-        return self.generate_responses(context, method_name, response_or_iterator)  # type: ignore (unknown variable warning)
+        def wrapper(
+            behaviour: HandlerMethod[Response | t.AsyncGenerator[Response, None]]
+        ) -> t.Callable[..., t.Any]:
+            @functools.wraps(behaviour)
+            async def new_behaviour(
+                request: Request, context: BentoServicerContext
+            ) -> Response:
+                tracer = self.tracer_provider.get_tracer(
+                    "opentelemetry.instrumentation.grpc"
+                )
+                span: Span = tracer.start_span("grpc")
+                span_context = span.get_span_context()
+                kind = str(request.input.WhichOneof("kind"))
+
+                start = default_timer()
+                with trace.use_span(span, end_on_exit=True):
+                    response = behaviour(request, context)
+                    if not hasattr(response, "__aiter__"):
+                        response = await response
+                # default_timer() returns seconds; report the latency in milliseconds.
+                latency = max(default_timer() - start, 0) * 1000
+
+                req_info = f"api_name={request.api_name},type={kind},size={request.input.ByteSize()}"
+                resp_info = f"status={context.code()},type={kind},size={response.output.ByteSize()}"
+                trace_and_span = f"trace={span_context.trace_id},span={span_context.span_id},sampled={1 if span_context.trace_flags.sampled else 0}"
+
+                self.logger.info(
+                    f"{context.peer()} ({req_info}) ({resp_info}) {latency:.3f}ms ({trace_and_span})"
+                )
+
+                return response
+
+            return new_behaviour
+
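+        # wrap_rpc_handler (bentoml._internal.utils.grpc) swaps the matching
+        # multicallable on the RpcMethodHandler for the wrapped behaviour, e.g.
+        # handler._replace(unary_unary=wrapper(handler.unary_unary)) for
+        # unary-unary RPCs; streaming handlers were already returned unwrapped above.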
return wrap_rpc_handler(wrapper, handler) diff --git a/bentoml/_internal/server/grpc/interceptors/access.py b/bentoml/_internal/server/grpc/interceptors/access.py deleted file mode 100644 index b7fb9b4bff0..00000000000 --- a/bentoml/_internal/server/grpc/interceptors/access.py +++ /dev/null @@ -1,60 +0,0 @@ -from __future__ import annotations - -import typing as t -import logging -from timeit import default_timer -from typing import TYPE_CHECKING - -from simple_di import inject -from simple_di import Provide -from opentelemetry import trace - -from . import AsyncServerInterceptor -from ....configuration.containers import BentoMLContainer - -if TYPE_CHECKING: - from opentelemetry.sdk.trace import TracerProvider - from opentelemetry.trace.span import Span - - from bentoml.grpc.v1.service_pb2 import Request - from bentoml.grpc.v1.service_pb2 import Response - - from ..types import HandlerMethod - from ..types import BentoServicerContext - - -class AccessLogInterceptor(AsyncServerInterceptor): - def __init__(self) -> None: - self.logger = logging.getLogger("bentoml.access") - - @inject - async def intercept( - self, - method: HandlerMethod[t.Any], - request: Request, - context: BentoServicerContext, - method_name: str, - *, - tracer_provider: TracerProvider = Provide[BentoMLContainer.tracer_provider], - ) -> t.AsyncGenerator[Response, None]: - tracer = tracer_provider.get_tracer("opentelemetry.instrumentation.grpc") - span: Span = tracer.start_span("grpc") - span_context = span.get_span_context() - kind = str(request.contents.WhichOneof("kind")) - - start = default_timer() - with trace.use_span(span, end_on_exit=True): - response_or_iterator = method(request, context) - if not hasattr(response_or_iterator, "__aiter__"): - response_or_iterator = await response_or_iterator - latency = max(default_timer() - start, 0) - - req_info = f"api_name={request.api_name},type={kind},size={request.contents.ByteSize()}" - resp_info = f"status={context.code()},type={kind},size={response_or_iterator.contents.ByteSize()}" - trace_and_span = f"trace={span_context.trace_id},span={span_context.span_id},sampled={1 if span_context.trace_flags.sampled else 0}" - - self.logger.info( - f"{context.peer()} ({req_info}) ({resp_info}) {latency:.3f}ms ({trace_and_span})" - ) - - return response_or_iterator diff --git a/bentoml/_internal/server/grpc/interceptors/trace.py b/bentoml/_internal/server/grpc/interceptors/trace.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/bentoml/_internal/server/grpc/servicer.py b/bentoml/_internal/server/grpc/servicer.py index 5ff5977d760..c8e638099f1 100644 --- a/bentoml/_internal/server/grpc/servicer.py +++ b/bentoml/_internal/server/grpc/servicer.py @@ -1,31 +1,52 @@ from __future__ import annotations +import sys import asyncio +import logging from typing import TYPE_CHECKING +import grpc import anyio from grpc import aio +from bentoml.exceptions import BentoMLException from bentoml.exceptions import UnprocessableEntity from bentoml.exceptions import MissingDependencyException from bentoml._internal.service.service import Service +from ...utils import LazyLoader +from ...utils.grpc import grpc_status_code + +logger = logging.getLogger(__name__) + if TYPE_CHECKING: + from logging import _ExcInfoType as ExcInfoType # type: ignore (private warning) + + from bentoml.grpc.v1 import service_pb2 as _service_pb2 + from bentoml.grpc.v1 import service_pb2_grpc as _service_pb2_grpc + from .types import BentoServicerContext +else: + _service_pb2 = LazyLoader("_service_pb2", 
globals(), "bentoml.grpc.v1.service_pb2") + _service_pb2_grpc = LazyLoader( + "_service_pb2_grpc", globals(), "bentoml.grpc.v1.service_pb2_grpc" + ) + + +def log_exception(request: _service_pb2.Request, exc_info: ExcInfoType) -> None: + logger.error(f"Exception on /{request.api_name}", exc_info=exc_info) def register_bento_servicer(service: Service, server: aio.Server) -> None: """ This is the actual implementation of BentoServicer. - Main inference entrypoint will be invoked via /bentoml.grpc..BentoService/Inference + Main inference entrypoint will be invoked via /bentoml.grpc..BentoService/Call """ - from bentoml.grpc.v1 import service_pb2 as _service_pb2 - from bentoml.grpc.v1 import service_pb2_grpc as _service_pb2_grpc class BentoServiceServicer(_service_pb2_grpc.BentoServiceServicer): """An asyncio implementation of BentoService servicer.""" - async def Infer( # type: ignore (no async types) + async def Call( # type: ignore (no async types) self, request: _service_pb2.Request, context: BentoServicerContext, @@ -36,15 +57,33 @@ async def Infer( # type: ignore (no async types) ) api = service.apis[request.api_name] - - input = await api.input.from_grpc_request(request, context) - - if asyncio.iscoroutinefunction(api.func): - output = await api.func(input) - else: - output = await anyio.to_thread.run_sync(api.func, input) - - return await api.output.to_grpc_response(output, context) + response = _service_pb2.Response() + + try: + input = await api.input.from_grpc_request(request, context) + + if asyncio.iscoroutinefunction(api.func): + output = await api.func(input) + else: + output = await anyio.to_thread.run_sync(api.func, input) + + response = await api.output.to_grpc_response(output, context) + except BentoMLException as e: + log_exception(request, sys.exc_info()) + await context.abort(code=grpc_status_code(e), details=e.message) + except (RuntimeError, TypeError, NotImplementedError): + log_exception(request, sys.exc_info()) + await context.abort( + code=grpc.StatusCode.INTERNAL, + details="An internal runtime error has occurred, check out error details in server logs.", + ) + except Exception: # type: ignore (generic exception) + log_exception(request, sys.exc_info()) + await context.abort( + code=grpc.StatusCode.UNKNOWN, + details="An error has occurred in BentoML user code when handling this request, find the error details in server logs.", + ) + return response _service_pb2_grpc.add_BentoServiceServicer_to_server(BentoServiceServicer(), server) # type: ignore (lack of asyncio types) @@ -61,6 +100,7 @@ async def register_health_servicer(server: aio.Server) -> None: "'grpcio-health-checking' is required for using health checking endpoints. Install with `pip install grpcio-health-checking`." ) try: + # reflection is required for health checking to work. 
from grpc_reflection.v1alpha import reflection except ImportError: raise MissingDependencyException( diff --git a/bentoml/_internal/server/grpc/types.py b/bentoml/_internal/server/grpc/types.py index 497322e44a3..d891dc2f4b3 100644 --- a/bentoml/_internal/server/grpc/types.py +++ b/bentoml/_internal/server/grpc/types.py @@ -3,7 +3,6 @@ """ from __future__ import annotations -from typing import Any from typing import TypeVar from typing import Callable from typing import Optional @@ -11,8 +10,6 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Protocol - import grpc from grpc import aio @@ -20,25 +17,14 @@ from bentoml.grpc.v1.service_pb2 import Response from bentoml.grpc.v1.service_pb2_grpc import BentoServiceServicer - P_con = TypeVar("P_con", contravariant=True) + P = TypeVar("P") BentoServicerContext = aio.ServicerContext[Response, Request] RequestDeserializerFn = Callable[[Request | None], object] | None ResponseSerializerFn = Callable[[bytes], Response | None] | None - HandlerMethod = Callable[[Request, BentoServicerContext], P_con] - - class HandlerFactoryProtocol(Protocol[P_con]): - def __call__( - self, - behaviour: HandlerMethod[P_con], - request_deserializer: RequestDeserializerFn = None, - response_serializer: ResponseSerializerFn = None, - ) -> grpc.RpcMethodHandler: - ... - - HandlerFactoryFn = HandlerFactoryProtocol[Any] + HandlerMethod = Callable[[Request, BentoServicerContext], P] class RpcMethodHandler( NamedTuple( @@ -47,10 +33,10 @@ class RpcMethodHandler( response_streaming=bool, request_deserializer=RequestDeserializerFn, response_serializer=ResponseSerializerFn, - unary_unary=Optional[aio.UnaryUnaryMultiCallable], - unary_stream=Optional[aio.UnaryStreamMultiCallable], - stream_unary=Optional[aio.StreamUnaryMultiCallable], - stream_stream=Optional[aio.StreamStreamMultiCallable], + unary_unary=Optional[HandlerMethod[Response]], + unary_stream=Optional[HandlerMethod[Response]], + stream_unary=Optional[HandlerMethod[Response]], + stream_stream=Optional[HandlerMethod[Response]], ), grpc.RpcMethodHandler, ): @@ -60,10 +46,10 @@ class RpcMethodHandler( response_streaming: bool request_deserializer: RequestDeserializerFn response_serializer: ResponseSerializerFn - unary_unary: Optional[aio.UnaryUnaryMultiCallable] - unary_stream: Optional[aio.UnaryStreamMultiCallable] - stream_unary: Optional[aio.StreamUnaryMultiCallable] - stream_stream: Optional[aio.StreamStreamMultiCallable] + unary_unary: Optional[HandlerMethod[Response]] + unary_stream: Optional[HandlerMethod[Response]] + stream_unary: Optional[HandlerMethod[Response]] + stream_stream: Optional[HandlerMethod[Response]] class HandlerCallDetails( NamedTuple("HandlerCallDetails", method=str, invocation_metadata=aio.Metadata), diff --git a/bentoml/_internal/server/grpc_app.py b/bentoml/_internal/server/grpc_app.py index 3036f9dea36..e7477b50730 100644 --- a/bentoml/_internal/server/grpc_app.py +++ b/bentoml/_internal/server/grpc_app.py @@ -33,12 +33,21 @@ class GRPCAppFactory: _is_ready: bool = False - def __init__(self, bento_service: Service, *, _thread_pool_size: int = 10) -> None: + @inject + def __init__( + self, + bento_service: Service, + *, + _thread_pool_size: int = 10, + maximum_concurrent_rpcs: int + | None = Provide[BentoMLContainer.grpc.maximum_concurrent_rpcs], + ) -> None: self.bento_service = bento_service self.server = aio.server( ThreadPoolExecutor(_thread_pool_size), interceptors=self.interceptors, options=self.options, + maximum_concurrent_rpcs=maximum_concurrent_rpcs, ) @property 
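For reference, a standalone sketch of the server wiring that `GRPCAppFactory.__init__` above performs — the pool size, option list, and RPC cap are illustrative values, not BentoML defaults:

    from concurrent.futures import ThreadPoolExecutor

    from grpc import aio

    server = aio.server(
        ThreadPoolExecutor(max_workers=10),
        interceptors=[],  # populated from the 'interceptors' property below
        options=[("grpc.max_receive_message_length", 1024 * 1024)],
        maximum_concurrent_rpcs=1000,  # None (the YAML default) means no limit
    )

Once the cap is reached, gRPC rejects further in-flight RPCs with RESOURCE_EXHAUSTED instead of queueing them, which is why the knob is exposed through `api_server.grpc.maximum_concurrent_rpcs` rather than hard-coded.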
@@ -121,22 +130,30 @@ def interceptors( from opentelemetry.sdk.trace.export import ConsoleSpanExporter from opentelemetry.sdk.trace.export import SimpleSpanProcessor - from .grpc.interceptors import ExceptionHandlerInterceptor - from .grpc.interceptors.access import AccessLogInterceptor + # from .grpc.interceptors import AccessLogInterceptor trace.set_tracer_provider(tracer_provider) trace.get_tracer_provider().add_span_processor( SimpleSpanProcessor(ConsoleSpanExporter()) ) + # from .grpc.interceptors.trace import ( + # AsyncOpenTelemetryServerInterceptor as OtelInterceptor, + # ) # TODO: prometheus interceptors. - interceptors: list[aio.ServerInterceptor] = [ - ExceptionHandlerInterceptor(), - AccessLogInterceptor(), - ] + # interceptors: list[aio.ServerInterceptor] = [OtelInterceptor()] + interceptors: list[aio.ServerInterceptor] = [] + + access_log_config = BentoMLContainer.api_server_config.logging.access + if access_log_config.enabled.get(): + from .grpc.interceptors import AccessLogInterceptor + + access_logger = logging.getLogger("bentoml.access") + if access_logger.getEffectiveLevel() <= logging.INFO: + interceptors.append( + AccessLogInterceptor(tracer_provider=tracer_provider) + ) # add users-defined interceptors. - interceptors.extend( - [interceptor() for interceptor in self.bento_service.interceptors] - ) + interceptors.extend(map(lambda x: x(), self.bento_service.interceptors)) return interceptors diff --git a/bentoml/_internal/utils/grpc/__init__.py b/bentoml/_internal/utils/grpc/__init__.py index b73d0f74940..a36952da8e4 100644 --- a/bentoml/_internal/utils/grpc/__init__.py +++ b/bentoml/_internal/utils/grpc/__init__.py @@ -19,12 +19,9 @@ from bentoml.io import IODescriptor from bentoml.grpc.v1 import service_pb2 + from ...server.grpc.types import Response from ...server.grpc.types import HandlerMethod - from ...server.grpc.types import HandlerFactoryFn from ...server.grpc.types import RpcMethodHandler - - # keep sync with bentoml.grpc.v1.service.Response - ContentsDict = dict[str, dict[str, t.Any]] else: service_pb2 = LazyLoader("service_pb2", globals(), "bentoml.grpc.v1.service_pb2") @@ -32,9 +29,7 @@ "grpc_status_code", "parse_method_name", "get_method_type", - "get_rpc_handler", "deserialize_proto", - "serialize_proto", ] logger = logging.getLogger(__name__) @@ -51,19 +46,13 @@ def deserialize_proto( if "preserving_proto_field_name" not in kwargs: kwargs.setdefault("preserving_proto_field_name", True) - kind = req.contents.WhichOneof("kind") - if kind not in io_descriptor.accepted_proto_kind: + kind = req.input.WhichOneof("kind") + if kind not in io_descriptor.accepted_proto_fields: raise UnprocessableEntity( - f"{kind} is not supported for {io_descriptor.__class__.__name__}. Supported message fields are: {io_descriptor.accepted_proto_kind}" + f"{kind} is not supported for {io_descriptor.__class__.__name__}. Supported protobuf message fields are: {io_descriptor.accepted_proto_fields}" ) - return kind, MessageToDict(getattr(req.contents, kind), **kwargs) - - -def serialize_proto(fields: str, contents_dict: ContentsDict) -> service_pb2.Response: - from google.protobuf.json_format import ParseDict - - return ParseDict({"contents": {fields: contents_dict}}, service_pb2.Response()) + return kind, MessageToDict(getattr(req.input, kind), **kwargs) _STATUS_CODE_MAPPING = { @@ -95,10 +84,8 @@ class MethodName: Represents a gRPC method name. 
Attributes: - package: This is defined by `package foo.bar`, - designation in the protocol buffer definition - service: service name in protocol buffer - definition (eg: service SearchService { ... }) + package: This is defined by `package foo.bar`, designation in the protocol buffer definition + service: service name in protocol buffer definition (eg: service SearchService { ... }) method: method name """ @@ -138,26 +125,20 @@ def get_method_type(request_streaming: bool, response_streaming: bool) -> str: return RpcMethodType.UNKNOWN -def get_rpc_handler( - handler: RpcMethodHandler, -) -> tuple[HandlerFactoryFn, HandlerMethod[t.Any]]: +def wrap_rpc_handler( + wrapper: t.Callable[[HandlerMethod[Response] | None], HandlerMethod[Response]], + handler: RpcMethodHandler | None, +) -> RpcMethodHandler | None: + if not handler: + return None + if not handler.request_streaming and not handler.response_streaming: - return grpc.unary_unary_rpc_method_handler, handler.unary_unary + return handler._replace(unary_unary=wrapper(handler.unary_unary)) elif not handler.request_streaming and handler.response_streaming: - return grpc.unary_stream_rpc_method_handler, handler.unary_stream + return handler._replace(unary_stream=wrapper(handler.unary_stream)) elif handler.request_streaming and not handler.response_streaming: - return grpc.stream_unary_rpc_method_handler, handler.stream_unary + return handler._replace(stream_unary=wrapper(handler.stream_unary)) elif handler.request_streaming and handler.response_streaming: - return grpc.stream_stream_rpc_method_handler, handler.stream_stream + return handler._replace(stream_stream=wrapper(handler.stream_stream)) else: raise BentoMLException(f"RPC method handler {handler} does not exist.") - - -def invoke_handler_factory( - fn: HandlerMethod[t.Any], factory: HandlerFactoryFn, handler: RpcMethodHandler -) -> t.Any: - return factory( - fn, - request_deserializer=handler.request_deserializer, - response_serializer=handler.response_serializer, - ) diff --git a/bentoml/grpc/v1/service.proto b/bentoml/grpc/v1/service.proto index f1ff06df797..6613e3d73ab 100644 --- a/bentoml/grpc/v1/service.proto +++ b/bentoml/grpc/v1/service.proto @@ -2,11 +2,9 @@ syntax = "proto3"; package bentoml.grpc.v1; -import "google/protobuf/struct.proto"; - // cc_enable_arenas pre-allocate memory for given message to improve speed. (C++ only) option cc_enable_arenas = true; -option cc_generic_services = true; +option cc_generic_services = false; option go_package = "github.com/bentoml/grpc/v1"; option java_multiple_files = true; option java_outer_classname = "ServiceProto"; @@ -17,103 +15,59 @@ option py_generic_services = true; // a gRPC BentoServer. service BentoService { // Infer handles unary API. - rpc Infer(Request) returns (Response) {} + rpc Call(Request) returns (Response) {} } -// Request for Infer. +// Request for Call. message Request { // a given API route the rpc request is sent to. string api_name = 1; - Value contents = 2; - -// TODO: -// The data contained in an input can be represented in -// "raw" bytes form or in the repeated type that matches the data type. -// Using the "raw" bytes form will typically allow higher performance due to the way protobuf -// allocation and reuse interacts with GRPC. -// For example, see https://github.com/grpc/grpc/issues/23231. -// bytes raw_bytes_contents = 3; + Value input = 2; } -// Response from Infer. +// Response from Call. message Response { // representation of the output value. 
-  Value contents = 1;
-
-// TODO:
-// The data contained in an input can be represented in
-// "raw" bytes form or in the repeated type that matches the data type.
-// Using the "raw" bytes form will typically allow higher performance due to the way protobuf
-// allocation and reuse interacts with GRPC.
-// For example, see https://github.com/grpc/grpc/issues/23231.
-// bytes raw_bytes_contents = 2;
+  Value output = 1;
 }
 
-// Represents a n-dimensional array.
-// This is synonymous to NumpyNdarray IO Descriptor.
-message NDArray {
+// Represents an n-dimensional array.
+message MultiDimensionalArray {
   // The shape of the array.
   repeated int32 shape = 1;
 
-  // The contents of the array as flattened.
+  // The flattened contents of the nd array.
   optional Array array = 2;
 }
 
 // This represents a 1-d array.
 message Array {
-  optional DataType dtype = 1;
-
-  // Type specific representations that make it easy to create protos in
-  // all languages. The values hold the flattened representation of the inputs in row major order.
-
-  // Serialized bytes contents.
-  // The purpose of this representation is to
-  // reduce serialization overhead during RPC call by avoiding serialization of
-  // many repeated small items.
-  // for quantized types, use bytes_contents
-  bytes bytes_contents = 2;
-
-  // DT_BOOL
-  repeated bool bool_contents = 7 [packed = true];
-
-  // DT_FLOAT, DT_COMPLEX64
-  // Note that since protobuf has no int16 type, we'll have some pointless zero padding for each value here.
-  repeated float float_contents = 3 [packed = true];
-
-  // DT_STRING
-  repeated string string_contents = 6;
-
-  // DT_DOUBLE, DT_COMPLEX128
-  repeated double double_contents = 4 [packed = true];
-
-  // DT_INT32, DT_INT16, DT_UINT16, DT_INT8, DT_UINT8, DT_HALF, DT_BFLOAT16
-  repeated int32 int_contents = 5 [packed = true];
-
-  // DT_INT64
-  repeated int64 long_contents = 8 [packed = true];
-
-  // DT_UINT32
-  repeated uint32 uint32_val = 11 [packed = true];
-
-  // DT_UINT64
-  repeated uint64 uint64_val = 12 [packed = true];
-
-  // TODO: support single/double precision complex value type.
-
-  // User can specify arbitrary struct that then can be parsed to numpy.
-  // DT_STRUCT
-  repeated google.protobuf.Struct struct_contents = 10;
-
-  reserved 13, 14, 15;
+  repeated bool bool_values = 5 [packed = true];
+  repeated float float_values = 1 [packed = true];
+  repeated string string_values = 4;
+  repeated double double_values = 2 [packed = true];
+  repeated int32 int_values = 3 [packed = true];
+  repeated int64 long_values = 6 [packed = true];
+  repeated uint32 uint32_values = 7 [packed = true];
+  repeated uint64 uint64_values = 8 [packed = true];
+
+  // TODO: support the following:
+  //  - arbitrary structs
+  //  - single/double precision complex value types
+  //  - quantized value types
+  //
+  // repeated google.protobuf.Struct struct_contents = 10;
+
+  reserved 10 to 15;
 }
 
-// Represents file types.
-// This is synonymous to File IO Descriptor.
-message File {
-  // type of file, let it be csv, text ,parquet, etc.
+// Represents raw bytes content.
+message Raw {
+  // type of the content: csv, text, parquet, tensor type, etc.
   optional string kind = 1;
 
+  map<string, string> metadata = 2;
+
   // contents of file as bytes.
-  bytes content = 2;
+  bytes content = 3;
 }
 
 // Represents a map value.
@@ -127,69 +81,16 @@ message Value {
   oneof kind {
     // Text()
     string string_value = 1;
-    // File(), Image()
-    File file_value = 2;
+    // File(), Image(), raw byte forms of ndarray, dataframe
+    Raw raw_value = 2;
     // NDArray(), etc.
Array array_value = 3; - // NDArray(), etc. - NDArray ndarray_value = 4; + // NDArray(), DataFrame(), etc. + MultiDimensionalArray multi_dimensional_array_value = 4; // DataFrame() MapValue map_value = 5; } + // We want to reserve these for future uses. reserved 56 to 100; } - -// Represents data type that can be passed to numpy. -enum DataType { - // Not a legal value for DataType. Used to indicate a DataType field - // has not been set. - DT_UNSPECIFIED = 0; - - // Data types that all computation devices are expected to be - // capable to support. - DT_FLOAT = 1; - DT_DOUBLE = 2; - - // int type - DT_INT64 = 9; - DT_INT32 = 3; - DT_INT16 = 5; - DT_INT8 = 6; - - // string type - DT_STRING = 7; - - // bool type - DT_BOOL = 10; - - // Represents half data type (32 -> 16 bits). - DT_HALF = 19; - // Float32 truncated to 16 bits. Only for cast ops. - DT_BFLOAT16 = 14; - - // Quantized int8 - DT_QINT8 = 11; - // Quantized uint8 - DT_QUINT8 = 12; - // Quantized int32 - DT_QINT32 = 13; - // Quantized int16 - DT_QINT16 = 15; - // Quantized uint16 - DT_QUINT16 = 16; - - // Double-precision complex - DT_COMPLEX128 = 18; - // Single-precision complex - DT_COMPLEX64 = 8; - - // unsigned int type - DT_UINT64 = 21; - DT_UINT32 = 20; - DT_UINT16 = 17; - DT_UINT8 = 4; - - // struct dtype - DT_STRUCT = 22; -}
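
To exercise the renamed unary endpoint end to end, a minimal asyncio client against the new schema — the channel address and "predict" api_name are placeholders, and the stub name assumes the usual grpcio codegen for this proto:

    import asyncio

    from grpc import aio

    from bentoml.grpc.v1 import service_pb2, service_pb2_grpc


    async def main() -> None:
        async with aio.insecure_channel("localhost:3000") as channel:
            stub = service_pb2_grpc.BentoServiceStub(channel)
            # build a Request with a 1-d array in the 'array_value' oneof field
            req = service_pb2.Request(
                api_name="predict",
                input=service_pb2.Value(
                    array_value=service_pb2.Array(float_values=[1.0, 2.0, 3.0])
                ),
            )
            resp = await stub.Call(req)
            # inspect which oneof field the server populated on the output Value
            print(resp.output.WhichOneof("kind"))


    asyncio.run(main())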