From b2e45a773f3449ab0c729839b00f7b9b6799ef41 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Mon, 28 Nov 2022 21:17:42 -0800 Subject: [PATCH] chore: make sure to shutdown channel if given connection is idle Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com> --- grpc-client/python/client.py | 14 ++++------ src/bentoml/_internal/io_descriptors/json.py | 4 +-- src/bentoml/client.py | 28 ++++++++++++++------ 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/grpc-client/python/client.py b/grpc-client/python/client.py index 3157dc1256a..d715573946e 100644 --- a/grpc-client/python/client.py +++ b/grpc-client/python/client.py @@ -1,18 +1,13 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING import numpy as np from bentoml.client import Client -if TYPE_CHECKING: - from bentoml.client import GrpcClient - -async def arun(client: GrpcClient): - print("registered services:", await client.get_services()) +async def arun(client: Client): res = await client.async_classify(np.array([[5.9, 3, 5.1, 1.8]])) print("Result from 'client.async_classify':\n", res) @@ -20,7 +15,7 @@ async def arun(client: GrpcClient): print("Result from 'client.async_call':\n", res) -def run(client: GrpcClient): +def run(client: Client): res = client.classify(np.array([[5.9, 3, 5.1, 1.8]])) print("Result from 'client.classify':\n", res) res = client.call("classify", np.array([[5.9, 3, 5.1, 1.8]])) @@ -31,10 +26,11 @@ def run(client: GrpcClient): import argparse parser = argparse.ArgumentParser() - parser.add_argument("-rwa", "--run-with-async", action="store_true") + parser.add_argument("-rwa", "--run-with-async", action="store_true", default=False) + parser.add_argument("--grpc", action="store_true", default=False) args = parser.parse_args() - c = Client.from_url("localhost:3000", grpc=True) + c = Client.from_url("localhost:3000", grpc=args.grpc) if args.run_with_async: loop = asyncio.new_event_loop() diff --git a/src/bentoml/_internal/io_descriptors/json.py b/src/bentoml/_internal/io_descriptors/json.py index f32dbda8020..469667b6a97 100644 --- a/src/bentoml/_internal/io_descriptors/json.py +++ b/src/bentoml/_internal/io_descriptors/json.py @@ -240,11 +240,11 @@ def from_spec(cls, spec: SpecDict) -> Self: if "args" not in spec: raise InvalidArgument(f"Missing args key in JSON spec: {spec}") if "has_pydantic_model" in spec["args"] and spec["args"]["has_pydantic_model"]: - logger.warning( + logger.debug( "BentoML does not support loading pydantic models from URLs; output will be a normal dictionary." ) if "has_json_encoder" in spec["args"] and spec["args"]["has_json_encoder"]: - logger.warning( + logger.debug( "BentoML does not support loading JSON encoders from URLs; output will be a normal dictionary." ) diff --git a/src/bentoml/client.py b/src/bentoml/client.py index 16a359fc397..2a293c7b9a5 100644 --- a/src/bentoml/client.py +++ b/src/bentoml/client.py @@ -411,7 +411,7 @@ def _create_client(parsed: ParseResult, **kwargs: t.Any) -> HTTPClient: route=route.lstrip("/"), ) - return HTTPClient(dummy_service, server_url) + return HTTPClient(dummy_service, parsed.geturl()) async def _call( self, inp: t.Any = None, *, _bentoml_api: InferenceAPI, **kwargs: t.Any @@ -548,6 +548,15 @@ def _reserved_kw_mapping(self): "reflection": "grpc.reflection.v1alpha.ServerReflection", } + async def _exit(self): + try: + if self._channel: + if self._channel.get_state() == grpc.ChannelConnectivity.IDLE: + await self._channel.close() + except AttributeError as e: + logger.error(f"Error closing channel: %s", e, exc_info=e) + raise + def __enter__(self): return self.service().__enter__() @@ -560,6 +569,7 @@ def __exit__( try: if exc_type is not None: self.service().__exit__(exc_type, exc, traceback) + self._loop.run_until_complete(self._exit()) except Exception as err: # pylint: disable=broad-except logger.error(f"Exception occurred: %s (%s)", err, exc_type, exc_info=err) return False @@ -594,6 +604,7 @@ async def __aexit__( try: if exc_type is not None: await self.aservice().__aexit__(exc_type, exc, traceback) + await self._exit() except Exception as err: # pylint: disable=broad-except logger.error(f"Exception occurred: %s (%s)", err, exc_type, exc_info=err) return False @@ -621,8 +632,9 @@ async def aservice( ): await self._register_service(service_name) - # create a blocking call to wait til channel is ready. - await self.channel.channel_ready() + if self.channel.get_state() != grpc.ChannelConnectivity.READY: + # create a blocking call to wait til channel is ready. + await self.channel.channel_ready() try: method_meta = self._service_cache[service_name] @@ -760,7 +772,7 @@ def _register_methods( async def _invoke( self, method_name: str, - to_json: bool = False, + _deserialize_output: bool = False, _serialize_input: bool = True, **attrs: t.Any, ): @@ -791,7 +803,7 @@ async def _invoke( t.Awaitable[t.Any], rpc_method["handler"](parsed, **channel_kwargs), ) - if not to_json: + if not _deserialize_output: return result return await t.cast( t.Awaitable[t.Dict[str, t.Any]], @@ -833,7 +845,7 @@ async def _call( ) -> t.Any: async with self: # we need to pop everything that is client specific to separate dictionary - to_json = attrs.pop("to_json", False) + _deserialize_output = attrs.pop("_deserialize_output", False) fn = functools.partial( self._invoke, **{ @@ -846,6 +858,7 @@ async def _call( "compression", } }, + _serialize_input=False, ) if _bentoml_api.multi_input: @@ -860,8 +873,7 @@ async def _call( # A call includes api_name and given proto_fields return await fn( self._call_rpc_method, - to_json=to_json, - _serialize_input=False, + _deserialize_output=_deserialize_output, **{ "api_name": self._rev_apis[_bentoml_api], _bentoml_api.input._proto_fields[0]: serialized_req,