diff --git a/bentoml/_internal/server/grpc/interceptors/__init__.py b/bentoml/_internal/server/grpc/interceptors/__init__.py deleted file mode 100644 index 8902f563f74..00000000000 --- a/bentoml/_internal/server/grpc/interceptors/__init__.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import annotations - -import typing as t -import functools -from typing import TYPE_CHECKING - -from grpc import aio - -from bentoml.grpc.utils import wrap_rpc_handler -from bentoml.grpc.utils import GRPC_CONTENT_TYPE - -if TYPE_CHECKING: - - from bentoml.grpc.types import Request - from bentoml.grpc.types import Response - from bentoml.grpc.types import RpcMethodHandler - from bentoml.grpc.types import AsyncHandlerMethod - from bentoml.grpc.types import HandlerCallDetails - from bentoml.grpc.types import BentoServicerContext - - -class GenericHeadersServerInterceptor(aio.ServerInterceptor): - """ - A light header interceptor that provides some initial metadata to the client. - Refers to https://chromium.googlesource.com/external/github.com/grpc/grpc/+/HEAD/doc/PROTOCOL-HTTP2.md - """ - - def __init__(self, *, message_format: str | None = None): - if not message_format: - # By default, we are sending proto message. - message_format = "proto" - self._content_type = f"{GRPC_CONTENT_TYPE}+{message_format}" - - def set_trailing_metadata(self, context: BentoServicerContext): - # We want to send some initial metadata to the client. - # gRPC doesn't use `:status` pseudo header to indicate success or failure - # of the current request. gRPC instead uses trailers for this purpose, and - # trailers are sent during `send_trailing_metadata` call - # For now we are sending over the content-type header. 
- context.set_trailing_metadata((("content-type", self._content_type),)) - - async def intercept_service( - self, - continuation: t.Callable[[HandlerCallDetails], t.Awaitable[RpcMethodHandler]], - handler_call_details: HandlerCallDetails, - ) -> RpcMethodHandler: - handler = await continuation(handler_call_details) - - if handler and (handler.response_streaming or handler.request_streaming): - return handler - - def wrapper(behaviour: AsyncHandlerMethod[Response]): - @functools.wraps(behaviour) - async def new_behaviour( - request: Request, context: BentoServicerContext - ) -> Response | t.Awaitable[Response]: - # setup metadata - self.set_trailing_metadata(context) - - # for the rpc itself. - return await behaviour(request, context) - - return new_behaviour - - return t.cast("RpcMethodHandler", wrap_rpc_handler(wrapper, handler)) diff --git a/bentoml/grpc/buf.yaml b/bentoml/grpc/buf.yaml index 1c14efa6950..4e26d9abb86 100644 --- a/bentoml/grpc/buf.yaml +++ b/bentoml/grpc/buf.yaml @@ -12,9 +12,9 @@ lint: - RPC_RESPONSE_STANDARD_NAME ignore_only: DEFAULT: - - bentoml/grpc/v1/service_test.proto + - bentoml/grpc/v1alpha1/service_test.proto ENUM_VALUE_PREFIX: - - bentoml/grpc/v1/service.proto + - bentoml/grpc/v1alpha1/service.proto enum_zero_value_suffix: _UNSPECIFIED rpc_allow_same_request_response: true rpc_allow_google_protobuf_empty_requests: true diff --git a/bentoml/grpc/interceptors/__init__.py b/bentoml/grpc/interceptors/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/bentoml/_internal/server/grpc/interceptors/access.py b/bentoml/grpc/interceptors/access.py similarity index 87% rename from bentoml/_internal/server/grpc/interceptors/access.py rename to bentoml/grpc/interceptors/access.py index d70e493ddb8..a89cc0f0a15 100644 --- a/bentoml/_internal/server/grpc/interceptors/access.py +++ b/bentoml/grpc/interceptors/access.py @@ -6,15 +6,14 @@ from timeit import default_timer from typing import TYPE_CHECKING -import grpc -from grpc 
import aio - from bentoml.grpc.utils import to_http_status from bentoml.grpc.utils import wrap_rpc_handler from bentoml.grpc.utils import GRPC_CONTENT_TYPE if TYPE_CHECKING: - from grpc.aio._typing import MetadataType + import grpc + from grpc import aio + from grpc.aio._typing import MetadataType # pylint: disable=unused-import from bentoml.grpc.types import Request from bentoml.grpc.types import Response @@ -24,14 +23,16 @@ from bentoml.grpc.types import BentoServicerContext from bentoml.grpc.v1alpha1 import service_pb2 as pb else: + from bentoml.grpc.utils import import_grpc from bentoml.grpc.utils import import_generated_stubs pb, _ = import_generated_stubs() + grpc, aio = import_grpc() class AccessLogServerInterceptor(aio.ServerInterceptor): """ - An asyncio interceptor for access log. + An asyncio interceptor for access logging. """ async def intercept_service( @@ -51,16 +52,13 @@ def wrapper(behaviour: AsyncHandlerMethod[Response]): async def new_behaviour( request: Request, context: BentoServicerContext ) -> Response | t.Awaitable[Response]: - content_type = GRPC_CONTENT_TYPE - trailing_metadata: MetadataType | None = context.trailing_metadata() if trailing_metadata: trailing = dict(trailing_metadata) content_type = trailing.get("content-type", GRPC_CONTENT_TYPE) response = pb.Response() - start = default_timer() try: response = await behaviour(request, context) @@ -68,8 +66,6 @@ async def new_behaviour( context.set_code(grpc.StatusCode.INTERNAL) context.set_details(str(e)) finally: - if TYPE_CHECKING: - assert response latency = max(default_timer() - start, 0) * 1000 req = [ @@ -79,6 +75,9 @@ async def new_behaviour( f"size={request.ByteSize()}", ] + # Note that in order for AccessLogServerInterceptor to work, the + # interceptor must be added to the server after AsyncOpenTeleServerInterceptor + # and PrometheusServerInterceptor. 
typed_context_code = t.cast(grpc.StatusCode, context.code()) resp = [ f"http_status={to_http_status(typed_context_code)}", diff --git a/bentoml/_internal/server/grpc/interceptors/opentelemetry.py b/bentoml/grpc/interceptors/opentelemetry.py similarity index 93% rename from bentoml/_internal/server/grpc/interceptors/opentelemetry.py rename to bentoml/grpc/interceptors/opentelemetry.py index 4c8fc00ce92..b36b19a12c4 100644 --- a/bentoml/_internal/server/grpc/interceptors/opentelemetry.py +++ b/bentoml/grpc/interceptors/opentelemetry.py @@ -6,8 +6,6 @@ from typing import TYPE_CHECKING from contextlib import asynccontextmanager -import grpc -from grpc import aio from simple_di import inject from simple_di import Provide from opentelemetry import trace @@ -21,11 +19,12 @@ from bentoml.grpc.utils import wrap_rpc_handler from bentoml.grpc.utils import GRPC_CONTENT_TYPE from bentoml.grpc.utils import parse_method_name - -from ....utils.pkg import get_pkg_version -from ....configuration.containers import BentoMLContainer +from bentoml._internal.utils.pkg import get_pkg_version +from bentoml._internal.configuration.containers import BentoMLContainer if TYPE_CHECKING: + import grpc + from grpc import aio from grpc.aio._typing import MetadataKey from grpc.aio._typing import MetadataType from grpc.aio._typing import MetadataValue @@ -38,6 +37,10 @@ from bentoml.grpc.types import AsyncHandlerMethod from bentoml.grpc.types import HandlerCallDetails from bentoml.grpc.types import BentoServicerContext +else: + from bentoml.grpc.utils import import_grpc + + grpc, aio = import_grpc() logger = logging.getLogger(__name__) @@ -48,6 +51,10 @@ def __init__(self, servicer_context: BentoServicerContext, active_span: Span): self._active_span = active_span self._code = grpc.StatusCode.OK self._details = "" + super().__init__() + + def __getattr__(self, attr: str) -> t.Any: + return getattr(self._servicer_context, attr) async def read(self) -> Request: return await 
self._servicer_context.read() @@ -156,15 +163,16 @@ async def set_remote_context( self, servicer_context: BentoServicerContext ) -> t.AsyncGenerator[None, None]: metadata = servicer_context.invocation_metadata() - if not metadata: - yield - md: dict[MetadataKey, MetadataValue] = {m.key: m.value for m in metadata} - ctx = extract(md) - token = attach(ctx) - try: + if metadata: + md: dict[MetadataKey, MetadataValue] = {m.key: m.value for m in metadata} + ctx = extract(md) + token = attach(ctx) + try: + yield + finally: + detach(token) + else: yield - finally: - detach(token) def start_span( self, diff --git a/bentoml/_internal/server/grpc/interceptors/prometheus.py b/bentoml/grpc/interceptors/prometheus.py similarity index 84% rename from bentoml/_internal/server/grpc/interceptors/prometheus.py rename to bentoml/grpc/interceptors/prometheus.py index 0f2303628b0..576a61d7e6e 100644 --- a/bentoml/_internal/server/grpc/interceptors/prometheus.py +++ b/bentoml/grpc/interceptors/prometheus.py @@ -7,19 +7,19 @@ from timeit import default_timer from typing import TYPE_CHECKING -import grpc -from grpc import aio from simple_di import inject from simple_di import Provide from bentoml.grpc.utils import to_http_status from bentoml.grpc.utils import wrap_rpc_handler - -from ....configuration.containers import BentoMLContainer +from bentoml._internal.configuration.containers import BentoMLContainer START_TIME_VAR: contextvars.ContextVar[float] = contextvars.ContextVar("START_TIME_VAR") if TYPE_CHECKING: + import grpc + from grpc import aio + from bentoml.grpc.types import Request from bentoml.grpc.types import Response from bentoml.grpc.types import RpcMethodHandler @@ -27,13 +27,14 @@ from bentoml.grpc.types import HandlerCallDetails from bentoml.grpc.types import BentoServicerContext from bentoml.grpc.v1alpha1 import service_pb2 as pb - - from ....service import Service - from ...metrics.prometheus import PrometheusClient + from bentoml._internal.service import Service + from 
bentoml._internal.server.metrics.prometheus import PrometheusClient else: + from bentoml.grpc.utils import import_grpc from bentoml.grpc.utils import import_generated_stubs pb, _ = import_generated_stubs() + grpc, aio = import_grpc() logger = logging.getLogger(__name__) @@ -53,25 +54,20 @@ def _setup( self, metrics_client: PrometheusClient = Provide[BentoMLContainer.metrics_client], ): # pylint: disable=attribute-defined-outside-init - - # a valid tag name may includes invalid characters, so we need to escape them - # ref: https://prometheus.io/docs/concepts/data_model/#metric-names-and-labels - service_name = self.bento_service.name.replace("-", ":").replace(".", "::") - self.metrics_request_duration = metrics_client.Histogram( - name=f"{service_name}_request_duration_seconds", - documentation=f"{service_name} API GRPC request duration in seconds", - labelnames=["api_name", "service_version", "http_response_code"], + name="request_duration_seconds", + documentation="API GRPC request duration in seconds", + labelnames=["api_name", "service_version", "http_response_code", "service"], ) self.metrics_request_total = metrics_client.Counter( - name=f"{service_name}_request_total", + name="request_total", documentation="Total number of GRPC requests", - labelnames=["api_name", "service_version", "http_response_code"], + labelnames=["api_name", "service_version", "http_response_code", "service"], ) self.metrics_request_in_progress = metrics_client.Gauge( - name=f"{service_name}_request_in_progress", + name="request_in_progress", documentation="Total number of GRPC requests in progress now", - labelnames=["api_name", "service_version"], + labelnames=["api_name", "service_version", "service"], multiprocess_mode="livesum", ) self._is_setup = True @@ -92,6 +88,9 @@ async def intercept_service( service_version = ( self.bento_service.tag.version if self.bento_service.tag else "" ) + # a valid tag name may include invalid characters, so we need to escape them + # ref: 
https://prometheus.io/docs/concepts/data_model/#metric-names-and-labels + service_name = self.bento_service.name.replace("-", ":").replace(".", "::") START_TIME_VAR.set(default_timer()) @@ -112,6 +111,7 @@ async def new_behaviour( http_response_code=to_http_status( t.cast(grpc.StatusCode, context.code()) ), + service=service_name, ).inc() # instrument request duration @@ -123,6 +123,7 @@ async def new_behaviour( http_response_code=to_http_status( t.cast(grpc.StatusCode, context.code()) ), + service=service_name, ).observe( total_time ) @@ -131,7 +132,9 @@ async def new_behaviour( # instrument request in progress with self.metrics_request_in_progress.labels( - api_name=api_name, service_version=service_version + api_name=api_name, + service_version=service_version, + service=service_name, ).track_inprogress(): response = await behaviour(request, context) return response diff --git a/bentoml/grpc/types.py b/bentoml/grpc/types.py new file mode 100644 index 00000000000..9fa1dceee39 --- /dev/null +++ b/bentoml/grpc/types.py @@ -0,0 +1,108 @@ +# pragma: no cover +""" +Specific types for BentoService gRPC server. 
+""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import typing as t + from functools import partial + + import grpc + from grpc import aio + + from bentoml.grpc.v1alpha1.service_pb2 import Request + from bentoml.grpc.v1alpha1.service_pb2 import Response + from bentoml.grpc.v1alpha1.service_pb2_grpc import BentoServiceServicer + + P = t.TypeVar("P") + + BentoServicerContext = aio.ServicerContext[Request, Response] + + RequestDeserializerFn = t.Callable[[Request | None], object] | None + ResponseSerializerFn = t.Callable[[bytes], Response | None] | None + + HandlerMethod = t.Callable[[Request, BentoServicerContext], P] + AsyncHandlerMethod = t.Callable[[Request, BentoServicerContext], t.Awaitable[P]] + + class RpcMethodHandler( + t.NamedTuple( + "RpcMethodHandler", + request_streaming=bool, + response_streaming=bool, + request_deserializer=RequestDeserializerFn, + response_serializer=ResponseSerializerFn, + unary_unary=t.Optional[HandlerMethod[Response]], + unary_stream=t.Optional[HandlerMethod[Response]], + stream_unary=t.Optional[HandlerMethod[Response]], + stream_stream=t.Optional[HandlerMethod[Response]], + ), + grpc.RpcMethodHandler, + ): + """An implementation of a single RPC method.""" + + request_streaming: bool + response_streaming: bool + request_deserializer: RequestDeserializerFn + response_serializer: ResponseSerializerFn + unary_unary: t.Optional[HandlerMethod[Response]] + unary_stream: t.Optional[HandlerMethod[Response]] + stream_unary: t.Optional[HandlerMethod[Response]] + stream_stream: t.Optional[HandlerMethod[Response]] + + class HandlerCallDetails( + t.NamedTuple( + "HandlerCallDetails", method=str, invocation_metadata=aio.Metadata + ), + grpc.HandlerCallDetails, + ): + """Describes an RPC that has just arrived for service. + + Attributes: + method: The method name of the RPC. + invocation_metadata: A sequence of metadatum, a key-value pair included in the HTTP header. 
+ An example is: ``('binary-metadata-bin', b'\\x00\\xFF')`` + """ + + method: str + invocation_metadata: aio.Metadata + + # Servicer types + ServicerImpl = t.TypeVar("ServicerImpl") + Servicer = t.Annotated[ServicerImpl, object] + ServicerClass = t.Type[Servicer[t.Any]] + AddServicerFn = t.Callable[[Servicer[t.Any], aio.Server | grpc.Server], None] + + # accepted proto fields + ProtoField = t.Annotated[ + str, + t.Literal[ + "dataframe", + "file", + "json", + "ndarray", + "series", + "text", + "multipart", + "serialized_bytes", + ], + ] + + Interceptors = list[ + t.Callable[[], aio.ServerInterceptor] | partial[aio.ServerInterceptor] + ] + + # types defined for client interceptors + BentoUnaryUnaryCall = aio.UnaryUnaryCall[Request, Response] + + __all__ = [ + "Request", + "Response", + "BentoServicerContext", + "BentoServiceServicer", + "HandlerCallDetails", + "RpcMethodHandler", + "BentoUnaryUnaryCall", + ] diff --git a/bentoml/grpc/utils/__init__.py b/bentoml/grpc/utils/__init__.py new file mode 100644 index 00000000000..710d71ad35b --- /dev/null +++ b/bentoml/grpc/utils/__init__.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +import typing as t +import logging +from http import HTTPStatus +from typing import TYPE_CHECKING +from functools import lru_cache +from dataclasses import dataclass + +from bentoml._internal.utils.lazy_loader import LazyLoader + +if TYPE_CHECKING: + import types + from enum import Enum + + import grpc + from google.protobuf import descriptor as descriptor_mod + + from bentoml.exceptions import BentoMLException + from bentoml.grpc.types import Request + from bentoml.grpc.types import Response + from bentoml.grpc.types import HandlerMethod + from bentoml.grpc.types import RpcMethodHandler + from bentoml.grpc.types import BentoServicerContext + from bentoml.grpc.v1alpha1 import service_pb2 as pb + + # We need this here so that __all__ is detected due to lazy import + def import_generated_stubs( + version: str = "v1alpha1", + ) -> 
tuple[types.ModuleType, types.ModuleType]: + ... + + def import_grpc() -> tuple[types.ModuleType, types.ModuleType]: + ... + +else: + from bentoml.grpc.utils._import_hook import import_grpc + from bentoml.grpc.utils._import_hook import import_generated_stubs + + pb, _ = import_generated_stubs() + grpc, _ = import_grpc() + descriptor_mod = LazyLoader( + "descriptor_mod", globals(), "google.protobuf.descriptor" + ) + +__all__ = [ + "grpc_status_code", + "parse_method_name", + "to_http_status", + "GRPC_CONTENT_TYPE", + "import_generated_stubs", + "import_grpc", +] + +logger = logging.getLogger(__name__) + +# content-type is always application/grpc +GRPC_CONTENT_TYPE = "application/grpc" + + +def get_field_by_name( + descriptor: descriptor_mod.FieldDescriptor | descriptor_mod.Descriptor, + field: str, +) -> descriptor_mod.FieldDescriptor: + if isinstance(descriptor, descriptor_mod.FieldDescriptor): + # descriptor is a FieldDescriptor + return descriptor.message_type.fields_by_name[field] + elif isinstance(descriptor, descriptor_mod.Descriptor): + # descriptor is a Descriptor + return descriptor.fields_by_name[field] + else: + raise NotImplementedError(f"Type {type(descriptor)} is not yet supported.") + + +def is_map_field(field: descriptor_mod.FieldDescriptor) -> bool: + return ( + field.type == descriptor_mod.FieldDescriptor.TYPE_MESSAGE + and field.message_type.has_options + and field.message_type.GetOptions().map_entry + ) + + +@lru_cache(maxsize=1) +def http_status_to_grpc_status_map() -> dict[Enum, grpc.StatusCode]: + # Maps HTTP status code to grpc.StatusCode + from http import HTTPStatus + + return { + HTTPStatus.OK: grpc.StatusCode.OK, + HTTPStatus.UNAUTHORIZED: grpc.StatusCode.UNAUTHENTICATED, + HTTPStatus.FORBIDDEN: grpc.StatusCode.PERMISSION_DENIED, + HTTPStatus.NOT_FOUND: grpc.StatusCode.UNIMPLEMENTED, + HTTPStatus.TOO_MANY_REQUESTS: grpc.StatusCode.UNAVAILABLE, + HTTPStatus.BAD_GATEWAY: grpc.StatusCode.UNAVAILABLE, + HTTPStatus.SERVICE_UNAVAILABLE: 
grpc.StatusCode.UNAVAILABLE, + HTTPStatus.GATEWAY_TIMEOUT: grpc.StatusCode.DEADLINE_EXCEEDED, + HTTPStatus.BAD_REQUEST: grpc.StatusCode.INVALID_ARGUMENT, + HTTPStatus.INTERNAL_SERVER_ERROR: grpc.StatusCode.INTERNAL, + HTTPStatus.UNPROCESSABLE_ENTITY: grpc.StatusCode.FAILED_PRECONDITION, + } + + +@lru_cache(maxsize=1) +def grpc_status_to_http_status_map() -> dict[grpc.StatusCode, Enum]: + return {v: k for k, v in http_status_to_grpc_status_map().items()} + + +@lru_cache(maxsize=1) +def filetype_pb_to_mimetype_map() -> dict[pb.File.FileType.ValueType, str]: + return { + pb.File.FILE_TYPE_CSV: "text/csv", + pb.File.FILE_TYPE_PLAINTEXT: "text/plain", + pb.File.FILE_TYPE_JSON: "application/json", + pb.File.FILE_TYPE_BYTES: "application/octet-stream", + pb.File.FILE_TYPE_PDF: "application/pdf", + pb.File.FILE_TYPE_PNG: "image/png", + pb.File.FILE_TYPE_JPEG: "image/jpeg", + pb.File.FILE_TYPE_GIF: "image/gif", + pb.File.FILE_TYPE_TIFF: "image/tiff", + pb.File.FILE_TYPE_BMP: "image/bmp", + pb.File.FILE_TYPE_WEBP: "image/webp", + pb.File.FILE_TYPE_SVG: "image/svg+xml", + } + + +@lru_cache(maxsize=1) +def mimetype_to_filetype_pb_map() -> dict[str, pb.File.FileType.ValueType]: + return {v: k for k, v in filetype_pb_to_mimetype_map().items()} + + +def grpc_status_code(err: BentoMLException) -> grpc.StatusCode: + """ + Convert BentoMLException.error_code to grpc.StatusCode. + """ + return http_status_to_grpc_status_map().get(err.error_code, grpc.StatusCode.UNKNOWN) + + +def to_http_status(status_code: grpc.StatusCode) -> int: + """ + Convert grpc.StatusCode to HTTPStatus. + """ + status = grpc_status_to_http_status_map().get( + status_code, HTTPStatus.INTERNAL_SERVER_ERROR + ) + + return status.value + + +@dataclass +class MethodName: + """ + Represents a gRPC method name. + + Attributes: + package: This is defined by `package foo.bar`, designation in the protocol buffer definition + service: service name in protocol buffer definition (eg: service SearchService { ... 
}) + method: method name + """ + + package: str = "" + service: str = "" + method: str = "" + + @property + def fully_qualified_service(self): + """return the service name prefixed with package""" + return f"{self.package}.{self.service}" if self.package else self.service + + +def parse_method_name(method_name: str) -> tuple[MethodName, bool]: + """ + Infers the grpc service and method name from the handler_call_details. + e.g. /package.ServiceName/MethodName + """ + method = method_name.split("/", maxsplit=2) + # sanity check for method. + if len(method) != 3: + return MethodName(), False + _, package_service, method = method + *packages, service = package_service.rsplit(".", maxsplit=1) + package = packages[0] if packages else "" + return MethodName(package, service, method), True + + +def wrap_rpc_handler( + wrapper: t.Callable[ + [HandlerMethod[Response]], t.Callable[[Request, BentoServicerContext], Response] + ], + handler: RpcMethodHandler | None, +) -> RpcMethodHandler | None: + if not handler: + return None + if not handler.request_streaming and not handler.response_streaming: + assert handler.unary_unary + return handler._replace(unary_unary=wrapper(handler.unary_unary)) + elif not handler.request_streaming and handler.response_streaming: + assert handler.unary_stream + return handler._replace(unary_stream=wrapper(handler.unary_stream)) + elif handler.request_streaming and not handler.response_streaming: + assert handler.stream_unary + return handler._replace(stream_unary=wrapper(handler.stream_unary)) + elif handler.request_streaming and handler.response_streaming: + assert handler.stream_stream + return handler._replace(stream_stream=wrapper(handler.stream_stream)) + else: + raise RuntimeError(f"RPC method handler {handler} does not exist.") from None diff --git a/bentoml/grpc/utils/_import_hook.py b/bentoml/grpc/utils/_import_hook.py new file mode 100644 index 00000000000..285e0abec06 --- /dev/null +++ b/bentoml/grpc/utils/_import_hook.py @@ -0,0 +1,67 
@@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from pathlib import Path + +if TYPE_CHECKING: + import types + + +def import_generated_stubs( + version: str = "v1alpha1", + file: str = "service.proto", +) -> tuple[types.ModuleType, types.ModuleType]: + """ + Import generated stubs. + + Args: + version: The version of the proto file to import. + file: The name of the proto file to import. + + Returns: + A tuple of the generated stubs for the proto file. + + Examples: + + .. code-block:: python + + from bentoml.grpc.utils import import_generated_stubs + + # given proto file bentoml/grpc/v1alpha2/service.proto exists + pb, services = import_generated_stubs(version="v1alpha2", file="service.proto") + """ + # generate git root from this file's path + from bentoml._internal.utils import LazyLoader + + GIT_ROOT = Path(__file__).parent.parent.parent.parent + + exception_message = f"Generated stubs for '{version}/{file}' are missing. To generate stubs, run '{GIT_ROOT}/scripts/generate_grpc_stubs.sh'" + file = file.split(".")[0] + + service_pb2 = LazyLoader( + f"{file}_pb2", + globals(), + f"bentoml.grpc.{version}.{file}_pb2", + exc_msg=exception_message, + ) + service_pb2_grpc = LazyLoader( + f"{file}_pb2_grpc", + globals(), + f"bentoml.grpc.{version}.{file}_pb2_grpc", + exc_msg=exception_message, + ) + return service_pb2, service_pb2_grpc + + +def import_grpc() -> tuple[types.ModuleType, types.ModuleType]: + from bentoml._internal.utils import LazyLoader + + exception_message = "'grpcio' is required for gRPC support. Install with 'pip install bentoml[grpc]'." 
+ grpc = LazyLoader( + "grpc", + globals(), + "grpc", + exc_msg=exception_message, + ) + aio = LazyLoader("aio", globals(), "grpc.aio", exc_msg=exception_message) + return grpc, aio diff --git a/bentoml/testing/server.py b/bentoml/testing/server.py index d520e0a88d2..8553f0b69fa 100644 --- a/bentoml/testing/server.py +++ b/bentoml/testing/server.py @@ -18,10 +18,11 @@ from typing import TYPE_CHECKING from contextlib import contextmanager +import psutil + from .._internal.tag import Tag from .._internal.utils import reserve_free_port from .._internal.utils import cached_contextmanager -from .._internal.utils.platform import kill_subprocess_tree logger = logging.getLogger("bentoml") @@ -75,6 +76,19 @@ async def async_request( return r.status, Headers(headers), r_body +def kill_subprocess_tree(p: subprocess.Popen[t.Any]) -> None: + """ + Tell the process to terminate and kill all of its children. Available both on Windows and Linux. + Note: It will return immediately rather than wait for the process to terminate. + Args: + p: subprocess.Popen object + """ + if psutil.WINDOWS: + subprocess.call(["taskkill", "/F", "/T", "/PID", str(p.pid)]) + else: + p.terminate() + + def _wait_until_api_server_ready( host_url: str, timeout: float,