From ddd553a782c32c25ee2caff1bb566a2790a963ac Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Thu, 8 Dec 2022 01:37:53 -0800 Subject: [PATCH] fix: make sure to call _sync_call in sync api Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> --- src/bentoml/_internal/client/__init__.py | 4 +- src/bentoml/_internal/client/grpc.py | 93 +++++++++---------- src/bentoml/_internal/client/http.py | 2 + .../bento_server_grpc/tests/test_metrics.py | 2 + 4 files changed, 49 insertions(+), 52 deletions(-) diff --git a/src/bentoml/_internal/client/__init__.py b/src/bentoml/_internal/client/__init__.py index aa77424bfe4..e41d518810b 100644 --- a/src/bentoml/_internal/client/__init__.py +++ b/src/bentoml/_internal/client/__init__.py @@ -46,7 +46,9 @@ def __init__(self, svc: Service, server_url: str): ) def call(self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any) -> t.Any: - return asyncio.run(self.async_call(bentoml_api_name, inp, **kwargs)) + return self._sync_call( + inp, _bentoml_api=self._svc.apis[bentoml_api_name], **kwargs + ) async def async_call( self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any diff --git a/src/bentoml/_internal/client/grpc.py b/src/bentoml/_internal/client/grpc.py index e1ab4ba76ef..ba9b99ac9f0 100644 --- a/src/bentoml/_internal/client/grpc.py +++ b/src/bentoml/_internal/client/grpc.py @@ -31,6 +31,7 @@ from google.protobuf import json_format as _json_format from ..types import PathType + from ...grpc.v1.service_pb2 import Response from ...grpc.v1.service_pb2 import ServiceMetadataResponse class ClientCredentials(t.TypedDict): @@ -39,16 +40,13 @@ class ClientCredentials(t.TypedDict): certificate_chain: t.NotRequired[PathType | bytes] else: - ClientCredentials = dict + grpc, aio = import_grpc() _json_format = LazyLoader( "_json_format", globals(), "google.protobuf.json_format", exc_msg=PROTOBUF_EXC_MESSAGE, ) - grpc, aio = import_grpc() - -_INDENTATION = " " * 4 # TODO: xDS support class GrpcClient(Client): @@ -65,13 +63,12 @@ def __init__( *, protocol_version: str = LATEST_PROTOCOL_VERSION, ): - self._pb, self._services = import_generated_stubs(protocol_version) + self._pb, _ = import_generated_stubs(protocol_version) self._protocol_version = protocol_version self._compression = compression self._options = channel_options self._interceptors = interceptors - self._channel = None self._credentials = None if ssl: assert ( @@ -85,38 +82,29 @@ def __init__( ) super().__init__(svc, server_url) - @cached_property + @property def channel(self): - if not self._channel: - if self._credentials is not None: - self._channel = aio.secure_channel( - self.server_url, - credentials=self._credentials, - options=self._options, - compression=self._compression, - interceptors=self._interceptors, - ) - self._channel = aio.insecure_channel( + if self._credentials is not None: + return aio.secure_channel( + self.server_url, + credentials=self._credentials, + options=self._options, + compression=self._compression, + interceptors=self._interceptors, + ) + else: + return aio.insecure_channel( self.server_url, options=self._options, compression=self._compression, interceptors=self._interceptors, ) - return self._channel @cached_property - def _rpc_handler_mapping(self): + def _rpc_metadata(self): # Currently all RPCs in BentoService are unary-unary return { - method: { - "handler": self.channel.unary_unary( - method=method, - request_serializer=input_type.SerializeToString, - response_deserializer=output_type.FromString, - ), - "input_type": input_type, - "output_type": output_type, - } + method: {"input_type": input_type, "output_type": output_type} for method, input_type, output_type in ( ( f"/bentoml.grpc.{self._protocol_version}.BentoService/Call", @@ -133,9 +121,9 @@ def _rpc_handler_mapping(self): async def _invoke(self, method_name: str, **attrs: t.Any): # channel kwargs include timeout, metadata, credentials, wait_for_ready and compression - # to pass it in kwargs add prefix _channel_ + # to pass it in kwargs add prefix _grpc_channel_ channel_kwargs = { - k: attrs.pop(f"_channel_{k}", None) + k: attrs.pop(f"_grpc_channel_{k}", None) for k in { "timeout", "metadata", @@ -144,17 +132,20 @@ async def _invoke(self, method_name: str, **attrs: t.Any): "compression", } } - if method_name not in self._rpc_handler_mapping: + if method_name not in self._rpc_metadata: raise ValueError( - f"'{method_name}' is a yet supported rpc. Current supported are: {list(self._rpc_handler_mapping.keys())}" + f"'{method_name}' is a yet supported rpc. Current supported are: {self._rpc_metadata}" ) - rpc_handler = self._rpc_handler_mapping[method_name] + metadata = self._rpc_metadata[method_name] + rpc = self.channel.unary_unary( + method_name, + request_serializer=metadata["input_type"].SerializeToString, + response_deserializer=metadata["output_type"].FromString, + ) return await t.cast( - t.Awaitable[t.Any], - rpc_handler["handler"]( - rpc_handler["input_type"](**attrs), **channel_kwargs - ), + "t.Awaitable[Response]", + rpc(metadata["input_type"](**attrs), **channel_kwargs), ) async def _call( @@ -164,14 +155,16 @@ async def _call( _bentoml_api: InferenceAPI, **attrs: t.Any, ) -> t.Any: - if self.channel.get_state() != grpc.ChannelConnectivity.READY: + state = self.channel.get_state(try_to_connect=True) + if state != grpc.ChannelConnectivity.READY: # create a blocking call to wait til channel is ready. await self.channel.channel_ready() fn = functools.partial( self._invoke, + method_name=f"/bentoml.grpc.{self._protocol_version}.BentoService/Call", **{ - f"_channel_{k}": attrs.pop(f"_channel_{k}", None) + f"_grpc_channel_{k}": attrs.pop(f"_grpc_channel_{k}", None) for k in { "timeout", "metadata", @@ -192,11 +185,10 @@ async def _call( serialized_req = await _bentoml_api.input.to_proto(inp) # A call includes api_name and given proto_fields - _rev_apis = {v: k for k, v in self._svc.apis.items()} + api_fn = {v: k for k, v in self._svc.apis.items()} return await fn( - f"/bentoml.grpc.{self._protocol_version}.BentoService/Call", **{ - "api_name": _rev_apis[_bentoml_api], + "api_name": api_fn[_bentoml_api], _bentoml_api.input._proto_fields[0]: serialized_req, }, ) @@ -266,15 +258,14 @@ def run(): ) # create an insecure channel to invoke ServiceMetadata rpc - with channel: - metadata = t.cast( - "ServiceMetadataResponse", - channel.unary_unary( - f"/bentoml.grpc.{protocol_version}.BentoService/ServiceMetadata", - request_serializer=pb.ServiceMetadataRequest.SerializeToString, - response_deserializer=pb.ServiceMetadataResponse.FromString, - )(pb.ServiceMetadataRequest()), - ) + metadata = t.cast( + "ServiceMetadataResponse", + channel.unary_unary( + f"/bentoml.grpc.{protocol_version}.BentoService/ServiceMetadata", + request_serializer=pb.ServiceMetadataRequest.SerializeToString, + response_deserializer=pb.ServiceMetadataResponse.FromString, + )(pb.ServiceMetadataRequest()), + ) dummy_service = Service(metadata.name) for api in metadata.apis: diff --git a/src/bentoml/_internal/client/http.py b/src/bentoml/_internal/client/http.py index 49f27188f64..db97bda4dc4 100644 --- a/src/bentoml/_internal/client/http.py +++ b/src/bentoml/_internal/client/http.py @@ -78,6 +78,8 @@ def from_url(cls, server_url: str, **kwargs: t.Any) -> HTTPClient: async def _call( self, inp: t.Any = None, *, _bentoml_api: InferenceAPI, **kwargs: t.Any ) -> t.Any: + # All gRPC kwargs should be poped out. + kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_grpc_")} api = _bentoml_api if api.multi_input: diff --git a/tests/e2e/bento_server_grpc/tests/test_metrics.py b/tests/e2e/bento_server_grpc/tests/test_metrics.py index 393be9ad0f1..caa8c40d4d2 100644 --- a/tests/e2e/bento_server_grpc/tests/test_metrics.py +++ b/tests/e2e/bento_server_grpc/tests/test_metrics.py @@ -22,3 +22,5 @@ async def test_metrics_available(host: str): compared=np.random.randint(255, size=(10, 10, 3)).astype("uint8"), ) assert isinstance(resp, pb.Response) + resp = await client.async_ensure_metrics_are_registered("input_data") + assert isinstance(resp, pb.Response)