diff --git a/src/bentoml/_internal/bento/build_dev_bentoml_whl.py b/src/bentoml/_internal/bento/build_dev_bentoml_whl.py index f3a9e1647b1..0476dd081bf 100644 --- a/src/bentoml/_internal/bento/build_dev_bentoml_whl.py +++ b/src/bentoml/_internal/bento/build_dev_bentoml_whl.py @@ -7,6 +7,7 @@ from ..utils.pkg import source_locations from ...exceptions import BentoMLException from ...exceptions import MissingDependencyException +from ...grpc.utils import LATEST_PROTOCOL_VERSION from ..configuration import is_pypi_installed_bentoml logger = logging.getLogger(__name__) @@ -15,7 +16,7 @@ def build_bentoml_editable_wheel( - target_path: str, *, _internal_stubs_version: str = "v1" + target_path: str, *, _internal_protocol_version: str = LATEST_PROTOCOL_VERSION ) -> None: """ This is for BentoML developers to create Bentos that contains the local bentoml @@ -52,10 +53,10 @@ def build_bentoml_editable_wheel( bentoml_path = Path(module_location) if not Path( - module_location, "grpc", _internal_stubs_version, "service_pb2.py" + module_location, "grpc", _internal_protocol_version, "service_pb2.py" ).exists(): raise ModuleNotFoundError( - f"Generated stubs for version {_internal_stubs_version} are missing. Make sure to run '{bentoml_path.as_posix()}/scripts/generate_grpc_stubs.sh {_internal_stubs_version}' beforehand to generate gRPC stubs." + f"Generated stubs for version {_internal_protocol_version} are missing. Make sure to run '{bentoml_path.as_posix()}/scripts/generate_grpc_stubs.sh {_internal_protocol_version}' beforehand to generate gRPC stubs." ) from None # location to pyproject.toml diff --git a/src/bentoml/_internal/io_descriptors/base.py b/src/bentoml/_internal/io_descriptors/base.py index 870144c865a..f6c03e0f1e1 100644 --- a/src/bentoml/_internal/io_descriptors/base.py +++ b/src/bentoml/_internal/io_descriptors/base.py @@ -36,7 +36,7 @@ IOType = t.TypeVar("IOType") -def from_spec(spec: dict[str, str]) -> IODescriptor[t.Any]: +def from_spec(spec: dict[str, t.Any]) -> IODescriptor[t.Any]: if "id" not in spec: raise InvalidArgument(f"IO descriptor spec ({spec}) missing ID.") return IO_DESCRIPTOR_REGISTRY[spec["id"]].from_spec(spec) diff --git a/src/bentoml/_internal/io_descriptors/json.py b/src/bentoml/_internal/io_descriptors/json.py index da6ed9bc502..2adbea6f631 100644 --- a/src/bentoml/_internal/io_descriptors/json.py +++ b/src/bentoml/_internal/io_descriptors/json.py @@ -30,6 +30,7 @@ import pydantic import pydantic.schema as schema + from google.protobuf import message as _message from google.protobuf import struct_pb2 from typing_extensions import Self @@ -392,19 +393,29 @@ async def to_proto(self, obj: JSONType) -> struct_pb2.Value: if LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(obj): obj = obj.dict() msg = struct_pb2.Value() - # To handle None cases. - if obj is not None: - from google.protobuf.json_format import ParseDict - - if isinstance(obj, (dict, str, list, float, int, bool)): - # ParseDict handles google.protobuf.Struct type - # directly if given object has a supported type - ParseDict(obj, msg) - else: - # If given object doesn't have a supported type, we will - # use given JSON encoder to convert it to dictionary - # and then parse it to google.protobuf.Struct. - # Note that if a custom JSON encoder is used, it mustn't - # take any arguments. - ParseDict(self._json_encoder().default(obj), msg) + return parse_dict_to_proto(obj, msg, json_encoder=self._json_encoder) + + +def parse_dict_to_proto( + obj: JSONType, + msg: _message.Message, + json_encoder: type[json.JSONEncoder] = DefaultJsonEncoder, +) -> t.Any: + if obj is None: + # this function is an identity op for the msg if obj is None. return msg + + from google.protobuf.json_format import ParseDict + + if isinstance(obj, (dict, str, list, float, int, bool)): + # ParseDict handles google.protobuf.Struct type + # directly if given object has a supported type + ParseDict(obj, msg) + else: + # If given object doesn't have a supported type, we will + # use given JSON encoder to convert it to dictionary + # and then parse it to google.protobuf.Struct. + # Note that if a custom JSON encoder is used, it mustn't + # take any arguments. + ParseDict(json_encoder().default(obj), msg) + return msg diff --git a/src/bentoml/_internal/server/grpc/__init__.py b/src/bentoml/_internal/server/grpc/__init__.py index df277aa1d9a..09ed39ca17f 100644 --- a/src/bentoml/_internal/server/grpc/__init__.py +++ b/src/bentoml/_internal/server/grpc/__init__.py @@ -1,4 +1,3 @@ from .server import Server -from .servicer import Servicer -__all__ = ["Server", "Servicer"] +__all__ = ["Server"] diff --git a/src/bentoml/_internal/server/grpc/server.py b/src/bentoml/_internal/server/grpc/server.py index b96d87506ef..f2ca52c8596 100644 --- a/src/bentoml/_internal/server/grpc/server.py +++ b/src/bentoml/_internal/server/grpc/server.py @@ -4,6 +4,7 @@ import sys import typing as t import asyncio +import inspect import logging from typing import TYPE_CHECKING from concurrent.futures import ThreadPoolExecutor @@ -29,7 +30,8 @@ from bentoml.grpc.v1 import service_pb2_grpc as services - from .servicer import Servicer + from ..grpc_app import GrpcServicerFactory + else: grpc, aio = import_grpc() _, services = import_generated_stubs() @@ -61,7 +63,7 @@ class Server(aio._server.Server): @inject def __init__( self, - servicer: Servicer, + servicer: GrpcServicerFactory, bind_address: str, max_message_length: int | None = Provide[BentoMLContainer.grpc.max_message_length], @@ -88,10 +90,6 @@ def __init__( self.ssl_keyfile = ssl_keyfile self.ssl_ca_certs = ssl_ca_certs - if not bool(self.servicer): - self.servicer.load() - assert self.servicer.loaded - super().__init__( # Note that the max_workers are used inside ThreadPoolExecutor. # This ThreadPoolExecutor are used by aio.Server() to execute non-AsyncIO RPC handlers. @@ -189,7 +187,11 @@ async def startup(self) -> None: from bentoml.exceptions import MissingDependencyException # Running on_startup callback. - await self.servicer.startup() + for handler in self.servicer.on_startup: + out = handler() + if inspect.isawaitable(out): + await out + # register bento servicer services.add_BentoServiceServicer_to_server(self.servicer.bento_servicer, self) services_health.add_HealthServicer_to_server( @@ -236,7 +238,11 @@ async def startup(self) -> None: async def shutdown(self): # Running on_startup callback. - await self.servicer.shutdown() + for handler in self.servicer.on_shutdown: + out = handler() + if inspect.isawaitable(out): + await out + await self.stop(grace=self.graceful_shutdown_timeout) await self.servicer.health_servicer.enter_graceful_shutdown() self.loop.stop() diff --git a/src/bentoml/_internal/server/grpc/servicer/__init__.py b/src/bentoml/_internal/server/grpc/servicer/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/bentoml/_internal/server/grpc/servicer/v1/__init__.py b/src/bentoml/_internal/server/grpc/servicer/v1/__init__.py new file mode 100644 index 00000000000..d88e7613e7d --- /dev/null +++ b/src/bentoml/_internal/server/grpc/servicer/v1/__init__.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import sys +import asyncio +import logging +from typing import TYPE_CHECKING + +import anyio + +from ......exceptions import InvalidArgument +from ......exceptions import BentoMLException +from ......grpc.utils import import_grpc +from ......grpc.utils import grpc_status_code +from ......grpc.utils import validate_proto_fields +from ......grpc.utils import import_generated_stubs + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from logging import _ExcInfoType as ExcInfoType # type: ignore (private warning) + + import grpc + + from bentoml.grpc.types import BentoServicerContext + + from ......grpc.v1 import service_pb2 as pb + from ......grpc.v1 import service_pb2_grpc as services + from .....service.service import Service +else: + grpc, _ = import_grpc() + pb, services = import_generated_stubs(version="v1") + + +def log_exception(request: pb.Request, exc_info: ExcInfoType) -> None: + # gRPC will always send a POST request. + logger.error("Exception on /%s [POST]", request.api_name, exc_info=exc_info) + + +def create_bento_servicer(service: Service) -> services.BentoServiceServicer: + """ + This is the actual implementation of BentoServicer. + Main inference entrypoint will be invoked via /bentoml.grpc..BentoService/Call + """ + + class BentoServiceImpl(services.BentoServiceServicer): + """An asyncio implementation of BentoService servicer.""" + + async def Call( # type: ignore (no async types) # pylint: disable=invalid-overridden-method + self: services.BentoServiceServicer, + request: pb.Request, + context: BentoServicerContext, + ) -> pb.Response | None: + if request.api_name not in service.apis: + raise InvalidArgument( + f"given 'api_name' is not defined in {service.name}", + ) from None + + api = service.apis[request.api_name] + response = pb.Response() + + # NOTE: since IODescriptor._proto_fields is a tuple, the order is preserved. + # This is important so that we know the order of fields to process. + # We will use fields descriptor to determine how to process that request. + try: + # we will check if the given fields list contains a pb.Multipart. + input_proto = getattr( + request, + validate_proto_fields(request.WhichOneof("content"), api.input), + ) + input_data = await api.input.from_proto(input_proto) + if asyncio.iscoroutinefunction(api.func): + if api.multi_input: + output = await api.func(**input_data) + else: + output = await api.func(input_data) + else: + if api.multi_input: + output = await anyio.to_thread.run_sync(api.func, **input_data) + else: + output = await anyio.to_thread.run_sync(api.func, input_data) + res = await api.output.to_proto(output) + # TODO(aarnphm): support multiple proto fields + response = pb.Response(**{api.output._proto_fields[0]: res}) + except BentoMLException as e: + log_exception(request, sys.exc_info()) + await context.abort(code=grpc_status_code(e), details=e.message) + except (RuntimeError, TypeError, NotImplementedError): + log_exception(request, sys.exc_info()) + await context.abort( + code=grpc.StatusCode.INTERNAL, + details="A runtime error has occurred, see stacktrace from logs.", + ) + except Exception: # pylint: disable=broad-except + log_exception(request, sys.exc_info()) + await context.abort( + code=grpc.StatusCode.INTERNAL, + details="An error has occurred in BentoML user code when handling this request, find the error details in server logs.", + ) + return response + + return BentoServiceImpl() diff --git a/src/bentoml/_internal/server/grpc/servicer.py b/src/bentoml/_internal/server/grpc/servicer/v1alpha1/__init__.py similarity index 54% rename from src/bentoml/_internal/server/grpc/servicer.py rename to src/bentoml/_internal/server/grpc/servicer/v1alpha1/__init__.py index b96ca7b8ea4..33688f14c93 100644 --- a/src/bentoml/_internal/server/grpc/servicer.py +++ b/src/bentoml/_internal/server/grpc/servicer/v1alpha1/__init__.py @@ -1,22 +1,18 @@ from __future__ import annotations import sys -import typing as t import asyncio import logging from typing import TYPE_CHECKING -from inspect import isawaitable import anyio -from bentoml.grpc.utils import import_grpc -from bentoml.grpc.utils import grpc_status_code -from bentoml.grpc.utils import validate_proto_fields -from bentoml.grpc.utils import import_generated_stubs - -from ...utils import LazyLoader -from ....exceptions import InvalidArgument -from ....exceptions import BentoMLException +from ......exceptions import InvalidArgument +from ......exceptions import BentoMLException +from ......grpc.utils import import_grpc +from ......grpc.utils import grpc_status_code +from ......grpc.utils import validate_proto_fields +from ......grpc.utils import import_generated_stubs logger = logging.getLogger(__name__) @@ -24,31 +20,15 @@ from logging import _ExcInfoType as ExcInfoType # type: ignore (private warning) import grpc - from grpc import aio - from grpc_health.v1 import health - from typing_extensions import Self - - from bentoml.grpc.v1 import service_pb2 as pb - from bentoml.grpc.v1 import service_pb2_grpc as services - from bentoml.grpc.types import Interceptors - from bentoml.grpc.types import AddServicerFn - from bentoml.grpc.types import ServicerClass - from bentoml.grpc.types import BentoServicerContext - from ...service.service import Service + from bentoml.grpc.types import BentoServicerContext + from ......grpc.v1alpha1 import service_pb2 as pb + from ......grpc.v1alpha1 import service_pb2_grpc as services + from .....service.service import Service else: - pb, services = import_generated_stubs() - grpc, aio = import_grpc() - health = LazyLoader( - "health", - globals(), - "grpc_health.v1.health", - exc_msg="'grpcio-health-checking' is required for using health checking endpoints. Install with 'pip install grpcio-health-checking'.", - ) - containers = LazyLoader( - "containers", globals(), "google.protobuf.internal.containers" - ) + grpc, _ = import_grpc() + pb, services = import_generated_stubs(version="v1alpha1") def log_exception(request: pb.Request, exc_info: ExcInfoType) -> None: @@ -56,61 +36,6 @@ def log_exception(request: pb.Request, exc_info: ExcInfoType) -> None: logger.error("Exception on /%s [POST]", request.api_name, exc_info=exc_info) -class Servicer: - """Create an instance of gRPC Servicer.""" - - def __init__( - self: Self, - service: Service, - on_startup: t.Sequence[t.Callable[[], t.Any]] | None = None, - on_shutdown: t.Sequence[t.Callable[[], t.Any]] | None = None, - mount_servicers: t.Sequence[tuple[ServicerClass, AddServicerFn, list[str]]] - | None = None, - interceptors: Interceptors | None = None, - ) -> None: - self.bento_service = service - - self.on_startup = [] if not on_startup else list(on_startup) - self.on_shutdown = [] if not on_shutdown else list(on_shutdown) - self.mount_servicers = [] if not mount_servicers else list(mount_servicers) - self.interceptors = [] if not interceptors else list(interceptors) - self.loaded = False - - def load(self): - assert not self.loaded - - self.interceptors_stack = self.build_interceptors_stack() - - self.bento_servicer = create_bento_servicer(self.bento_service) - - # Create a health check servicer. We use the non-blocking implementation - # to avoid thread starvation. - self.health_servicer = health.aio.HealthServicer() - - self.service_names = tuple( - service.full_name for service in pb.DESCRIPTOR.services_by_name.values() - ) + (health.SERVICE_NAME,) - self.loaded = True - - def build_interceptors_stack(self) -> list[aio.ServerInterceptor]: - return list(map(lambda x: x(), self.interceptors)) - - async def startup(self): - for handler in self.on_startup: - out = handler() - if isawaitable(out): - await out - - async def shutdown(self): - for handler in self.on_shutdown: - out = handler() - if isawaitable(out): - await out - - def __bool__(self): - return self.loaded - - def create_bento_servicer(service: Service) -> services.BentoServiceServicer: """ This is the actual implementation of BentoServicer. @@ -121,7 +46,7 @@ class BentoServiceImpl(services.BentoServiceServicer): """An asyncio implementation of BentoService servicer.""" async def Call( # type: ignore (no async types) # pylint: disable=invalid-overridden-method - self, + self: services.BentoServiceServicer, request: pb.Request, context: BentoServicerContext, ) -> pb.Response | None: diff --git a/src/bentoml/_internal/server/grpc_app.py b/src/bentoml/_internal/server/grpc_app.py index cc6df9e624e..71c537c93e5 100644 --- a/src/bentoml/_internal/server/grpc_app.py +++ b/src/bentoml/_internal/server/grpc_app.py @@ -3,34 +3,55 @@ import typing as t import asyncio import logging +import importlib from typing import TYPE_CHECKING from functools import partial from simple_di import inject from simple_di import Provide +from ..utils import LazyLoader +from ...grpc.utils import import_generated_stubs +from ...grpc.utils import LATEST_PROTOCOL_VERSION from ..configuration.containers import BentoMLContainer logger = logging.getLogger(__name__) if TYPE_CHECKING: + from types import ModuleType + + from grpc_health.v1 import health from bentoml.grpc.types import Interceptors from ..service import Service - from .grpc.servicer import Servicer + from ...grpc.v1 import service_pb2_grpc as services + + class ServicerModule(ModuleType): + @staticmethod + def create_bento_servicer(service: Service) -> services.BentoServiceServicer: + ... OnStartup = list[t.Callable[[], t.Union[None, t.Coroutine[t.Any, t.Any, None]]]] +else: + health = LazyLoader( + "health", + globals(), + "grpc_health.v1.health", + exc_msg="'grpcio-health-checking' is required for using health checking endpoints. Install with 'pip install grpcio-health-checking'.", + ) -class GRPCAppFactory: +class GrpcServicerFactory: """ - GRPCApp creates an async gRPC API server based on APIs defined with a BentoService via BentoService#apis. - This is a light wrapper around GRPCServer with addition to `on_startup` and `on_shutdown` hooks. + GrpcServicerFactory creates an async gRPC API server based on APIs defined with a BentoService via BentoService.apis. + This is a light wrapper around GrpcServer with addition to `on_startup` and `on_shutdown` hooks. Note that even though the code are similar with BaseAppFactory, gRPC protocol is different from ASGI. """ + _cached_module = None + @inject def __init__( self, @@ -39,9 +60,27 @@ def __init__( enable_metrics: bool = Provide[ BentoMLContainer.api_server_config.metrics.enabled ], + protocol_version: str = LATEST_PROTOCOL_VERSION, ) -> None: + pb, _ = import_generated_stubs(protocol_version) + self.bento_service = bento_service self.enable_metrics = enable_metrics + self.protocol_version = protocol_version + self.interceptors_stack = list(map(lambda x: x(), self.interceptors)) + + self.bento_servicer = self._servicer_module.create_bento_servicer( + self.bento_service + ) + self.mount_servicers = self.bento_service.mount_servicers + + # Create a health check servicer. We use the non-blocking implementation + # to avoid thread starvation. + self.health_servicer = health.aio.HealthServicer() + + self.service_names = tuple( + service.full_name for service in pb.DESCRIPTOR.services_by_name.values() + ) + (health.SERVICE_NAME,) @inject async def wait_for_runner_ready( @@ -93,34 +132,37 @@ def on_shutdown(self) -> list[t.Callable[[], None]]: return on_shutdown - def __call__(self) -> Servicer: - from .grpc import Servicer - - return Servicer( - self.bento_service, - on_startup=self.on_startup, - on_shutdown=self.on_shutdown, - mount_servicers=self.bento_service.mount_servicers, - interceptors=self.interceptors, - ) + @property + def _servicer_module(self) -> ServicerModule: + if self._cached_module is None: + object.__setattr__( + self, + "_cached_module", + importlib.import_module( + f".grpc.servicer.{self.protocol_version}", + package="bentoml._internal.server", + ), + ) + assert self._cached_module is not None + return self._cached_module @property def interceptors(self) -> Interceptors: # Note that order of interceptors is important here. - from bentoml.grpc.interceptors.opentelemetry import ( + from ...grpc.interceptors.opentelemetry import ( AsyncOpenTelemetryServerInterceptor, ) interceptors: Interceptors = [AsyncOpenTelemetryServerInterceptor] if self.enable_metrics: - from bentoml.grpc.interceptors.prometheus import PrometheusServerInterceptor + from ...grpc.interceptors.prometheus import PrometheusServerInterceptor interceptors.append(PrometheusServerInterceptor) if BentoMLContainer.api_server_config.logging.access.enabled.get(): - from bentoml.grpc.interceptors.access import AccessLogServerInterceptor + from ...grpc.interceptors.access import AccessLogServerInterceptor access_logger = logging.getLogger("bentoml.access") if access_logger.getEffectiveLevel() <= logging.INFO: diff --git a/src/bentoml/_internal/service/service.py b/src/bentoml/_internal/service/service.py index c905f818011..5b820c3c859 100644 --- a/src/bentoml/_internal/service/service.py +++ b/src/bentoml/_internal/service/service.py @@ -25,7 +25,7 @@ from .. import external_typing as ext from ..bento import Bento - from ..server.grpc.servicer import Servicer + from ..server.grpc_app import GrpcServicerFactory from .openapi.specification import OpenAPISpecification else: grpc, _ = import_grpc() @@ -221,10 +221,10 @@ def on_grpc_server_shutdown(self) -> None: pass @property - def grpc_servicer(self) -> Servicer: - from ..server.grpc_app import GRPCAppFactory + def grpc_servicer(self) -> GrpcServicerFactory: + from ..server.grpc_app import GrpcServicerFactory - return GRPCAppFactory(self)() + return GrpcServicerFactory(self) @property def asgi_app(self) -> "ext.ASGIApp": diff --git a/src/bentoml/grpc/utils/__init__.py b/src/bentoml/grpc/utils/__init__.py index a0a998334f6..f5ba32f5284 100644 --- a/src/bentoml/grpc/utils/__init__.py +++ b/src/bentoml/grpc/utils/__init__.py @@ -10,6 +10,7 @@ from bentoml.exceptions import InvalidArgument from bentoml.grpc.utils._import_hook import import_grpc from bentoml.grpc.utils._import_hook import import_generated_stubs +from bentoml.grpc.utils._import_hook import LATEST_PROTOCOL_VERSION if TYPE_CHECKING: from enum import Enum @@ -36,6 +37,7 @@ "import_generated_stubs", "import_grpc", "validate_proto_fields", + "LATEST_PROTOCOL_VERSION", ] logger = logging.getLogger(__name__) diff --git a/src/bentoml/grpc/utils/_import_hook.py b/src/bentoml/grpc/utils/_import_hook.py index 29b33eac705..147f0c2921d 100644 --- a/src/bentoml/grpc/utils/_import_hook.py +++ b/src/bentoml/grpc/utils/_import_hook.py @@ -5,9 +5,11 @@ if TYPE_CHECKING: import types +LATEST_PROTOCOL_VERSION = "v1" + def import_generated_stubs( - version: str = "v1", + version: str = LATEST_PROTOCOL_VERSION, file: str = "service.proto", ) -> tuple[types.ModuleType, types.ModuleType]: """ diff --git a/src/bentoml/testing/grpc/__init__.py b/src/bentoml/testing/grpc/__init__.py index ad1f68d5d08..7367460f12d 100644 --- a/src/bentoml/testing/grpc/__init__.py +++ b/src/bentoml/testing/grpc/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing as t +import importlib import traceback from typing import TYPE_CHECKING from contextlib import ExitStack @@ -9,11 +10,11 @@ from bentoml.exceptions import BentoMLException from bentoml.grpc.utils import import_grpc from bentoml.grpc.utils import import_generated_stubs +from bentoml.grpc.utils import LATEST_PROTOCOL_VERSION from bentoml._internal.utils import LazyLoader from bentoml._internal.utils import reserve_free_port from bentoml._internal.utils import cached_contextmanager from bentoml._internal.utils import add_experimental_docstring -from bentoml._internal.server.grpc.servicer import create_bento_servicer if TYPE_CHECKING: import grpc @@ -23,9 +24,9 @@ from grpc.aio._channel import Channel from google.protobuf.message import Message + from bentoml import Service from bentoml.grpc.v1 import service_pb2 as pb else: - pb, _ = import_generated_stubs() grpc, aio = import_grpc() # pylint: disable=E1111 np = LazyLoader("np", globals(), "numpy") @@ -39,17 +40,38 @@ ] -def randomize_pb_ndarray(shape: tuple[int, ...]) -> pb.NDArray: +def create_bento_servicer( + protocol_version: str = LATEST_PROTOCOL_VERSION, +) -> t.Callable[[Service], t.Any]: + try: + module = importlib.import_module( + f".{protocol_version}", package="bentoml._internal.server.grpc.servicer" + ) + return getattr(module, "create_bento_servicer") + except (ImportError, ModuleNotFoundError): + raise BentoMLException( + f"Failed to load servicer implementation for version {protocol_version}" + ) from None + + +def randomize_pb_ndarray( + shape: tuple[int, ...], protocol_version: str = LATEST_PROTOCOL_VERSION +) -> pb.NDArray: + pb, _ = import_generated_stubs(protocol_version) arr: NDArray[np.float32] = t.cast("NDArray[np.float32]", np.random.rand(*shape)) return pb.NDArray( shape=list(shape), dtype=pb.NDArray.DTYPE_FLOAT, float_values=arr.ravel() ) -def make_pb_ndarray(arr: NDArray[t.Any]) -> pb.NDArray: +def make_pb_ndarray( + arr: NDArray[t.Any], protocol_version: str = LATEST_PROTOCOL_VERSION +) -> pb.NDArray: from bentoml._internal.io_descriptors.numpy import npdtype_to_dtypepb_map from bentoml._internal.io_descriptors.numpy import npdtype_to_fieldpb_map + pb, _ = import_generated_stubs(protocol_version) + try: fieldpb = npdtype_to_fieldpb_map()[arr.dtype] dtypepb = npdtype_to_dtypepb_map()[arr.dtype] @@ -76,7 +98,7 @@ async def async_client_call( assert_code: grpc.StatusCode | None = None, assert_details: str | None = None, assert_trailing_metadata: aio.Metadata | None = None, - _internal_stubs_version: str = "v1", + protocol_version: str = LATEST_PROTOCOL_VERSION, ) -> pb.Response | None: """ Invoke a given API method via a client. @@ -95,11 +117,12 @@ async def async_client_call( Returns: The response from the server. """ + pb, _ = import_generated_stubs(protocol_version) res: pb.Response | None = None try: Call = channel.unary_unary( - f"/bentoml.grpc.{_internal_stubs_version}.BentoService/Call", + f"/bentoml.grpc.{protocol_version}.BentoService/Call", request_serializer=pb.Request.SerializeToString, response_deserializer=pb.Response.FromString, ) diff --git a/src/bentoml/testing/server.py b/src/bentoml/testing/server.py index b7dd72ca9a5..6d2b1d03321 100644 --- a/src/bentoml/testing/server.py +++ b/src/bentoml/testing/server.py @@ -25,6 +25,8 @@ from bentoml._internal.utils import reserve_free_port from bentoml._internal.utils import cached_contextmanager +from ..grpc.utils import LATEST_PROTOCOL_VERSION + if TYPE_CHECKING: from grpc import aio from grpc_health.v1 import health_pb2 as pb_health @@ -75,7 +77,7 @@ async def server_warmup( check_interval: float = 1, popen: subprocess.Popen[t.Any] | None = None, service_name: str | None = None, - _internal_stubs_version: str = "v1", + protocol_version: str = LATEST_PROTOCOL_VERSION, ) -> bool: start_time = time.time() proxy_handler = urllib.request.ProxyHandler({}) @@ -87,9 +89,7 @@ async def server_warmup( try: if service_name is None: - service_name = ( - f"bentoml.grpc.{_internal_stubs_version}.BentoService" - ) + service_name = f"bentoml.grpc.{protocol_version}.BentoService" async with create_channel(host_url) as channel: Check = channel.unary_unary( "/grpc.health.v1.Health/Check", @@ -177,7 +177,7 @@ def containerize( subprocess.call([backend, "rmi", image_tag]) -@cached_contextmanager("{image_tag}, {config_file}, {use_grpc}") +@cached_contextmanager("{image_tag}, {config_file}, {use_grpc}, {protocol_version}") def run_bento_server_container( image_tag: str, config_file: str | None = None, @@ -185,6 +185,7 @@ def run_bento_server_container( timeout: float = 90, host: str = "127.0.0.1", backend: str = "docker", + protocol_version: str = LATEST_PROTOCOL_VERSION, ): """ Launch a bentoml service container from a container, yield the host URL @@ -227,7 +228,13 @@ def run_bento_server_container( try: host_url = f"{host}:{port}" if asyncio.run( - server_warmup(host_url, timeout=timeout, popen=proc, grpc=use_grpc) + server_warmup( + host_url, + timeout=timeout, + popen=proc, + grpc=use_grpc, + protocol_version=protocol_version, + ) ): yield host_url else: @@ -247,6 +254,7 @@ def run_bento_server_standalone( config_file: str | None = None, timeout: float = 90, host: str = "127.0.0.1", + protocol_version: str = LATEST_PROTOCOL_VERSION, ): """ Launch a bentoml service directly by the bentoml CLI, yields the host URL. @@ -277,7 +285,13 @@ def run_bento_server_standalone( try: host_url = f"{host}:{server_port}" assert asyncio.run( - server_warmup(host_url, timeout=timeout, popen=p, grpc=use_grpc) + server_warmup( + host_url, + timeout=timeout, + popen=p, + grpc=use_grpc, + protocol_version=protocol_version, + ) ) yield host_url finally: @@ -302,6 +316,7 @@ def run_bento_server_distributed( use_grpc: bool = False, timeout: float = 90, host: str = "127.0.0.1", + protocol_version: str = LATEST_PROTOCOL_VERSION, ): """ Launch a bentoml service as a simulated distributed environment(Yatai), yields the host URL. @@ -391,7 +406,14 @@ def run_bento_server_distributed( ) try: host_url = f"{host}:{server_port}" - asyncio.run(server_warmup(host_url, timeout=timeout, grpc=use_grpc)) + asyncio.run( + server_warmup( + host_url, + timeout=timeout, + grpc=use_grpc, + protocol_version=protocol_version, + ) + ) yield host_url finally: for p in processes: @@ -404,7 +426,7 @@ def run_bento_server_distributed( @cached_contextmanager( - "{bento_name}, {project_path}, {config_file}, {deployment_mode}, {bentoml_home}, {use_grpc}" + "{bento_name}, {project_path}, {config_file}, {deployment_mode}, {bentoml_home}, {use_grpc}, {protocol_version}" ) def host_bento( bento_name: str | Tag | None = None, @@ -417,6 +439,7 @@ def host_bento( host: str = "127.0.0.1", timeout: float = 120, backend: str = "docker", + protocol_version: str = LATEST_PROTOCOL_VERSION, ) -> t.Generator[str, None, None]: """ Host a bentoml service, yields the host URL. @@ -473,6 +496,7 @@ def host_bento( use_grpc=use_grpc, host=host, timeout=timeout, + protocol_version=protocol_version, ) as host_url: yield host_url elif deployment_mode == "container": @@ -492,6 +516,7 @@ def host_bento( host=host, timeout=timeout, backend=backend, + protocol_version=protocol_version, ) as host_url: yield host_url elif deployment_mode == "distributed": @@ -501,6 +526,7 @@ def host_bento( use_grpc=use_grpc, host=host, timeout=timeout, + protocol_version=protocol_version, ) as host_url: yield host_url else: diff --git a/tests/unit/grpc/interceptors/test_access.py b/tests/unit/grpc/interceptors/test_access.py index d07f0607d2a..57bb99177c2 100644 --- a/tests/unit/grpc/interceptors/test_access.py +++ b/tests/unit/grpc/interceptors/test_access.py @@ -29,7 +29,6 @@ from google.protobuf import wrappers_pb2 from bentoml import Service - from bentoml.grpc.v1 import service_pb2_grpc as services from bentoml.grpc.types import Request from bentoml.grpc.types import Response from bentoml.grpc.types import RpcMethodHandler @@ -37,7 +36,6 @@ from bentoml.grpc.types import HandlerCallDetails from bentoml.grpc.types import BentoServicerContext else: - _, services = import_generated_stubs() grpc, aio = import_grpc() wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2") @@ -125,7 +123,12 @@ async def test_trailing_metadata(caplog: LogCaptureFixture): @pytest.mark.asyncio @pytest.mark.usefixtures("propagate_logs") -async def test_access_log_exception(caplog: LogCaptureFixture, simple_service: Service): +@pytest.mark.parametrize("protocol_version", ["v1", "v1alpha1"]) +async def test_access_log_exception( + caplog: LogCaptureFixture, simple_service: Service, protocol_version: str +): + _, services = import_generated_stubs(protocol_version) + with make_standalone_server( # we need to also setup opentelemetry interceptor # to make sure the access log is correctly setup. @@ -135,7 +138,7 @@ async def test_access_log_exception(caplog: LogCaptureFixture, simple_service: S ] ) as (server, host_url): services.add_BentoServiceServicer_to_server( - create_bento_servicer(simple_service), server + create_bento_servicer(protocol_version)(simple_service), server ) try: await server.start() @@ -146,9 +149,10 @@ async def test_access_log_exception(caplog: LogCaptureFixture, simple_service: S channel=channel, data={"text": wrappers_pb2.StringValue(value="asdf")}, assert_code=grpc.StatusCode.INTERNAL, + protocol_version=protocol_version, ) assert ( - "(scheme=http,path=/bentoml.grpc.v1.BentoService/Call,type=application/grpc,size=17) (http_status=500,grpc_status=13,type=application/grpc,size=0)" + f"(scheme=http,path=/bentoml.grpc.{protocol_version}.BentoService/Call,type=application/grpc,size=17) (http_status=500,grpc_status=13,type=application/grpc,size=0)" in caplog.text ) finally: diff --git a/tests/unit/grpc/interceptors/test_prometheus.py b/tests/unit/grpc/interceptors/test_prometheus.py index 237cc0c47c7..4c5fad79ef0 100644 --- a/tests/unit/grpc/interceptors/test_prometheus.py +++ b/tests/unit/grpc/interceptors/test_prometheus.py @@ -27,10 +27,7 @@ from google.protobuf import wrappers_pb2 from bentoml import Service - from bentoml.grpc.v1 import service_pb2_grpc as services else: - - _, services = import_generated_stubs() wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2") grpc, aio = import_grpc() @@ -106,19 +103,23 @@ async def test_empty_metrics(): ("gauge", ["api_name", "service_version", "service_name"]), ], ) +@pytest.mark.parametrize("protocol_version", ["v1", "v1alpha1"]) async def test_metrics_interceptors( simple_service: Service, metric_type: str, parent_set: list[str], + protocol_version: str, ): metrics_client = BentoMLContainer.metrics_client.get() + _, services = import_generated_stubs(protocol_version) + with make_standalone_server(interceptors=[interceptor]) as ( server, host_url, ): services.add_BentoServiceServicer_to_server( - create_bento_servicer(simple_service), server + create_bento_servicer(protocol_version)(simple_service), server ) try: await server.start() @@ -127,6 +128,7 @@ async def test_metrics_interceptors( "noop_sync", channel=channel, data={"text": wrappers_pb2.StringValue(value="BentoML")}, + protocol_version=protocol_version, ) for m in metrics_client.text_string_to_metric_families(): for sample in m.samples: