Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
- Loading branch information
Showing
47 changed files
with
2,755 additions
and
446 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
from __future__ import annotations | ||
|
||
import typing as t | ||
import traceback | ||
from typing import TYPE_CHECKING | ||
from contextlib import contextmanager | ||
from contextlib import asynccontextmanager | ||
from concurrent.futures import ThreadPoolExecutor | ||
|
||
from bentoml._internal.utils import LazyLoader | ||
|
||
from ._io import make_pb_ndarray | ||
from ._io import randomize_pb_ndarray | ||
|
||
if TYPE_CHECKING: | ||
import grpc | ||
from grpc import aio | ||
from grpc.aio._channel import Channel | ||
from google.protobuf.message import Message | ||
|
||
from bentoml.grpc.v1alpha1 import service_pb2 as pb | ||
else: | ||
from bentoml.grpc.utils import import_generated_stubs | ||
|
||
pb, _ = import_generated_stubs() | ||
exception_msg = ( | ||
"'grpcio' is not installed. Please install it with 'pip install -U grpcio'" | ||
) | ||
grpc = LazyLoader("grpc", globals(), "grpc", exc_msg=exception_msg) | ||
aio = LazyLoader("aio", globals(), "grpc.aio", exc_msg=exception_msg) | ||
|
||
__all__ = [ | ||
"async_client_call", | ||
"randomize_pb_ndarray", | ||
"make_pb_ndarray", | ||
"create_channel", | ||
"make_standalone_server", | ||
] | ||
|
||
|
||
async def async_client_call( | ||
method: str, | ||
channel: Channel, | ||
data: dict[str, Message | pb.Part | bytes | str | dict[str, t.Any]], | ||
assert_data: pb.Response | t.Callable[[pb.Response], bool] | None = None, | ||
assert_code: grpc.StatusCode | None = None, | ||
assert_details: str | None = None, | ||
timeout: int | None = None, | ||
sanity: bool = True, | ||
) -> pb.Response: | ||
""" | ||
Note that to use this function, 'channel' should not be created with any client interceptors, | ||
since we will handle interceptors' lifecycle separately. | ||
This function will also mimic the generated stubs function 'Call' from given 'channel'. | ||
Args: | ||
method: The method name to call. | ||
channel: The given aio.Channel to use for invoking the RPC. | ||
data: The data to send to the server. | ||
assert_data: The data to assert against the response. | ||
assert_code: The code to assert against the response. | ||
assert_details: The details to assert against the response. | ||
timeout: The timeout for the RPC. | ||
sanity: Whether to perform sanity check on the response. | ||
Returns: | ||
The response from the server. | ||
""" | ||
from bentoml.testing.grpc.interceptors import AssertClientInterceptor | ||
|
||
if assert_code is None: | ||
# by default, we want to check if the request is healthy | ||
assert_code = grpc.StatusCode.OK | ||
assert ( | ||
len( | ||
list( | ||
filter( | ||
lambda x: len(x) != 0, | ||
map( | ||
lambda stack: getattr(channel, stack), | ||
[ | ||
"_unary_unary_interceptors", | ||
"_unary_stream_interceptors", | ||
"_stream_unary_interceptors", | ||
"_stream_stream_interceptors", | ||
], | ||
), | ||
) | ||
) | ||
) | ||
== 0 | ||
), "'channel' shouldn't have any interceptors." | ||
try: | ||
# we will handle adding our testing interceptors here. | ||
# prefer not to use private attributes, but this will do | ||
channel._unary_unary_interceptors.append( # type: ignore (private warning) | ||
AssertClientInterceptor( | ||
assert_code=assert_code, assert_details=assert_details | ||
) | ||
) | ||
Call = channel.unary_unary( | ||
"/bentoml.grpc.v1alpha1.BentoService/Call", | ||
request_serializer=pb.Request.SerializeToString, | ||
response_deserializer=pb.Response.FromString, | ||
) | ||
output = await t.cast( | ||
t.Awaitable[pb.Response], | ||
Call(pb.Request(api_name=method, **data), timeout=timeout), | ||
) | ||
if sanity: | ||
assert output | ||
if assert_data: | ||
try: | ||
if callable(assert_data): | ||
assert assert_data(output) | ||
else: | ||
assert output == assert_data | ||
except AssertionError: | ||
raise AssertionError(f"Failed while checking data: {output}") from None | ||
return output | ||
finally: | ||
# we will reset interceptors per call | ||
channel._unary_unary_interceptors = [] # type: ignore (private warning) | ||
|
||
|
||
@asynccontextmanager | ||
async def create_channel( | ||
host_url: str, interceptors: t.Sequence[aio.ClientInterceptor] | None = None | ||
) -> t.AsyncGenerator[Channel, None]: | ||
channel: Channel | None = None | ||
try: | ||
async with aio.insecure_channel(host_url, interceptors=interceptors) as channel: | ||
# create a blocking call to wait til channel is ready. | ||
await channel.channel_ready() | ||
yield channel | ||
except aio.AioRpcError as e: | ||
traceback.print_exc() | ||
raise e from None | ||
finally: | ||
if channel: | ||
await channel.close() | ||
|
||
|
||
@contextmanager | ||
def make_standalone_server( | ||
bind_address: str, interceptors: t.Sequence[aio.ServerInterceptor] | None = None | ||
) -> t.Generator[aio.Server, None, None]: | ||
server = aio.server( | ||
interceptors=interceptors, | ||
migration_thread_pool=ThreadPoolExecutor(max_workers=1), | ||
options=(("grpc.so_reuseport", 0),), | ||
) | ||
server.add_insecure_port(bind_address) | ||
yield server |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from __future__ import annotations | ||
|
||
import typing as t | ||
from typing import TYPE_CHECKING | ||
|
||
from bentoml.exceptions import BentoMLException | ||
from bentoml._internal.utils import LazyLoader | ||
|
||
if TYPE_CHECKING: | ||
import numpy as np | ||
from numpy.typing import NDArray | ||
|
||
from bentoml.grpc.v1alpha1 import service_pb2 as pb | ||
else: | ||
from bentoml.grpc.utils import import_generated_stubs | ||
|
||
pb, _ = import_generated_stubs() | ||
np = LazyLoader("np", globals(), "numpy") | ||
|
||
|
||
def randomize_pb_ndarray(shape: tuple[int, ...]) -> pb.NDArray: | ||
arr: NDArray[np.float32] = t.cast("NDArray[np.float32]", np.random.rand(*shape)) | ||
return pb.NDArray( | ||
shape=list(shape), dtype=pb.NDArray.DTYPE_FLOAT, float_values=arr.ravel() | ||
) | ||
|
||
|
||
def make_pb_ndarray(arr: NDArray[t.Any]) -> pb.NDArray: | ||
from bentoml._internal.io_descriptors.numpy import npdtype_to_dtypepb_map | ||
from bentoml._internal.io_descriptors.numpy import npdtype_to_fieldpb_map | ||
|
||
try: | ||
fieldpb = npdtype_to_fieldpb_map()[arr.dtype] | ||
dtypepb = npdtype_to_dtypepb_map()[arr.dtype] | ||
return pb.NDArray( | ||
**{ | ||
fieldpb: arr.ravel().tolist(), | ||
"dtype": dtypepb, | ||
"shape": tuple(arr.shape), | ||
}, | ||
) | ||
except KeyError: | ||
raise BentoMLException( | ||
f"Unsupported dtype '{arr.dtype}' for response message.", | ||
) from None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from __future__ import annotations | ||
|
||
import typing as t | ||
from typing import TYPE_CHECKING | ||
|
||
from bentoml._internal.utils import LazyLoader | ||
|
||
if TYPE_CHECKING: | ||
import grpc | ||
from grpc import aio | ||
|
||
from bentoml.grpc.types import Request | ||
from bentoml.grpc.types import BentoUnaryUnaryCall | ||
else: | ||
aio = LazyLoader("aio", globals(), "grpc.aio") | ||
|
||
|
||
class AssertClientInterceptor(aio.UnaryUnaryClientInterceptor): | ||
def __init__( | ||
self, | ||
assert_code: grpc.StatusCode | None = None, | ||
assert_details: str | None = None, | ||
assert_trailing_metadata: aio.Metadata | None = None, | ||
): | ||
self._assert_code = assert_code | ||
self._assert_details = assert_details | ||
self._assert_trailing_metadata = assert_trailing_metadata | ||
|
||
async def intercept_unary_unary( # type: ignore (unable to infer types from parameters) | ||
self, | ||
continuation: t.Callable[[aio.ClientCallDetails, Request], BentoUnaryUnaryCall], | ||
client_call_details: aio.ClientCallDetails, | ||
request: Request, | ||
) -> BentoUnaryUnaryCall: | ||
# Note that we cast twice here since grpc.aio._call.UnaryUnaryCall | ||
# implements __await__, which returns ResponseType. However, pyright | ||
# are unable to determine types from given mixin. | ||
# | ||
# continuation(client_call_details, request) -> call: UnaryUnaryCall | ||
# await call -> ResponseType | ||
call = await t.cast( | ||
"t.Awaitable[BentoUnaryUnaryCall]", | ||
continuation(client_call_details, request), | ||
) | ||
try: | ||
code = await call.code() | ||
details = await call.details() | ||
trailing_metadata = await call.trailing_metadata() | ||
if self._assert_code: | ||
assert ( | ||
code == self._assert_code | ||
), f"{repr(call)} returns {await call.code()} while expecting {self._assert_code}." | ||
if self._assert_details: | ||
assert ( | ||
self._assert_details in details | ||
), f"'{self._assert_details}' is not in {await call.details()}." | ||
if self._assert_trailing_metadata: | ||
assert ( | ||
self._assert_trailing_metadata == trailing_metadata | ||
), f"Trailing metadata '{trailing_metadata}' while expecting '{self._assert_trailing_metadata}'." | ||
return call | ||
except AssertionError as e: | ||
raise e from None |
Oops, something went wrong.