From c0b500988c3575b87fc1e65d5b4d9b3a7c535c17 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Sun, 31 Jul 2022 12:58:17 -0700 Subject: [PATCH] feat: adding OpenTelemetry interceptor (#2844) --- .../_internal/bento/build_dev_bentoml_whl.py | 16 +- bentoml/_internal/io_descriptors/numpy.py | 44 +-- bentoml/_internal/server/grpc/__init__.py | 5 +- .../server/grpc/interceptors/__init__.py | 130 ++++++--- .../server/grpc/interceptors/opentelemetry.py | 271 ++++++++++++++++++ .../server/grpc/interceptors/trace.py | 0 bentoml/_internal/server/grpc/server.py | 58 ++++ bentoml/_internal/server/grpc/servicer.py | 45 +-- bentoml/_internal/server/grpc/types.py | 4 +- bentoml/_internal/server/grpc_app.py | 73 ++--- bentoml/_internal/utils/grpc/__init__.py | 56 ++-- bentoml/_internal/utils/grpc/codec.py | 55 ++++ bentoml/grpc/v1/service.proto | 1 - 13 files changed, 592 insertions(+), 166 deletions(-) create mode 100644 bentoml/_internal/server/grpc/interceptors/opentelemetry.py delete mode 100644 bentoml/_internal/server/grpc/interceptors/trace.py create mode 100644 bentoml/_internal/utils/grpc/codec.py diff --git a/bentoml/_internal/bento/build_dev_bentoml_whl.py b/bentoml/_internal/bento/build_dev_bentoml_whl.py index 323f81dd21..d65b12e767 100644 --- a/bentoml/_internal/bento/build_dev_bentoml_whl.py +++ b/bentoml/_internal/bento/build_dev_bentoml_whl.py @@ -26,6 +26,14 @@ def build_bentoml_editable_wheel(target_path: str) -> None: # skip this entirely if BentoML is installed from PyPI return + try: + from build import ProjectBuilder + from build.env import IsolatedEnvBuilder + except ModuleNotFoundError: + raise BentoMLException( + f"`{BENTOML_DEV_BUILD}=True`, which requires the `pypa/build` package. Install development dependencies with `pip install -r requirements/dev-requirements.txt` and try again." + ) + # Find bentoml module path module_location = source_locations("bentoml") if not module_location: @@ -40,14 +48,6 @@ def build_bentoml_editable_wheel(target_path: str) -> None: logger.info( "BentoML is installed in `editable` mode; building BentoML distribution with the local BentoML code base. The built wheel file will be included in the target bento." ) - try: - from build import ProjectBuilder - from build.env import IsolatedEnvBuilder - except ModuleNotFoundError: - raise BentoMLException( - f"Environment variable {BENTOML_DEV_BUILD}=True detected, which requires the `pypa/build` package. Make sure to install all dev dependencies via `pip install -r requirements/dev-requirements.txt` and try again." - ) - with IsolatedEnvBuilder() as env: builder = ProjectBuilder(os.path.dirname(pyproject)) builder.python_executable = env.executable diff --git a/bentoml/_internal/io_descriptors/numpy.py b/bentoml/_internal/io_descriptors/numpy.py index 30c3463567..3138025586 100644 --- a/bentoml/_internal/io_descriptors/numpy.py +++ b/bentoml/_internal/io_descriptors/numpy.py @@ -218,26 +218,28 @@ def openapi_responses_schema(self) -> t.Dict[str, t.Any]: def _verify_ndarray( self, obj: "ext.NpNDArray", + dtype: np.dtype[t.Any] | None, + shape: tuple[int, ...] | None, exception_cls: t.Type[Exception] = BadInput, ) -> "ext.NpNDArray": - if self._dtype is not None and self._dtype != obj.dtype: + if dtype is not None and dtype != obj.dtype: # ‘same_kind’ means only safe casts or casts within a kind, like float64 # to float32, are allowed. 
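# [Illustrative sketch, not part of this diff] The "same_kind" rule referenced
# in the comment above: float64 -> float32 stays within the floating-point kind
# and is accepted, while float64 -> int64 crosses kinds and is rejected, which
# is what sends _verify_ndarray down the warning/exception path.
import numpy as np

arr = np.array([1.5, 2.5], dtype=np.float64)
assert np.can_cast(arr.dtype, np.dtype(np.float32), casting="same_kind")
assert not np.can_cast(arr.dtype, np.dtype(np.int64), casting="same_kind")
downcast = arr.astype(np.float32, casting="same_kind")  # dtype('float32')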
- if np.can_cast(obj.dtype, self._dtype, casting="same_kind"): - obj = obj.astype(self._dtype, casting="same_kind") # type: ignore + if np.can_cast(obj.dtype, dtype, casting="same_kind"): + obj = obj.astype(dtype, casting="same_kind") # type: ignore else: - msg = f'{self.__class__.__name__}: Expecting ndarray of dtype "{self._dtype}", but "{obj.dtype}" was received.' + msg = f'{self.__class__.__name__}: Expecting ndarray of dtype "{dtype}", but "{obj.dtype}" was received.' if self._enforce_dtype: raise exception_cls(msg) else: logger.debug(msg) - if self._shape is not None and not _is_matched_shape(self._shape, obj.shape): - msg = f'{self.__class__.__name__}: Expecting ndarray of shape "{self._shape}", but "{obj.shape}" was received.' + if shape is not None and not _is_matched_shape(shape, obj.shape): + msg = f'{self.__class__.__name__}: Expecting ndarray of shape "{shape}", but "{obj.shape}" was received.' if self._enforce_shape: raise exception_cls(msg) try: - obj = obj.reshape(self._shape) + obj = obj.reshape(shape) except ValueError as e: logger.debug(f"{msg} Failed to reshape: {e}.") @@ -260,7 +262,7 @@ async def from_http_request(self, request: Request) -> ext.NpNDArray: res = np.array(obj, dtype=self._dtype) except ValueError: res = np.array(obj) - return self._verify_ndarray(res, BadInput) + return self._verify_ndarray(res, dtype=self._dtype, shape=self._shape) async def to_http_response(self, obj: ext.NpNDArray, ctx: Context | None = None): """ @@ -273,7 +275,9 @@ async def to_http_response(self, obj: ext.NpNDArray, ctx: Context | None = None) HTTP Response of type `starlette.responses.Response`. This can be accessed via cURL or any external web traffic. """ - obj = self._verify_ndarray(obj, InternalServerError) + obj = self._verify_ndarray( + obj, dtype=self._dtype, shape=self._shape, exception_cls=InternalServerError + ) if ctx is not None: res = Response( json.dumps(obj.tolist()), @@ -313,6 +317,7 @@ async def from_grpc_request( ) if not self._shape: raise UnprocessableEntity("'shape' is required when 'packed' is set.") + metadata = serialized["metadata"] if not self._dtype: if "dtype" not in metadata: @@ -340,13 +345,11 @@ async def from_grpc_request( logger.warning( f"'shape={self._shape},enforce_shape={self._enforce_shape}' is set with {self.__class__.__name__}, while 'shape' field is present in request message. To avoid this warning, set 'enforce_shape=True'. Using 'shape={shape}' from request message." ) - self._shape = shape else: logger.debug( f"'enforce_shape={self._enforce_shape}', ignoring 'shape' field in request message." ) - else: - self._shape = shape + shape = self._shape array = serialized["array"] else: @@ -360,20 +363,18 @@ async def from_grpc_request( logger.warning( f"'dtype={self._dtype},enforce_dtype={self._enforce_dtype}' is set with {self.__class__.__name__}, while 'dtype' field is present in request message. To avoid this warning, set 'enforce_dtype=True'. Using 'dtype={dtype}' from request message." ) - self._dtype = dtype else: logger.debug( f"'enforce_dtype={self._enforce_dtype}', ignoring 'dtype' field in request message." 
) - else: - self._dtype = dtype + dtype = self._dtype try: - res = np.array(content, dtype=self._dtype) + res = np.array(content, dtype=dtype) except ValueError: res = np.array(content) - return self._verify_ndarray(res, BadInput) + return self._verify_ndarray(res, dtype=dtype, shape=shape) async def to_grpc_response( self, obj: ext.NpNDArray, context: BentoServicerContext @@ -395,7 +396,12 @@ async def to_grpc_response( value_key = _NP_TO_VALUE_MAP[obj.dtype] try: - obj = self._verify_ndarray(obj, InternalServerError) + obj = self._verify_ndarray( + obj, + dtype=self._dtype, + shape=self._shape, + exception_cls=InternalServerError, + ) except InternalServerError as e: context.set_code(grpc_status_code(e)) context.set_details(e.message) @@ -411,7 +417,7 @@ async def to_grpc_response( ) value.raw_value.CopyFrom(raw) else: - if self._bytesorder: + if self._bytesorder and self._bytesorder != "C": logger.warning( f"'bytesorder={self._bytesorder}' is ignored when 'packed={self._packed}'." ) diff --git a/bentoml/_internal/server/grpc/__init__.py b/bentoml/_internal/server/grpc/__init__.py index ed466ca664..3ffb8f96b0 100644 --- a/bentoml/_internal/server/grpc/__init__.py +++ b/bentoml/_internal/server/grpc/__init__.py @@ -1,5 +1,4 @@ from .server import GRPCServer -from .servicer import register_bento_servicer -from .servicer import register_health_servicer +from .servicer import create_bento_servicer -__all__ = ["GRPCServer", "register_health_servicer", "register_bento_servicer"] +__all__ = ["GRPCServer", "create_bento_servicer"] diff --git a/bentoml/_internal/server/grpc/interceptors/__init__.py b/bentoml/_internal/server/grpc/interceptors/__init__.py index 379a71c0e4..1e94c19f36 100644 --- a/bentoml/_internal/server/grpc/interceptors/__init__.py +++ b/bentoml/_internal/server/grpc/interceptors/__init__.py @@ -6,36 +6,54 @@ from timeit import default_timer from typing import TYPE_CHECKING +import grpc from grpc import aio -from opentelemetry import trace +from ....utils import LazyLoader +from ....utils.grpc import ProtoCodec +from ....utils.grpc import to_http_status from ....utils.grpc import wrap_rpc_handler +from ....utils.grpc import get_grpc_content_type +from ....utils.grpc.codec import GRPC_CONTENT_TYPE if TYPE_CHECKING: - from opentelemetry.sdk.trace import TracerProvider - from opentelemetry.trace.span import Span + from grpc.aio._typing import MetadataType + + from bentoml.grpc.v1 import service_pb2 from ..types import Request from ..types import Response - from ..types import HandlerMethod from ..types import RpcMethodHandler + from ..types import AsyncHandlerMethod from ..types import HandlerCallDetails from ..types import BentoServicerContext + from ....utils.grpc.codec import Codec +else: + service_pb2 = LazyLoader("service_pb2", globals(), "bentoml.grpc.v1.service_pb2") logger = logging.getLogger(__name__) -class AccessLogInterceptor(aio.ServerInterceptor): +class GenericHeadersServerInterceptor(aio.ServerInterceptor): """ - An asyncio interceptors for access log. - - .. TODO: - - Add support for streaming RPCs. + A light header interceptor that provides some initial metadata to the client. + TODO: https://chromium.googlesource.com/external/github.com/grpc/grpc/+/HEAD/doc/PROTOCOL-HTTP2.md """ - def __init__(self, tracer_provider: TracerProvider) -> None: - self.logger = logging.getLogger("bentoml.access") - self.tracer_provider = tracer_provider + def __init__(self, *, codec: Codec | None = None): + if not codec: + # By default, we use ProtoCodec. 
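# [Illustrative sketch, not part of this diff] set_trailing_metadata() below
# derives the content-type trailer from the codec's content subtype; with the
# default ProtoCodec the expected value is "application/grpc+proto". A
# standalone version of that composition (the helper name is hypothetical):
GRPC_CONTENT_TYPE = "application/grpc"

def content_type_for(content_subtype=None):
    # e.g. content_type_for("proto") -> "application/grpc+proto"
    return f"{GRPC_CONTENT_TYPE}+{content_subtype}" if content_subtype else GRPC_CONTENT_TYPE

assert content_type_for("proto") == "application/grpc+proto"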
+ codec = ProtoCodec() + self._codec = codec + + def set_trailing_metadata(self, context: BentoServicerContext): + # We want to send some initial metadata to the client. + # gRPC doesn't use `:status` pseudo header to indicate success or failure + # of the current request. gRPC instead uses trailers for this purpose, and + # trailers are sent during `send_trailing_metadata` call + # For now we are sending over the content-type header. + headers = [("content-type", get_grpc_content_type(codec=self._codec))] + context.set_trailing_metadata(headers) async def intercept_service( self, @@ -43,41 +61,87 @@ async def intercept_service( handler_call_details: HandlerCallDetails, ) -> RpcMethodHandler: handler = await continuation(handler_call_details) + + if handler and (handler.response_streaming or handler.request_streaming): + return handler + + def wrapper(behaviour: AsyncHandlerMethod[Response]): + @functools.wraps(behaviour) + async def new_behaviour( + request: Request, context: BentoServicerContext + ) -> Response | t.Awaitable[Response]: + # setup metadata + self.set_trailing_metadata(context) + + # for the rpc itself. + resp = behaviour(request, context) + if not hasattr(resp, "__aiter__"): + resp = await resp + return resp + + return new_behaviour + + return wrap_rpc_handler(wrapper, handler) + + +class AccessLogServerInterceptor(aio.ServerInterceptor): + """ + An asyncio interceptors for access log. + """ + + async def intercept_service( + self, + continuation: t.Callable[[HandlerCallDetails], t.Awaitable[RpcMethodHandler]], + handler_call_details: HandlerCallDetails, + ) -> RpcMethodHandler: + logger = logging.getLogger("bentoml.access") + handler = await continuation(handler_call_details) method_name = handler_call_details.method if handler and (handler.response_streaming or handler.request_streaming): return handler - def wrapper( - behaviour: HandlerMethod[Response | t.AsyncGenerator[Response, None]] - ) -> t.Callable[..., t.Any]: + def wrapper(behaviour: AsyncHandlerMethod[Response]): @functools.wraps(behaviour) async def new_behaviour( request: Request, context: BentoServicerContext - ) -> Response: + ) -> Response | t.Awaitable[Response]: + + content_type = GRPC_CONTENT_TYPE - tracer = self.tracer_provider.get_tracer( - "opentelemetry.instrumentation.grpc" - ) - span: Span = tracer.start_span("grpc") - span_context = span.get_span_context() - kind = str(request.input.WhichOneof("kind")) + trailing_metadata: MetadataType | None = context.trailing_metadata() + if trailing_metadata: + trailing = dict(trailing_metadata) + content_type = trailing.get("content-type", GRPC_CONTENT_TYPE) start = default_timer() - with trace.use_span(span, end_on_exit=True): + try: response = behaviour(request, context) if not hasattr(response, "__aiter__"): response = await response - latency = max(default_timer() - start, 0) - - req_info = f"api_name={request.api_name},type={kind},size={request.input.ByteSize()}" - resp_info = f"status={context.code()},type={kind},size={response.output.ByteSize()}" - trace_and_span = f"trace={span_context.trace_id},span={span_context.span_id},sampled={1 if span_context.trace_flags.sampled else 0}" - - self.logger.info( - f"{context.peer()} ({req_info}) ({resp_info}) {latency:.3f}ms ({trace_and_span})" - ) - + except Exception as e: + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + response = service_pb2.Response() + finally: + latency = max(default_timer() - start, 0) + req = [ + "scheme=http", # TODO: support https when ssl is added + 
f"path={method_name}", + f"type={content_type}", + f"size={request.ByteSize()}", + ] + resp = [ + f"http_status={to_http_status(context.code())}", + f"grpc_status={context.code().value[0]}", + f"type={content_type}", + f"size={response.ByteSize()}", + ] + + # TODO: fix ports + logger.info( + f"{context.peer()} ({','.join(req)}) ({','.join(resp)}) {latency:.3f}ms" + ) return response return new_behaviour diff --git a/bentoml/_internal/server/grpc/interceptors/opentelemetry.py b/bentoml/_internal/server/grpc/interceptors/opentelemetry.py new file mode 100644 index 0000000000..c46ce96a83 --- /dev/null +++ b/bentoml/_internal/server/grpc/interceptors/opentelemetry.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +import typing as t +import logging +import functools +from typing import TYPE_CHECKING +from contextlib import asynccontextmanager + +import grpc +from grpc import aio +from simple_di import inject +from simple_di import Provide +from opentelemetry import trace +from opentelemetry.context import attach +from opentelemetry.context import detach +from opentelemetry.propagate import extract +from opentelemetry.trace.status import Status +from opentelemetry.trace.status import StatusCode +from opentelemetry.semconv.trace import SpanAttributes + +from ....utils.pkg import get_pkg_version +from ....utils.grpc import wrap_rpc_handler +from ....utils.grpc import parse_method_name +from ....utils.grpc.codec import GRPC_CONTENT_TYPE +from ....configuration.containers import BentoMLContainer + +if TYPE_CHECKING: + from grpc.aio._typing import MetadataKey + from grpc.aio._typing import MetadataType + from grpc.aio._typing import MetadataValue + from opentelemetry.trace import Span + from opentelemetry.sdk.trace import TracerProvider + + from ..types import Request + from ..types import Response + from ..types import RpcMethodHandler + from ..types import AsyncHandlerMethod + from ..types import HandlerCallDetails + from ..types import BentoServicerContext + +logger = logging.getLogger(__name__) + + +class _OpenTelemetryServicerContext(aio.ServicerContext["Request", "Response"]): + def __init__(self, servicer_context: BentoServicerContext, active_span: Span): + self._servicer_context = servicer_context + self._active_span = active_span + self._code = grpc.StatusCode.OK + self._details = "" + + async def read(self) -> Request: + return await self._servicer_context.read() + + async def write(self, message: Response) -> None: + return await self._servicer_context.write(message) + + def trailing_metadata(self) -> aio.Metadata: + return self._servicer_context.trailing_metadata() # type: ignore (unfinished type) + + def auth_context(self) -> t.Mapping[str, t.Iterable[bytes]]: + return self._servicer_context.auth_context() + + def peer_identity_key(self) -> str | None: + return self._servicer_context.peer_identity_key() + + def peer_identities(self) -> t.Iterable[bytes] | None: + return self._servicer_context.peer_identities() + + def peer(self) -> str: + return self._servicer_context.peer() + + def disable_next_message_compression(self) -> None: + self._servicer_context.disable_next_message_compression() + + def set_compression(self, compression: grpc.Compression) -> None: + return self._servicer_context.set_compression(compression) + + def invocation_metadata(self) -> aio.Metadata | None: + return self._servicer_context.invocation_metadata() + + def set_trailing_metadata(self, trailing_metadata: MetadataType) -> None: + self._servicer_context.set_trailing_metadata(trailing_metadata) + + async 
def send_initial_metadata(self, initial_metadata: MetadataType) -> None: + return await self._servicer_context.send_initial_metadata(initial_metadata) + + async def abort( + self, + code: grpc.StatusCode, + details: str = "", + trailing_metadata: MetadataType = tuple(), + ) -> None: + self._code = code + self._details = details + self._active_span.set_attribute( + SpanAttributes.RPC_GRPC_STATUS_CODE, code.value[0] + ) + self._active_span.set_status( + Status(status_code=StatusCode.ERROR, description=f"{code}:{details}") + ) + return await self._servicer_context.abort( + code, details=details, trailing_metadata=trailing_metadata + ) + + def set_code(self, code: grpc.StatusCode) -> None: + self._code = code + details = self._details or code.value[1] + self._active_span.set_attribute( + SpanAttributes.RPC_GRPC_STATUS_CODE, code.value[0] + ) + if code != grpc.StatusCode.OK: + self._active_span.set_status( + Status(status_code=StatusCode.ERROR, description=f"{code}:{details}") + ) + return self._servicer_context.set_code(code) + + def code(self) -> grpc.StatusCode: + return self._code + + def set_details(self, details: str) -> None: + self._details = details + if self._code != grpc.StatusCode.OK: + self._active_span.set_status( + Status( + status_code=StatusCode.ERROR, description=f"{self._code}:{details}" + ) + ) + return self._servicer_context.set_details(details) + + def details(self) -> str: + return self._details + + +# Since opentelemetry doesn't provide an async implementation for the server interceptor, +# we will need to create an async implementation ourselves. +# By doing this we will have more control over how to handle span and context propagation. +# +# Until there is a solution upstream, this implementation is sufficient for our needs. +class AsyncOpenTelemetryServerInterceptor(aio.ServerInterceptor): + @inject + def __init__( + self, + *, + tracer_provider: TracerProvider = Provide[BentoMLContainer.tracer_provider], + schema_url: str | None = None, + ): + self._tracer = tracer_provider.get_tracer( + "opentelemetry.instrumentation.grpc", + get_pkg_version("opentelemetry-instrumentation-grpc"), + schema_url=schema_url, + ) + + @asynccontextmanager + async def set_remote_context( + self, servicer_context: BentoServicerContext + ) -> t.AsyncGenerator[None, None]: + metadata = servicer_context.invocation_metadata() + if not metadata: + yield + md: dict[MetadataKey, MetadataValue] = {m.key: m.value for m in metadata} + ctx = extract(md) + token = attach(ctx) + try: + yield + finally: + detach(token) + + def start_span( + self, + method_name: str, + context: BentoServicerContext, + set_status_on_exception: bool = False, + ) -> t.ContextManager[Span]: + attributes: dict[str, str | bytes] = { + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[0], + } + + # method_name shouldn't be none, otherwise + # it will never reach this point. 
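# [Illustrative sketch, not part of this diff] gRPC method names arrive as
# "/<package>.<Service>/<Method>", e.g. "/bentoml.grpc.v1.BentoService/Call".
# A standalone version of the split that parse_method_name() performs (the
# helper below is hypothetical, not the util's exact return shape):
def split_method_name(method_name):
    _, fully_qualified_service, method = method_name.split("/")
    package, _, service = fully_qualified_service.rpartition(".")
    return package, service, method

assert split_method_name("/bentoml.grpc.v1.BentoService/Call") == (
    "bentoml.grpc.v1",
    "BentoService",
    "Call",
)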
+ method_rpc, _ = parse_method_name(method_name) + attributes.update( + { + SpanAttributes.RPC_METHOD: method_rpc.method, + SpanAttributes.RPC_SERVICE: method_rpc.fully_qualified_service, + } + ) + + # add some attributes from the metadata + metadata = context.invocation_metadata() + if metadata: + dct: dict[str, str | bytes] = dict(metadata) + if "user-agent" in dct: + attributes["rpc.user_agent"] = dct["user-agent"] + + # get trailing metadata + trailing_metadata: MetadataType | None = context.trailing_metadata() + if trailing_metadata: + trailing = dict(trailing_metadata) + attributes["rpc.content_type"] = trailing.get( + "content-type", GRPC_CONTENT_TYPE + ) + + # Split up the peer to keep with how other telemetry sources + # do it. This looks like: + # * ipv6:[::1]:57284 + # * ipv4:127.0.0.1:57284 + # * ipv4:10.2.1.1:57284,127.0.0.1:57284 + # + # the process ip and port would be [::1] 57284 + try: + ipv4_addr = context.peer().split(",")[0] + ip, port = ipv4_addr.split(":", 1)[1].rsplit(":", 1) + attributes.update( + { + SpanAttributes.NET_PEER_IP: ip, + SpanAttributes.NET_PEER_PORT: port, + } + ) + # other telemetry sources add this, so we will too + if ip in ("[::1]", "127.0.0.1"): + attributes[SpanAttributes.NET_PEER_NAME] = "localhost" + except IndexError: + logger.warning("Failed to parse peer address '%s'", context.peer()) + + return self._tracer.start_as_current_span( + name=method_name, + kind=trace.SpanKind.SERVER, + attributes=attributes, + set_status_on_exception=set_status_on_exception, + ) + + async def intercept_service( + self, + continuation: t.Callable[[HandlerCallDetails], t.Awaitable[RpcMethodHandler]], + handler_call_details: HandlerCallDetails, + ) -> RpcMethodHandler: + handler = await continuation(handler_call_details) + method_name = handler_call_details.method + + # Currently not support streaming RPCs. + if handler and (handler.response_streaming or handler.request_streaming): + return handler + + def wrapper(behaviour: AsyncHandlerMethod[Response]): + @functools.wraps(behaviour) + async def new_behaviour( + request: Request, context: BentoServicerContext + ) -> Response | t.Awaitable[Response]: + + async with self.set_remote_context(context): + with self.start_span(method_name, context) as span: + # wrap context + wrapped_context = _OpenTelemetryServicerContext(context, span) + + # And now we run the actual RPC. + try: + response = behaviour(request, wrapped_context) + if not hasattr(response, "__aiter__"): + response = await response + return response + except Exception as e: + # We are interested in uncaught exception, otherwise + # it will be handled by gRPC. 
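# [Illustrative sketch, not part of this diff] The peer-parsing block in
# start_span() above turns strings such as "ipv4:127.0.0.1:57284" or
# "ipv6:[::1]:57284" into an (ip, port) pair. A standalone check of that split:
def split_peer(peer):
    first = peer.split(",")[0]                      # drop secondary addresses
    ip, port = first.split(":", 1)[1].rsplit(":", 1)
    return ip, port

assert split_peer("ipv4:127.0.0.1:57284") == ("127.0.0.1", "57284")
assert split_peer("ipv6:[::1]:57284") == ("[::1]", "57284")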
+ if type(e) != Exception: + span.record_exception(e) + raise e + + return new_behaviour + + return wrap_rpc_handler(wrapper, handler) diff --git a/bentoml/_internal/server/grpc/interceptors/trace.py b/bentoml/_internal/server/grpc/interceptors/trace.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/bentoml/_internal/server/grpc/server.py b/bentoml/_internal/server/grpc/server.py index 1a4be0ed10..5be3955bfa 100644 --- a/bentoml/_internal/server/grpc/server.py +++ b/bentoml/_internal/server/grpc/server.py @@ -3,14 +3,36 @@ import typing as t import asyncio import logging +from typing import TYPE_CHECKING import grpc from grpc import aio +from bentoml.exceptions import MissingDependencyException + +from ...utils import LazyLoader from ...utils import cached_property logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from grpc_health.v1 import health + from grpc_health.v1 import health_pb2 + from grpc_health.v1 import health_pb2_grpc + + from bentoml.grpc.v1 import service_pb2 + from bentoml.grpc.v1 import service_pb2_grpc +else: + service_pb2 = LazyLoader("service_pb2", globals(), "bentoml.grpc.v1.service_pb2") + service_pb2_grpc = LazyLoader( + "service_pb2_grpc", globals(), "bentoml.grpc.v1.service_pb2_grpc" + ) + health = LazyLoader("health", globals(), "grpc_health.v1.health") + health_pb2 = LazyLoader("health_pb2", globals(), "grpc_health.v1.health_pb2") + health_pb2_grpc = LazyLoader( + "health_pb2_grpc", globals(), "grpc_health.v1.health_pb2_grpc" + ) + class GRPCServer: """An ASGI-like implementation for async gRPC server.""" @@ -22,7 +44,11 @@ def __init__( on_shutdown: t.Sequence[t.Callable[[], t.Any]] | None = None, *, _grace_period: int = 5, + _bento_servicer: service_pb2_grpc.BentoServiceServicer, + _health_servicer: health.aio.HealthServicer, ): + self._bento_servicer = _bento_servicer + self._health_servicer = _health_servicer self._grace_period = _grace_period self.server = server @@ -54,6 +80,14 @@ async def serve(self, bind_addr: str) -> None: await self.wait_for_termination() async def startup(self) -> None: + try: + # reflection is required for health checking to work. + from grpc_reflection.v1alpha import reflection + except ImportError: + raise MissingDependencyException( + "reflection is enabled, which requires 'grpcio-reflection' to be installed. Install with `pip install 'grpcio-relfection'.`" + ) + # Running on_startup callback. 
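# [Illustrative sketch, not part of this diff] Once startup() below registers
# the health servicer and marks every service as SERVING, readiness can be
# verified with the standard gRPC health protocol. The address/port and the
# fully qualified service name are assumptions, not taken from this diff:
import grpc
from grpc_health.v1 import health_pb2, health_pb2_grpc

def check_health(address="localhost:3000"):
    with grpc.insecure_channel(address) as channel:
        stub = health_pb2_grpc.HealthStub(channel)
        req = health_pb2.HealthCheckRequest(service="bentoml.grpc.v1.BentoService")
        resp = stub.Check(req)
        return health_pb2.HealthCheckResponse.ServingStatus.Name(resp.status)  # e.g. "SERVING"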
for handler in self.on_startup: if asyncio.iscoroutinefunction(handler): @@ -61,6 +95,24 @@ async def startup(self) -> None: else: handler() + # register bento servicer + service_pb2_grpc.add_BentoServiceServicer_to_server( + self._bento_servicer, self.server # type: ignore (unfinished async types) + ) + health_pb2_grpc.add_HealthServicer_to_server(self._health_servicer, self.server) + + services = tuple( + service.full_name + for service in service_pb2.DESCRIPTOR.services_by_name.values() + ) + (health.SERVICE_NAME, reflection.SERVICE_NAME) + reflection.enable_server_reflection(services, self.server) + + # mark all services as healthy + for service in services: + await self._health_servicer.set( + service, health_pb2.HealthCheckResponse.SERVING # type: ignore (no types available) + ) + await self.server.start() async def shutdown(self): @@ -72,6 +124,7 @@ async def shutdown(self): handler() await self.server.stop(grace=self._grace_period) + await self._health_servicer.enter_graceful_shutdown() async def wait_for_termination(self, timeout: int | None = None) -> bool: return await self.server.wait_for_termination(timeout=timeout) @@ -81,3 +134,8 @@ def add_insecure_port(self, address: str) -> int: def add_secure_port(self, address: str, credentials: grpc.ServerCredentials) -> int: return self.server.add_secure_port(address, credentials) + + def add_generic_rpc_handlers( + self, generic_rpc_handlers: t.Sequence[grpc.GenericRpcHandler] + ) -> None: + self.server.add_generic_rpc_handlers(generic_rpc_handlers) diff --git a/bentoml/_internal/server/grpc/servicer.py b/bentoml/_internal/server/grpc/servicer.py index c8e638099f..533171ac42 100644 --- a/bentoml/_internal/server/grpc/servicer.py +++ b/bentoml/_internal/server/grpc/servicer.py @@ -7,11 +7,9 @@ import grpc import anyio -from grpc import aio from bentoml.exceptions import BentoMLException from bentoml.exceptions import UnprocessableEntity -from bentoml.exceptions import MissingDependencyException from bentoml._internal.service.service import Service from ...utils import LazyLoader @@ -34,10 +32,11 @@ def log_exception(request: _service_pb2.Request, exc_info: ExcInfoType) -> None: - logger.error(f"Exception on /{request.api_name}", exc_info=exc_info) + # gRPC will always send a POST request. + logger.error(f"Exception on /{request.api_name} [POST]", exc_info=exc_info) -def register_bento_servicer(service: Service, server: aio.Server) -> None: +def create_bento_servicer(service: Service) -> _service_pb2_grpc.BentoServiceServicer: """ This is the actual implementation of BentoServicer. Main inference entrypoint will be invoked via /bentoml.grpc..BentoService/Call @@ -85,40 +84,4 @@ async def Call( # type: ignore (no async types) ) return response - _service_pb2_grpc.add_BentoServiceServicer_to_server(BentoServiceServicer(), server) # type: ignore (lack of asyncio types) - - -async def register_health_servicer(server: aio.Server) -> None: - from bentoml.grpc.v1 import service_pb2 - - try: - from grpc_health.v1 import health - from grpc_health.v1 import health_pb2 - from grpc_health.v1 import health_pb2_grpc - except ImportError: - raise MissingDependencyException( - "'grpcio-health-checking' is required for using health checking endpoints. Install with `pip install grpcio-health-checking`." - ) - try: - # reflection is required for health checking to work. - from grpc_reflection.v1alpha import reflection - except ImportError: - raise MissingDependencyException( - "reflection is enabled, which requires 'grpcio-reflection' to be installed. 
Install with `pip install 'grpcio-relfection'.`" - ) - - # Create a health check servicer. We use the non-blocking implementation - # to avoid thread starvation. - health_servicer = health.aio.HealthServicer() - health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server) - - # create a list of service we want to export for health checking. - services = tuple( - service.full_name - for service in service_pb2.DESCRIPTOR.services_by_name.values() - ) + (health.SERVICE_NAME, reflection.SERVICE_NAME) - reflection.enable_server_reflection(services, server) - - # mark all services as healthy - for service in services: - await health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING) # type: ignore (unfinished grpcio-health-checking type) + return BentoServiceServicer() diff --git a/bentoml/_internal/server/grpc/types.py b/bentoml/_internal/server/grpc/types.py index d891dc2f4b..e73fa62960 100644 --- a/bentoml/_internal/server/grpc/types.py +++ b/bentoml/_internal/server/grpc/types.py @@ -6,6 +6,7 @@ from typing import TypeVar from typing import Callable from typing import Optional +from typing import Awaitable from typing import NamedTuple from typing import TYPE_CHECKING @@ -19,12 +20,13 @@ P = TypeVar("P") - BentoServicerContext = aio.ServicerContext[Response, Request] + BentoServicerContext = aio.ServicerContext[Request, Response] RequestDeserializerFn = Callable[[Request | None], object] | None ResponseSerializerFn = Callable[[bytes], Response | None] | None HandlerMethod = Callable[[Request, BentoServicerContext], P] + AsyncHandlerMethod = Callable[[Request, BentoServicerContext], Awaitable[P]] class RpcMethodHandler( NamedTuple( diff --git a/bentoml/_internal/server/grpc_app.py b/bentoml/_internal/server/grpc_app.py index e7477b5073..1db9796f5e 100644 --- a/bentoml/_internal/server/grpc_app.py +++ b/bentoml/_internal/server/grpc_app.py @@ -10,14 +10,14 @@ from simple_di import inject from simple_di import Provide +from bentoml.exceptions import MissingDependencyException + from .grpc.server import GRPCServer from ..configuration.containers import BentoMLContainer logger = logging.getLogger(__name__) if TYPE_CHECKING: - from opentelemetry.trace import TracerProvider - from ..service import Service OnStartup = list[t.Callable[[], None | t.Coroutine[t.Any, t.Any, None]]] @@ -59,9 +59,6 @@ def mark_as_ready(self) -> None: @property def on_startup(self) -> OnStartup: - from .grpc import register_bento_servicer - from .grpc import register_health_servicer - on_startup: OnStartup = [ self.mark_as_ready, self.bento_service.on_grpc_server_startup, @@ -73,17 +70,6 @@ def on_startup(self) -> OnStartup: for runner in self.bento_service.runners: on_startup.append(runner.init_client) - on_startup.extend( - [ - functools.partial( - register_bento_servicer, - service=self.bento_service, - server=self.server, - ), - functools.partial(register_health_servicer, server=self.server), - ] - ) - return on_startup @property @@ -95,10 +81,25 @@ def on_shutdown(self) -> list[t.Callable[[], None]]: return on_shutdown def __call__(self) -> GRPCServer: + try: + from grpc_health.v1 import health + except ImportError: + raise MissingDependencyException( + "'grpcio-health-checking' is required for using health checking endpoints. Install with `pip install grpcio-health-checking`." + ) + from .grpc.servicer import create_bento_servicer + + # Create a health check servicer. We use the non-blocking implementation + # to avoid thread starvation. 
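# [Illustrative sketch, not part of this diff] health.aio.HealthServicer,
# created just below, is the asyncio (non-blocking) variant, so its set() is
# awaitable -- the same call GRPCServer.startup() makes for each service. The
# service name here is an assumption:
import asyncio
from grpc_health.v1 import health, health_pb2

async def mark_serving(service_name):
    servicer = health.aio.HealthServicer()
    await servicer.set(service_name, health_pb2.HealthCheckResponse.SERVING)
    return servicer

asyncio.run(mark_serving("bentoml.grpc.v1.BentoService"))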
+ health_servicer = health.aio.HealthServicer() + bento_servicer = create_bento_servicer(self.bento_service) + return GRPCServer( server=self.server, on_startup=self.on_startup, on_shutdown=self.on_shutdown, + _health_servicer=health_servicer, + _bento_servicer=bento_servicer, ) @property @@ -120,40 +121,28 @@ def options( return options @property - @inject - def interceptors( - self, - *, - tracer_provider: TracerProvider = Provide[BentoMLContainer.tracer_provider], - ) -> list[aio.ServerInterceptor]: - from opentelemetry import trace - from opentelemetry.sdk.trace.export import ConsoleSpanExporter - from opentelemetry.sdk.trace.export import SimpleSpanProcessor - - # from .grpc.interceptors import AccessLogInterceptor - - trace.set_tracer_provider(tracer_provider) - trace.get_tracer_provider().add_span_processor( - SimpleSpanProcessor(ConsoleSpanExporter()) + def interceptors(self) -> list[aio.ServerInterceptor]: + # Note that order of interceptors is important here. + from .grpc.interceptors import GenericHeadersServerInterceptor + from .grpc.interceptors.opentelemetry import ( + AsyncOpenTelemetryServerInterceptor as AsyncOtelInterceptor, ) - # from .grpc.interceptors.trace import ( - # AsyncOpenTelemetryServerInterceptor as OtelInterceptor, - # ) # TODO: prometheus interceptors. - # interceptors: list[aio.ServerInterceptor] = [OtelInterceptor()] - interceptors: list[aio.ServerInterceptor] = [] + interceptors: list[t.Type[aio.ServerInterceptor]] = [ + GenericHeadersServerInterceptor, + AsyncOtelInterceptor, + ] access_log_config = BentoMLContainer.api_server_config.logging.access if access_log_config.enabled.get(): - from .grpc.interceptors import AccessLogInterceptor + from .grpc.interceptors import AccessLogServerInterceptor access_logger = logging.getLogger("bentoml.access") if access_logger.getEffectiveLevel() <= logging.INFO: - interceptors.append( - AccessLogInterceptor(tracer_provider=tracer_provider) - ) + interceptors.append(AccessLogServerInterceptor) # add users-defined interceptors. 
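# [Illustrative sketch, not part of this diff] The property now collects
# interceptor *classes* and instantiates them only on return, so the built-ins
# come first and user-defined interceptors are appended after. gRPC invokes
# server interceptors in list order; in isolation the instantiation step is:
class First: ...
class Second: ...

classes = [First, Second]                 # stand-ins for the interceptor classes
instances = [cls() for cls in classes]    # what list(map(lambda x: x(), ...)) returns
assert [type(i) for i in instances] == [First, Second]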
- interceptors.extend(map(lambda x: x(), self.bento_service.interceptors)) - return interceptors + interceptors.extend(self.bento_service.interceptors) + + return list(map(lambda x: x(), interceptors)) diff --git a/bentoml/_internal/utils/grpc/__init__.py b/bentoml/_internal/utils/grpc/__init__.py index a36952da8e..a355a485e7 100644 --- a/bentoml/_internal/utils/grpc/__init__.py +++ b/bentoml/_internal/utils/grpc/__init__.py @@ -12,6 +12,8 @@ from bentoml.exceptions import BentoMLException from bentoml.exceptions import UnprocessableEntity +from .codec import ProtoCodec +from .codec import get_grpc_content_type from ..lazy_loader import LazyLoader if TYPE_CHECKING: @@ -19,8 +21,6 @@ from bentoml.io import IODescriptor from bentoml.grpc.v1 import service_pb2 - from ...server.grpc.types import Response - from ...server.grpc.types import HandlerMethod from ...server.grpc.types import RpcMethodHandler else: service_pb2 = LazyLoader("service_pb2", globals(), "bentoml.grpc.v1.service_pb2") @@ -28,8 +28,10 @@ __all__ = [ "grpc_status_code", "parse_method_name", - "get_method_type", "deserialize_proto", + "to_http_status", + "get_grpc_content_type", + "ProtoCodec", ] logger = logging.getLogger(__name__) @@ -55,10 +57,18 @@ def deserialize_proto( return kind, MessageToDict(getattr(req.input, kind), **kwargs) +# Maps HTTP status code to grpc.StatusCode _STATUS_CODE_MAPPING = { + HTTPStatus.OK: grpc.StatusCode.OK, + HTTPStatus.UNAUTHORIZED: grpc.StatusCode.UNAUTHENTICATED, + HTTPStatus.FORBIDDEN: grpc.StatusCode.PERMISSION_DENIED, + HTTPStatus.NOT_FOUND: grpc.StatusCode.UNIMPLEMENTED, + HTTPStatus.TOO_MANY_REQUESTS: grpc.StatusCode.UNAVAILABLE, + HTTPStatus.BAD_GATEWAY: grpc.StatusCode.UNAVAILABLE, + HTTPStatus.SERVICE_UNAVAILABLE: grpc.StatusCode.UNAVAILABLE, + HTTPStatus.GATEWAY_TIMEOUT: grpc.StatusCode.DEADLINE_EXCEEDED, HTTPStatus.BAD_REQUEST: grpc.StatusCode.INVALID_ARGUMENT, HTTPStatus.INTERNAL_SERVER_ERROR: grpc.StatusCode.INTERNAL, - HTTPStatus.NOT_FOUND: grpc.StatusCode.NOT_FOUND, HTTPStatus.UNPROCESSABLE_ENTITY: grpc.StatusCode.FAILED_PRECONDITION, } @@ -70,6 +80,18 @@ def grpc_status_code(err: BentoMLException) -> grpc.StatusCode: return _STATUS_CODE_MAPPING.get(err.error_code, grpc.StatusCode.UNKNOWN) +def to_http_status(status_code: grpc.StatusCode) -> int: + """ + Convert grpc.StatusCode to HTTPStatus. 
+ """ + try: + status = {v: k for k, v in _STATUS_CODE_MAPPING.items()}[status_code] + except KeyError: + status = HTTPStatus.INTERNAL_SERVER_ERROR + + return status.value + + class RpcMethodType(str, enum.Enum): UNARY = "UNARY" CLIENT_STREAMING = "CLIENT_STREAMING" @@ -112,33 +134,31 @@ def parse_method_name(method_name: str) -> tuple[MethodName, bool]: return MethodName(package, service, method), True -def get_method_type(request_streaming: bool, response_streaming: bool) -> str: - if not request_streaming and not response_streaming: - return RpcMethodType.UNARY - elif not request_streaming and response_streaming: - return RpcMethodType.SERVER_STREAMING - elif request_streaming and not response_streaming: - return RpcMethodType.CLIENT_STREAMING - elif request_streaming and response_streaming: - return RpcMethodType.BIDI_STREAMING - else: - return RpcMethodType.UNKNOWN - - def wrap_rpc_handler( - wrapper: t.Callable[[HandlerMethod[Response] | None], HandlerMethod[Response]], + wrapper: t.Callable[..., t.Any], handler: RpcMethodHandler | None, ) -> RpcMethodHandler | None: if not handler: return None + # The reason we are using TYPE_CHECKING for assert here + # is that if the following bool request_streaming and response_streaming + # are set, then it is guaranteed that RpcMethodHandler are not None. if not handler.request_streaming and not handler.response_streaming: + if TYPE_CHECKING: + assert handler.unary_unary return handler._replace(unary_unary=wrapper(handler.unary_unary)) elif not handler.request_streaming and handler.response_streaming: + if TYPE_CHECKING: + assert handler.unary_stream return handler._replace(unary_stream=wrapper(handler.unary_stream)) elif handler.request_streaming and not handler.response_streaming: + if TYPE_CHECKING: + assert handler.stream_unary return handler._replace(stream_unary=wrapper(handler.stream_unary)) elif handler.request_streaming and handler.response_streaming: + if TYPE_CHECKING: + assert handler.stream_stream return handler._replace(stream_stream=wrapper(handler.stream_stream)) else: raise BentoMLException(f"RPC method handler {handler} does not exist.") diff --git a/bentoml/_internal/utils/grpc/codec.py b/bentoml/_internal/utils/grpc/codec.py new file mode 100644 index 0000000000..cef98769b1 --- /dev/null +++ b/bentoml/_internal/utils/grpc/codec.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import typing as t +from abc import ABC +from abc import abstractmethod +from typing import TYPE_CHECKING + +from typing_extensions import Self + +if TYPE_CHECKING: + from google.protobuf.message import Message + +# content-type is always application/grpc +GRPC_CONTENT_TYPE = "application/grpc" + + +class Codec(ABC): + _content_subtype: str + + def __new__(cls: type[Self]) -> Self: + obj = object.__new__(cls) + if not cls._content_subtype: + raise TypeError(f"{cls} should have a '_content_subtype' attribute") + obj.__setattr__("_content_subtype", cls._content_subtype) + return obj + + @property + def content_type(self) -> str: + return self._content_subtype + + @abstractmethod + def encode(self, message: t.Any, message_type: t.Type[Message]) -> bytes: + # TODO: We will want to use this to encode headers message. + pass + + @abstractmethod + def decode(self, data: bytes, message_type: t.Type[Message]) -> t.Any: + # TODO: We will want to use this to decode headers message. 
+ pass + + +class ProtoCodec(Codec): + _content_subtype: str = "proto" + + def encode(self, message: t.Any, message_type: t.Type[Message]) -> bytes: + if not isinstance(message, message_type): + raise TypeError(f"message should be a {message_type}, got {type(message)}.") + return message.SerializeToString() + + def decode(self, data: bytes, message_type: t.Type[Message]) -> t.Any: + return message_type.FromString(data) + + +def get_grpc_content_type(codec: Codec | None = None) -> str: + return f"{GRPC_CONTENT_TYPE}" + f"+{codec.content_type}" if codec else "" diff --git a/bentoml/grpc/v1/service.proto b/bentoml/grpc/v1/service.proto index 6613e3d73a..bb145a321a 100644 --- a/bentoml/grpc/v1/service.proto +++ b/bentoml/grpc/v1/service.proto @@ -4,7 +4,6 @@ package bentoml.grpc.v1; // cc_enable_arenas pre-allocate memory for given message to improve speed. (C++ only) option cc_enable_arenas = true; -option cc_generic_services = false; option go_package = "github.com/bentoml/grpc/v1"; option java_multiple_files = true; option java_outer_classname = "ServiceProto";
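# [Illustrative sketch, not part of this diff] A round trip through the
# ProtoCodec added in bentoml/_internal/utils/grpc/codec.py, using a well-known
# protobuf type so the snippet stands alone:
from google.protobuf import struct_pb2

from bentoml._internal.utils.grpc.codec import ProtoCodec

codec = ProtoCodec()
msg = struct_pb2.Value(string_value="hello")
payload = codec.encode(msg, struct_pb2.Value)         # -> bytes
assert codec.decode(payload, struct_pb2.Value) == msg
assert codec.content_type == "proto"                  # composed into "application/grpc+proto"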