diff --git a/src/bentoml/_internal/bento/build_dev_bentoml_whl.py b/src/bentoml/_internal/bento/build_dev_bentoml_whl.py index f3a9e1647b..0476dd081b 100644 --- a/src/bentoml/_internal/bento/build_dev_bentoml_whl.py +++ b/src/bentoml/_internal/bento/build_dev_bentoml_whl.py @@ -7,6 +7,7 @@ from ..utils.pkg import source_locations from ...exceptions import BentoMLException from ...exceptions import MissingDependencyException +from ...grpc.utils import LATEST_PROTOCOL_VERSION from ..configuration import is_pypi_installed_bentoml logger = logging.getLogger(__name__) @@ -15,7 +16,7 @@ def build_bentoml_editable_wheel( - target_path: str, *, _internal_stubs_version: str = "v1" + target_path: str, *, _internal_protocol_version: str = LATEST_PROTOCOL_VERSION ) -> None: """ This is for BentoML developers to create Bentos that contains the local bentoml @@ -52,10 +53,10 @@ def build_bentoml_editable_wheel( bentoml_path = Path(module_location) if not Path( - module_location, "grpc", _internal_stubs_version, "service_pb2.py" + module_location, "grpc", _internal_protocol_version, "service_pb2.py" ).exists(): raise ModuleNotFoundError( - f"Generated stubs for version {_internal_stubs_version} are missing. Make sure to run '{bentoml_path.as_posix()}/scripts/generate_grpc_stubs.sh {_internal_stubs_version}' beforehand to generate gRPC stubs." + f"Generated stubs for version {_internal_protocol_version} are missing. Make sure to run '{bentoml_path.as_posix()}/scripts/generate_grpc_stubs.sh {_internal_protocol_version}' beforehand to generate gRPC stubs." ) from None # location to pyproject.toml diff --git a/src/bentoml/_internal/io_descriptors/base.py b/src/bentoml/_internal/io_descriptors/base.py index 870144c865..f6c03e0f1e 100644 --- a/src/bentoml/_internal/io_descriptors/base.py +++ b/src/bentoml/_internal/io_descriptors/base.py @@ -36,7 +36,7 @@ IOType = t.TypeVar("IOType") -def from_spec(spec: dict[str, str]) -> IODescriptor[t.Any]: +def from_spec(spec: dict[str, t.Any]) -> IODescriptor[t.Any]: if "id" not in spec: raise InvalidArgument(f"IO descriptor spec ({spec}) missing ID.") return IO_DESCRIPTOR_REGISTRY[spec["id"]].from_spec(spec) diff --git a/src/bentoml/_internal/io_descriptors/json.py b/src/bentoml/_internal/io_descriptors/json.py index da6ed9bc50..2adbea6f63 100644 --- a/src/bentoml/_internal/io_descriptors/json.py +++ b/src/bentoml/_internal/io_descriptors/json.py @@ -30,6 +30,7 @@ import pydantic import pydantic.schema as schema + from google.protobuf import message as _message from google.protobuf import struct_pb2 from typing_extensions import Self @@ -392,19 +393,29 @@ async def to_proto(self, obj: JSONType) -> struct_pb2.Value: if LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(obj): obj = obj.dict() msg = struct_pb2.Value() - # To handle None cases. - if obj is not None: - from google.protobuf.json_format import ParseDict - - if isinstance(obj, (dict, str, list, float, int, bool)): - # ParseDict handles google.protobuf.Struct type - # directly if given object has a supported type - ParseDict(obj, msg) - else: - # If given object doesn't have a supported type, we will - # use given JSON encoder to convert it to dictionary - # and then parse it to google.protobuf.Struct. - # Note that if a custom JSON encoder is used, it mustn't - # take any arguments. - ParseDict(self._json_encoder().default(obj), msg) + return parse_dict_to_proto(obj, msg, json_encoder=self._json_encoder) + + +def parse_dict_to_proto( + obj: JSONType, + msg: _message.Message, + json_encoder: type[json.JSONEncoder] = DefaultJsonEncoder, +) -> t.Any: + if obj is None: + # this function is an identity op for the msg if obj is None. return msg + + from google.protobuf.json_format import ParseDict + + if isinstance(obj, (dict, str, list, float, int, bool)): + # ParseDict handles google.protobuf.Struct type + # directly if given object has a supported type + ParseDict(obj, msg) + else: + # If given object doesn't have a supported type, we will + # use given JSON encoder to convert it to dictionary + # and then parse it to google.protobuf.Struct. + # Note that if a custom JSON encoder is used, it mustn't + # take any arguments. + ParseDict(json_encoder().default(obj), msg) + return msg diff --git a/src/bentoml/_internal/server/grpc/__init__.py b/src/bentoml/_internal/server/grpc/__init__.py index df277aa1d9..e69de29bb2 100644 --- a/src/bentoml/_internal/server/grpc/__init__.py +++ b/src/bentoml/_internal/server/grpc/__init__.py @@ -1,4 +0,0 @@ -from .server import Server -from .servicer import Servicer - -__all__ = ["Server", "Servicer"] diff --git a/src/bentoml/_internal/server/grpc/server.py b/src/bentoml/_internal/server/grpc/server.py deleted file mode 100644 index b96d87506e..0000000000 --- a/src/bentoml/_internal/server/grpc/server.py +++ /dev/null @@ -1,242 +0,0 @@ -from __future__ import annotations - -import os -import sys -import typing as t -import asyncio -import logging -from typing import TYPE_CHECKING -from concurrent.futures import ThreadPoolExecutor - -from simple_di import inject -from simple_di import Provide - -from bentoml.grpc.utils import import_grpc -from bentoml.grpc.utils import import_generated_stubs - -from ...utils import LazyLoader -from ...utils import cached_property -from ...utils import resolve_user_filepath -from ...configuration.containers import BentoMLContainer - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - import grpc - from grpc import aio - from grpc_health.v1 import health_pb2 as pb_health - from grpc_health.v1 import health_pb2_grpc as services_health - - from bentoml.grpc.v1 import service_pb2_grpc as services - - from .servicer import Servicer -else: - grpc, aio = import_grpc() - _, services = import_generated_stubs() - health_exception_msg = "'grpcio-health-checking' is required for using health checking endpoints. Install with 'pip install grpcio-health-checking'." - pb_health = LazyLoader( - "pb_health", - globals(), - "grpc_health.v1.health_pb2", - exc_msg=health_exception_msg, - ) - services_health = LazyLoader( - "services_health", - globals(), - "grpc_health.v1.health_pb2_grpc", - exc_msg=health_exception_msg, - ) - - -def _load_from_file(p: str) -> bytes: - rp = resolve_user_filepath(p, ctx=None) - with open(rp, "rb") as f: - return f.read() - - -# NOTE: we are using the internal aio._server.Server (which is initialized with aio.server) -class Server(aio._server.Server): - """An async implementation of a gRPC server.""" - - @inject - def __init__( - self, - servicer: Servicer, - bind_address: str, - max_message_length: int - | None = Provide[BentoMLContainer.grpc.max_message_length], - maximum_concurrent_rpcs: int - | None = Provide[BentoMLContainer.grpc.maximum_concurrent_rpcs], - enable_reflection: bool = False, - enable_channelz: bool = False, - max_concurrent_streams: int | None = None, - migration_thread_pool_workers: int = 1, - ssl_certfile: str | None = None, - ssl_keyfile: str | None = None, - ssl_ca_certs: str | None = None, - graceful_shutdown_timeout: float | None = None, - compression: grpc.Compression | None = None, - ): - self.servicer = servicer - self.max_message_length = max_message_length - self.max_concurrent_streams = max_concurrent_streams - self.bind_address = bind_address - self.enable_reflection = enable_reflection - self.enable_channelz = enable_channelz - self.graceful_shutdown_timeout = graceful_shutdown_timeout - self.ssl_certfile = ssl_certfile - self.ssl_keyfile = ssl_keyfile - self.ssl_ca_certs = ssl_ca_certs - - if not bool(self.servicer): - self.servicer.load() - assert self.servicer.loaded - - super().__init__( - # Note that the max_workers are used inside ThreadPoolExecutor. - # This ThreadPoolExecutor are used by aio.Server() to execute non-AsyncIO RPC handlers. - # Setting it to 1 makes it thread-safe for sync APIs. - thread_pool=ThreadPoolExecutor(max_workers=migration_thread_pool_workers), - generic_handlers=() if self.handlers is None else self.handlers, - interceptors=self.servicer.interceptors_stack, - options=self.options, - # maximum_concurrent_rpcs defines the maximum number of concurrent RPCs this server - # will service before returning RESOURCE_EXHAUSTED status. - # Set to None will indicate no limit. - maximum_concurrent_rpcs=maximum_concurrent_rpcs, - compression=compression, - ) - - @property - def options(self) -> grpc.aio.ChannelArgumentType: - options: grpc.aio.ChannelArgumentType = [] - - if sys.platform != "win32": - # https://github.com/grpc/grpc/blob/master/include/grpc/impl/codegen/grpc_types.h#L294 - # Eventhough GRPC_ARG_ALLOW_REUSEPORT is set to 1 by default, we want still - # want to explicitly set it to 1 so that we can spawn multiple gRPC servers in - # production settings. - options.append(("grpc.so_reuseport", 1)) - if self.max_concurrent_streams: - options.append(("grpc.max_concurrent_streams", self.max_concurrent_streams)) - if self.enable_channelz: - options.append(("grpc.enable_channelz", 1)) - if self.max_message_length: - options.extend( - ( - # grpc.max_message_length this is a deprecated options, for backward compatibility - ("grpc.max_message_length", self.max_message_length), - ("grpc.max_receive_message_length", self.max_message_length), - ("grpc.max_send_message_length", self.max_message_length), - ) - ) - - return tuple(options) - - @property - def handlers(self) -> t.Sequence[grpc.GenericRpcHandler] | None: - # Note that currently BentoML doesn't provide any specific - # handlers for gRPC. If users have any specific handlers, - # BentoML will pass it through to grpc.aio.Server - return self.servicer.bento_service.grpc_handlers - - @cached_property - def loop(self) -> asyncio.AbstractEventLoop: - return asyncio.get_event_loop() - - def run(self) -> None: - try: - self.loop.run_until_complete(self.serve()) - finally: - try: - self.loop.call_soon_threadsafe( - lambda: asyncio.ensure_future(self.shutdown()) - ) - except Exception as e: # pylint: disable=broad-except - raise RuntimeError(f"Server failed unexpectedly: {e}") from None - - def configure_port(self, addr: str): - if self.ssl_certfile: - client_auth = False - ca_cert = None - assert ( - self.ssl_keyfile - ), "'ssl_keyfile' is required when 'ssl_certfile' is provided." - if self.ssl_ca_certs is not None: - client_auth = True - ca_cert = _load_from_file(self.ssl_ca_certs) - server_credentials = grpc.ssl_server_credentials( - ( - ( - _load_from_file(self.ssl_keyfile), - _load_from_file(self.ssl_certfile), - ), - ), - root_certificates=ca_cert, - require_client_auth=client_auth, - ) - - self.add_secure_port(addr, server_credentials) - else: - self.add_insecure_port(addr) - - async def serve(self) -> None: - self.configure_port(self.bind_address) - await self.startup() - await self.wait_for_termination() - - async def startup(self) -> None: - from bentoml.exceptions import MissingDependencyException - - # Running on_startup callback. - await self.servicer.startup() - # register bento servicer - services.add_BentoServiceServicer_to_server(self.servicer.bento_servicer, self) - services_health.add_HealthServicer_to_server( - self.servicer.health_servicer, self - ) - - service_names = self.servicer.service_names - # register custom servicer - for ( - user_servicer, - add_servicer_fn, - user_service_names, - ) in self.servicer.mount_servicers: - add_servicer_fn(user_servicer(), self) - service_names += tuple(user_service_names) - if self.enable_channelz: - try: - from grpc_channelz.v1 import channelz - except ImportError: - raise MissingDependencyException( - "'--debug' is passed, which requires 'grpcio-channelz' to be installed. Install with 'pip install bentoml[grpc-channelz]'." - ) from None - if "GRPC_TRACE" not in os.environ: - logger.debug( - "channelz is enabled, while GRPC_TRACE is not set. No channel tracing will be recorded." - ) - channelz.add_channelz_servicer(self) - if self.enable_reflection: - try: - # reflection is required for health checking to work. - from grpc_reflection.v1alpha import reflection - except ImportError: - raise MissingDependencyException( - "reflection is enabled, which requires 'grpcio-reflection' to be installed. Install with 'pip install bentoml[grpc-reflection]'." - ) from None - service_names += (reflection.SERVICE_NAME,) - reflection.enable_server_reflection(service_names, self) - # mark all services as healthy - for service in service_names: - await self.servicer.health_servicer.set( - service, pb_health.HealthCheckResponse.SERVING # type: ignore (no types available) - ) - await self.start() - - async def shutdown(self): - # Running on_startup callback. - await self.servicer.shutdown() - await self.stop(grace=self.graceful_shutdown_timeout) - await self.servicer.health_servicer.enter_graceful_shutdown() - self.loop.stop() diff --git a/src/bentoml/_internal/server/grpc/servicer/__init__.py b/src/bentoml/_internal/server/grpc/servicer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/bentoml/_internal/server/grpc/servicer/v1/__init__.py b/src/bentoml/_internal/server/grpc/servicer/v1/__init__.py new file mode 100644 index 0000000000..7886d828ee --- /dev/null +++ b/src/bentoml/_internal/server/grpc/servicer/v1/__init__.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import sys +import asyncio +import logging +from typing import TYPE_CHECKING + +import anyio + +from ......exceptions import InvalidArgument +from ......exceptions import BentoMLException +from ......grpc.utils import import_grpc +from ......grpc.utils import grpc_status_code +from ......grpc.utils import validate_proto_fields +from ......grpc.utils import import_generated_stubs + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from logging import _ExcInfoType as ExcInfoType # type: ignore (private warning) + + import grpc + + from ......grpc.v1 import service_pb2 as pb + from ......grpc.v1 import service_pb2_grpc as services + from ......grpc.types import BentoServicerContext + from .....service.service import Service +else: + grpc, _ = import_grpc() + pb, services = import_generated_stubs(version="v1") + + +def log_exception(request: pb.Request, exc_info: ExcInfoType) -> None: + # gRPC will always send a POST request. + logger.error("Exception on /%s [POST]", request.api_name, exc_info=exc_info) + + +def create_bento_servicer(service: Service) -> services.BentoServiceServicer: + """ + This is the actual implementation of BentoServicer. + Main inference entrypoint will be invoked via /bentoml.grpc..BentoService/Call + """ + + class BentoServiceImpl(services.BentoServiceServicer): + """An asyncio implementation of BentoService servicer.""" + + async def Call( # type: ignore (no async types) # pylint: disable=invalid-overridden-method + self: services.BentoServiceServicer, + request: pb.Request, + context: BentoServicerContext, + ) -> pb.Response | None: + if request.api_name not in service.apis: + raise InvalidArgument( + f"given 'api_name' is not defined in {service.name}", + ) from None + + api = service.apis[request.api_name] + response = pb.Response() + + # NOTE: since IODescriptor._proto_fields is a tuple, the order is preserved. + # This is important so that we know the order of fields to process. + # We will use fields descriptor to determine how to process that request. + try: + # we will check if the given fields list contains a pb.Multipart. + input_proto = getattr( + request, + validate_proto_fields(request.WhichOneof("content"), api.input), + ) + input_data = await api.input.from_proto(input_proto) + if asyncio.iscoroutinefunction(api.func): + if api.multi_input: + output = await api.func(**input_data) + else: + output = await api.func(input_data) + else: + if api.multi_input: + output = await anyio.to_thread.run_sync(api.func, **input_data) + else: + output = await anyio.to_thread.run_sync(api.func, input_data) + res = await api.output.to_proto(output) + # TODO(aarnphm): support multiple proto fields + response = pb.Response(**{api.output._proto_fields[0]: res}) + except BentoMLException as e: + log_exception(request, sys.exc_info()) + await context.abort(code=grpc_status_code(e), details=e.message) + except (RuntimeError, TypeError, NotImplementedError): + log_exception(request, sys.exc_info()) + await context.abort( + code=grpc.StatusCode.INTERNAL, + details="A runtime error has occurred, see stacktrace from logs.", + ) + except Exception: # pylint: disable=broad-except + log_exception(request, sys.exc_info()) + await context.abort( + code=grpc.StatusCode.INTERNAL, + details="An error has occurred in BentoML user code when handling this request, find the error details in server logs.", + ) + return response + + return BentoServiceImpl() diff --git a/src/bentoml/_internal/server/grpc/servicer.py b/src/bentoml/_internal/server/grpc/servicer/v1alpha1/__init__.py similarity index 53% rename from src/bentoml/_internal/server/grpc/servicer.py rename to src/bentoml/_internal/server/grpc/servicer/v1alpha1/__init__.py index b96ca7b8ea..4fd00d36da 100644 --- a/src/bentoml/_internal/server/grpc/servicer.py +++ b/src/bentoml/_internal/server/grpc/servicer/v1alpha1/__init__.py @@ -1,22 +1,18 @@ from __future__ import annotations import sys -import typing as t import asyncio import logging from typing import TYPE_CHECKING -from inspect import isawaitable import anyio -from bentoml.grpc.utils import import_grpc -from bentoml.grpc.utils import grpc_status_code -from bentoml.grpc.utils import validate_proto_fields -from bentoml.grpc.utils import import_generated_stubs - -from ...utils import LazyLoader -from ....exceptions import InvalidArgument -from ....exceptions import BentoMLException +from ......exceptions import InvalidArgument +from ......exceptions import BentoMLException +from ......grpc.utils import import_grpc +from ......grpc.utils import grpc_status_code +from ......grpc.utils import validate_proto_fields +from ......grpc.utils import import_generated_stubs logger = logging.getLogger(__name__) @@ -24,31 +20,14 @@ from logging import _ExcInfoType as ExcInfoType # type: ignore (private warning) import grpc - from grpc import aio - from grpc_health.v1 import health - from typing_extensions import Self - - from bentoml.grpc.v1 import service_pb2 as pb - from bentoml.grpc.v1 import service_pb2_grpc as services - from bentoml.grpc.types import Interceptors - from bentoml.grpc.types import AddServicerFn - from bentoml.grpc.types import ServicerClass - from bentoml.grpc.types import BentoServicerContext - - from ...service.service import Service + from ......grpc.types import BentoServicerContext + from ......grpc.v1alpha1 import service_pb2 as pb + from ......grpc.v1alpha1 import service_pb2_grpc as services + from .....service.service import Service else: - pb, services = import_generated_stubs() - grpc, aio = import_grpc() - health = LazyLoader( - "health", - globals(), - "grpc_health.v1.health", - exc_msg="'grpcio-health-checking' is required for using health checking endpoints. Install with 'pip install grpcio-health-checking'.", - ) - containers = LazyLoader( - "containers", globals(), "google.protobuf.internal.containers" - ) + grpc, _ = import_grpc() + pb, services = import_generated_stubs(version="v1alpha1") def log_exception(request: pb.Request, exc_info: ExcInfoType) -> None: @@ -56,61 +35,6 @@ def log_exception(request: pb.Request, exc_info: ExcInfoType) -> None: logger.error("Exception on /%s [POST]", request.api_name, exc_info=exc_info) -class Servicer: - """Create an instance of gRPC Servicer.""" - - def __init__( - self: Self, - service: Service, - on_startup: t.Sequence[t.Callable[[], t.Any]] | None = None, - on_shutdown: t.Sequence[t.Callable[[], t.Any]] | None = None, - mount_servicers: t.Sequence[tuple[ServicerClass, AddServicerFn, list[str]]] - | None = None, - interceptors: Interceptors | None = None, - ) -> None: - self.bento_service = service - - self.on_startup = [] if not on_startup else list(on_startup) - self.on_shutdown = [] if not on_shutdown else list(on_shutdown) - self.mount_servicers = [] if not mount_servicers else list(mount_servicers) - self.interceptors = [] if not interceptors else list(interceptors) - self.loaded = False - - def load(self): - assert not self.loaded - - self.interceptors_stack = self.build_interceptors_stack() - - self.bento_servicer = create_bento_servicer(self.bento_service) - - # Create a health check servicer. We use the non-blocking implementation - # to avoid thread starvation. - self.health_servicer = health.aio.HealthServicer() - - self.service_names = tuple( - service.full_name for service in pb.DESCRIPTOR.services_by_name.values() - ) + (health.SERVICE_NAME,) - self.loaded = True - - def build_interceptors_stack(self) -> list[aio.ServerInterceptor]: - return list(map(lambda x: x(), self.interceptors)) - - async def startup(self): - for handler in self.on_startup: - out = handler() - if isawaitable(out): - await out - - async def shutdown(self): - for handler in self.on_shutdown: - out = handler() - if isawaitable(out): - await out - - def __bool__(self): - return self.loaded - - def create_bento_servicer(service: Service) -> services.BentoServiceServicer: """ This is the actual implementation of BentoServicer. @@ -121,7 +45,7 @@ class BentoServiceImpl(services.BentoServiceServicer): """An asyncio implementation of BentoService servicer.""" async def Call( # type: ignore (no async types) # pylint: disable=invalid-overridden-method - self, + self: services.BentoServiceServicer, request: pb.Request, context: BentoServicerContext, ) -> pb.Response | None: diff --git a/src/bentoml/_internal/server/grpc_app.py b/src/bentoml/_internal/server/grpc_app.py index cc6df9e624..3da154c5ee 100644 --- a/src/bentoml/_internal/server/grpc_app.py +++ b/src/bentoml/_internal/server/grpc_app.py @@ -1,47 +1,135 @@ from __future__ import annotations +import os +import sys import typing as t import asyncio +import inspect import logging from typing import TYPE_CHECKING from functools import partial +from concurrent.futures import ThreadPoolExecutor from simple_di import inject from simple_di import Provide +from bentoml.grpc.utils import import_grpc +from bentoml.grpc.utils import import_generated_stubs + +from ..utils import LazyLoader +from ..utils import cached_property +from ..utils import resolve_user_filepath +from ...grpc.utils import LATEST_PROTOCOL_VERSION from ..configuration.containers import BentoMLContainer logger = logging.getLogger(__name__) if TYPE_CHECKING: - - from bentoml.grpc.types import Interceptors + import grpc + from grpc import aio + from grpc_health.v1 import health + from grpc_health.v1 import health_pb2 as pb_health + from grpc_health.v1 import health_pb2_grpc as services_health from ..service import Service - from .grpc.servicer import Servicer + from ...grpc.types import Interceptors OnStartup = list[t.Callable[[], t.Union[None, t.Coroutine[t.Any, t.Any, None]]]] - -class GRPCAppFactory: - """ - GRPCApp creates an async gRPC API server based on APIs defined with a BentoService via BentoService#apis. - This is a light wrapper around GRPCServer with addition to `on_startup` and `on_shutdown` hooks. - - Note that even though the code are similar with BaseAppFactory, gRPC protocol is different from ASGI. - """ +else: + grpc, aio = import_grpc() + health_exception_msg = "'grpcio-health-checking' is required for using health checking endpoints. Install with 'pip install grpcio-health-checking'." + pb_health = LazyLoader( + "pb_health", + globals(), + "grpc_health.v1.health_pb2", + exc_msg=health_exception_msg, + ) + services_health = LazyLoader( + "services_health", + globals(), + "grpc_health.v1.health_pb2_grpc", + exc_msg=health_exception_msg, + ) + health = LazyLoader( + "health", + globals(), + "grpc_health.v1.health", + exc_msg="'grpcio-health-checking' is required for using health checking endpoints. Install with 'pip install grpcio-health-checking'.", + ) + + +def _load_from_file(p: str) -> bytes: + rp = resolve_user_filepath(p, ctx=None) + with open(rp, "rb") as f: + return f.read() + + +# NOTE: we are using the internal aio._server.Server (which is initialized with aio.server) +class Server(aio._server.Server): + """An async implementation of a gRPC server.""" @inject def __init__( self, bento_service: Service, - *, - enable_metrics: bool = Provide[ - BentoMLContainer.api_server_config.metrics.enabled - ], - ) -> None: + bind_address: str, + max_message_length: int + | None = Provide[BentoMLContainer.grpc.max_message_length], + maximum_concurrent_rpcs: int + | None = Provide[BentoMLContainer.grpc.maximum_concurrent_rpcs], + enable_reflection: bool = False, + enable_channelz: bool = False, + max_concurrent_streams: int | None = None, + migration_thread_pool_workers: int = 1, + ssl_certfile: str | None = None, + ssl_keyfile: str | None = None, + ssl_ca_certs: str | None = None, + graceful_shutdown_timeout: float | None = None, + compression: grpc.Compression | None = None, + protocol_version: str = LATEST_PROTOCOL_VERSION, + ): + pb, _ = import_generated_stubs(protocol_version) + self.bento_service = bento_service - self.enable_metrics = enable_metrics + self.servicer = bento_service.get_grpc_servicer(protocol_version) + + # options + self.max_message_length = max_message_length + self.max_concurrent_streams = max_concurrent_streams + self.bind_address = bind_address + self.enable_reflection = enable_reflection + self.enable_channelz = enable_channelz + self.graceful_shutdown_timeout = graceful_shutdown_timeout + self.ssl_certfile = ssl_certfile + self.ssl_keyfile = ssl_keyfile + self.ssl_ca_certs = ssl_ca_certs + self.protocol_version = protocol_version + + # Create a health check servicer. We use the non-blocking implementation + # to avoid thread starvation. + self.health_servicer = health.aio.HealthServicer() + + self.mount_servicers = self.bento_service.mount_servicers + + self.service_names = tuple( + service.full_name for service in pb.DESCRIPTOR.services_by_name.values() + ) + (health.SERVICE_NAME,) + + super().__init__( + # Note that the max_workers are used inside ThreadPoolExecutor. + # This ThreadPoolExecutor are used by aio.Server() to execute non-AsyncIO RPC handlers. + # Setting it to 1 makes it thread-safe for sync APIs. + thread_pool=ThreadPoolExecutor(max_workers=migration_thread_pool_workers), + generic_handlers=() if self.handlers is None else self.handlers, + interceptors=list(map(lambda x: x(), self.interceptors)), + options=self.options, + # maximum_concurrent_rpcs defines the maximum number of concurrent RPCs this server + # will service before returning RESOURCE_EXHAUSTED status. + # Set to None will indicate no limit. + maximum_concurrent_rpcs=maximum_concurrent_rpcs, + compression=compression, + ) @inject async def wait_for_runner_ready( @@ -73,54 +161,48 @@ async def wait_for_runner_ready( logger.info("All runners ready.") @property - def on_startup(self) -> OnStartup: - on_startup: OnStartup = [self.bento_service.on_grpc_server_startup] - if BentoMLContainer.development_mode.get(): - for runner in self.bento_service.runners: - on_startup.append(partial(runner.init_local, quiet=True)) - else: - for runner in self.bento_service.runners: - on_startup.append(runner.init_client) - - on_startup.append(self.wait_for_runner_ready) - return on_startup - - @property - def on_shutdown(self) -> list[t.Callable[[], None]]: - on_shutdown = [self.bento_service.on_grpc_server_shutdown] - for runner in self.bento_service.runners: - on_shutdown.append(runner.destroy) - - return on_shutdown - - def __call__(self) -> Servicer: - from .grpc import Servicer - - return Servicer( - self.bento_service, - on_startup=self.on_startup, - on_shutdown=self.on_shutdown, - mount_servicers=self.bento_service.mount_servicers, - interceptors=self.interceptors, - ) + def options(self) -> grpc.aio.ChannelArgumentType: + options: grpc.aio.ChannelArgumentType = [] + + if sys.platform != "win32": + # https://github.com/grpc/grpc/blob/master/include/grpc/impl/codegen/grpc_types.h#L294 + # Eventhough GRPC_ARG_ALLOW_REUSEPORT is set to 1 by default, we want still + # want to explicitly set it to 1 so that we can spawn multiple gRPC servers in + # production settings. + options.append(("grpc.so_reuseport", 1)) + if self.max_concurrent_streams: + options.append(("grpc.max_concurrent_streams", self.max_concurrent_streams)) + if self.enable_channelz: + options.append(("grpc.enable_channelz", 1)) + if self.max_message_length: + options.extend( + ( + # grpc.max_message_length this is a deprecated options, for backward compatibility + ("grpc.max_message_length", self.max_message_length), + ("grpc.max_receive_message_length", self.max_message_length), + ("grpc.max_send_message_length", self.max_message_length), + ) + ) + + return tuple(options) @property def interceptors(self) -> Interceptors: # Note that order of interceptors is important here. - from bentoml.grpc.interceptors.opentelemetry import ( + from ...grpc.interceptors.opentelemetry import ( AsyncOpenTelemetryServerInterceptor, ) interceptors: Interceptors = [AsyncOpenTelemetryServerInterceptor] - if self.enable_metrics: - from bentoml.grpc.interceptors.prometheus import PrometheusServerInterceptor + if BentoMLContainer.api_server_config.metrics.enabled.get(): + from ...grpc.interceptors.prometheus import PrometheusServerInterceptor interceptors.append(PrometheusServerInterceptor) if BentoMLContainer.api_server_config.logging.access.enabled.get(): - from bentoml.grpc.interceptors.access import AccessLogServerInterceptor + from ...grpc.interceptors.access import AccessLogServerInterceptor access_logger = logging.getLogger("bentoml.access") if access_logger.getEffectiveLevel() <= logging.INFO: @@ -130,3 +212,140 @@ def interceptors(self) -> Interceptors: interceptors.extend(self.bento_service.interceptors) return interceptors + + @property + def handlers(self) -> t.Sequence[grpc.GenericRpcHandler] | None: + # Note that currently BentoML doesn't provide any specific + # handlers for gRPC. If users have any specific handlers, + # BentoML will pass it through to grpc.aio.Server + return self.bento_service.grpc_handlers + + @cached_property + def loop(self) -> asyncio.AbstractEventLoop: + return asyncio.get_event_loop() + + def run(self) -> None: + try: + self.loop.run_until_complete(self.serve()) + finally: + try: + self.loop.call_soon_threadsafe( + lambda: asyncio.ensure_future(self.shutdown()) + ) + except Exception as e: # pylint: disable=broad-except + raise RuntimeError(f"Server failed unexpectedly: {e}") from None + + def configure_port(self, addr: str): + if self.ssl_certfile: + client_auth = False + ca_cert = None + assert ( + self.ssl_keyfile + ), "'ssl_keyfile' is required when 'ssl_certfile' is provided." + if self.ssl_ca_certs is not None: + client_auth = True + ca_cert = _load_from_file(self.ssl_ca_certs) + server_credentials = grpc.ssl_server_credentials( + ( + ( + _load_from_file(self.ssl_keyfile), + _load_from_file(self.ssl_certfile), + ), + ), + root_certificates=ca_cert, + require_client_auth=client_auth, + ) + + self.add_secure_port(addr, server_credentials) + else: + self.add_insecure_port(addr) + + async def serve(self) -> None: + self.configure_port(self.bind_address) + await self.startup() + await self.wait_for_termination() + + @property + def on_startup(self) -> OnStartup: + on_startup: OnStartup = [self.bento_service.on_grpc_server_startup] + if BentoMLContainer.development_mode.get(): + for runner in self.bento_service.runners: + on_startup.append(partial(runner.init_local, quiet=True)) + else: + for runner in self.bento_service.runners: + on_startup.append(runner.init_client) + + on_startup.append(self.wait_for_runner_ready) + return on_startup + + async def startup(self) -> None: + from ...exceptions import MissingDependencyException + + _, services = import_generated_stubs(self.protocol_version) + + # Running on_startup callback. + for handler in self.on_startup: + out = handler() + if inspect.isawaitable(out): + await out + + # register bento servicer + services.add_BentoServiceServicer_to_server(self.servicer, self) + services_health.add_HealthServicer_to_server(self.health_servicer, self) + + service_names = self.service_names + # register custom servicer + for ( + user_servicer, + add_servicer_fn, + user_service_names, + ) in self.mount_servicers: + add_servicer_fn(user_servicer(), self) + service_names += tuple(user_service_names) + if self.enable_channelz: + try: + from grpc_channelz.v1 import channelz + except ImportError: + raise MissingDependencyException( + "'--debug' is passed, which requires 'grpcio-channelz' to be installed. Install with 'pip install bentoml[grpc-channelz]'." + ) from None + if "GRPC_TRACE" not in os.environ: + logger.debug( + "channelz is enabled, while GRPC_TRACE is not set. No channel tracing will be recorded." + ) + channelz.add_channelz_servicer(self) + if self.enable_reflection: + try: + # reflection is required for health checking to work. + from grpc_reflection.v1alpha import reflection + except ImportError: + raise MissingDependencyException( + "reflection is enabled, which requires 'grpcio-reflection' to be installed. Install with 'pip install bentoml[grpc-reflection]'." + ) from None + service_names += (reflection.SERVICE_NAME,) + reflection.enable_server_reflection(service_names, self) + # mark all services as healthy + for service in service_names: + await self.health_servicer.set( + service, pb_health.HealthCheckResponse.SERVING # type: ignore (no types available) + ) + await self.start() + + @property + def on_shutdown(self) -> list[t.Callable[[], None]]: + on_shutdown = [self.bento_service.on_grpc_server_shutdown] + for runner in self.bento_service.runners: + on_shutdown.append(runner.destroy) + + return on_shutdown + + async def shutdown(self): + # Running on_startup callback. + for handler in self.on_shutdown: + out = handler() + if inspect.isawaitable(out): + await out + + await self.stop(grace=self.graceful_shutdown_timeout) + await self.health_servicer.enter_graceful_shutdown() + self.loop.stop() diff --git a/src/bentoml/_internal/service/service.py b/src/bentoml/_internal/service/service.py index c905f81801..e290b78299 100644 --- a/src/bentoml/_internal/service/service.py +++ b/src/bentoml/_internal/service/service.py @@ -2,6 +2,7 @@ import typing as t import logging +import importlib from typing import TYPE_CHECKING from functools import partial @@ -13,6 +14,7 @@ from ..models import Model from ..runner import Runner from ...grpc.utils import import_grpc +from ...grpc.utils import LATEST_PROTOCOL_VERSION from ..bento.bento import get_default_svc_readme from .inference_api import InferenceAPI from ..io_descriptors import IODescriptor @@ -25,8 +27,9 @@ from .. import external_typing as ext from ..bento import Bento - from ..server.grpc.servicer import Servicer + from ...grpc.v1 import service_pb2_grpc as services from .openapi.specification import OpenAPISpecification + else: grpc, _ = import_grpc() @@ -220,11 +223,26 @@ def on_grpc_server_startup(self) -> None: def on_grpc_server_shutdown(self) -> None: pass - @property - def grpc_servicer(self) -> Servicer: - from ..server.grpc_app import GRPCAppFactory + def get_grpc_servicer( + self, protocol_version: str = LATEST_PROTOCOL_VERSION + ) -> services.BentoServiceServicer: + """ + Return a gRPC servicer instance for this service. + + Args: + protocol_version: The protocol version to use for the gRPC servicer. + + Returns: + A bento gRPC servicer implementation. + """ + return importlib.import_module( + f".grpc.servicer.{protocol_version}", + package="bentoml._internal.server", + ).create_bento_servicer(self) - return GRPCAppFactory(self)() + @property + def grpc_servicer(self): + return self.get_grpc_servicer(protocol_version=LATEST_PROTOCOL_VERSION) @property def asgi_app(self) -> "ext.ASGIApp": diff --git a/src/bentoml/grpc/utils/__init__.py b/src/bentoml/grpc/utils/__init__.py index a0a998334f..f5ba32f528 100644 --- a/src/bentoml/grpc/utils/__init__.py +++ b/src/bentoml/grpc/utils/__init__.py @@ -10,6 +10,7 @@ from bentoml.exceptions import InvalidArgument from bentoml.grpc.utils._import_hook import import_grpc from bentoml.grpc.utils._import_hook import import_generated_stubs +from bentoml.grpc.utils._import_hook import LATEST_PROTOCOL_VERSION if TYPE_CHECKING: from enum import Enum @@ -36,6 +37,7 @@ "import_generated_stubs", "import_grpc", "validate_proto_fields", + "LATEST_PROTOCOL_VERSION", ] logger = logging.getLogger(__name__) diff --git a/src/bentoml/grpc/utils/_import_hook.py b/src/bentoml/grpc/utils/_import_hook.py index 29b33eac70..147f0c2921 100644 --- a/src/bentoml/grpc/utils/_import_hook.py +++ b/src/bentoml/grpc/utils/_import_hook.py @@ -5,9 +5,11 @@ if TYPE_CHECKING: import types +LATEST_PROTOCOL_VERSION = "v1" + def import_generated_stubs( - version: str = "v1", + version: str = LATEST_PROTOCOL_VERSION, file: str = "service.proto", ) -> tuple[types.ModuleType, types.ModuleType]: """ diff --git a/src/bentoml/serve.py b/src/bentoml/serve.py index 1abb23c025..99d7158dca 100644 --- a/src/bentoml/serve.py +++ b/src/bentoml/serve.py @@ -16,6 +16,7 @@ from simple_di import inject from simple_di import Provide +from .grpc.utils import LATEST_PROTOCOL_VERSION from ._internal.utils import experimental from ._internal.configuration.containers import BentoMLContainer @@ -490,6 +491,7 @@ def serve_grpc_development( reload: bool = False, channelz: bool = Provide[BentoMLContainer.grpc.channelz.enabled], reflection: bool = Provide[BentoMLContainer.grpc.reflection.enabled], + protocol_version: str = LATEST_PROTOCOL_VERSION, ) -> None: prometheus_dir = ensure_prometheus_dir() @@ -539,6 +541,8 @@ def serve_grpc_development( "--prometheus-dir", prometheus_dir, *ssl_args, + "--protocol-version", + protocol_version, ] if reflection: @@ -668,6 +672,7 @@ def serve_grpc_production( | None = Provide[BentoMLContainer.grpc.max_concurrent_streams], channelz: bool = Provide[BentoMLContainer.grpc.channelz.enabled], reflection: bool = Provide[BentoMLContainer.grpc.reflection.enabled], + protocol_version: str = LATEST_PROTOCOL_VERSION, ) -> None: prometheus_dir = ensure_prometheus_dir() @@ -812,6 +817,8 @@ def serve_grpc_production( "--worker-id", "$(CIRCUS.WID)", *ssl_args, + "--protocol-version", + protocol_version, ] if reflection: args.append("--enable-reflection") diff --git a/src/bentoml/start.py b/src/bentoml/start.py index df85491fa8..8cb918f8ec 100644 --- a/src/bentoml/start.py +++ b/src/bentoml/start.py @@ -10,6 +10,7 @@ from simple_di import inject from simple_di import Provide +from .grpc.utils import LATEST_PROTOCOL_VERSION from ._internal.configuration.containers import BentoMLContainer logger = logging.getLogger(__name__) @@ -241,6 +242,7 @@ def start_grpc_server( ssl_certfile: str | None = Provide[BentoMLContainer.ssl.certfile], ssl_keyfile: str | None = Provide[BentoMLContainer.ssl.keyfile], ssl_ca_certs: str | None = Provide[BentoMLContainer.ssl.ca_certs], + protocol_version: str = LATEST_PROTOCOL_VERSION, ) -> None: from .serve import ensure_prometheus_dir @@ -297,6 +299,8 @@ def start_grpc_server( "--worker-id", "$(CIRCUS.WID)", *ssl_args, + "--protocol-version", + protocol_version, ] if reflection: args.append("--enable-reflection") diff --git a/src/bentoml/testing/grpc/__init__.py b/src/bentoml/testing/grpc/__init__.py index ad1f68d5d0..057b56331e 100644 --- a/src/bentoml/testing/grpc/__init__.py +++ b/src/bentoml/testing/grpc/__init__.py @@ -1,19 +1,20 @@ from __future__ import annotations import typing as t +import importlib import traceback from typing import TYPE_CHECKING from contextlib import ExitStack from contextlib import asynccontextmanager -from bentoml.exceptions import BentoMLException -from bentoml.grpc.utils import import_grpc -from bentoml.grpc.utils import import_generated_stubs -from bentoml._internal.utils import LazyLoader -from bentoml._internal.utils import reserve_free_port -from bentoml._internal.utils import cached_contextmanager -from bentoml._internal.utils import add_experimental_docstring -from bentoml._internal.server.grpc.servicer import create_bento_servicer +from ...exceptions import BentoMLException +from ...grpc.utils import import_grpc +from ...grpc.utils import import_generated_stubs +from ...grpc.utils import LATEST_PROTOCOL_VERSION +from ..._internal.utils import LazyLoader +from ..._internal.utils import reserve_free_port +from ..._internal.utils import cached_contextmanager +from ..._internal.utils import add_experimental_docstring if TYPE_CHECKING: import grpc @@ -23,9 +24,9 @@ from grpc.aio._channel import Channel from google.protobuf.message import Message - from bentoml.grpc.v1 import service_pb2 as pb + from ...grpc.v1 import service_pb2 as pb + from ..._internal.service import Service else: - pb, _ = import_generated_stubs() grpc, aio = import_grpc() # pylint: disable=E1111 np = LazyLoader("np", globals(), "numpy") @@ -35,21 +36,43 @@ "make_pb_ndarray", "create_channel", "make_standalone_server", - "create_bento_servicer", + "create_test_bento_servicer", ] -def randomize_pb_ndarray(shape: tuple[int, ...]) -> pb.NDArray: +def create_test_bento_servicer( + service: Service, + protocol_version: str = LATEST_PROTOCOL_VERSION, +) -> t.Callable[[Service], t.Any]: + try: + module = importlib.import_module( + f".{protocol_version}", package="bentoml._internal.server.grpc.servicer" + ) + return getattr(module, "create_bento_servicer")(service) + except (ImportError, ModuleNotFoundError): + raise BentoMLException( + f"Failed to load servicer implementation for version {protocol_version}" + ) from None + + +def randomize_pb_ndarray( + shape: tuple[int, ...], protocol_version: str = LATEST_PROTOCOL_VERSION +) -> pb.NDArray: + pb, _ = import_generated_stubs(protocol_version) arr: NDArray[np.float32] = t.cast("NDArray[np.float32]", np.random.rand(*shape)) return pb.NDArray( shape=list(shape), dtype=pb.NDArray.DTYPE_FLOAT, float_values=arr.ravel() ) -def make_pb_ndarray(arr: NDArray[t.Any]) -> pb.NDArray: +def make_pb_ndarray( + arr: NDArray[t.Any], protocol_version: str = LATEST_PROTOCOL_VERSION +) -> pb.NDArray: from bentoml._internal.io_descriptors.numpy import npdtype_to_dtypepb_map from bentoml._internal.io_descriptors.numpy import npdtype_to_fieldpb_map + pb, _ = import_generated_stubs(protocol_version) + try: fieldpb = npdtype_to_fieldpb_map()[arr.dtype] dtypepb = npdtype_to_dtypepb_map()[arr.dtype] @@ -76,7 +99,7 @@ async def async_client_call( assert_code: grpc.StatusCode | None = None, assert_details: str | None = None, assert_trailing_metadata: aio.Metadata | None = None, - _internal_stubs_version: str = "v1", + protocol_version: str = LATEST_PROTOCOL_VERSION, ) -> pb.Response | None: """ Invoke a given API method via a client. @@ -95,11 +118,12 @@ async def async_client_call( Returns: The response from the server. """ + pb, _ = import_generated_stubs(protocol_version) res: pb.Response | None = None try: Call = channel.unary_unary( - f"/bentoml.grpc.{_internal_stubs_version}.BentoService/Call", + f"/bentoml.grpc.{protocol_version}.BentoService/Call", request_serializer=pb.Request.SerializeToString, response_deserializer=pb.Response.FromString, ) diff --git a/src/bentoml/testing/server.py b/src/bentoml/testing/server.py index b7dd72ca9a..6d2b1d0332 100644 --- a/src/bentoml/testing/server.py +++ b/src/bentoml/testing/server.py @@ -25,6 +25,8 @@ from bentoml._internal.utils import reserve_free_port from bentoml._internal.utils import cached_contextmanager +from ..grpc.utils import LATEST_PROTOCOL_VERSION + if TYPE_CHECKING: from grpc import aio from grpc_health.v1 import health_pb2 as pb_health @@ -75,7 +77,7 @@ async def server_warmup( check_interval: float = 1, popen: subprocess.Popen[t.Any] | None = None, service_name: str | None = None, - _internal_stubs_version: str = "v1", + protocol_version: str = LATEST_PROTOCOL_VERSION, ) -> bool: start_time = time.time() proxy_handler = urllib.request.ProxyHandler({}) @@ -87,9 +89,7 @@ async def server_warmup( try: if service_name is None: - service_name = ( - f"bentoml.grpc.{_internal_stubs_version}.BentoService" - ) + service_name = f"bentoml.grpc.{protocol_version}.BentoService" async with create_channel(host_url) as channel: Check = channel.unary_unary( "/grpc.health.v1.Health/Check", @@ -177,7 +177,7 @@ def containerize( subprocess.call([backend, "rmi", image_tag]) -@cached_contextmanager("{image_tag}, {config_file}, {use_grpc}") +@cached_contextmanager("{image_tag}, {config_file}, {use_grpc}, {protocol_version}") def run_bento_server_container( image_tag: str, config_file: str | None = None, @@ -185,6 +185,7 @@ def run_bento_server_container( timeout: float = 90, host: str = "127.0.0.1", backend: str = "docker", + protocol_version: str = LATEST_PROTOCOL_VERSION, ): """ Launch a bentoml service container from a container, yield the host URL @@ -227,7 +228,13 @@ def run_bento_server_container( try: host_url = f"{host}:{port}" if asyncio.run( - server_warmup(host_url, timeout=timeout, popen=proc, grpc=use_grpc) + server_warmup( + host_url, + timeout=timeout, + popen=proc, + grpc=use_grpc, + protocol_version=protocol_version, + ) ): yield host_url else: @@ -247,6 +254,7 @@ def run_bento_server_standalone( config_file: str | None = None, timeout: float = 90, host: str = "127.0.0.1", + protocol_version: str = LATEST_PROTOCOL_VERSION, ): """ Launch a bentoml service directly by the bentoml CLI, yields the host URL. @@ -277,7 +285,13 @@ def run_bento_server_standalone( try: host_url = f"{host}:{server_port}" assert asyncio.run( - server_warmup(host_url, timeout=timeout, popen=p, grpc=use_grpc) + server_warmup( + host_url, + timeout=timeout, + popen=p, + grpc=use_grpc, + protocol_version=protocol_version, + ) ) yield host_url finally: @@ -302,6 +316,7 @@ def run_bento_server_distributed( use_grpc: bool = False, timeout: float = 90, host: str = "127.0.0.1", + protocol_version: str = LATEST_PROTOCOL_VERSION, ): """ Launch a bentoml service as a simulated distributed environment(Yatai), yields the host URL. @@ -391,7 +406,14 @@ def run_bento_server_distributed( ) try: host_url = f"{host}:{server_port}" - asyncio.run(server_warmup(host_url, timeout=timeout, grpc=use_grpc)) + asyncio.run( + server_warmup( + host_url, + timeout=timeout, + grpc=use_grpc, + protocol_version=protocol_version, + ) + ) yield host_url finally: for p in processes: @@ -404,7 +426,7 @@ def run_bento_server_distributed( @cached_contextmanager( - "{bento_name}, {project_path}, {config_file}, {deployment_mode}, {bentoml_home}, {use_grpc}" + "{bento_name}, {project_path}, {config_file}, {deployment_mode}, {bentoml_home}, {use_grpc}, {protocol_version}" ) def host_bento( bento_name: str | Tag | None = None, @@ -417,6 +439,7 @@ def host_bento( host: str = "127.0.0.1", timeout: float = 120, backend: str = "docker", + protocol_version: str = LATEST_PROTOCOL_VERSION, ) -> t.Generator[str, None, None]: """ Host a bentoml service, yields the host URL. @@ -473,6 +496,7 @@ def host_bento( use_grpc=use_grpc, host=host, timeout=timeout, + protocol_version=protocol_version, ) as host_url: yield host_url elif deployment_mode == "container": @@ -492,6 +516,7 @@ def host_bento( host=host, timeout=timeout, backend=backend, + protocol_version=protocol_version, ) as host_url: yield host_url elif deployment_mode == "distributed": @@ -501,6 +526,7 @@ def host_bento( use_grpc=use_grpc, host=host, timeout=timeout, + protocol_version=protocol_version, ) as host_url: yield host_url else: diff --git a/src/bentoml_cli/serve.py b/src/bentoml_cli/serve.py index 78d3f4fbd2..65ad548d43 100644 --- a/src/bentoml_cli/serve.py +++ b/src/bentoml_cli/serve.py @@ -13,6 +13,7 @@ def add_serve_command(cli: click.Group) -> None: + from bentoml.grpc.utils import LATEST_PROTOCOL_VERSION from bentoml._internal.log import configure_server_logging from bentoml._internal.configuration.containers import BentoMLContainer @@ -324,6 +325,14 @@ def serve( # type: ignore (unused warning) help="CA certificates file", show_default=True, ) + @click.option( + "-pv", + "--protocol-version", + type=click.Choice(["v1", "v1alpha1"]), + help="Determine the version of generated gRPC stubs to use.", + default=LATEST_PROTOCOL_VERSION, + show_default=True, + ) @add_experimental_docstring def serve_grpc( # type: ignore (unused warning) bento: str, @@ -340,6 +349,7 @@ def serve_grpc( # type: ignore (unused warning) enable_reflection: bool, enable_channelz: bool, max_concurrent_streams: int | None, + protocol_version: str, ): """Start a gRPC BentoServer from a given 🍱 @@ -400,6 +410,7 @@ def serve_grpc( # type: ignore (unused warning) max_concurrent_streams=max_concurrent_streams, reflection=enable_reflection, channelz=enable_channelz, + protocol_version=protocol_version, ) else: from bentoml.serve import serve_grpc_development @@ -417,4 +428,5 @@ def serve_grpc( # type: ignore (unused warning) max_concurrent_streams=max_concurrent_streams, reflection=enable_reflection, channelz=enable_channelz, + protocol_version=protocol_version, ) diff --git a/src/bentoml_cli/start.py b/src/bentoml_cli/start.py index f768a1edab..da6a6e686b 100644 --- a/src/bentoml_cli/start.py +++ b/src/bentoml_cli/start.py @@ -13,6 +13,7 @@ def add_start_command(cli: click.Group) -> None: + from bentoml.grpc.utils import LATEST_PROTOCOL_VERSION from bentoml._internal.utils import add_experimental_docstring from bentoml._internal.configuration.containers import BentoMLContainer @@ -348,6 +349,14 @@ def start_runner_server( # type: ignore (unused warning) default=None, help="CA certificates file", ) + @click.option( + "-pv", + "--protocol-version", + type=click.Choice(["v1", "v1alpha1"]), + help="Determine the version of generated gRPC stubs to use.", + default=LATEST_PROTOCOL_VERSION, + show_default=True, + ) @add_experimental_docstring def start_grpc_server( # type: ignore (unused warning) bento: str, @@ -363,6 +372,7 @@ def start_grpc_server( # type: ignore (unused warning) ssl_ca_certs: str | None, enable_channelz: bool, max_concurrent_streams: int | None, + protocol_version: str, ) -> None: """ Start a gRPC API server standalone. This will be used inside Yatai. @@ -393,4 +403,5 @@ def start_grpc_server( # type: ignore (unused warning) ssl_ca_certs=ssl_ca_certs, channelz=enable_channelz, max_concurrent_streams=max_concurrent_streams, + protocol_version=protocol_version, ) diff --git a/src/bentoml_cli/worker/grpc_api_server.py b/src/bentoml_cli/worker/grpc_api_server.py index 87a8e336eb..0ee9ecca32 100644 --- a/src/bentoml_cli/worker/grpc_api_server.py +++ b/src/bentoml_cli/worker/grpc_api_server.py @@ -70,6 +70,13 @@ default=None, help="CA certificates file", ) +@click.option( + "--protocol-version", + type=click.Choice(["v1", "v1alpha1"]), + help="Determine the version of generated gRPC stubs to use.", + default="v1", + show_default=True, +) def main( bento_identifier: str, host: str, @@ -84,6 +91,7 @@ def main( ssl_certfile: str | None, ssl_keyfile: str | None, ssl_ca_certs: str | None, + protocol_version: str, ): """ Start BentoML API server. @@ -126,12 +134,13 @@ def main( component_context.bento_name = svc.tag.name component_context.bento_version = svc.tag.version or "not available" - from bentoml._internal.server import grpc + from bentoml._internal.server import grpc_app as grpc grpc_options: dict[str, t.Any] = { "bind_address": f"{host}:{port}", "enable_reflection": enable_reflection, "enable_channelz": enable_channelz, + "protocol_version": protocol_version, } if max_concurrent_streams: grpc_options["max_concurrent_streams"] = int(max_concurrent_streams) @@ -142,7 +151,7 @@ def main( if ssl_ca_certs: grpc_options["ssl_ca_certs"] = ssl_ca_certs - grpc.Server(svc.grpc_servicer, **grpc_options).run() + grpc.Server(svc, **grpc_options).run() if __name__ == "__main__": diff --git a/src/bentoml_cli/worker/grpc_dev_api_server.py b/src/bentoml_cli/worker/grpc_dev_api_server.py index 9d277a7732..48acc96732 100644 --- a/src/bentoml_cli/worker/grpc_dev_api_server.py +++ b/src/bentoml_cli/worker/grpc_dev_api_server.py @@ -52,6 +52,13 @@ default=None, help="CA certificates file", ) +@click.option( + "--protocol-version", + type=click.Choice(["v1", "v1alpha1"]), + help="Determine the version of generated gRPC stubs to use.", + default="v1", + show_default=True, +) def main( bento_identifier: str, host: str, @@ -64,6 +71,7 @@ def main( ssl_certfile: str | None, ssl_keyfile: str | None, ssl_ca_certs: str | None, + protocol_version: str, ): import psutil @@ -96,12 +104,13 @@ def main( asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # type: ignore - from bentoml._internal.server import grpc + from bentoml._internal.server import grpc_app as grpc grpc_options: dict[str, t.Any] = { "bind_address": f"{host}:{port}", "enable_reflection": enable_reflection, "enable_channelz": enable_channelz, + "protocol_version": protocol_version, } if max_concurrent_streams is not None: grpc_options["max_concurrent_streams"] = int(max_concurrent_streams) @@ -112,7 +121,7 @@ def main( if ssl_ca_certs: grpc_options["ssl_ca_certs"] = ssl_ca_certs - grpc.Server(svc.grpc_servicer, **grpc_options).run() + grpc.Server(svc, **grpc_options).run() if __name__ == "__main__": diff --git a/tests/unit/grpc/interceptors/test_access.py b/tests/unit/grpc/interceptors/test_access.py index d07f0607d2..3c47fc6664 100644 --- a/tests/unit/grpc/interceptors/test_access.py +++ b/tests/unit/grpc/interceptors/test_access.py @@ -15,8 +15,8 @@ from bentoml.grpc.utils import import_generated_stubs from bentoml.testing.grpc import create_channel from bentoml.testing.grpc import async_client_call -from bentoml.testing.grpc import create_bento_servicer from bentoml.testing.grpc import make_standalone_server +from bentoml.testing.grpc import create_test_bento_servicer from bentoml._internal.utils import LazyLoader from tests.unit.grpc.conftest import TestServiceServicer from bentoml.grpc.interceptors.access import AccessLogServerInterceptor @@ -29,7 +29,6 @@ from google.protobuf import wrappers_pb2 from bentoml import Service - from bentoml.grpc.v1 import service_pb2_grpc as services from bentoml.grpc.types import Request from bentoml.grpc.types import Response from bentoml.grpc.types import RpcMethodHandler @@ -37,7 +36,6 @@ from bentoml.grpc.types import HandlerCallDetails from bentoml.grpc.types import BentoServicerContext else: - _, services = import_generated_stubs() grpc, aio = import_grpc() wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2") @@ -125,7 +123,12 @@ async def test_trailing_metadata(caplog: LogCaptureFixture): @pytest.mark.asyncio @pytest.mark.usefixtures("propagate_logs") -async def test_access_log_exception(caplog: LogCaptureFixture, simple_service: Service): +@pytest.mark.parametrize("protocol_version", ["v1", "v1alpha1"]) +async def test_access_log_exception( + caplog: LogCaptureFixture, simple_service: Service, protocol_version: str +): + _, services = import_generated_stubs(protocol_version) + with make_standalone_server( # we need to also setup opentelemetry interceptor # to make sure the access log is correctly setup. @@ -135,7 +138,7 @@ async def test_access_log_exception(caplog: LogCaptureFixture, simple_service: S ] ) as (server, host_url): services.add_BentoServiceServicer_to_server( - create_bento_servicer(simple_service), server + create_test_bento_servicer(simple_service, protocol_version), server ) try: await server.start() @@ -146,9 +149,10 @@ async def test_access_log_exception(caplog: LogCaptureFixture, simple_service: S channel=channel, data={"text": wrappers_pb2.StringValue(value="asdf")}, assert_code=grpc.StatusCode.INTERNAL, + protocol_version=protocol_version, ) assert ( - "(scheme=http,path=/bentoml.grpc.v1.BentoService/Call,type=application/grpc,size=17) (http_status=500,grpc_status=13,type=application/grpc,size=0)" + f"(scheme=http,path=/bentoml.grpc.{protocol_version}.BentoService/Call,type=application/grpc,size=17) (http_status=500,grpc_status=13,type=application/grpc,size=0)" in caplog.text ) finally: diff --git a/tests/unit/grpc/interceptors/test_prometheus.py b/tests/unit/grpc/interceptors/test_prometheus.py index 237cc0c47c..26da32d70c 100644 --- a/tests/unit/grpc/interceptors/test_prometheus.py +++ b/tests/unit/grpc/interceptors/test_prometheus.py @@ -15,8 +15,8 @@ from bentoml.grpc.utils import import_generated_stubs from bentoml.testing.grpc import create_channel from bentoml.testing.grpc import async_client_call -from bentoml.testing.grpc import create_bento_servicer from bentoml.testing.grpc import make_standalone_server +from bentoml.testing.grpc import create_test_bento_servicer from bentoml._internal.utils import LazyLoader from tests.unit.grpc.conftest import TestServiceServicer from bentoml.grpc.interceptors.prometheus import PrometheusServerInterceptor @@ -27,10 +27,7 @@ from google.protobuf import wrappers_pb2 from bentoml import Service - from bentoml.grpc.v1 import service_pb2_grpc as services else: - - _, services = import_generated_stubs() wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2") grpc, aio = import_grpc() @@ -106,19 +103,23 @@ async def test_empty_metrics(): ("gauge", ["api_name", "service_version", "service_name"]), ], ) +@pytest.mark.parametrize("protocol_version", ["v1", "v1alpha1"]) async def test_metrics_interceptors( simple_service: Service, metric_type: str, parent_set: list[str], + protocol_version: str, ): metrics_client = BentoMLContainer.metrics_client.get() + _, services = import_generated_stubs(protocol_version) + with make_standalone_server(interceptors=[interceptor]) as ( server, host_url, ): services.add_BentoServiceServicer_to_server( - create_bento_servicer(simple_service), server + create_test_bento_servicer(simple_service, protocol_version), server ) try: await server.start() @@ -127,6 +128,7 @@ async def test_metrics_interceptors( "noop_sync", channel=channel, data={"text": wrappers_pb2.StringValue(value="BentoML")}, + protocol_version=protocol_version, ) for m in metrics_client.text_string_to_metric_families(): for sample in m.samples: