From 87b91222af7d80457ff8be8563e24a7961c6e4ea Mon Sep 17 00:00:00 2001
From: Ubuntu <29749331+aarnphm@users.noreply.github.com>
Date: Mon, 28 Nov 2022 11:30:46 +0000
Subject: [PATCH] feat: sync and async client implementation

Signed-off-by: Ubuntu <29749331+aarnphm@users.noreply.github.com>
---
 grpc-client/python/client.py                  |  85 ++++--
 src/bentoml/client.py                         | 280 ++++++++++++++----
 .../bento_server_grpc/tests/test_metrics.py   |   2 -
 3 files changed, 288 insertions(+), 79 deletions(-)

diff --git a/grpc-client/python/client.py b/grpc-client/python/client.py
index e558048e089..ab4df9f4107 100644
--- a/grpc-client/python/client.py
+++ b/grpc-client/python/client.py
@@ -1,33 +1,80 @@
 from __future__ import annotations
 
 import asyncio
+import functools
+from typing import TYPE_CHECKING
+
+import numpy as np
 
 from bentoml.client import Client
 
+if TYPE_CHECKING:
+    from bentoml.client import GrpcClient
 
-async def run():
-    c = Client.from_url("localhost:3000", grpc=True)
-    async with c.aservice("grpc.health.v1.Health") as health_client:
-        print(await health_client.Check(service="bentoml.grpc.v1.BentoService"))
+
+async def arun(c: GrpcClient):
+    print("registered services:", await c.get_services())
+
+    async with c.aservice("health") as health_client:
+        res = await health_client.Check(service="bentoml.grpc.v1.BentoService")
+        print(res)
 
     async with c.aservice() as client:
-        print(
-            await client.Call(
-                api_name="classify",
-                ndarray=dict(
-                    dtype=1,
-                    shape=[1, 4],
-                    float_values=[5.9, 3, 5.1, 1.8],
-                ),
-                return_proto=True,
+        res = await client.Call(
+            api_name="classify",
+            ndarray=dict(
+                dtype=1,
+                shape=[1, 4],
+                float_values=[5.9, 3, 5.1, 1.8],
+            ),
+        )
+        print("Result from 'await client.Call':\n", res)
+        res = await client.async_classify(np.array([[5.9, 3, 5.1, 1.8]]))
+        print("Result from 'client.async_classify':\n", res)
+
+
+def run(c: Client):
+    with c.service("health") as health_client:
+        Check = functools.partial(
+            health_client.Check, service="bentoml.grpc.v1.BentoService"
+        )
+        for to_json in (True, False):
+            print(
+                f"Health check ({'serialized' if to_json else 'deserialized'}):",
+                Check(to_json=to_json),
             )
+
+    with c.service() as client:
+        res = client.Call(
+            api_name="classify",
+            ndarray=dict(
+                dtype=1,
+                shape=[1, 4],
+                float_values=[5.9, 3, 5.1, 1.8],
+            ),
         )
+        print("Result from 'client.Call':\n", res)
+        res = client.classify(np.array([[5.9, 3, 5.1, 1.8]]))
+        print("Result from 'client.classify':\n", res)
+        res = client.call("classify", np.array([[5.9, 3, 5.1, 1.8]]))
+        print("Result from 'client.call(bentoml_api_name='classify')':\n", res)
 
 
 if __name__ == "__main__":
-    loop = asyncio.new_event_loop()
-    try:
-        loop.run_until_complete(run())
-    finally:
-        loop.close()
-        assert loop.is_closed()
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-rwa", "--run-with-async", action="store_true")
+    args = parser.parse_args()
+
+    c = Client.from_url("localhost:3000", grpc=True)
+
+    if args.run_with_async:
+        loop = asyncio.new_event_loop()
+        try:
+            loop.run_until_complete(arun(c))
+        finally:
+            loop.close()
+            assert loop.is_closed()
+    else:
+        run(c)
diff --git a/src/bentoml/client.py b/src/bentoml/client.py
index e32e1da296a..5a840fbf8a5 100644
--- a/src/bentoml/client.py
+++ b/src/bentoml/client.py
@@ -3,13 +3,14 @@
 import json
 import typing as t
 import asyncio
+import inspect
 import logging
 import functools
+import contextlib
 from abc import ABC
+from abc import abstractmethod
 from enum import Enum
 from typing import TYPE_CHECKING
-from contextlib import contextmanager
-from contextlib import asynccontextmanager
 from http.client import HTTPConnection
 from urllib.parse import urlparse
 
@@ -44,13 +45,17 @@
     import grpc
     from grpc import aio
     from google.protobuf import message as _message
-    from google.protobuf import descriptor as _descriptor
     from google.protobuf import json_format as _json_format
     from google.protobuf import descriptor_pb2 as pb_descriptor
     from google.protobuf import descriptor_pool as _descriptor_pool
     from google.protobuf import symbol_database as _symbol_database
     from grpc_reflection.v1alpha import reflection_pb2 as pb_reflection
     from grpc_reflection.v1alpha import reflection_pb2_grpc as services_reflection
+
+    # type hint specific imports.
+    from google.protobuf.descriptor import MethodDescriptor
+    from google.protobuf.descriptor import ServiceDescriptor
+    from google.protobuf.descriptor_pb2 import FileDescriptorProto
     from google.protobuf.descriptor_pool import DescriptorPool
     from google.protobuf.symbol_database import SymbolDatabase
     from grpc_reflection.v1alpha.reflection_pb2 import ServiceResponse
@@ -95,12 +100,6 @@ class RpcMethod(t.TypedDict):
         "google.protobuf.symbol_database",
         exc_msg=PROTOBUF_EXC_MESSAGE,
     )
-    _descriptor = LazyLoader(
-        "_descriptor",
-        globals(),
-        "google.protobuf.descriptor",
-        exc_msg=PROTOBUF_EXC_MESSAGE,
-    )
     _json_format = LazyLoader(
         "_json_format",
         globals(),
@@ -224,7 +223,6 @@ class GRPC:
         ssl: bool = attr.field(default=False)
         channel_options: t.Optional[aio.ChannelArgumentType] = attr.field(default=None)
         compression: t.Optional[grpc.Compression] = attr.field(default=None)
-        graceful_shutdown_timeout: t.Optional[float] = attr.field(default=None)
         ssl_client_credentials: t.Optional[ClientCredentials] = attr.field(
             factory=lambda: ClientCredentials()
         )
@@ -258,7 +256,11 @@ class Client(ABC):
     def __init__(self, svc: Service, server_url: str):
         if len(svc.apis) == 0:
             raise BentoMLException("No APIs were found when constructing client.")
+        self._register_service_endpoint(svc)
+        self._svc = svc
+        self.server_url = server_url
 
+    def _register_service_endpoint(self, svc: Service):
         for name, api in svc.apis.items():
             if not hasattr(self, name):
                 setattr(
@@ -271,11 +273,15 @@ def __init__(self, svc: Service, server_url: str):
                     f"async_{name}",
                     functools.partial(self._call, _bentoml_api=api),
                 )
-        self.server_url = server_url
-        self._svc = svc
+
+    @cached_property
+    def _loop(self) -> asyncio.AbstractEventLoop:
+        return asyncio.get_event_loop()
 
     def call(self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any) -> t.Any:
-        return asyncio.run(self.async_call(bentoml_api_name, inp, **kwargs))
+        return self._loop.run_until_complete(
+            self.async_call(bentoml_api_name, inp, **kwargs)
+        )
 
     async def async_call(
         self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any
@@ -287,8 +293,11 @@ async def async_call(
     def _sync_call(
         self, inp: t.Any = None, *, _bentoml_api: InferenceAPI, **kwargs: t.Any
     ):
-        return asyncio.run(self._call(inp, _bentoml_api=_bentoml_api, **kwargs))
+        return self._loop.run_until_complete(
+            self._call(inp, _bentoml_api=_bentoml_api, **kwargs)
+        )
 
+    @abstractmethod
     async def _call(
         self, inp: t.Any = None, *, _bentoml_api: InferenceAPI, **kwargs: t.Any
     ) -> t.Any:
@@ -502,7 +511,6 @@ def __init__(
         channel_options: aio.ChannelArgumentType | None = None,
         interceptors: t.Sequence[aio.ClientInterceptor] | None = None,
         compression: grpc.Compression | None = None,
-        graceful_shutdown_timeout: float | None = None,
         ssl_client_credentials: ClientCredentials | None = None,
         *,
         protocol_version: str = LATEST_PROTOCOL_VERSION,
@@ -510,11 +518,17 @@
     ):
         if svc is not None and len(svc.apis) == 0:
             raise BentoMLException("No APIs were found when constructing client.")
 
-        self._svc = svc or _sentinel_svc
         self.server_url = server_url
-        self._graceful_shutdown_timeout = graceful_shutdown_timeout
-        self._protocol_version = protocol_version
+        self._svc = svc or _sentinel_svc
+        # Call requires an api_name, so keep a reverse mapping of self._svc.apis.
+        self._rev_apis = {v: k for k, v in self._svc.apis.items()}
+
+        self._protocol_version = protocol_version
+        self._compression = compression
+        self._options = channel_options
+        self._interceptors = interceptors
+        self._channel = None
         self._credentials = None
         if ssl:
             assert (
@@ -526,18 +540,17 @@ def __init__(
                     for k, v in ssl_client_credentials.items()
                 }
             )
-        self._compression = compression
-        self._options = channel_options
-        self._interceptors = interceptors
-        self._channel = None
 
         self._descriptor_pool: DescriptorPool = _descriptor_pool.Default()
         self._symbol_database: SymbolDatabase = _symbol_database.Default()
 
         self._registered_services: tuple[str, ...] = tuple()
+        # Cache of all available RPCs for a given service.
         self._service_methods_cache: dict[str, dict[str, RpcMethod]] = {}
+        # Whether all services exposed by the gRPC server have been registered.
         self._is_registered = False
-        self._registered_file_name = set()
+        # Set of registered FileDescriptorProto names.
+        self._registered_file_name: set[str] = set()
 
     @cached_property
     def channel(self):
@@ -558,22 +571,79 @@ def channel(self):
             )
         return self._channel
 
-    @asynccontextmanager
+    def _reset_cache(self):
+        self._registered_file_name.clear()
+        self._registered_services = tuple()
+        self._service_methods_cache.clear()
+        self._is_registered = False
+
+    @contextlib.contextmanager
+    def service(self, service_name: str = "default"):
+        stack = contextlib.AsyncExitStack()
+
+        async def enter():
+            res = await stack.enter_async_context(
+                self.aservice(service_name, _wrap_in_sync=True)
+            )
+            return res
+
+        try:
+            yield self._loop.run_until_complete(enter())
+        finally:
+
+            async def cleanup():
+                await stack.aclose()
+
+            self._loop.run_until_complete(cleanup())
+
+    @staticmethod
+    def make_rpc_method(service_name: str, method: str):
+        return f"/{service_name}/{method}"
+
+    @property
+    def _call_rpc_method(self):
+        return self.make_rpc_method(
+            f"bentoml.grpc.{self._protocol_version}.BentoService", "Call"
+        )
+
+    @cached_property
+    def _reserved_kw_mapping(self):
+        return {
+            "default": f"bentoml.grpc.{self._protocol_version}.BentoService",
+            "health": "grpc.health.v1.Health",
+            "reflection": "grpc.reflection.v1alpha.ServerReflection",
+        }
+
+    @contextlib.asynccontextmanager
     async def aservice(
-        self, service_name: str = "default"
+        self,
+        service_name: str = "default",
+        *,
+        _wrap_in_sync: bool = False,
     ) -> t.AsyncGenerator[t.Self, None]:
         # This is the entrypoint for user to instantiate a client for a given service.
-        # default is a special service name to get the BentoService proto.
-        if service_name == "default":
-            service_name = f"bentoml.grpc.{self._protocol_version}.BentoService"
+        # 'default' is a special case that maps to the BentoService proto.
+        if service_name in self._reserved_kw_mapping:
+            service_name = self._reserved_kw_mapping[service_name]
+
+        global _cached_grpc_client
+        cached_client_name = f"{service_name.replace('.', '_').lower()}_{'sync' if _wrap_in_sync else 'async'}"
+        if cached_client_name in _cached_grpc_client:
+            yield _cached_grpc_client[cached_client_name]
+            return
 
-        await self.get_available_services()
+        # We need to clear all caches when setting up the sync wrappers
+        # so that async and sync functions don't get mixed up.
+        if _wrap_in_sync:
+            self._reset_cache()
+
+        await self._register_services()
 
         if (
             service_name in self._registered_services
            and service_name not in self._service_methods_cache
         ):
-            await self.register_service(service_name)
+            await self._register_service(service_name)
 
         # create a blocking call to wait til channel is ready.
         await self.channel.channel_ready()
@@ -583,17 +653,58 @@ async def aservice(
             raise ValueError(
                 f"Failed to find service '{service_name}'. Available: {list(self._service_methods_cache.keys())}"
             ) from None
-        for method in method_meta:
-            object_setattr(
-                self,
-                method,
-                functools.partial(self._invoke, f"/{service_name}/{method}"),
+
+        def _register(method: str):
+            finaliser = f = functools.partial(
+                self._invoke, self.make_rpc_method(service_name, method)
             )
+            if _wrap_in_sync:
+                # We will have to run the async function in a sync wrapper.
+                @functools.wraps(f)
+                def wrapper(*args: t.Any, **kwargs: t.Any):
+                    coro = f(*args, **kwargs)
+                    task = asyncio.ensure_future(coro, loop=self._loop)
+                    try:
+                        res = self._loop.run_until_complete(task)
+                        if inspect.isasyncgen(res):
+                            # If this is an async generator, then we need to yield again.
+                            async def call():
+                                return await res.__anext__()
+
+                            return self._loop.run_until_complete(call())
+                        return res
+                    except BaseException:
+                        # Consume all exceptions.
+                        if task.done() and not task.cancelled():
+                            task.exception()
+                        raise
+
+                finaliser = wrapper
+            object_setattr(self, method, finaliser)
+
+        # 1. Register all RPC methods.
+        for method in reversed(method_meta):
+            _register(method)
+
+        # 2. Register the service endpoints if the given service is not _sentinel_svc.
+        # We only use _sentinel_svc if the given protocol is older than v1.
+        if self._svc is not _sentinel_svc:
+            self._register_service_endpoint(self._svc)
+
+        # Cache the client for this service for later use.
+        _cached_grpc_client[cached_client_name] = self
         yield self
 
-    async def register_service(self, service_name: str) -> None:
+    async def _register_services(self):
+        await self.get_services()
+        if not self._is_registered:
+            for svc in self._registered_services:
+                await self._register_service(svc)
+            self._is_registered = True
+
+    async def _register_service(self, service_name: str) -> None:
         cache = self._service_methods_cache
-        svc_descriptor: _descriptor.ServiceDescriptor | None = None
+        svc_descriptor: ServiceDescriptor | None = None
         try:
             svc_descriptor = self._descriptor_pool.FindServiceByName(service_name)
         except KeyError:
@@ -612,9 +723,7 @@ async def register_service(self, service_name: str) -> None:
         if svc_descriptor is not None:
             cache[service_name] = self._register_methods(svc_descriptor)
 
-    async def _add_file_descriptor(
-        self, file_descriptor: pb_descriptor.FileDescriptorProto
-    ):
+    async def _add_file_descriptor(self, file_descriptor: FileDescriptorProto):
         dependencies = file_descriptor.dependency
         for deps in dependencies:
             if deps not in self._registered_file_name:
@@ -626,6 +735,7 @@ async def _add_file_descriptor(
     async def _find_descriptor_by_symbol(self, symbol: str):
         req = pb_reflection.ServerReflectionRequest(file_containing_symbol=symbol)
         res = await self._do_one_request(req)
+        assert res is not None
         fdr: FileDescriptorResponse = res.file_descriptor_response
         fdp: list[bytes] = fdr.file_descriptor_proto
         return pb_descriptor.FileDescriptorProto.FromString(fdp[0])
@@ -633,6 +743,7 @@
     async def _find_descriptor_by_filename(self, name: str):
         req = pb_reflection.ServerReflectionRequest(file_by_filename=name)
         res = await self._do_one_request(req)
+        assert res is not None
         fdr: FileDescriptorResponse = res.file_descriptor_response
         fdp: list[bytes] = fdr.file_descriptor_proto
         return pb_descriptor.FileDescriptorProto.FromString(fdp[0])
@@ -651,7 +762,7 @@ def _get_rpc_metadata(self, method_name: str) -> RpcMethod:
             ) from None
 
     def _register_methods(
-        self, service_descriptor: _descriptor.ServiceDescriptor
+        self, service_descriptor: ServiceDescriptor
     ) -> dict[str, RpcMethod]:
         service_descriptor_proto = pb_descriptor.ServiceDescriptorProto()
         service_descriptor.CopyToProto(service_descriptor_proto)
@@ -659,8 +770,8 @@ def _register_methods(
         metadata: dict[str, RpcMethod] = {}
         for method_proto in service_descriptor_proto.method:
             method_name = method_proto.name
-            method_descriptor: _descriptor.MethodDescriptor = (
-                service_descriptor.FindMethodByName(method_name)
+            method_descriptor: MethodDescriptor = service_descriptor.FindMethodByName(
+                method_name
             )
             input_type = self._symbol_database.GetPrototype(
                 method_descriptor.input_type
@@ -691,7 +802,8 @@ async def _invoke(
         self,
         method_name: str,
         /,
-        return_proto: bool = False,
+        to_json: bool = False,
+        _serialize_input: bool = True,
         **attrs: t.Any,
     ):
         # channel kwargs include timeout, metadata, credentials, wait_for_ready and compression
@@ -711,13 +823,16 @@ async def _invoke(
             rpc_method["request_streaming"], rpc_method["response_streaming"]
         )
 
-        parsed = handler_type.request_serializer(rpc_method["input_type"], **attrs)
+        if _serialize_input:
+            parsed = handler_type.request_serializer(rpc_method["input_type"], **attrs)
+        else:
+            parsed = rpc_method["input_type"](**attrs)
         if handler_type.is_unary_response():
             result = await t.cast(
                 t.Awaitable[t.Any],
                 rpc_method["handler"](parsed, **channel_kwargs),
             )
-            if return_proto:
+            if not to_json:
                 return result
             return await t.cast(
                 t.Awaitable[t.Dict[str, t.Any]],
@@ -729,17 +844,12 @@ async def _invoke(
         )
 
     async def _validate_rpc(self, method_name: str):
-        if not self._registered_services:
-            await self.get_available_services()
-        if not self._is_registered:
-            for svc in self._registered_services:
-                await self.register_service(svc)
-            self._is_registered = True
+        await self.get_services()
 
         mn, _ = parse_method_name(method_name)
-        if mn.service not in self._registered_services:
+        if mn.fully_qualified_service not in self._registered_services:
             raise ValueError(
-                f"{mn.service} is not available in server. Registered services: {self.get_available_services}"
+                f"{mn.service} is not available in server. Registered services: {self._registered_services}"
             )
         return True
 
@@ -747,7 +857,8 @@ async def invoke(
         self,
         method_name: str,
         /,
-        return_proto: bool = False,
+        to_json: bool = False,
+        _serialize_input: bool = True,
         **attrs: t.Any,
     ):
         """Entrypoint to invoke the RPC.
@@ -755,7 +866,55 @@
         To pass in channel specific options, add prefix ``_channel_`` to the option name. For example, to pass in ``timeout``, add ``_channel_timeout=60``
         """
         await self._validate_rpc(method_name)
-        return await self._invoke(method_name, return_proto=return_proto, **attrs)
+        return await self._invoke(
+            method_name,
+            to_json=to_json,
+            _serialize_input=_serialize_input,
+            **attrs,
+        )
+
+    async def _call(
+        self,
+        inp: t.Any = None,
+        *,
+        _bentoml_api: InferenceAPI,
+        **attrs: t.Any,
+    ) -> t.Any:
+        # Pop everything that is client-specific into a separate dictionary.
+        to_json = attrs.pop("to_json", False)
+        fn = functools.partial(
+            self.invoke,
+            **{
+                f"_channel_{k}": attrs.pop(f"_channel_{k}", None)
+                for k in {
+                    "timeout",
+                    "metadata",
+                    "credentials",
+                    "wait_for_ready",
+                    "compression",
+                }
+            },
+        )
+
+        if _bentoml_api.multi_input:
+            if inp is not None:
+                raise BentoMLException(
+                    f"'{_bentoml_api.name}' takes multiple inputs; all inputs must be passed as keyword arguments."
+                )
+            serialized_req = await _bentoml_api.input.to_proto(attrs)
+        else:
+            serialized_req = await _bentoml_api.input.to_proto(inp)
+
+        # A Call consists of an api_name and the given proto_fields.
+        return await fn(
+            self._call_rpc_method,
+            to_json=to_json,
+            _serialize_input=False,
+            **{
+                "api_name": self._rev_apis[_bentoml_api],
+                _bentoml_api.input._proto_fields[0]: serialized_req,
+            },
+        )
 
     def __del__(self):
         if self._channel:
@@ -764,15 +923,17 @@ def __del__(self):
             try:
             except Exception:  # pylint: disable=bare-except
                 pass
 
-    async def get_available_services(self) -> None:
+    async def get_services(self):
         if not self._registered_services:
             resp = await self._do_one_request(
                 pb_reflection.ServerReflectionRequest(list_services="")
             )
+            assert resp is not None
             list_services: ListServiceResponse = resp.list_services_response
             services: list[ServiceResponse] = list_services.service
             self._registered_services = tuple([t.cast(str, s.name) for s in services])
         assert self._registered_services
+        return self._registered_services
 
     def _reflection_request(self, *reqs: pb_reflection.ServerReflectionRequest):
         # ServerReflectionInfo is a stream RPC, hence the generator.
@@ -784,7 +945,7 @@ def _reflection_request(self, *reqs: pb_reflection.ServerReflectionRequest):
 
     async def _do_one_request(
         self, req: pb_reflection.ServerReflectionRequest
-    ) -> pb_reflection.ServerReflectionResponse:
+    ) -> pb_reflection.ServerReflectionResponse | None:
         resps: t.AsyncIterator[
             pb_reflection.ServerReflectionResponse
         ] = self._reflection_request(req)
@@ -802,6 +963,9 @@ async def _do_one_request(
             ) from None
 
 
+_cached_grpc_client: dict[str, GrpcClient] = {}
+
+
 class RpcType(Enum):
     UNARY_UNARY = 1
     UNARY_STREAM = 2
diff --git a/tests/e2e/bento_server_grpc/tests/test_metrics.py b/tests/e2e/bento_server_grpc/tests/test_metrics.py
index 50287820749..66ff64aa494 100644
--- a/tests/e2e/bento_server_grpc/tests/test_metrics.py
+++ b/tests/e2e/bento_server_grpc/tests/test_metrics.py
@@ -33,13 +33,11 @@ async def test_metrics_available(host: str, img_file: str):
                     "compared": {"file": {"kind": "image/bmp", "content": fb}},
                 }
             },
-            return_proto=True,
         )
         assert isinstance(resp, pb.Response)
 
         resp = await service.Call(
            api_name="ensure_metrics_are_registered",
            text=wrappers_pb2.StringValue(value="input_string"),
-            return_proto=True,
         )
         assert isinstance(resp, pb.Response)
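
Usage note (appended for review, not part of the patch): a minimal sketch of the sync and async call patterns this change introduces, condensed from grpc-client/python/client.py above. It assumes the same setup as that example: a BentoService exposing a `classify` API served over gRPC on localhost:3000.

    import asyncio

    import numpy as np

    from bentoml.client import Client

    c = Client.from_url("localhost:3000", grpc=True)

    # Sync path: service() wraps aservice() and drives the event loop internally.
    def run_sync():
        with c.service("health") as health:
            print(health.Check(service="bentoml.grpc.v1.BentoService", to_json=True))
        with c.service() as client:
            print(client.classify(np.array([[5.9, 3, 5.1, 1.8]])))

    # Async path: the same surface, awaited. As in the example script, pick one
    # path per process (it selects sync vs async via --run-with-async).
    async def run_async():
        async with c.aservice() as client:
            print(await client.async_classify(np.array([[5.9, 3, 5.1, 1.8]])))

    if __name__ == "__main__":
        run_sync()  # or: asyncio.new_event_loop().run_until_complete(run_async())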