Skip to content

Commit

Permalink
fix: make sure to call _sync_call in sync api
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm committed Dec 8, 2022
1 parent 0c71ee8 commit ddd553a
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 52 deletions.
4 changes: 3 additions & 1 deletion src/bentoml/_internal/client/__init__.py
Expand Up @@ -46,7 +46,9 @@ def __init__(self, svc: Service, server_url: str):
)

def call(self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any) -> t.Any:
return asyncio.run(self.async_call(bentoml_api_name, inp, **kwargs))
return self._sync_call(
inp, _bentoml_api=self._svc.apis[bentoml_api_name], **kwargs
)

async def async_call(
self, bentoml_api_name: str, inp: t.Any = None, **kwargs: t.Any
Expand Down
93 changes: 42 additions & 51 deletions src/bentoml/_internal/client/grpc.py
Expand Up @@ -31,6 +31,7 @@
from google.protobuf import json_format as _json_format

from ..types import PathType
from ...grpc.v1.service_pb2 import Response
from ...grpc.v1.service_pb2 import ServiceMetadataResponse

class ClientCredentials(t.TypedDict):
Expand All @@ -39,16 +40,13 @@ class ClientCredentials(t.TypedDict):
certificate_chain: t.NotRequired[PathType | bytes]

else:
ClientCredentials = dict
grpc, aio = import_grpc()
_json_format = LazyLoader(
"_json_format",
globals(),
"google.protobuf.json_format",
exc_msg=PROTOBUF_EXC_MESSAGE,
)
grpc, aio = import_grpc()

_INDENTATION = " " * 4

# TODO: xDS support
class GrpcClient(Client):
Expand All @@ -65,13 +63,12 @@ def __init__(
*,
protocol_version: str = LATEST_PROTOCOL_VERSION,
):
self._pb, self._services = import_generated_stubs(protocol_version)
self._pb, _ = import_generated_stubs(protocol_version)

self._protocol_version = protocol_version
self._compression = compression
self._options = channel_options
self._interceptors = interceptors
self._channel = None
self._credentials = None
if ssl:
assert (
Expand All @@ -85,38 +82,29 @@ def __init__(
)
super().__init__(svc, server_url)

@cached_property
@property
def channel(self):
if not self._channel:
if self._credentials is not None:
self._channel = aio.secure_channel(
self.server_url,
credentials=self._credentials,
options=self._options,
compression=self._compression,
interceptors=self._interceptors,
)
self._channel = aio.insecure_channel(
if self._credentials is not None:
return aio.secure_channel(
self.server_url,
credentials=self._credentials,
options=self._options,
compression=self._compression,
interceptors=self._interceptors,
)
else:
return aio.insecure_channel(
self.server_url,
options=self._options,
compression=self._compression,
interceptors=self._interceptors,
)
return self._channel

@cached_property
def _rpc_handler_mapping(self):
def _rpc_metadata(self):
# Currently all RPCs in BentoService are unary-unary
return {
method: {
"handler": self.channel.unary_unary(
method=method,
request_serializer=input_type.SerializeToString,
response_deserializer=output_type.FromString,
),
"input_type": input_type,
"output_type": output_type,
}
method: {"input_type": input_type, "output_type": output_type}
for method, input_type, output_type in (
(
f"/bentoml.grpc.{self._protocol_version}.BentoService/Call",
Expand All @@ -133,9 +121,9 @@ def _rpc_handler_mapping(self):

async def _invoke(self, method_name: str, **attrs: t.Any):
# channel kwargs include timeout, metadata, credentials, wait_for_ready and compression
# to pass it in kwargs add prefix _channel_<args>
# to pass them via kwargs, add the prefix _grpc_channel_ to each name (e.g. _grpc_channel_timeout)
channel_kwargs = {
k: attrs.pop(f"_channel_{k}", None)
k: attrs.pop(f"_grpc_channel_{k}", None)
for k in {
"timeout",
"metadata",
Expand All @@ -144,17 +132,20 @@ async def _invoke(self, method_name: str, **attrs: t.Any):
"compression",
}
}
if method_name not in self._rpc_handler_mapping:
if method_name not in self._rpc_metadata:
raise ValueError(
f"'{method_name}' is a yet supported rpc. Current supported are: {list(self._rpc_handler_mapping.keys())}"
f"'{method_name}' is a yet supported rpc. Current supported are: {self._rpc_metadata}"
)
rpc_handler = self._rpc_handler_mapping[method_name]
metadata = self._rpc_metadata[method_name]
rpc = self.channel.unary_unary(
method_name,
request_serializer=metadata["input_type"].SerializeToString,
response_deserializer=metadata["output_type"].FromString,
)

return await t.cast(
t.Awaitable[t.Any],
rpc_handler["handler"](
rpc_handler["input_type"](**attrs), **channel_kwargs
),
"t.Awaitable[Response]",
rpc(metadata["input_type"](**attrs), **channel_kwargs),
)

async def _call(
Expand All @@ -164,14 +155,16 @@ async def _call(
_bentoml_api: InferenceAPI,
**attrs: t.Any,
) -> t.Any:
if self.channel.get_state() != grpc.ChannelConnectivity.READY:
state = self.channel.get_state(try_to_connect=True)
if state != grpc.ChannelConnectivity.READY:
# create a blocking call to wait until the channel is ready.
await self.channel.channel_ready()

fn = functools.partial(
self._invoke,
method_name=f"/bentoml.grpc.{self._protocol_version}.BentoService/Call",
**{
f"_channel_{k}": attrs.pop(f"_channel_{k}", None)
f"_grpc_channel_{k}": attrs.pop(f"_grpc_channel_{k}", None)
for k in {
"timeout",
"metadata",
Expand All @@ -192,11 +185,10 @@ async def _call(
serialized_req = await _bentoml_api.input.to_proto(inp)

# A call includes api_name and given proto_fields
_rev_apis = {v: k for k, v in self._svc.apis.items()}
api_fn = {v: k for k, v in self._svc.apis.items()}
return await fn(
f"/bentoml.grpc.{self._protocol_version}.BentoService/Call",
**{
"api_name": _rev_apis[_bentoml_api],
"api_name": api_fn[_bentoml_api],
_bentoml_api.input._proto_fields[0]: serialized_req,
},
)
Expand Down Expand Up @@ -266,15 +258,14 @@ def run():
)

# create an insecure channel to invoke ServiceMetadata rpc
with channel:
metadata = t.cast(
"ServiceMetadataResponse",
channel.unary_unary(
f"/bentoml.grpc.{protocol_version}.BentoService/ServiceMetadata",
request_serializer=pb.ServiceMetadataRequest.SerializeToString,
response_deserializer=pb.ServiceMetadataResponse.FromString,
)(pb.ServiceMetadataRequest()),
)
metadata = t.cast(
"ServiceMetadataResponse",
channel.unary_unary(
f"/bentoml.grpc.{protocol_version}.BentoService/ServiceMetadata",
request_serializer=pb.ServiceMetadataRequest.SerializeToString,
response_deserializer=pb.ServiceMetadataResponse.FromString,
)(pb.ServiceMetadataRequest()),
)
dummy_service = Service(metadata.name)

for api in metadata.apis:
Expand Down
2 changes: 2 additions & 0 deletions src/bentoml/_internal/client/http.py
Expand Up @@ -78,6 +78,8 @@ def from_url(cls, server_url: str, **kwargs: t.Any) -> HTTPClient:
async def _call(
self, inp: t.Any = None, *, _bentoml_api: InferenceAPI, **kwargs: t.Any
) -> t.Any:
# All gRPC-specific kwargs (keys prefixed with "_grpc_") are filtered out here.
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_grpc_")}
api = _bentoml_api

if api.multi_input:
Expand Down
2 changes: 2 additions & 0 deletions tests/e2e/bento_server_grpc/tests/test_metrics.py
Expand Up @@ -22,3 +22,5 @@ async def test_metrics_available(host: str):
compared=np.random.randint(255, size=(10, 10, 3)).astype("uint8"),
)
assert isinstance(resp, pb.Response)
resp = await client.async_ensure_metrics_are_registered("input_data")
assert isinstance(resp, pb.Response)

0 comments on commit ddd553a

Please sign in to comment.