Skip to content

Commit

Permalink
fix: tests
Browse files Browse the repository at this point in the history
Signed-off-by: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
  • Loading branch information
aarnphm committed Dec 6, 2022
1 parent 1466e0c commit be25b4f
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 25 deletions.
7 changes: 4 additions & 3 deletions src/bentoml/_internal/bento/build_dev_bentoml_whl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ..utils.pkg import source_locations
from ...exceptions import BentoMLException
from ...exceptions import MissingDependencyException
from ...grpc.utils import LATEST_PROTOCOL_VERSION
from ..configuration import is_pypi_installed_bentoml

logger = logging.getLogger(__name__)
Expand All @@ -15,7 +16,7 @@


def build_bentoml_editable_wheel(
target_path: str, *, _internal_stubs_version: str = "v1"
target_path: str, *, _internal_protocol_version: str = LATEST_PROTOCOL_VERSION
) -> None:
"""
This is for BentoML developers to create Bentos that contains the local bentoml
Expand Down Expand Up @@ -52,10 +53,10 @@ def build_bentoml_editable_wheel(
bentoml_path = Path(module_location)

if not Path(
module_location, "grpc", _internal_stubs_version, "service_pb2.py"
module_location, "grpc", _internal_protocol_version, "service_pb2.py"
).exists():
raise ModuleNotFoundError(
f"Generated stubs for version {_internal_stubs_version} are missing. Make sure to run '{bentoml_path.as_posix()}/scripts/generate_grpc_stubs.sh {_internal_stubs_version}' beforehand to generate gRPC stubs."
f"Generated stubs for version {_internal_protocol_version} are missing. Make sure to run '{bentoml_path.as_posix()}/scripts/generate_grpc_stubs.sh {_internal_protocol_version}' beforehand to generate gRPC stubs."
) from None

# location to pyproject.toml
Expand Down
24 changes: 20 additions & 4 deletions src/bentoml/testing/grpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing as t
import importlib
import traceback
from typing import TYPE_CHECKING
from contextlib import ExitStack
Expand All @@ -9,11 +10,11 @@
from bentoml.exceptions import BentoMLException
from bentoml.grpc.utils import import_grpc
from bentoml.grpc.utils import import_generated_stubs
from bentoml.grpc.utils import LATEST_PROTOCOL_VERSION
from bentoml._internal.utils import LazyLoader
from bentoml._internal.utils import reserve_free_port
from bentoml._internal.utils import cached_contextmanager
from bentoml._internal.utils import add_experimental_docstring
from bentoml._internal.server.grpc.servicer import create_bento_servicer

if TYPE_CHECKING:
import grpc
Expand All @@ -23,9 +24,9 @@
from grpc.aio._channel import Channel
from google.protobuf.message import Message

from bentoml import Service
from bentoml.grpc.v1 import service_pb2 as pb
else:
pb, _ = import_generated_stubs()
grpc, aio = import_grpc() # pylint: disable=E1111
np = LazyLoader("np", globals(), "numpy")

Expand All @@ -39,6 +40,20 @@
]


def create_bento_servicer(
protocol_version: str = LATEST_PROTOCOL_VERSION,
) -> t.Callable[[Service], t.Any]:
try:
module = importlib.import_module(
f".{protocol_version}", package="bentoml._internal.server.grpc.servicer"
)
return getattr(module, "create_bento_servicer")
except (ImportError, ModuleNotFoundError):
raise BentoMLException(
f"Failed to load servicer implementation for version {protocol_version}"
) from None


def randomize_pb_ndarray(shape: tuple[int, ...]) -> pb.NDArray:
arr: NDArray[np.float32] = t.cast("NDArray[np.float32]", np.random.rand(*shape))
return pb.NDArray(
Expand Down Expand Up @@ -76,7 +91,7 @@ async def async_client_call(
assert_code: grpc.StatusCode | None = None,
assert_details: str | None = None,
assert_trailing_metadata: aio.Metadata | None = None,
_internal_stubs_version: str = "v1",
protocol_version: str = "v1",
) -> pb.Response | None:
"""
Invoke a given API method via a client.
Expand All @@ -95,11 +110,12 @@ async def async_client_call(
Returns:
The response from the server.
"""
pb, _ = import_generated_stubs(protocol_version)

res: pb.Response | None = None
try:
Call = channel.unary_unary(
f"/bentoml.grpc.{_internal_stubs_version}.BentoService/Call",
f"/bentoml.grpc.{protocol_version}.BentoService/Call",
request_serializer=pb.Request.SerializeToString,
response_deserializer=pb.Response.FromString,
)
Expand Down
44 changes: 35 additions & 9 deletions src/bentoml/testing/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from bentoml._internal.utils import reserve_free_port
from bentoml._internal.utils import cached_contextmanager

from ..grpc.utils import LATEST_PROTOCOL_VERSION

if TYPE_CHECKING:
from grpc import aio
from grpc_health.v1 import health_pb2 as pb_health
Expand Down Expand Up @@ -75,7 +77,7 @@ async def server_warmup(
check_interval: float = 1,
popen: subprocess.Popen[t.Any] | None = None,
service_name: str | None = None,
_internal_stubs_version: str = "v1",
protocol_version: str = LATEST_PROTOCOL_VERSION,
) -> bool:
start_time = time.time()
proxy_handler = urllib.request.ProxyHandler({})
Expand All @@ -87,9 +89,7 @@ async def server_warmup(

try:
if service_name is None:
service_name = (
f"bentoml.grpc.{_internal_stubs_version}.BentoService"
)
service_name = f"bentoml.grpc.{protocol_version}.BentoService"
async with create_channel(host_url) as channel:
Check = channel.unary_unary(
"/grpc.health.v1.Health/Check",
Expand Down Expand Up @@ -177,14 +177,15 @@ def containerize(
subprocess.call([backend, "rmi", image_tag])


@cached_contextmanager("{image_tag}, {config_file}, {use_grpc}")
@cached_contextmanager("{image_tag}, {config_file}, {use_grpc}, {protocol_version}")
def run_bento_server_container(
image_tag: str,
config_file: str | None = None,
use_grpc: bool = False,
timeout: float = 90,
host: str = "127.0.0.1",
backend: str = "docker",
protocol_version: str = LATEST_PROTOCOL_VERSION,
):
"""
Launch a bentoml service container from a container, yield the host URL
Expand Down Expand Up @@ -227,7 +228,13 @@ def run_bento_server_container(
try:
host_url = f"{host}:{port}"
if asyncio.run(
server_warmup(host_url, timeout=timeout, popen=proc, grpc=use_grpc)
server_warmup(
host_url,
timeout=timeout,
popen=proc,
grpc=use_grpc,
protocol_version=protocol_version,
)
):
yield host_url
else:
Expand All @@ -247,6 +254,7 @@ def run_bento_server_standalone(
config_file: str | None = None,
timeout: float = 90,
host: str = "127.0.0.1",
protocol_version: str = LATEST_PROTOCOL_VERSION,
):
"""
Launch a bentoml service directly by the bentoml CLI, yields the host URL.
Expand Down Expand Up @@ -277,7 +285,13 @@ def run_bento_server_standalone(
try:
host_url = f"{host}:{server_port}"
assert asyncio.run(
server_warmup(host_url, timeout=timeout, popen=p, grpc=use_grpc)
server_warmup(
host_url,
timeout=timeout,
popen=p,
grpc=use_grpc,
protocol_version=protocol_version,
)
)
yield host_url
finally:
Expand All @@ -302,6 +316,7 @@ def run_bento_server_distributed(
use_grpc: bool = False,
timeout: float = 90,
host: str = "127.0.0.1",
protocol_version: str = LATEST_PROTOCOL_VERSION,
):
"""
Launch a bentoml service as a simulated distributed environment(Yatai), yields the host URL.
Expand Down Expand Up @@ -391,7 +406,14 @@ def run_bento_server_distributed(
)
try:
host_url = f"{host}:{server_port}"
asyncio.run(server_warmup(host_url, timeout=timeout, grpc=use_grpc))
asyncio.run(
server_warmup(
host_url,
timeout=timeout,
grpc=use_grpc,
protocol_version=protocol_version,
)
)
yield host_url
finally:
for p in processes:
Expand All @@ -404,7 +426,7 @@ def run_bento_server_distributed(


@cached_contextmanager(
"{bento_name}, {project_path}, {config_file}, {deployment_mode}, {bentoml_home}, {use_grpc}"
"{bento_name}, {project_path}, {config_file}, {deployment_mode}, {bentoml_home}, {use_grpc}, {protocol_version}"
)
def host_bento(
bento_name: str | Tag | None = None,
Expand All @@ -417,6 +439,7 @@ def host_bento(
host: str = "127.0.0.1",
timeout: float = 120,
backend: str = "docker",
protocol_version: str = LATEST_PROTOCOL_VERSION,
) -> t.Generator[str, None, None]:
"""
Host a bentoml service, yields the host URL.
Expand Down Expand Up @@ -473,6 +496,7 @@ def host_bento(
use_grpc=use_grpc,
host=host,
timeout=timeout,
protocol_version=protocol_version,
) as host_url:
yield host_url
elif deployment_mode == "container":
Expand All @@ -492,6 +516,7 @@ def host_bento(
host=host,
timeout=timeout,
backend=backend,
protocol_version=protocol_version,
) as host_url:
yield host_url
elif deployment_mode == "distributed":
Expand All @@ -501,6 +526,7 @@ def host_bento(
use_grpc=use_grpc,
host=host,
timeout=timeout,
protocol_version=protocol_version,
) as host_url:
yield host_url
else:
Expand Down
14 changes: 9 additions & 5 deletions tests/unit/grpc/interceptors/test_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,13 @@
from google.protobuf import wrappers_pb2

from bentoml import Service
from bentoml.grpc.v1 import service_pb2_grpc as services
from bentoml.grpc.types import Request
from bentoml.grpc.types import Response
from bentoml.grpc.types import RpcMethodHandler
from bentoml.grpc.types import AsyncHandlerMethod
from bentoml.grpc.types import HandlerCallDetails
from bentoml.grpc.types import BentoServicerContext
else:
_, services = import_generated_stubs()
grpc, aio = import_grpc()
wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2")

Expand Down Expand Up @@ -125,7 +123,12 @@ async def test_trailing_metadata(caplog: LogCaptureFixture):

@pytest.mark.asyncio
@pytest.mark.usefixtures("propagate_logs")
async def test_access_log_exception(caplog: LogCaptureFixture, simple_service: Service):
@pytest.mark.parametrize("protocol_version", ["v1", "v1alpha1"])
async def test_access_log_exception(
caplog: LogCaptureFixture, simple_service: Service, protocol_version: str
):
_, services = import_generated_stubs(protocol_version)

with make_standalone_server(
# we need to also setup opentelemetry interceptor
# to make sure the access log is correctly setup.
Expand All @@ -135,7 +138,7 @@ async def test_access_log_exception(caplog: LogCaptureFixture, simple_service: S
]
) as (server, host_url):
services.add_BentoServiceServicer_to_server(
create_bento_servicer(simple_service), server
create_bento_servicer(protocol_version)(simple_service), server
)
try:
await server.start()
Expand All @@ -146,9 +149,10 @@ async def test_access_log_exception(caplog: LogCaptureFixture, simple_service: S
channel=channel,
data={"text": wrappers_pb2.StringValue(value="asdf")},
assert_code=grpc.StatusCode.INTERNAL,
protocol_version=protocol_version,
)
assert (
"(scheme=http,path=/bentoml.grpc.v1.BentoService/Call,type=application/grpc,size=17) (http_status=500,grpc_status=13,type=application/grpc,size=0)"
f"(scheme=http,path=/bentoml.grpc.{protocol_version}.BentoService/Call,type=application/grpc,size=17) (http_status=500,grpc_status=13,type=application/grpc,size=0)"
in caplog.text
)
finally:
Expand Down
10 changes: 6 additions & 4 deletions tests/unit/grpc/interceptors/test_prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@
from google.protobuf import wrappers_pb2

from bentoml import Service
from bentoml.grpc.v1 import service_pb2_grpc as services
else:

_, services = import_generated_stubs()
wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2")
grpc, aio = import_grpc()

Expand Down Expand Up @@ -106,19 +103,23 @@ async def test_empty_metrics():
("gauge", ["api_name", "service_version", "service_name"]),
],
)
@pytest.mark.parametrize("protocol_version", ["v1", "v1alpha1"])
async def test_metrics_interceptors(
simple_service: Service,
metric_type: str,
parent_set: list[str],
protocol_version: str,
):
metrics_client = BentoMLContainer.metrics_client.get()

_, services = import_generated_stubs(protocol_version)

with make_standalone_server(interceptors=[interceptor]) as (
server,
host_url,
):
services.add_BentoServiceServicer_to_server(
create_bento_servicer(simple_service), server
create_bento_servicer(protocol_version)(simple_service), server
)
try:
await server.start()
Expand All @@ -127,6 +128,7 @@ async def test_metrics_interceptors(
"noop_sync",
channel=channel,
data={"text": wrappers_pb2.StringValue(value="BentoML")},
protocol_version=protocol_version,
)
for m in metrics_client.text_string_to_metric_families():
for sample in m.samples:
Expand Down

0 comments on commit be25b4f

Please sign in to comment.