Skip to content

Commit

Permalink
chore: make sure to shut down channel if given connection is idle
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm committed Nov 29, 2022
1 parent 3378ae1 commit b2e45a7
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 19 deletions.
14 changes: 5 additions & 9 deletions grpc-client/python/client.py
@@ -1,26 +1,21 @@
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

import numpy as np

from bentoml.client import Client

if TYPE_CHECKING:
from bentoml.client import GrpcClient


async def arun(client: GrpcClient):
print("registered services:", await client.get_services())
async def arun(client: Client):

res = await client.async_classify(np.array([[5.9, 3, 5.1, 1.8]]))
print("Result from 'client.async_classify':\n", res)
res = await client.async_call("classify", np.array([[5.9, 3, 5.1, 1.8]]))
print("Result from 'client.async_call':\n", res)


def run(client: GrpcClient):
def run(client: Client):
res = client.classify(np.array([[5.9, 3, 5.1, 1.8]]))
print("Result from 'client.classify':\n", res)
res = client.call("classify", np.array([[5.9, 3, 5.1, 1.8]]))
Expand All @@ -31,10 +26,11 @@ def run(client: GrpcClient):
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-rwa", "--run-with-async", action="store_true")
parser.add_argument("-rwa", "--run-with-async", action="store_true", default=False)
parser.add_argument("--grpc", action="store_true", default=False)
args = parser.parse_args()

c = Client.from_url("localhost:3000", grpc=True)
c = Client.from_url("localhost:3000", grpc=args.grpc)

if args.run_with_async:
loop = asyncio.new_event_loop()
Expand Down
4 changes: 2 additions & 2 deletions src/bentoml/_internal/io_descriptors/json.py
Expand Up @@ -240,11 +240,11 @@ def from_spec(cls, spec: SpecDict) -> Self:
if "args" not in spec:
raise InvalidArgument(f"Missing args key in JSON spec: {spec}")
if "has_pydantic_model" in spec["args"] and spec["args"]["has_pydantic_model"]:
logger.warning(
logger.debug(
"BentoML does not support loading pydantic models from URLs; output will be a normal dictionary."
)
if "has_json_encoder" in spec["args"] and spec["args"]["has_json_encoder"]:
logger.warning(
logger.debug(
"BentoML does not support loading JSON encoders from URLs; output will be a normal dictionary."
)

Expand Down
28 changes: 20 additions & 8 deletions src/bentoml/client.py
Expand Up @@ -411,7 +411,7 @@ def _create_client(parsed: ParseResult, **kwargs: t.Any) -> HTTPClient:
route=route.lstrip("/"),
)

return HTTPClient(dummy_service, server_url)
return HTTPClient(dummy_service, parsed.geturl())

async def _call(
self, inp: t.Any = None, *, _bentoml_api: InferenceAPI, **kwargs: t.Any
Expand Down Expand Up @@ -548,6 +548,15 @@ def _reserved_kw_mapping(self):
"reflection": "grpc.reflection.v1alpha.ServerReflection",
}

async def _exit(self):
try:
if self._channel:
if self._channel.get_state() == grpc.ChannelConnectivity.IDLE:
await self._channel.close()
except AttributeError as e:
logger.error(f"Error closing channel: %s", e, exc_info=e)
raise

def __enter__(self):
return self.service().__enter__()

Expand All @@ -560,6 +569,7 @@ def __exit__(
try:
if exc_type is not None:
self.service().__exit__(exc_type, exc, traceback)
self._loop.run_until_complete(self._exit())
except Exception as err: # pylint: disable=broad-except
logger.error(f"Exception occurred: %s (%s)", err, exc_type, exc_info=err)
return False
Expand Down Expand Up @@ -594,6 +604,7 @@ async def __aexit__(
try:
if exc_type is not None:
await self.aservice().__aexit__(exc_type, exc, traceback)
await self._exit()
except Exception as err: # pylint: disable=broad-except
logger.error(f"Exception occurred: %s (%s)", err, exc_type, exc_info=err)
return False
Expand Down Expand Up @@ -621,8 +632,9 @@ async def aservice(
):
await self._register_service(service_name)

# create a blocking call to wait til channel is ready.
await self.channel.channel_ready()
if self.channel.get_state() != grpc.ChannelConnectivity.READY:
# create a blocking call to wait til channel is ready.
await self.channel.channel_ready()

try:
method_meta = self._service_cache[service_name]
Expand Down Expand Up @@ -760,7 +772,7 @@ def _register_methods(
async def _invoke(
self,
method_name: str,
to_json: bool = False,
_deserialize_output: bool = False,
_serialize_input: bool = True,
**attrs: t.Any,
):
Expand Down Expand Up @@ -791,7 +803,7 @@ async def _invoke(
t.Awaitable[t.Any],
rpc_method["handler"](parsed, **channel_kwargs),
)
if not to_json:
if not _deserialize_output:
return result
return await t.cast(
t.Awaitable[t.Dict[str, t.Any]],
Expand Down Expand Up @@ -833,7 +845,7 @@ async def _call(
) -> t.Any:
async with self:
# we need to pop everything that is client specific to separate dictionary
to_json = attrs.pop("to_json", False)
_deserialize_output = attrs.pop("_deserialize_output", False)
fn = functools.partial(
self._invoke,
**{
Expand All @@ -846,6 +858,7 @@ async def _call(
"compression",
}
},
_serialize_input=False,
)

if _bentoml_api.multi_input:
Expand All @@ -860,8 +873,7 @@ async def _call(
# A call includes api_name and given proto_fields
return await fn(
self._call_rpc_method,
to_json=to_json,
_serialize_input=False,
_deserialize_output=_deserialize_output,
**{
"api_name": self._rev_apis[_bentoml_api],
_bentoml_api.input._proto_fields[0]: serialized_req,
Expand Down

0 comments on commit b2e45a7

Please sign in to comment.