diff --git a/grpc-client/python/client.py b/grpc-client/python/client.py index 7275b7f9253..d715573946e 100644 --- a/grpc-client/python/client.py +++ b/grpc-client/python/client.py @@ -1,32 +1,43 @@ +from __future__ import annotations + import asyncio -import grpc +import numpy as np + +from bentoml.client import Client + -from bentoml.grpc.utils import import_generated_stubs +async def arun(client: Client): -pb, services = import_generated_stubs() + res = await client.async_classify(np.array([[5.9, 3, 5.1, 1.8]])) + print("Result from 'client.async_classify':\n", res) + res = await client.async_call("classify", np.array([[5.9, 3, 5.1, 1.8]])) + print("Result from 'client.async_call':\n", res) -async def run(): - async with grpc.aio.insecure_channel("localhost:3000") as channel: - stub = services.BentoServiceStub(channel) - req = await stub.Call( - request=pb.Request( - api_name="classify", - ndarray=pb.NDArray( - dtype=pb.NDArray.DTYPE_FLOAT, - shape=(1, 4), - float_values=[5.9, 3, 5.1, 1.8], - ), - ) - ) - print(req) +def run(client: Client): + res = client.classify(np.array([[5.9, 3, 5.1, 1.8]])) + print("Result from 'client.classify':\n", res) + res = client.call("classify", np.array([[5.9, 3, 5.1, 1.8]])) + print("Result from 'client.call(bentoml_api_name='classify')':\n", res) if __name__ == "__main__": - loop = asyncio.new_event_loop() - try: - loop.run_until_complete(run()) - finally: - loop.close() - assert loop.is_closed() + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-rwa", "--run-with-async", action="store_true", default=False) + parser.add_argument("--grpc", action="store_true", default=False) + args = parser.parse_args() + + c = Client.from_url("localhost:3000", grpc=args.grpc) + + if args.run_with_async: + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(arun(c)) + finally: + loop.close() + assert loop.is_closed() + else: + run(c) diff --git a/pyproject.toml b/pyproject.toml index 
867617b82b3..476708f41b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,17 +136,15 @@ grpc = [ # Restrict maximum version due to breaking protobuf 4.21.0 changes # (see https://github.com/protocolbuffers/protobuf/issues/10051) # 3.19.5 is currently breaking on a lot of system. - "protobuf>=3.5.0, <3.20, !=3.19.5", + "protobuf>=3.5.0,<4.0dev,!=3.19.5", # Lowest version that support 3.10. We need to set an upper bound # We can't use 1.48.2 since it depends on 3.19.5 - "grpcio>=1.41.0,!=1.48.2", - # grpcio>=1.48.0 provides a pre-built M1 wheel. - "grpcio>=1.48.0,!=1.48.2;platform_machine=='arm64' and platform_system=='Darwin'", - "grpcio-health-checking>=1.41.0,!=1.48.2", + "grpcio>=1.41.0", + "grpcio-health-checking>=1.41.0", "opentelemetry-instrumentation-grpc==0.35b0", ] -grpc-reflection = ["bentoml[grpc]", "grpcio-reflection>=1.41.0,!=1.48.2"] -grpc-channelz = ["bentoml[grpc]", "grpcio-channelz>=1.41.0,!=1.48.2"] +grpc-reflection = ["bentoml[grpc]", "grpcio-reflection>=1.41.0"] +grpc-channelz = ["bentoml[grpc]", "grpcio-channelz>=1.41.0"] # We kept for compatibility with previous # versions of BentoML. It is discouraged to use this, instead use any # of the above tracing.* extras. 
@@ -316,7 +314,7 @@ skip_glob = [ ] [tool.pyright] -pythonVersion = "3.10" +pythonVersion = "3.11" include = ["src/", "examples/", "tests/"] exclude = [ 'src/bentoml/_version.py', diff --git a/requirements/tests-requirements.txt b/requirements/tests-requirements.txt index 7c122a661c3..cc46c59a24b 100644 --- a/requirements/tests-requirements.txt +++ b/requirements/tests-requirements.txt @@ -17,6 +17,4 @@ imageio==2.22.4 pyarrow==10.0.1 build[virtualenv]==0.9.0 protobuf==3.19.6 -grpcio>=1.41.0, <1.49, !=1.48.2 -grpcio-health-checking>=1.41.0, <1.49, !=1.48.2 opentelemetry-instrumentation-grpc==0.35b0 diff --git a/scripts/generate_grpc_stubs.sh b/scripts/generate_grpc_stubs.sh index 9e4353964d1..4e8cb3678eb 100755 --- a/scripts/generate_grpc_stubs.sh +++ b/scripts/generate_grpc_stubs.sh @@ -1,7 +1,10 @@ #!/usr/bin/env bash +set -e + GIT_ROOT=$(git rev-parse --show-toplevel) STUBS_GENERATOR="bentoml/stubs-generator" +BASEDIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]:-$0}")" &>/dev/null && pwd 2>/dev/null)" cd "$GIT_ROOT/src" || exit 1 @@ -9,7 +12,7 @@ main() { local VERSION="${1:-v1}" # Use inline heredoc for even faster build # Keeping image as cache should be fine since we only want to generate the stubs. - if [[ $(docker images --filter=reference="$STUBS_GENERATOR" -q) == "" ]] || test "$(git diff --name-only --diff-filter=d -- "$0")"; then + if test "$(git diff --name-only --diff-filter=d -- "$BASEDIR/$(basename "$0")")" || [[ $(docker images --filter=reference="$STUBS_GENERATOR" -q) == "" ]]; then docker buildx build --platform=linux/amd64 -t "$STUBS_GENERATOR" --load -f- . < Self: if "args" not in spec: raise InvalidArgument(f"Missing args key in JSON spec: {spec}") if "has_pydantic_model" in spec["args"] and spec["args"]["has_pydantic_model"]: - logger.warning( + logger.debug( "BentoML does not support loading pydantic models from URLs; output will be a normal dictionary." 
) if "has_json_encoder" in spec["args"] and spec["args"]["has_json_encoder"]: - logger.warning( + logger.debug( "BentoML does not support loading JSON encoders from URLs; output will be a normal dictionary." ) diff --git a/src/bentoml/_internal/server/grpc/server.py b/src/bentoml/_internal/server/grpc/server.py index b96d87506ef..4806025d291 100644 --- a/src/bentoml/_internal/server/grpc/server.py +++ b/src/bentoml/_internal/server/grpc/server.py @@ -48,7 +48,7 @@ ) -def _load_from_file(p: str) -> bytes: +def load_from_file(p: str) -> bytes: rp = resolve_user_filepath(p, ctx=None) with open(rp, "rb") as f: return f.read() @@ -164,12 +164,12 @@ def configure_port(self, addr: str): ), "'ssl_keyfile' is required when 'ssl_certfile' is provided." if self.ssl_ca_certs is not None: client_auth = True - ca_cert = _load_from_file(self.ssl_ca_certs) + ca_cert = load_from_file(self.ssl_ca_certs) server_credentials = grpc.ssl_server_credentials( ( ( - _load_from_file(self.ssl_keyfile), - _load_from_file(self.ssl_certfile), + load_from_file(self.ssl_keyfile), + load_from_file(self.ssl_certfile), ), ), root_certificates=ca_cert, diff --git a/src/bentoml/_internal/server/grpc/servicer.py b/src/bentoml/_internal/server/grpc/servicer.py index d87f0d3f182..0a25cda386e 100644 --- a/src/bentoml/_internal/server/grpc/servicer.py +++ b/src/bentoml/_internal/server/grpc/servicer.py @@ -169,6 +169,7 @@ async def Call( else: output = await api.func(input_data) else: + assert api.func is not None if api.multi_input: output = await anyio.to_thread.run_sync(api.func, **input_data) else: @@ -212,10 +213,12 @@ async def ServiceMetadata( name=api.name, docs=api.doc, input=make_descriptor_spec( - api.input.to_spec(), ServiceMetadataResponse + t.cast("SpecDict", api.input.to_spec()), + ServiceMetadataResponse, ), output=make_descriptor_spec( - api.output.to_spec(), ServiceMetadataResponse + t.cast("SpecDict", api.output.to_spec()), + ServiceMetadataResponse, ), ) for api in service.apis.values() 
diff --git a/src/bentoml/_internal/service/inference_api.py b/src/bentoml/_internal/service/inference_api.py index bffd6126965..9fb9d98b609 100644 --- a/src/bentoml/_internal/service/inference_api.py +++ b/src/bentoml/_internal/service/inference_api.py @@ -26,7 +26,7 @@ class InferenceAPI: def __init__( self, - user_defined_callback: t.Callable[..., t.Any], + user_defined_callback: t.Callable[..., t.Any] | None, input_descriptor: IODescriptor[t.Any], output_descriptor: IODescriptor[t.Any], name: Optional[str], diff --git a/src/bentoml/client.py b/src/bentoml/client.py index e75630f9e92..2a293c7b9a5 100644 --- a/src/bentoml/client.py +++ b/src/bentoml/client.py @@ -3,48 +3,293 @@ import json import typing as t import asyncio +import inspect +import logging import functools +import contextlib from abc import ABC from abc import abstractmethod +from enum import Enum +from typing import TYPE_CHECKING from http.client import HTTPConnection from urllib.parse import urlparse +import attr import aiohttp import starlette.requests import starlette.datastructures +from packaging.version import parse -import bentoml -from bentoml import Service - +from . import io +from . import Service from .exceptions import BentoMLException +from .grpc.utils import import_grpc +from .grpc.utils import parse_method_name +from .grpc.utils import import_generated_stubs +from .grpc.utils import LATEST_PROTOCOL_VERSION +from ._internal.utils import LazyLoader +from ._internal.utils import bentoml_cattr +from ._internal.utils import cached_property +from ._internal.configuration import get_debug_mode +from ._internal.server.grpc.server import load_from_file from ._internal.service.inference_api import InferenceAPI +logger = logging.getLogger(__name__) + +PROTOBUF_EXC_MESSAGE = "'protobuf' is required to use gRPC Client. Install with 'pip install bentoml[grpc]'." +REFLECTION_EXC_MESSAGE = "'grpcio-reflection' is required to use gRPC Client. Install with 'pip install bentoml[grpc-reflection]'." 
+ +if TYPE_CHECKING: + from types import TracebackType + from urllib.parse import ParseResult + + import grpc + from grpc import aio + from google.protobuf import message as _message + from google.protobuf import json_format as _json_format + from google.protobuf import descriptor_pb2 as pb_descriptor + from google.protobuf import descriptor_pool as _descriptor_pool + from google.protobuf import symbol_database as _symbol_database + from grpc_reflection.v1alpha import reflection_pb2 as pb_reflection + from grpc_reflection.v1alpha import reflection_pb2_grpc as services_reflection + + # type hint specific imports. + from google.protobuf.descriptor import MethodDescriptor + from google.protobuf.descriptor import ServiceDescriptor + from google.protobuf.descriptor_pb2 import FileDescriptorProto + from google.protobuf.descriptor_pb2 import MethodDescriptorProto + from google.protobuf.descriptor_pool import DescriptorPool + from google.protobuf.symbol_database import SymbolDatabase + from grpc_reflection.v1alpha.reflection_pb2 import ServiceResponse + from grpc_reflection.v1alpha.reflection_pb2 import ListServiceResponse + from grpc_reflection.v1alpha.reflection_pb2_grpc import ServerReflectionStub + + from .grpc.types import MultiCallable + from ._internal.types import PathType + from ._internal.io_descriptors.base import SpecDict + + if LATEST_PROTOCOL_VERSION == "v1": + from .grpc.v1.service_pb2 import ServiceMetadataResponse + + class ClientCredentials(t.TypedDict): + root_certificates: t.NotRequired[PathType | bytes] + private_key: t.NotRequired[PathType | bytes] + certificate_chain: t.NotRequired[PathType | bytes] + + class RpcMethod(t.TypedDict): + request_streaming: t.Literal[True, False] + response_streaming: bool + input_type: type[t.Any] + output_type: t.NotRequired[type[t.Any]] + handler: MultiCallable + +else: + pb_descriptor = LazyLoader( + "pb_descriptor", + globals(), + "google.protobuf.descriptor_pb2", + exc_msg=PROTOBUF_EXC_MESSAGE, + ) + 
_descriptor_pool = LazyLoader( + "_descriptor_pool", + globals(), + "google.protobuf.descriptor_pool", + exc_msg=PROTOBUF_EXC_MESSAGE, + ) + _symbol_database = LazyLoader( + "_symbol_database", + globals(), + "google.protobuf.symbol_database", + exc_msg=PROTOBUF_EXC_MESSAGE, + ) + _json_format = LazyLoader( + "_json_format", + globals(), + "google.protobuf.json_format", + exc_msg=PROTOBUF_EXC_MESSAGE, + ) + grpc, aio = import_grpc() + services_reflection = LazyLoader( + "services_reflection", + globals(), + "grpc_reflection.v1alpha.reflection_pb2_grpc", + exc_msg=REFLECTION_EXC_MESSAGE, + ) + pb_reflection = LazyLoader( + "pb_reflection", + globals(), + "grpc_reflection.v1alpha.reflection_pb2", + exc_msg=REFLECTION_EXC_MESSAGE, + ) + ClientCredentials = dict + SpecDict = dict + RpcMethod = dict + + +@attr.define +class ClientConfig: + http: HTTP = attr.field( + default=attr.Factory(lambda self: self.HTTP(), takes_self=True) + ) + grpc: GRPC = attr.field( + default=attr.Factory(lambda self: self.GRPC(), takes_self=True) + ) + + def with_grpc_options(self, **kwargs: t.Any) -> ClientConfig: + _self_grpc_config = kwargs.pop("_self_grpc_config", None) + if not isinstance(_self_grpc_config, self.GRPC): + _self_grpc_config = ClientConfig.GRPC.from_options(**kwargs) + return attr.evolve(self, **{"grpc": _self_grpc_config}) + + def with_http_options(self, **kwargs: t.Any) -> ClientConfig: + _self_http_config = kwargs.pop("_self_http_config", None) + if not isinstance(_self_http_config, self.HTTP): + _self_http_config = ClientConfig.HTTP.from_options(**kwargs) + return attr.evolve(self, **{"http": _self_http_config}) + + @classmethod + def from_options(cls, **kwargs: t.Any) -> ClientConfig: + return bentoml_cattr.structure(kwargs, cls) + + @staticmethod + def from_grpc_options(**kwargs: t.Any) -> GRPC: + return ClientConfig.GRPC.from_options(**kwargs) + + @staticmethod + def from_http_options(**kwargs: t.Any) -> HTTP: + return ClientConfig.HTTP.from_options(**kwargs) + + def 
unstructure( + self, target: t.Literal["http", "grpc", "default"] = "default" + ) -> dict[str, t.Any]: + if target == "default": + targ = self + elif target == "http": + targ = self.http + elif target == "grpc": + targ = self.grpc + else: + raise ValueError( + f"Invalid target: {target}. Accepted value are 'http', 'grpc', 'default'." + ) + return bentoml_cattr.unstructure(targ) + + @attr.define + class HTTP: + """HTTP ClientConfig. + + .. TODO:: Add HTTP specific options here. + + """ + + # forbid additional keys to prevent typos. + __forbid_extra_keys__ = True + # Don't omit empty field. + __omit_if_default__ = False + + @classmethod + def from_options(cls, **kwargs: t.Any) -> ClientConfig.HTTP: + return bentoml_cattr.structure(kwargs, cls) + + def unstructure(self) -> dict[str, t.Any]: + return ( + ClientConfig() + .with_http_options( + _self_http_config=self, + ) + .unstructure(target="http") + ) + + @attr.define + class GRPC: + """gRPC ClientConfig. + + .. code-block:: python + + from bentoml.client import ClientConfig + from bentoml.client import Client + + config = ClientConfig.from_grpc_options( + ssl=True, + ssl_client_credentials={ + "root_certificates": "path/to/cert.pem", + "private_key": "/path/to/key", + }, + protocol_version="v1alpha1", + ) + client = Client.from_url("localhost:50051", config) + + """ + + # forbid additional keys to prevent typos. + __forbid_extra_keys__ = True + # Don't omit empty field. 
+ __omit_if_default__ = False + + ssl: bool = attr.field(default=False) + channel_options: t.Optional[aio.ChannelArgumentType] = attr.field(default=None) + compression: t.Optional[grpc.Compression] = attr.field(default=None) + ssl_client_credentials: t.Optional[ClientCredentials] = attr.field( + factory=lambda: ClientCredentials() + ) + protocol_version: str = attr.field(default=LATEST_PROTOCOL_VERSION) + interceptors: t.Optional[t.Sequence[aio.ClientInterceptor]] = attr.field( + default=None + ) + + @classmethod + def from_options(cls, **kwargs: t.Any) -> ClientConfig.GRPC: + return bentoml_cattr.structure(kwargs, cls) + + def unstructure(self) -> dict[str, t.Any]: + return ( + ClientConfig() + .with_grpc_options( + _self_grpc_config=self, + ) + .unstructure(target="grpc") + ) + + +if TYPE_CHECKING: + ClientConfigT = ClientConfig | ClientConfig.HTTP | ClientConfig.GRPC + + +_sentinel_svc = Service("sentinel_svc") +_object_setattr = object.__setattr__ + class Client(ABC): server_url: str + _svc: Service - def __init__(self, svc: Service, server_url: str): - self._svc = svc + def __init__(self, svc: Service | None, server_url: str): + self._svc = svc or _sentinel_svc self.server_url = server_url - if len(self._svc.apis) == 0: - raise BentoMLException("No APIs were found when constructing client") - for name, api in self._svc.apis.items(): - if not hasattr(self, name): - setattr( - self, name, functools.partial(self._sync_call, _bentoml_api=api) - ) + if svc is not None and len(svc.apis) == 0: + raise BentoMLException("No APIs were found when constructing client.") - for name, api in self._svc.apis.items(): - if not hasattr(self, f"async_{name}"): - setattr( - self, - f"async_{name}", - functools.partial(self._call, _bentoml_api=api), - ) + # Register service method if given service is not _sentinel_svc + # We only set _sentinel_svc if given protocol is older than v1 + if self._svc is not _sentinel_svc: + for name, api in self._svc.apis.items(): + if not 
hasattr(self, name): + setattr( + self, name, functools.partial(self._sync_call, _bentoml_api=api) + ) + + if not hasattr(self, f"async_{name}"): + setattr( + self, + f"async_{name}", + functools.partial(self._call, _bentoml_api=api), + ) def call(self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any) -> t.Any: - return asyncio.run(self.async_call(bentoml_api_name, inp, **kwargs)) + return self._loop.run_until_complete( + self.async_call(bentoml_api_name, inp, **kwargs) + ) async def async_call( self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any @@ -53,10 +298,63 @@ async def async_call( inp, _bentoml_api=self._svc.apis[bentoml_api_name], **kwargs ) + @t.overload + @staticmethod + def from_url( + server_url: str, + config: ClientConfigT | None = ..., + *, + grpc: t.Literal[False] = ..., + ) -> HTTPClient: + ... + + @t.overload + @staticmethod + def from_url( + server_url: str, + config: ClientConfigT | None = ..., + *, + grpc: t.Literal[True] = ..., + ) -> GrpcClient: + ... + + @staticmethod + def from_url( + server_url: str, config: ClientConfigT | None = None, *, grpc: bool = False + ) -> Client: + server_url = server_url if "://" in server_url else "http://" + server_url + client_type = "http" if not grpc else "grpc" + klass = HTTPClient if not grpc else GrpcClient + + if config is None: + config = ClientConfig() + + # First, if config is a ClientConfig that contains both HTTP and gRPC fields, then we use + # grpc_client boolean to determine which configset to use. 
+ # If config is either ClientConfig.HTTP or ClientConfig.GRPC, then we use unstructure for kwargs + kwargs = config.unstructure() + + if isinstance(config, ClientConfig): + # by default we will set the config to HTTP (backward compatibility) + kwargs = config.unstructure(target=client_type) + + try: + return klass._create_client(urlparse(server_url), **kwargs) + except Exception as e: # pylint: disable=broad-except + raise BentoMLException( + f"Failed to create a BentoML client from given URL '{server_url}': {e} ({e.__class__.__name__})" + ) from e + + @cached_property + def _loop(self) -> asyncio.AbstractEventLoop: + return asyncio.get_event_loop() + def _sync_call( self, inp: t.Any = None, *, _bentoml_api: InferenceAPI, **kwargs: t.Any ): - return asyncio.run(self._call(inp, _bentoml_api=_bentoml_api, **kwargs)) + return self._loop.run_until_complete( + self._call(inp, _bentoml_api=_bentoml_api, **kwargs) + ) @abstractmethod async def _call( @@ -65,12 +363,18 @@ async def _call( raise NotImplementedError @staticmethod - def from_url(server_url: str) -> Client: - server_url = server_url if "://" in server_url else "http://" + server_url - url_parts = urlparse(server_url) + @abstractmethod + def _create_client(parsed: ParseResult, **kwargs: t.Any) -> Client: + raise NotImplementedError + - # TODO: SSL and grpc support - conn = HTTPConnection(url_parts.netloc) +class HTTPClient(Client): + @staticmethod + def _create_client(parsed: ParseResult, **kwargs: t.Any) -> HTTPClient: + # TODO: HTTP SSL support + server_url = parsed.netloc + conn = HTTPConnection(server_url) + conn.set_debuglevel(logging.DEBUG if get_debug_mode() else 0) conn.request("GET", "/docs.json") resp = conn.getresponse() openapi_spec = json.load(resp) @@ -96,10 +400,10 @@ def from_url(server_url: str) -> Client: ) dummy_service.apis[meth_spec["x-bentoml-name"]] = InferenceAPI( None, - bentoml.io.from_spec( + io.from_spec( meth_spec["requestBody"]["x-bentoml-io-descriptor"] ), - bentoml.io.from_spec( 
+ io.from_spec( meth_spec["responses"]["200"]["x-bentoml-io-descriptor"] ), name=meth_spec["x-bentoml-name"], @@ -107,13 +411,7 @@ def from_url(server_url: str) -> Client: route=route.lstrip("/"), ) - res = HTTPClient(dummy_service, server_url) - res.server_url = server_url - return res - - -class HTTPClient(Client): - _svc: Service + return HTTPClient(dummy_service, parsed.geturl()) async def _call( self, inp: t.Any = None, *, _bentoml_api: InferenceAPI, **kwargs: t.Any @@ -149,3 +447,587 @@ async def _call( fake_req._headers = headers # type: ignore (request._headers is property) return await api.output.from_http_request(fake_req) + + +# TODO: xDS support +class GrpcClient(Client): + def __init__( + self, + server_url: str, + svc: Service | None = None, + # gRPC specific options + ssl: bool = False, + channel_options: aio.ChannelArgumentType | None = None, + interceptors: t.Sequence[aio.ClientInterceptor] | None = None, + compression: grpc.Compression | None = None, + ssl_client_credentials: ClientCredentials | None = None, + *, + protocol_version: str = LATEST_PROTOCOL_VERSION, + ): + super().__init__(svc, server_url) + + # Call requires an api_name, therefore we need a reserved keyset of self._svc.apis + self._rev_apis = {v: k for k, v in self._svc.apis.items()} + + self._protocol_version = protocol_version + self._compression = compression + self._options = channel_options + self._interceptors = interceptors + self._channel = None + self._credentials = None + if ssl: + assert ( + ssl_client_credentials is not None + ), "'ssl=True' requires 'credentials'" + self._credentials = grpc.ssl_channel_credentials( + **{ + k: load_from_file(v) if isinstance(v, str) else v + for k, v in ssl_client_credentials.items() + } + ) + + self._descriptor_pool: DescriptorPool = _descriptor_pool.Default() + self._symbol_database: SymbolDatabase = _symbol_database.Default() + + self._registered_services: tuple[str, ...] = tuple() + # cached of all available rpc for a given service. 
+ self._service_cache: dict[str, dict[str, RpcMethod]] = {} + # Sets of FileDescriptorProto name to be registered + self._registered_file_name: set[str] = set() + self._reflection_stub: ServerReflectionStub | None = None + + @cached_property + def channel(self): + if not self._channel: + if self._credentials is not None: + self._channel = aio.secure_channel( + self.server_url, + credentials=self._credentials, + options=self._options, + compression=self._compression, + interceptors=self._interceptors, + ) + else: + self._channel = aio.insecure_channel( + self.server_url, options=self._options, + compression=self._compression, + interceptors=self._interceptors, + ) + return self._channel + + async def get_services(self): + if not self._registered_services: + resp = await self._do_one_request( + pb_reflection.ServerReflectionRequest(list_services="") + ) + assert resp is not None + services: list[ServiceResponse] = resp.list_services_response.service + self._registered_services = tuple([t.cast(str, s.name) for s in services]) + return self._registered_services + + @staticmethod + def make_rpc_method(service_name: str, method: str): + return f"/{service_name}/{method}" + + def _reset_cache(self): + self._registered_services = tuple() + self._registered_file_name.clear() + self._service_cache.clear() + + @property + def _call_rpc_method(self): + return self.make_rpc_method( + f"bentoml.grpc.{self._protocol_version}.BentoService", "Call" + ) + + @cached_property + def _reserved_kw_mapping(self): + return { + "default": f"bentoml.grpc.{self._protocol_version}.BentoService", + "health": "grpc.health.v1.Health", + "reflection": "grpc.reflection.v1alpha.ServerReflection", + } + + async def _exit(self): + try: + if self._channel: + if self._channel.get_state() == grpc.ChannelConnectivity.IDLE: + await self._channel.close() + except AttributeError as e: + logger.error("Error closing channel: %s", e, exc_info=e) + raise + + def __enter__(self): + return self.service().__enter__() + 
+ def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + try: + if exc_type is not None: + self.service().__exit__(exc_type, exc, traceback) + self._loop.run_until_complete(self._exit()) + except Exception as err: # pylint: disable=broad-except + logger.error("Exception occurred: %s (%s)", err, exc_type, exc_info=err) + return False + + @contextlib.contextmanager + def service(self, service_name: str = "default"): + stack = contextlib.AsyncExitStack() + + async def close(): + await stack.aclose() + + async def enter(): + res = await stack.enter_async_context( + self.aservice(service_name, _wrap_in_sync=True) + ) + return res + + try: + yield self._loop.run_until_complete(enter()) + finally: + self._loop.run_until_complete(close()) + + async def __aenter__(self): + return await self.aservice().__aenter__() + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + try: + if exc_type is not None: + await self.aservice().__aexit__(exc_type, exc, traceback) + await self._exit() + except Exception as err: # pylint: disable=broad-except + logger.error("Exception occurred: %s (%s)", err, exc_type, exc_info=err) + return False + + @contextlib.asynccontextmanager + async def aservice( + self, service_name: str = "default", *, _wrap_in_sync: bool = False + ) -> t.AsyncGenerator[t.Self, None]: + # This is the entrypoint for user to instantiate a client for a given service. + + # default is a special case for BentoService proto. + if service_name in self._reserved_kw_mapping: + service_name = self._reserved_kw_mapping[service_name] + + # we need to clean all cache to setup for sync method + # so that we don't mix-up async and sync function. 
+ if _wrap_in_sync: + self._reset_cache() + + await self.get_services() + + if ( + service_name in self._registered_services + and service_name not in self._service_cache + ): + await self._register_service(service_name) + + if self.channel.get_state() != grpc.ChannelConnectivity.READY: + # create a blocking call to wait til channel is ready. + await self.channel.channel_ready() + + try: + method_meta = self._service_cache[service_name] + except KeyError: + raise ValueError( + f"Failed to find service '{service_name}'. Available: {list(self._service_cache.keys())}" + ) from None + + def _register(method: str): + finaliser = f = functools.partial( + self._invoke, self.make_rpc_method(service_name, method) + ) + if _wrap_in_sync: + # We will have to run the async function in a sync wrapper + @functools.wraps(f) + def wrapper(*args: t.Any, **kwargs: t.Any): + coro = f(*args, **kwargs) + task = asyncio.ensure_future(coro, loop=self._loop) + try: + res = self._loop.run_until_complete(task) + if inspect.isasyncgen(res): + # If this is an async generator, then we need to yield again + async def call(): + return await res.__anext__() + + return self._loop.run_until_complete(call()) + return res + except BaseException: + # Consume all exceptions. + if task.done() and not task.cancelled(): + task.exception() + raise + + finaliser = wrapper + _object_setattr(self, method, finaliser) + + # Register all RPC method. + for method in reversed(method_meta): + _register(method) + + yield self + + async def _register_service(self, service_name: str) -> None: + svc_descriptor: ServiceDescriptor | None = None + try: + svc_descriptor = self._descriptor_pool.FindServiceByName(service_name) + except KeyError: + file_descriptor = await self._find_descriptor_by_symbol(service_name) + await self._add_file_descriptor(file_descriptor) + # try to register from FileDescriptorProto again. 
+ svc_descriptor = self._descriptor_pool.FindServiceByName(service_name) + except Exception as e: # pylint: disable=broad-except + logger.warning( + "Failed to register %s. This might have already been registered.", + service_name, + exc_info=e, + ) + raise + finally: + if svc_descriptor is not None: + self._service_cache[service_name] = self._register_methods( + svc_descriptor + ) + + def _get_rpc_metadata(self, method_name: str) -> RpcMethod: + mn, is_valid = parse_method_name(method_name) + if not is_valid: + raise ValueError( + f"{method_name} is not a valid method name. Make sure to follow the format '/package.ServiceName/MethodName'" + ) + try: + return self._service_cache[mn.fully_qualified_service][mn.method] + except KeyError: + raise BentoMLException( + f"Method '{method_name}' is not registered in current service client." + ) from None + + async def _add_file_descriptor(self, file_descriptor: FileDescriptorProto): + dependencies = file_descriptor.dependency + for deps in dependencies: + if deps not in self._registered_file_name: + d_descriptor = await self._find_descriptor_by_filename(deps) + await self._add_file_descriptor(d_descriptor) + self._registered_file_name.add(deps) + self._descriptor_pool.Add(file_descriptor) + + async def _find_descriptor_by_symbol(self, symbol: str): + req = pb_reflection.ServerReflectionRequest(file_containing_symbol=symbol) + res = await self._do_one_request(req) + assert res is not None + fdp: list[bytes] = res.file_descriptor_response.file_descriptor_proto + return pb_descriptor.FileDescriptorProto.FromString(fdp[0]) + + async def _find_descriptor_by_filename(self, name: str): + req = pb_reflection.ServerReflectionRequest(file_by_filename=name) + res = await self._do_one_request(req) + assert res is not None + fdp: list[bytes] = res.file_descriptor_response.file_descriptor_proto + return pb_descriptor.FileDescriptorProto.FromString(fdp[0]) + + def _register_methods( + self, service_descriptor: ServiceDescriptor + ) -> 
dict[str, RpcMethod]: + service_descriptor_proto = pb_descriptor.ServiceDescriptorProto() + service_descriptor.CopyToProto(service_descriptor_proto) + full_name = service_descriptor.full_name + metadata: dict[str, RpcMethod] = {} + for method_proto in service_descriptor_proto.method: + method_name = method_proto.name + method_descriptor: MethodDescriptor = service_descriptor.FindMethodByName( + method_name + ) + input_type = self._symbol_database.GetPrototype( + method_descriptor.input_type + ) + output_type = self._symbol_database.GetPrototype( + method_descriptor.output_type + ) + metadata[method_name] = RpcMethod( + request_streaming=method_proto.client_streaming, + response_streaming=method_proto.server_streaming, + input_type=input_type, + output_type=output_type, + handler=getattr( + self.channel, + _RpcType.from_method_descriptor(method_proto), + )( + method=f"/{full_name}/{method_name}", + request_serializer=input_type.SerializeToString, + response_deserializer=output_type.FromString, + ), + ) + return metadata + + async def _invoke( + self, + method_name: str, + _deserialize_output: bool = False, + _serialize_input: bool = True, + **attrs: t.Any, + ): + await self._validate_rpc(method_name) + # channel kwargs include timeout, metadata, credentials, wait_for_ready and compression + # to pass it in kwargs add prefix _channel_ + channel_kwargs = { + k: attrs.pop(f"_channel_{k}", None) + for k in { + "timeout", + "metadata", + "credentials", + "wait_for_ready", + "compression", + } + } + rpc_method = self._get_rpc_metadata(method_name) + handler_type = _RpcType.from_streaming_type( + rpc_method["request_streaming"], rpc_method["response_streaming"] + ) + + if _serialize_input: + parsed = handler_type.request_serializer(rpc_method["input_type"], **attrs) + else: + parsed = rpc_method["input_type"](**attrs) + if handler_type.is_unary_response(): + result = await t.cast( + t.Awaitable[t.Any], + rpc_method["handler"](parsed, **channel_kwargs), + ) + if not 
_deserialize_output: + return result + return await t.cast( + t.Awaitable[t.Dict[str, t.Any]], + handler_type.response_deserializer(result), + ) + # streaming response + return handler_type.response_deserializer( + rpc_method["handler"](parsed, **channel_kwargs) + ) + + async def _validate_rpc(self, method_name: str): + await self.get_services() + + mn, _ = parse_method_name(method_name) + if mn.fully_qualified_service not in self._registered_services: + raise ValueError( + f"{mn.service} is not available in server. Registered services: {self._registered_services}" + ) + return True + + def _sync_call( + self, + inp: t.Any = None, + *, + _bentoml_api: InferenceAPI, + **kwargs: t.Any, + ): + with self: + return self._loop.run_until_complete( + self._call(inp, _bentoml_api=_bentoml_api, **kwargs) + ) + + async def _call( + self, + inp: t.Any = None, + *, + _bentoml_api: InferenceAPI, + **attrs: t.Any, + ) -> t.Any: + async with self: + # we need to pop everything that is client specific to separate dictionary + _deserialize_output = attrs.pop("_deserialize_output", False) + fn = functools.partial( + self._invoke, + **{ + f"_channel_{k}": attrs.pop(f"_channel_{k}", None) + for k in { + "timeout", + "metadata", + "credentials", + "wait_for_ready", + "compression", + } + }, + _serialize_input=False, + ) + + if _bentoml_api.multi_input: + if inp is not None: + raise BentoMLException( + f"'{_bentoml_api.name}' takes multiple inputs; all inputs must be passed as keyword arguments." 
+ ) + serialized_req = await _bentoml_api.input.to_proto(attrs) + else: + serialized_req = await _bentoml_api.input.to_proto(inp) + + # A call includes api_name and given proto_fields + return await fn( + self._call_rpc_method, + _deserialize_output=_deserialize_output, + **{ + "api_name": self._rev_apis[_bentoml_api], + _bentoml_api.input._proto_fields[0]: serialized_req, + }, + ) + + @staticmethod + def _create_client(parsed: ParseResult, **kwargs: t.Any) -> GrpcClient: + server_url = parsed.netloc + protocol_version = kwargs.get("protocol_version", LATEST_PROTOCOL_VERSION) + + # Since v1, we introduce a ServiceMetadata rpc to retrieve bentoml.Service metadata. + # This means if user are using client for protocol version v1alpha1, + # then `client.predict` or `client.classify` won't be available. + # client.Call will still persist for both protocol version. + dummy_service: Service | None = None + if parse(protocol_version) < parse("v1"): + logger.warning( + "Using protocol version %s older than v1. This means the client won't have service API functions as attributes. To invoke the RPC endpoint, use 'client.Call()'.", + protocol_version, + ) + else: + pb, _ = import_generated_stubs(protocol_version) + + # create an insecure channel to invoke ServiceMetadata rpc + with grpc.insecure_channel(server_url) as channel: + # gRPC sync stub is WIP. 
+ ServiceMetadata = channel.unary_unary( + f"/bentoml.grpc.{protocol_version}.BentoService/ServiceMetadata", + request_serializer=pb.ServiceMetadataRequest.SerializeToString, + response_deserializer=pb.ServiceMetadataResponse.FromString, + ) + metadata = t.cast( + "ServiceMetadataResponse", + ServiceMetadata(pb.ServiceMetadataRequest()), + ) + dummy_service = Service(metadata.name) + + for api in metadata.apis: + dummy_service.apis[api.name] = InferenceAPI( + None, + io.from_spec( + SpecDict( + id=api.input.descriptor_id, + args=_json_format.MessageToDict(api.input.attributes)[ + "args" + ], + ) + ), + io.from_spec( + SpecDict( + id=api.output.descriptor_id, + args=_json_format.MessageToDict(api.output.attributes)[ + "args" + ], + ) + ), + name=api.name, + doc=api.docs, + ) + + return GrpcClient(server_url, dummy_service, **kwargs) + + def __del__(self): + if self._channel: + try: + del self._channel + except Exception: # pylint: disable=bare-except + pass + + def _reflection_request(self, *reqs: pb_reflection.ServerReflectionRequest): + if self._reflection_stub is None: + # ServerReflectionInfo is a stream RPC, hence the generator. + self._reflection_stub = services_reflection.ServerReflectionStub( + self.channel + ) + res: t.AsyncIterator[ + pb_reflection.ServerReflectionResponse + ] = self._reflection_stub.ServerReflectionInfo((r for r in reqs)) + return res + + async def _do_one_request( + self, req: pb_reflection.ServerReflectionRequest + ) -> pb_reflection.ServerReflectionResponse | None: + resps: t.AsyncIterator[ + pb_reflection.ServerReflectionResponse + ] = self._reflection_request(req) + try: + async for r in resps: + return r + except aio.AioRpcError as err: + code = err.code() + if code == grpc.StatusCode.UNIMPLEMENTED: + raise BentoMLException( + f"[{code}] Couldn't locate servicer method. The running server might not have reflection enabled. 
Make sure to pass '--enable-reflection'" + ) + raise BentoMLException( + f"Caught AioRpcError while handling reflection request: {err}" + ) from None + + +class _RpcType(Enum): + UNARY_UNARY = 1 + UNARY_STREAM = 2 + STREAM_UNARY = 3 + STREAM_STREAM = 4 + + def is_unary_request(self) -> bool: + return self.name.lower().startswith("unary_") + + def is_unary_response(self) -> bool: + return self.name.lower().endswith("_unary") + + @classmethod + def from_method_descriptor(cls, method_descriptor: MethodDescriptorProto) -> str: + rpcs = cls.from_streaming_type( + method_descriptor.client_streaming, method_descriptor.server_streaming + ) + return rpcs.name.lower() + + @classmethod + def from_streaming_type( + cls, client_streaming: bool, server_streaming: bool + ) -> t.Self: + if not client_streaming and not server_streaming: + return cls.UNARY_UNARY + elif client_streaming and not server_streaming: + return cls.STREAM_UNARY + elif not client_streaming and server_streaming: + return cls.UNARY_STREAM + else: + return cls.STREAM_STREAM + + @property + def request_serializer(self) -> t.Callable[..., t.Any]: + def _(input_type: type[t.Any], **request_data: t.Any): + data = request_data or {} + return _json_format.ParseDict(data, input_type()) + + def _it(input_type: type[t.Any], request_data: t.Iterable[t.Any]): + for data in request_data: + yield _(input_type, **data) + + return _ if self.is_unary_request() else _it + + @property + def response_deserializer(self) -> t.Callable[..., t.Any]: + async def _(response: _message.Message): + return _json_format.MessageToDict( + response, preserving_proto_field_name=True + ) + + async def _it(response: t.AsyncIterator[_message.Message]): + async for r in response: + yield await _(r) + + return _ if self.is_unary_response() else _it diff --git a/src/bentoml/grpc/types.py b/src/bentoml/grpc/types.py index 942aa44db0f..f4567e81871 100644 --- a/src/bentoml/grpc/types.py +++ b/src/bentoml/grpc/types.py @@ -12,6 +12,12 @@ import grpc from 
grpc import aio + from grpc.aio._typing import SerializingFunction + from grpc.aio._typing import DeserializingFunction + from grpc.aio._base_channel import UnaryUnaryMultiCallable + from grpc.aio._base_channel import StreamUnaryMultiCallable + from grpc.aio._base_channel import UnaryStreamMultiCallable + from grpc.aio._base_channel import StreamStreamMultiCallable from bentoml.grpc.v1.service_pb2 import Request from bentoml.grpc.v1.service_pb2 import Response @@ -94,6 +100,13 @@ class HandlerCallDetails( t.Callable[[], aio.ServerInterceptor] | partial[aio.ServerInterceptor] ] + MultiCallable = ( + UnaryUnaryMultiCallable + | UnaryStreamMultiCallable + | StreamUnaryMultiCallable + | StreamStreamMultiCallable + ) + __all__ = [ "Request", "Response", diff --git a/src/bentoml/testing/server.py b/src/bentoml/testing/server.py index b7dd72ca9a5..46e9624a26f 100644 --- a/src/bentoml/testing/server.py +++ b/src/bentoml/testing/server.py @@ -218,6 +218,8 @@ def run_bento_server_container( cmd.append(image_tag) serve_cmd = "serve-grpc" if use_grpc else "serve-http" cmd.extend([serve_cmd, "--production"]) + if use_grpc: + cmd.extend(["--enable-reflection"]) print(f"Running API server in container: '{' '.join(cmd)}'") with subprocess.Popen( cmd, @@ -265,7 +267,7 @@ def run_bento_server_standalone( f"{server_port}", ] if use_grpc: - cmd += ["--host", f"{host}"] + cmd += ["--host", f"{host}", "--enable-reflection"] cmd += [bento] print(f"Running command: '{' '.join(cmd)}'") p = subprocess.Popen( @@ -378,6 +380,8 @@ def run_bento_server_distributed( path, *itertools.chain.from_iterable(runner_args), ] + if use_grpc: + cmd.extend(["--enable-reflection"]) with reserve_free_port(host=host, enable_so_reuseport=use_grpc) as server_port: cmd.extend(["--port", f"{server_port}"]) print(f"Running command: '{' '.join(cmd)}'") diff --git a/tests/e2e/bento_server_grpc/tests/test_metrics.py b/tests/e2e/bento_server_grpc/tests/test_metrics.py index f3ea0adfd76..fcbda1ded36 100644 --- 
a/tests/e2e/bento_server_grpc/tests/test_metrics.py +++ b/tests/e2e/bento_server_grpc/tests/test_metrics.py @@ -2,19 +2,15 @@ from typing import TYPE_CHECKING +import numpy as np import pytest +from bentoml.client import Client from bentoml.grpc.utils import import_generated_stubs -from bentoml.testing.grpc import create_channel -from bentoml.testing.grpc import async_client_call -from bentoml._internal.utils import LazyLoader if TYPE_CHECKING: - from google.protobuf import wrappers_pb2 - from bentoml.grpc.v1 import service_pb2 as pb else: - wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2") pb, _ = import_generated_stubs() @@ -23,21 +19,16 @@ async def test_metrics_available(host: str, img_file: str): with open(str(img_file), "rb") as f: fb = f.read() - async with create_channel(host) as channel: - await async_client_call( - "predict_multi_images", - channel=channel, - data={ - "multipart": { - "fields": { - "original": pb.Part(file=pb.File(kind="image/bmp", content=fb)), - "compared": pb.Part(file=pb.File(kind="image/bmp", content=fb)), - } - } - }, + client = Client.from_url(host, grpc=True) + async with client.aservice() as service: + resp = await service.async_predict_multi_images( + original=np.random.randint(255, size=(10, 10, 3)).astype("uint8"), + compared=np.random.randint(255, size=(10, 10, 3)).astype("uint8"), ) - await async_client_call( - "ensure_metrics_are_registered", - channel=channel, - data={"text": wrappers_pb2.StringValue(value="input_string")}, + assert isinstance(resp, pb.Response) + + resp = await service.Call( + api_name="ensure_metrics_are_registered", + text="input_string", ) + assert isinstance(resp, pb.Response)