diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dbbdeaf9f9..03cacb91f6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,6 +13,11 @@ env: LINES: 120 COLUMNS: 120 +# https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun +defaults: + run: + shell: bash --noprofile --norc -exo pipefail {0} + jobs: diff: runs-on: ubuntu-latest @@ -34,7 +39,10 @@ jobs: - scripts/ci/config.yml - scripts/ci/run_tests.sh - requirements/tests-requirements.txt + protos: &protos + - "bentoml/grpc/**/*.proto" bentoml: + - *protos - *related - "bentoml/**" - "tests/**" @@ -46,9 +54,6 @@ jobs: codestyle_check: runs-on: ubuntu-latest - defaults: - run: - shell: bash needs: - diff @@ -72,9 +77,13 @@ jobs: uses: actions/setup-node@v3 with: node-version: "17" - - name: install pyright + - name: Install pyright run: | npm install -g npm@^7 pyright + - name: Setup bufbuild/buf + uses: bufbuild/buf-setup-action@v1.8.0 + with: + github_token: ${{ github.token }} - name: Cache pip dependencies uses: actions/cache@v3 @@ -94,12 +103,11 @@ jobs: run: make ci-lint - name: Type check run: make ci-pyright + - name: Proto check + if: ${{ (github.event_name == 'pull_request' && needs.diff.outputs.protos == 'true') || github.event_name == 'push' }} + run: buf lint --config "bentoml/grpc/buf.yaml" --error-format msvs --path "bentoml/grpc" documentation_spelling_check: - defaults: - run: - shell: bash - runs-on: ubuntu-latest needs: - diff @@ -138,7 +146,6 @@ jobs: - name: Run spellcheck script run: make spellcheck-docs - shell: bash unit_tests: needs: @@ -149,9 +156,6 @@ jobs: matrix: os: [ubuntu-latest, macos-latest, windows-latest] python-version: ["3.7", "3.8", "3.9", "3.10"] - defaults: - run: - shell: bash if: ${{ (github.event_name == 'pull_request' && needs.diff.outputs.bentoml == 'true') || github.event_name == 'push' }} name: python${{ matrix.python-version }}_unit_tests (${{ matrix.os }}) @@ -182,17 +186,18 @@ jobs: - name: Install dependencies run: | - pip install . + pip install ".[grpc]" pip install -r requirements/tests-requirements.txt - name: Run unit tests - if: ${{ matrix.os != 'windows-latest' }} - run: make tests-unit - - - name: Run unit tests (Windows) - if: ${{ matrix.os == 'windows-latest' }} - run: make tests-unit - shell: bash + run: | + OPTS=(--cov-config pyproject.toml --cov-report=xml:unit.xml -vvv) + if [ "${{ matrix.os }}" != 'windows-latest' ]; then + # we will use pytest-xdist to improve tests run-time. + OPTS=(${OPTS[@]} --dist loadfile -n auto --run-grpc-tests) + fi + # Now run the unit tests + python -m pytest tests/unit "${OPTS[@]}" - name: Upload test coverage to Codecov uses: codecov/codecov-action@v3 @@ -213,12 +218,13 @@ jobs: matrix: os: [ubuntu-latest, macos-latest, windows-latest] python-version: ["3.7", "3.8", "3.9", "3.10"] - defaults: - run: - shell: bash + server_type: ["http", "grpc"] + exclude: + - os: windows-latest + server_type: "grpc" if: ${{ (github.event_name == 'pull_request' && needs.diff.outputs.bentoml == 'true') || github.event_name == 'push' }} - name: python${{ matrix.python-version }}_e2e_tests (${{ matrix.os }}) + name: python${{ matrix.python-version }}_${{ matrix.server_type }}_e2e_tests (${{ matrix.os }}) runs-on: ${{ matrix.os }} timeout-minutes: 20 @@ -256,24 +262,29 @@ jobs: path: ${{ steps.cache-dir.outputs.dir }} key: ${{ runner.os }}-tests-${{ hashFiles('requirements/tests-requirements.txt') }} - - name: Install dependencies + - name: Install dependencies for ${{ matrix.server_type }}-based tests. run: | - pip install -e ".[grpc]" pip install -r requirements/tests-requirements.txt - pip install -r tests/e2e/bento_server_general_features/requirements.txt - - - name: Export Action Envvar - run: export GITHUB_ACTION=true - - - name: Run tests and generate coverage report - run: ./scripts/ci/run_tests.sh general_features + if [ "${{ matrix.server_type }}" == 'grpc' ]; then + pip install -e ".[grpc]" + else + pip install -e . + fi + if [ -f "tests/e2e/bento_server_${{ matrix.server_type }}/requirements.txt" ]; then + pip install -r tests/e2e/bento_server_${{ matrix.server_type }}/requirements.txt + fi + + - name: Run ${{ matrix.server_type }} tests and generate coverage report + run: ./scripts/ci/run_tests.sh ${{ matrix.server_type }}_server --verbose - name: Upload test coverage to Codecov uses: codecov/codecov-action@v3 with: - flags: e2e-tests + flags: e2e-tests-${{ matrix.server_type }} + name: codecov-${{ matrix.os }}-python${{ matrix.python-version }}-e2e + fail_ci_if_error: true directory: ./ - files: ./tests/e2e/bento_server_general_features/general_features.xml + files: ./tests/e2e/bento_server_${{ matrix.server_type }}/${{ matrix.server_type }}_server.xml verbose: true concurrency: diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index a16ea34851..911115d938 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -353,7 +353,7 @@ Flags: If `pytest_additional_arguments` is given, the additional arguments will be passed to all of the tests run by the tests script. Example: - $ ./scripts/ci/run_tests.sh pytorch --gpus --capture=tee-sys + $ ./scripts/ci/run_tests.sh pytorch --run-gpus-tests --capture=tee-sys ``` All tests are then defined under [config.yml](./scripts/ci/config.yml) where each field follows the following format: diff --git a/Makefile b/Makefile index 8d781d9802..5d99bfd5cf 100644 --- a/Makefile +++ b/Makefile @@ -4,6 +4,7 @@ SHELL := /bin/bash GIT_ROOT ?= $(shell git rev-parse --show-toplevel) USE_VERBOSE ?=false USE_GPU ?= false +USE_GRPC ?= false help: ## Show all Makefile targets @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' @@ -49,7 +50,9 @@ tests-%: ifeq ($(USE_VERBOSE),true) ./scripts/ci/run_tests.sh -v $(type) $(__positional) else ifeq ($(USE_GPU),true) - ./scripts/ci/run_tests.sh -v $(type) --gpus $(__positional) + ./scripts/ci/run_tests.sh -v $(type) --run-gpu-tests $(__positional) +else ifeq ($(USE_GPRC),true) + ./scripts/ci/run_tests.sh -v $(type) --run-gprc-tests $(__positional) else ./scripts/ci/run_tests.sh $(type) $(__positional) endif diff --git a/bentoml/_internal/io_descriptors/file.py b/bentoml/_internal/io_descriptors/file.py index 452088578c..029c34aff0 100644 --- a/bentoml/_internal/io_descriptors/file.py +++ b/bentoml/_internal/io_descriptors/file.py @@ -227,7 +227,7 @@ async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]: mime_type = mapping[field.kind] if mime_type != self._mime_type: raise BadInput( - f"Inferred mime_type from 'kind' is '{mime_type}', while '{repr(self)}' is expecting '{self._mime_type}'", + f"Inferred mime_type from 'kind' is '{mime_type}', while '{self!r}' is expecting '{self._mime_type}'", ) except KeyError: raise BadInput( diff --git a/bentoml/_internal/io_descriptors/image.py b/bentoml/_internal/io_descriptors/image.py index 0f0ba23e4a..b803ca6e71 100644 --- a/bentoml/_internal/io_descriptors/image.py +++ b/bentoml/_internal/io_descriptors/image.py @@ -358,7 +358,7 @@ async def from_proto(self, field: pb.File | bytes) -> ImageType: mime_type = mapping[field.kind] if mime_type != self._mime_type: raise BadInput( - f"Inferred mime_type from 'kind' is '{mime_type}', while '{repr(self)}' is expecting '{self._mime_type}'", + f"Inferred mime_type from 'kind' is '{mime_type}', while '{self!r}' is expecting '{self._mime_type}'", ) except KeyError: raise BadInput( diff --git a/bentoml/_internal/io_descriptors/multipart.py b/bentoml/_internal/io_descriptors/multipart.py index afc62190b4..0d31efadc6 100644 --- a/bentoml/_internal/io_descriptors/multipart.py +++ b/bentoml/_internal/io_descriptors/multipart.py @@ -143,12 +143,12 @@ async def predict( | +--------------------------------------------------------+ | | | | | | | Multipart(arr=NumpyNdarray(), annotations=JSON()) | | - | | | | - | +----------------+-----------------------+---------------+ | - | | | | - | | | | - | | | | - | +----+ +---------+ | + | | | | | | + | +---------------+-----------------------+----------------+ | + | | | | + | | | | + | | | | + | +-----+ +--------+ | | | | | | +---------------v--------v---------+ | | | def predict(arr, annotations): | | @@ -236,28 +236,33 @@ async def to_http_response( def validate_input_mapping(self, field: t.MutableMapping[str, t.Any]) -> None: if len(set(field) - set(self._inputs)) != 0: raise InvalidArgument( - f"'{repr(self)}' accepts the following keys: {set(self._inputs)}. Given {field.__class__.__qualname__} has invalid fields: {set(field) - set(self._inputs)}", + f"'{self!r}' accepts the following keys: {set(self._inputs)}. Given {field.__class__.__qualname__} has invalid fields: {set(field) - set(self._inputs)}", ) from None async def from_proto(self, field: pb.Multipart) -> dict[str, t.Any]: + from bentoml.grpc.utils import validate_proto_fields + if isinstance(field, bytes): raise InvalidArgument( f"cannot use 'serialized_bytes' with {self.__class__.__name__}" ) from None message = field.fields self.validate_input_mapping(message) + to_populate = {self._inputs[k]: message[k] for k in self._inputs} reqs = await asyncio.gather( *tuple( - io_.from_proto(getattr(input_pb, io_._proto_fields[0])) - for io_, input_pb in self.io_fields_mapping(message).items() + descriptor.from_proto( + getattr( + part, + validate_proto_fields( + part.WhichOneof("representation"), descriptor + ), + ) + ) + for descriptor, part in to_populate.items() ) ) - return dict(zip(message, reqs)) - - def io_fields_mapping( - self, message: t.MutableMapping[str, pb.Part] - ) -> dict[IODescriptor[t.Any], pb.Part]: - return {io_: part for io_, part in zip(self._inputs.values(), message.values())} + return dict(zip(self._inputs.keys(), reqs)) async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart: self.validate_input_mapping(obj) @@ -268,13 +273,14 @@ async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart: ) ) return pb.Multipart( - fields={ - key: pb.Part( - **{ - io_._proto_fields[0]: resp + fields=dict( + zip( + obj, + [ + # TODO: support multiple proto_fields + pb.Part(**{io_._proto_fields[0]: resp}) for io_, resp in zip(self._inputs.values(), resps) - } + ], ) - for key in obj - } + ) ) diff --git a/bentoml/_internal/server/grpc/servicer.py b/bentoml/_internal/server/grpc/servicer.py index cbce08b3a6..83fd021d4e 100644 --- a/bentoml/_internal/server/grpc/servicer.py +++ b/bentoml/_internal/server/grpc/servicer.py @@ -9,6 +9,7 @@ import anyio from bentoml.grpc.utils import grpc_status_code +from bentoml.grpc.utils import validate_proto_fields from ....exceptions import InvalidArgument from ....exceptions import BentoMLException @@ -27,7 +28,6 @@ from bentoml.grpc.types import AddServicerFn from bentoml.grpc.types import ServicerClass from bentoml.grpc.types import BentoServicerContext - from bentoml.grpc.types import GeneratedProtocolMessageType from bentoml.grpc.v1alpha1 import service_pb2 as pb from bentoml.grpc.v1alpha1 import service_pb2_grpc as services @@ -148,28 +148,24 @@ async def Call( # type: ignore (no async types) # pylint: disable=invalid-overr # We will use fields descriptor to determine how to process that request. try: # we will check if the given fields list contains a pb.Multipart. - field = request.WhichOneof("content") - if field is None: - raise InvalidArgument("Request cannot be empty.") - accepted_fields = api.input._proto_fields + ("serialized_bytes",) - if field not in accepted_fields: - raise InvalidArgument( - f"'{api.input.__class__.__name__}' accepts one of the following fields: '{', '.join(accepted_fields)}', and none of them are found in the request message.", - ) from None - input_ = await api.input.from_proto(getattr(request, field)) + input_proto = getattr( + request, + validate_proto_fields(request.WhichOneof("content"), api.input), + ) + input_data = await api.input.from_proto(input_proto) if asyncio.iscoroutinefunction(api.func): if isinstance(api.input, Multipart): - output = await api.func(**input_) + output = await api.func(**input_data) else: - output = await api.func(input_) + output = await api.func(input_data) else: if isinstance(api.input, Multipart): - output = await anyio.to_thread.run_sync(api.func, **input_) + output = await anyio.to_thread.run_sync(api.func, **input_data) else: - output = await anyio.to_thread.run_sync(api.func, input_) - protos = await api.output.to_proto(output) + output = await anyio.to_thread.run_sync(api.func, input_data) + res = await api.output.to_proto(output) # TODO(aarnphm): support multiple proto fields - response = pb.Response(**{api.output._proto_fields[0]: protos}) + response = pb.Response(**{api.output._proto_fields[0]: res}) except BentoMLException as e: log_exception(request, sys.exc_info()) await context.abort(code=grpc_status_code(e), details=e.message) diff --git a/bentoml/bentos.py b/bentoml/bentos.py index 1196f1a690..06b1b6daa6 100644 --- a/bentoml/bentos.py +++ b/bentoml/bentos.py @@ -424,7 +424,7 @@ def construct_dockerfile( with open(bento.path_of(dockerfile_path), "r") as f: FINAL_DOCKERFILE = f"""\ {f.read()} -FROM base-{bento.info.docker.distro} +FROM base-{bento.info.docker.distro} as final # Additional instructions for final image. {final_instruction} """ diff --git a/bentoml/grpc/utils/__init__.py b/bentoml/grpc/utils/__init__.py index 4f252146cc..8d75fa3975 100644 --- a/bentoml/grpc/utils/__init__.py +++ b/bentoml/grpc/utils/__init__.py @@ -7,18 +7,20 @@ from functools import lru_cache from dataclasses import dataclass -from bentoml._internal.utils.lazy_loader import LazyLoader +from bentoml.exceptions import InvalidArgument if TYPE_CHECKING: import types from enum import Enum import grpc - from google.protobuf import descriptor as descriptor_mod from bentoml.exceptions import BentoMLException + from bentoml.grpc.types import ProtoField from bentoml.grpc.types import RpcMethodHandler + from bentoml.grpc.types import BentoServicerContext from bentoml.grpc.v1alpha1 import service_pb2 as pb + from bentoml._internal.io_descriptors import IODescriptor # We need this here so that __all__ is detected due to lazy import def import_generated_stubs( @@ -35,9 +37,6 @@ def import_grpc() -> tuple[types.ModuleType, types.ModuleType]: pb, _ = import_generated_stubs() grpc, _ = import_grpc() - descriptor_mod = LazyLoader( - "descriptor_mod", globals(), "google.protobuf.descriptor" - ) __all__ = [ "grpc_status_code", @@ -46,6 +45,7 @@ def import_grpc() -> tuple[types.ModuleType, types.ModuleType]: "GRPC_CONTENT_TYPE", "import_generated_stubs", "import_grpc", + "validate_proto_fields", ] logger = logging.getLogger(__name__) @@ -54,26 +54,17 @@ def import_grpc() -> tuple[types.ModuleType, types.ModuleType]: GRPC_CONTENT_TYPE = "application/grpc" -def get_field_by_name( - descriptor: descriptor_mod.FieldDescriptor | descriptor_mod.Descriptor, - field: str, -) -> descriptor_mod.FieldDescriptor: - if isinstance(descriptor, descriptor_mod.FieldDescriptor): - # descriptor is a FieldDescriptor - return descriptor.message_type.fields_by_name[field] - elif isinstance(descriptor, descriptor_mod.Descriptor): - # descriptor is a Descriptor - return descriptor.fields_by_name[field] - else: - raise NotImplementedError(f"Type {type(descriptor)} is not yet supported.") - - -def is_map_field(field: descriptor_mod.FieldDescriptor) -> bool: - return ( - field.type == descriptor_mod.FieldDescriptor.TYPE_MESSAGE - and field.message_type.has_options - and field.message_type.GetOptions().map_entry - ) +def validate_proto_fields( + field: str | None, io_: IODescriptor[t.Any] +) -> str | ProtoField: + if field is None: + raise InvalidArgument('"field" cannot be empty.') + accepted_fields = io_._proto_fields + ("serialized_bytes",) + if field not in accepted_fields: + raise InvalidArgument( + f"'{io_.__class__.__name__}' accepts one of the following fields: '{','.join(accepted_fields)}' got '{field}' instead.", + ) from None + return field @lru_cache(maxsize=1) @@ -179,7 +170,13 @@ def parse_method_name(method_name: str) -> tuple[MethodName, bool]: def wrap_rpc_handler( - wrapper: t.Callable[..., t.Any], + wrapper: t.Callable[ + ..., + t.Callable[ + [pb.Request, BentoServicerContext], + t.Coroutine[t.Any, t.Any, pb.Response | t.Awaitable[pb.Response]], + ], + ], handler: RpcMethodHandler | None, ) -> RpcMethodHandler | None: if not handler: diff --git a/bentoml/models.py b/bentoml/models.py index b9d1490cec..f6d2cea9c4 100644 --- a/bentoml/models.py +++ b/bentoml/models.py @@ -221,15 +221,10 @@ def push( @inject -def pull( - tag: t.Union[Tag, str], - *, - force: bool = False, - _model_store: "ModelStore" = Provide[BentoMLContainer.model_store], -) -> Model: +def pull(tag: t.Union[Tag, str], *, force: bool = False) -> Model: from bentoml._internal.yatai_client import yatai_client - yatai_client.pull_model(tag, force=force) + return yatai_client.pull_model(tag, force=force) @inject diff --git a/bentoml/testing/grpc/__init__.py b/bentoml/testing/grpc/__init__.py new file mode 100644 index 0000000000..9ebf307311 --- /dev/null +++ b/bentoml/testing/grpc/__init__.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import typing as t +import traceback +from typing import TYPE_CHECKING +from contextlib import ExitStack +from contextlib import asynccontextmanager + +from bentoml.exceptions import BentoMLException +from bentoml._internal.utils import LazyLoader +from bentoml._internal.utils import reserve_free_port +from bentoml._internal.utils import cached_contextmanager +from bentoml._internal.utils import add_experimental_docstring +from bentoml._internal.server.grpc.servicer import create_bento_servicer + +from .servicer import TestServiceServicer + +if TYPE_CHECKING: + import grpc + import numpy as np + from grpc import aio + from numpy.typing import NDArray + from grpc.aio._channel import Channel + from google.protobuf.message import Message + + from bentoml.grpc.v1alpha1 import service_pb2 as pb + from bentoml.grpc.v1alpha1 import service_test_pb2_grpc as services_test +else: + from bentoml.grpc.utils import import_grpc + from bentoml.grpc.utils import import_generated_stubs + + pb, _ = import_generated_stubs() + _, services_test = import_generated_stubs(file="service_test.proto") + grpc, aio = import_grpc() # pylint: disable=E1111 + np = LazyLoader("np", globals(), "numpy") + +__all__ = [ + "async_client_call", + "randomize_pb_ndarray", + "make_pb_ndarray", + "create_channel", + "make_standalone_server", + "TestServiceServicer", + "create_bento_servicer", +] + + +def randomize_pb_ndarray(shape: tuple[int, ...]) -> pb.NDArray: + arr: NDArray[np.float32] = t.cast("NDArray[np.float32]", np.random.rand(*shape)) + return pb.NDArray( + shape=list(shape), dtype=pb.NDArray.DTYPE_FLOAT, float_values=arr.ravel() + ) + + +def make_pb_ndarray(arr: NDArray[t.Any]) -> pb.NDArray: + from bentoml._internal.io_descriptors.numpy import npdtype_to_dtypepb_map + from bentoml._internal.io_descriptors.numpy import npdtype_to_fieldpb_map + + try: + fieldpb = npdtype_to_fieldpb_map()[arr.dtype] + dtypepb = npdtype_to_dtypepb_map()[arr.dtype] + return pb.NDArray( + **{ + fieldpb: arr.ravel().tolist(), + "dtype": dtypepb, + "shape": tuple(arr.shape), + }, + ) + except KeyError: + raise BentoMLException( + f"Unsupported dtype '{arr.dtype}' for response message.", + ) from None + + +async def async_client_call( + method: str, + channel: Channel, + data: dict[str, Message | pb.Part | bytes | str | dict[str, t.Any]], + assert_data: pb.Response | t.Callable[[pb.Response], bool] | None = None, + assert_code: grpc.StatusCode | None = None, + assert_details: str | None = None, + timeout: int | None = None, + sanity: bool = True, +) -> pb.Response: + """ + Note that to use this function, 'channel' should not be created with any client interceptors, + since we will handle interceptors' lifecycle separately. + + This function will also mimic the generated stubs function 'Call' from given 'channel'. + + Args: + method: The method name to call. + channel: The given aio.Channel to use for invoking the RPC. Channels shouldn't include + any client interceptors. as we will handle interceptors' lifecycle separately. + data: The data to send to the server. + assert_data: The data to assert against the response. + assert_code: The code to assert against the response. + assert_details: The details to assert against the response. + timeout: The timeout for the RPC. + sanity: Whether to perform sanity check on the response. + + Returns: + The response from the server. + """ + from bentoml.testing.grpc.interceptors import AssertClientInterceptor + + if assert_code is None: + # by default, we want to check if the request is healthy + assert_code = grpc.StatusCode.OK + # We will add our own interceptors to the channel, which means + # We will have to check whether channel already has interceptors. + assert ( + len( + list( + filter( + lambda x: len(x) != 0, + map( + lambda stack: getattr(channel, stack), + [ + "_unary_unary_interceptors", + "_unary_stream_interceptors", + "_stream_unary_interceptors", + "_stream_stream_interceptors", + ], + ), + ) + ) + ) + == 0 + ), "'channel' shouldn't have any interceptors." + try: + # we will handle adding our testing interceptors here. + # prefer not to use private attributes, but this will do + channel._unary_unary_interceptors.append( # type: ignore (private warning) + AssertClientInterceptor( + assert_code=assert_code, assert_details=assert_details + ) + ) + Call = channel.unary_unary( + "/bentoml.grpc.v1alpha1.BentoService/Call", + request_serializer=pb.Request.SerializeToString, + response_deserializer=pb.Response.FromString, + ) + output = await t.cast( + t.Awaitable[pb.Response], + Call(pb.Request(api_name=method, **data), timeout=timeout), + ) + if sanity: + assert output + if assert_data: + try: + if callable(assert_data): + assert assert_data(output) + else: + assert output == assert_data + except AssertionError: + raise AssertionError(f"Failed while checking data: {output}") from None + return output + finally: + # we will reset interceptors per call + channel._unary_unary_interceptors = [] # type: ignore (private warning) + + +@asynccontextmanager +@add_experimental_docstring +async def create_channel( + host_url: str, interceptors: t.Sequence[aio.ClientInterceptor] | None = None +) -> t.AsyncGenerator[Channel, None]: + """Create an async channel with given host_url and client interceptors.""" + channel: Channel | None = None + try: + async with aio.insecure_channel(host_url, interceptors=interceptors) as channel: + # create a blocking call to wait til channel is ready. + await channel.channel_ready() + yield channel + except aio.AioRpcError as e: + traceback.print_exc() + raise e from None + finally: + if channel: + await channel.close() + + +@add_experimental_docstring +@cached_contextmanager("{interceptors}") +def make_standalone_server( + interceptors: t.Sequence[aio.ServerInterceptor] | None = None, + host: str = "0.0.0.0", +) -> t.Generator[tuple[aio.Server, str], None, None]: + """ + Create a standalone aio.Server for testing. + + Args: + interceptors: The interceptors to use for the server, default to None. + + Returns: + The server and the host_url. + + Example for async test cases: + + .. code-block:: python + + async def test_some_async(): + with make_standalone_server() as (server, host_url): + try: + await server.start() + channel = grpc.aio.insecure_channel(host_url) + ... # test code here + finally: + await server.stop(None) + + Example for sync test cases: + + .. code-block:: python + + def test_cases(): + import asyncio + + loop = asyncio.new_event_loop() + with make_standalone_server() as (server, host_url): + try: + loop.run_until_complete(server.start()) + channel = grpc.insecure_channel(host_url) + ... # test code here + finally: + loop.call_soon_threadsafe(lambda: asyncio.ensure_future(server.stop(None))) + loop.close() + assert loop.is_closed() + """ + stack = ExitStack() + port = stack.enter_context(reserve_free_port(enable_so_reuseport=True)) + server = aio.server( + interceptors=interceptors, + options=(("grpc.so_reuseport", 1),), + ) + services_test.add_TestServiceServicer_to_server(TestServiceServicer(), server) # type: ignore (no async types) # pylint: disable=E0601 + server.add_insecure_port(f"{host}:{port}") + print("Using port %d..." % port) + try: + yield server, "localhost:%d" % port + finally: + stack.close() diff --git a/bentoml/testing/grpc/interceptors.py b/bentoml/testing/grpc/interceptors.py new file mode 100644 index 0000000000..737b38b0d2 --- /dev/null +++ b/bentoml/testing/grpc/interceptors.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING + +from bentoml._internal.utils import LazyLoader + +if TYPE_CHECKING: + import grpc + from grpc import aio + + from bentoml.grpc.types import Request + from bentoml.grpc.types import BentoUnaryUnaryCall +else: + aio = LazyLoader("aio", globals(), "grpc.aio") + + +class AssertClientInterceptor(aio.UnaryUnaryClientInterceptor): + def __init__( + self, + assert_code: grpc.StatusCode | None = None, + assert_details: str | None = None, + assert_trailing_metadata: aio.Metadata | None = None, + ): + self._assert_code = assert_code + self._assert_details = assert_details + self._assert_trailing_metadata = assert_trailing_metadata + + async def intercept_unary_unary( # type: ignore (unable to infer types from parameters) + self, + continuation: t.Callable[[aio.ClientCallDetails, Request], BentoUnaryUnaryCall], + client_call_details: aio.ClientCallDetails, + request: Request, + ) -> BentoUnaryUnaryCall: + # Note that we cast twice here since grpc.aio._call.UnaryUnaryCall + # implements __await__, which returns ResponseType. However, pyright + # are unable to determine types from given mixin. + # + # continuation(client_call_details, request) -> call: UnaryUnaryCall + # await call -> ResponseType + call = await t.cast( + "t.Awaitable[BentoUnaryUnaryCall]", + continuation(client_call_details, request), + ) + try: + code = await call.code() + details = await call.details() + trailing_metadata = await call.trailing_metadata() + if self._assert_code: + assert ( + code == self._assert_code + ), f"{call!r} returns {await call.code()} while expecting {self._assert_code}." + if self._assert_details: + assert ( + self._assert_details in details + ), f"'{self._assert_details}' is not in {await call.details()}." + if self._assert_trailing_metadata: + assert ( + self._assert_trailing_metadata == trailing_metadata + ), f"Trailing metadata '{trailing_metadata}' while expecting '{self._assert_trailing_metadata}'." + return call + except AssertionError as e: + raise e from None diff --git a/bentoml/testing/grpc/servicer.py b/bentoml/testing/grpc/servicer.py new file mode 100644 index 0000000000..e7a5d4d86a --- /dev/null +++ b/bentoml/testing/grpc/servicer.py @@ -0,0 +1,23 @@ +# pylint: disable=unused-argument +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from grpc import aio + + from bentoml.grpc.v1alpha1 import service_test_pb2 as pb + from bentoml.grpc.v1alpha1 import service_test_pb2_grpc as services +else: + from bentoml.grpc.utils import import_generated_stubs + + pb, services = import_generated_stubs(file="service_test.proto") + + +class TestServiceServicer(services.TestServiceServicer): + async def Execute( # type: ignore (no async types) # pylint: disable=invalid-overridden-method + self, + request: pb.ExecuteRequest, + context: aio.ServicerContext[pb.ExecuteRequest, pb.ExecuteResponse], + ) -> pb.ExecuteResponse: + return pb.ExecuteResponse(output="Hello, {}!".format(request.input)) diff --git a/bentoml/testing/pytest/__init__.py b/bentoml/testing/pytest/__init__.py new file mode 100644 index 0000000000..b6d2667d21 --- /dev/null +++ b/bentoml/testing/pytest/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from .plugin import TEST_MODEL_CONTEXT + +__all__ = ["TEST_MODEL_CONTEXT"] diff --git a/bentoml/testing/pytest/plugin.py b/bentoml/testing/pytest/plugin.py new file mode 100644 index 0000000000..6cfb972a44 --- /dev/null +++ b/bentoml/testing/pytest/plugin.py @@ -0,0 +1,307 @@ +# pylint: disable=unused-argument +from __future__ import annotations + +import os +import typing as t +import tempfile +import contextlib +from typing import TYPE_CHECKING + +import psutil +import pytest +from pytest import MonkeyPatch + +import bentoml +from bentoml._internal.utils import LazyLoader +from bentoml._internal.utils import validate_or_create_dir +from bentoml._internal.models import ModelContext +from bentoml._internal.configuration import CLEAN_BENTOML_VERSION +from bentoml._internal.configuration.containers import BentoMLContainer + +if TYPE_CHECKING: + import numpy as np + from _pytest.main import Session + from _pytest.nodes import Item + from _pytest.config import Config + from _pytest.config import ExitCode + from _pytest.python import Metafunc + from _pytest.fixtures import FixtureRequest + from _pytest.config.argparsing import Parser + + from bentoml._internal.server.metrics.prometheus import PrometheusClient + +else: + np = LazyLoader("np", globals(), "numpy") + + +TEST_MODEL_CONTEXT = ModelContext( + framework_name="testing", + framework_versions={"testing": "v1"}, +) + +_RUN_GPU_TESTS_MARKER = "--run-gpu-tests" +_RUN_GRPC_TESTS_MARKER = "--run-grpc-tests" + + +@pytest.mark.tryfirst +def pytest_report_header(config: Config) -> list[str]: + return [f"bentoml: version={CLEAN_BENTOML_VERSION}"] + + +@pytest.hookimpl +def pytest_addoption(parser: Parser) -> None: + group = parser.getgroup("bentoml", "BentoML pytest plugins.") + group.addoption( + _RUN_GPU_TESTS_MARKER, + action="store_true", + default=False, + help="run gpus related tests.", + ) + group.addoption( + _RUN_GRPC_TESTS_MARKER, + action="store_true", + default=False, + help="run grpc related tests.", + ) + + +def pytest_configure(config: Config) -> None: + # We will inject marker documentation here. + config.addinivalue_line( + "markers", + "requires_gpus: requires GPU to run given test.", + ) + config.addinivalue_line( + "markers", + "requires_grpc: requires gRPC support to run given test.", + ) + + +@pytest.hookimpl(tryfirst=True) +def pytest_runtest_setup(item: Item) -> None: + config = item.config + if "requires_gpus" in item.keywords and not config.getoption(_RUN_GPU_TESTS_MARKER): + item.add_marker( + pytest.mark.skip( + reason=f"need {_RUN_GPU_TESTS_MARKER} option to run gpus related tests." + ) + ) + # We don't run gRPC tests on Windows + if ("requires_grpc" in item.keywords or psutil.WINDOWS) and not config.getoption( + _RUN_GRPC_TESTS_MARKER + ): + item.add_marker( + pytest.mark.skip( + reason=f"need {_RUN_GRPC_TESTS_MARKER} option to run grpc related tests." + ) + ) + + +def _setup_deployment_mode(metafunc: Metafunc): + """ + Setup deployment mode for test session. + We will dynamically add this fixture to tests functions that has ``deployment_mode`` fixtures. + + Current matrix: + - deployment_mode: ["docker", "distributed", "standalone"] + """ + if os.getenv("VSCODE_IPC_HOOK_CLI") and not os.getenv("GITHUB_CODESPACE_TOKEN"): + # When running inside VSCode remote container locally, we don't have access to + # exposed reserved ports, so we can't run docker-based tests. However on GitHub + # Codespaces, we can run docker-based tests. + # Note that inside the remote container, it is already running as a Linux container. + deployment_mode = ["distributed", "standalone"] + else: + if os.environ.get("GITHUB_ACTIONS") and (psutil.WINDOWS or psutil.MACOS): + # Due to GitHub Actions' limitation, we can't run docker-based tests + # on Windows and macOS. However, we can still running those tests on + # local development. + if psutil.MACOS: + deployment_mode = ["distributed", "standalone"] + else: + deployment_mode = ["standalone"] + else: + if psutil.WINDOWS: + deployment_mode = ["standalone", "docker"] + else: + deployment_mode = ["distributed", "standalone", "docker"] + metafunc.parametrize("deployment_mode", deployment_mode, scope="session") + + +def _setup_model_store(metafunc: Metafunc): + """Setup dummy models for test session.""" + with bentoml.models.create( + "testmodel", + module=__name__, + signatures={}, + context=TEST_MODEL_CONTEXT, + ): + pass + with bentoml.models.create( + "testmodel", + module=__name__, + signatures={}, + context=TEST_MODEL_CONTEXT, + ): + pass + with bentoml.models.create( + "anothermodel", + module=__name__, + signatures={}, + context=TEST_MODEL_CONTEXT, + ): + pass + + metafunc.parametrize( + "model_store", [BentoMLContainer.model_store.get()], scope="session" + ) + + +@pytest.mark.tryfirst +def pytest_generate_tests(metafunc: Metafunc): + if "deployment_mode" in metafunc.fixturenames: + _setup_deployment_mode(metafunc) + if "model_store" in metafunc.fixturenames: + _setup_model_store(metafunc) + + +def _setup_session_environment( + mp: MonkeyPatch, o: Session | Config, *pairs: tuple[str, str] +): + """Setup environment variable for test session.""" + for p in pairs: + key, value = p + _ENV_VAR = os.environ.get(key, None) + if _ENV_VAR is not None: + mp.setattr(o, f"_original_{key}", _ENV_VAR, raising=False) + os.environ[key] = value + + +def _setup_test_directory() -> tuple[str, str]: + # Ensure we setup correct home and prometheus_multiproc_dir folders. + # For any given test session. + bentoml_home = tempfile.mkdtemp("bentoml-pytest") + bentos = os.path.join(bentoml_home, "bentos") + models = os.path.join(bentoml_home, "models") + multiproc_dir = os.path.join(bentoml_home, "prometheus_multiproc_dir") + validate_or_create_dir(bentos, models, multiproc_dir) + + # We need to set the below value inside container due to + # the fact that each value is a singleton, and will be cached. + BentoMLContainer.bentoml_home.set(bentoml_home) + BentoMLContainer.bento_store_dir.set(bentos) + BentoMLContainer.model_store_dir.set(models) + BentoMLContainer.prometheus_multiproc_dir.set(multiproc_dir) + return bentoml_home, multiproc_dir + + +@pytest.mark.tryfirst +def pytest_sessionstart(session: Session) -> None: + """Create a temporary directory for the BentoML home directory, then monkey patch to config.""" + from bentoml._internal.utils import analytics + + # We need to clear analytics cache before running tests. + analytics.usage_stats.do_not_track.cache_clear() + analytics.usage_stats._usage_event_debugging.cache_clear() # type: ignore (private warning) + + mp = MonkeyPatch() + config = session.config + config.add_cleanup(mp.undo) + + _PYTEST_BENTOML_HOME, _PYTEST_MULTIPROC_DIR = _setup_test_directory() + + # The evironment variable patch ensures that we will + # always build bento using bentoml from source, use the correct + # test bentoml home directory, and setup prometheus multiproc directory. + _setup_session_environment( + mp, + session, + ("PROMETHEUS_MULTIPROC_DIR", _PYTEST_MULTIPROC_DIR), + ("BENTOML_BUNDLE_LOCAL_BUILD", "True"), + ("SETUPTOOLS_USE_DISTUTILS", "stdlib"), + ("__BENTOML_DEBUG_USAGE", "False"), + ("BENTOML_DO_NOT_TRACK", "True"), + ) + + _setup_session_environment(mp, config, ("BENTOML_HOME", _PYTEST_BENTOML_HOME)) + + +def _teardown_session_environment(o: Session | Config, *variables: str): + """Restore environment variable to original value.""" + for variable in variables: + if hasattr(o, f"_original_{variable}"): + os.environ[variable] = getattr(o, f"_original_{variable}") + else: + os.environ.pop(variable, None) + + +@pytest.mark.tryfirst +def pytest_sessionfinish(session: Session, exitstatus: int | ExitCode) -> None: + config = session.config + + _teardown_session_environment( + session, + "BENTOML_BUNDLE_LOCAL_BUILD", + "PROMETHEUS_MULTIPROC_DIR", + "SETUPTOOLS_USE_DISTUTILS", + "__BENTOML_DEBUG_USAGE", + "BENTOML_DO_NOT_TRACK", + ) + _teardown_session_environment(config, "BENTOML_HOME") + + # reset home and prometheus_multiproc_dir to default + BentoMLContainer.prometheus_multiproc_dir.reset() + + +@pytest.fixture(scope="session") +def bentoml_home() -> str: + """ + Return the BentoML home directory for the test session. + This directory is created via ``pytest_sessionstart``. + """ + return BentoMLContainer.bentoml_home.get() + + +@pytest.fixture(scope="session", autouse=True) +def clean_context() -> t.Generator[contextlib.ExitStack, None, None]: + """ + Create a ExitStack to cleanup contextmanager. + This fixture is available to all tests. + """ + stack = contextlib.ExitStack() + yield stack + stack.close() + + +@pytest.fixture() +def img_file(tmpdir: str) -> str: + """Create a random image/bmp file.""" + from PIL.Image import fromarray + + img_file_ = tmpdir.join("test_img.bmp") + img = fromarray(np.random.randint(255, size=(10, 10, 3)).astype("uint8")) + img.save(str(img_file_)) + return str(img_file_) + + +@pytest.fixture() +def bin_file(tmpdir: str) -> str: + """Create a random binary file.""" + bin_file_ = tmpdir.join("bin_file.bin") + with open(bin_file_, "wb") as of: + of.write("â".encode("gb18030")) + return str(bin_file_) + + +@pytest.fixture(scope="module", name="prom_client") +def fixture_metrics_client() -> PrometheusClient: + """This fixtures return a PrometheusClient instance that can be used for testing.""" + return BentoMLContainer.metrics_client.get() + + +@pytest.fixture(scope="function", name="change_test_dir") +def fixture_change_dir(request: FixtureRequest) -> t.Generator[None, None, None]: + """A fixture to change given test directory to the directory of the current running test.""" + os.chdir(request.fspath.dirname) # type: ignore (bad pytest stubs) + yield + os.chdir(request.config.invocation_dir) # type: ignore (bad pytest stubs) diff --git a/bentoml/testing/server.py b/bentoml/testing/server.py index 8553f0b69f..235df484bb 100644 --- a/bentoml/testing/server.py +++ b/bentoml/testing/server.py @@ -1,14 +1,13 @@ -# pylint: disable=redefined-outer-name # pragma: no cover +# pylint: disable=redefined-outer-name,not-context-manager from __future__ import annotations import os -import re import sys import time import socket import typing as t import urllib -import logging +import asyncio import itertools import contextlib import subprocess @@ -20,20 +19,27 @@ import psutil -from .._internal.tag import Tag -from .._internal.utils import reserve_free_port -from .._internal.utils import cached_contextmanager - -logger = logging.getLogger("bentoml") - +from bentoml.grpc.utils import import_grpc +from bentoml._internal.tag import Tag +from bentoml._internal.utils import LazyLoader +from bentoml._internal.utils import reserve_free_port +from bentoml._internal.utils import cached_contextmanager if TYPE_CHECKING: + from grpc import aio + from grpc_health.v1 import health_pb2 as pb_health from aiohttp.typedefs import LooseHeaders from starlette.datastructures import Headers from starlette.datastructures import FormData + from bentoml._internal.bento.bento import Bento + +else: + pb_health = LazyLoader("pb_health", globals(), "grpc_health.v1.health_pb2") + _, aio = import_grpc() -async def parse_multipart_form(headers: "Headers", body: bytes) -> "FormData": + +async def parse_multipart_form(headers: Headers, body: bytes) -> FormData: """ parse starlette forms from headers and body """ @@ -52,10 +58,10 @@ async def async_bytesio(bytes_: bytes) -> t.AsyncGenerator[bytes, None]: async def async_request( method: str, url: str, - headers: t.Optional["LooseHeaders"] = None, + headers: LooseHeaders | None = None, data: t.Any = None, - timeout: t.Optional[int] = None, -) -> t.Tuple[int, "Headers", bytes]: + timeout: int | None = None, +) -> tuple[int, Headers, bytes]: """ A HTTP client with async API. """ @@ -80,6 +86,7 @@ def kill_subprocess_tree(p: subprocess.Popen[t.Any]) -> None: """ Tell the process to terminate and kill all of its children. Availabe both on Windows and Linux. Note: It will return immediately rather than wait for the process to terminate. + Args: p: subprocess.Popen object """ @@ -89,91 +96,125 @@ def kill_subprocess_tree(p: subprocess.Popen[t.Any]) -> None: p.terminate() -def _wait_until_api_server_ready( +async def server_warmup( host_url: str, timeout: float, + grpc: bool = False, check_interval: float = 1, - popen: t.Optional["subprocess.Popen[t.Any]"] = None, + popen: subprocess.Popen[t.Any] | None = None, + service_name: str | None = None, ) -> bool: + from bentoml.testing.grpc import create_channel + start_time = time.time() - proxy_handler = urllib.request.ProxyHandler({}) - opener = urllib.request.build_opener(proxy_handler) - logger.info("Waiting for host %s to be ready..", host_url) + print("Waiting for host %s to be ready.." % host_url) while time.time() - start_time < timeout: try: if popen and popen.poll() is not None: return False - elif opener.open(f"http://{host_url}/readyz", timeout=1).status == 200: - return True + elif grpc: + if service_name is None: + service_name = "bentoml.grpc.v1alpha1.BentoService" + async with create_channel(host_url) as channel: + Check = channel.unary_unary( + "/grpc.health.v1.Health/Check", + request_serializer=pb_health.HealthCheckRequest.SerializeToString, # type: ignore (no grpc_health type) + response_deserializer=pb_health.HealthCheckResponse.FromString, # type: ignore (no grpc_health type) + ) + resp = await t.cast( + t.Awaitable[pb_health.HealthCheckResponse], + Check( + pb_health.HealthCheckRequest(service=service_name), + timeout=timeout, + ), + ) + if resp.status == pb_health.HealthCheckResponse.SERVING: # type: ignore (no generated enum types) + return True + else: + time.sleep(check_interval) else: - time.sleep(check_interval) + proxy_handler = urllib.request.ProxyHandler({}) + opener = urllib.request.build_opener(proxy_handler) + if opener.open(f"http://{host_url}/readyz", timeout=1).status == 200: + return True + else: + time.sleep(check_interval) except ( + aio.AioRpcError, ConnectionError, urllib.error.URLError, socket.timeout, ) as e: - logger.info(f"[{e}]retrying to connect to the host {host_url}...") - logger.error(e) + print(f"[{e}] Retrying to connect to the host {host_url}...") time.sleep(check_interval) - logger.info( - f"Timed out waiting {timeout} seconds for Server {host_url} to be ready, " - ) + print(f"Timed out waiting {timeout} seconds for Server {host_url} to be ready.") return False -@cached_contextmanager("{project_path}") -def bentoml_build(project_path: str) -> t.Generator["Tag", None, None]: +@cached_contextmanager("{project_path}, {cleanup}") +def bentoml_build( + project_path: str, cleanup: bool = True +) -> t.Generator[Bento, None, None]: """ Build a BentoML project. """ - logger.info(f"Building bento: {project_path}") - output = subprocess.check_output( - ["bentoml", "build", project_path], - stderr=subprocess.STDOUT, - env=dict(os.environ, COLUMNS="200"), - ) - match = re.search( - r'Bento\(tag="([A-Za-z0-9\-_\.]+:[a-z0-9]+)"\)', - output.decode(), - ) - assert match, f"Build failed. The details:\n {output.decode()}" - tag = Tag.from_taglike(match[1]) - yield tag - logger.info(f"Deleting bento: {tag}") - subprocess.call(["bentoml", "delete", "-y", str(tag)]) + from bentoml import bentos + print(f"Building bento: {project_path}") + bento = bentos.build_bentofile(build_ctx=project_path) + yield bento + if cleanup: + print(f"Deleting bento: {str(bento.tag)}") + bentos.delete(bento.tag) -@cached_contextmanager("{bento_tag}, {image_tag}") + +@cached_contextmanager("{bento_tag}, {image_tag}, {cleanup}, {use_grpc}") def bentoml_containerize( - bento_tag: t.Union[str, "Tag"], - image_tag: t.Optional[str] = None, + bento_tag: str | Tag, + image_tag: str | None = None, + cleanup: bool = True, + use_grpc: bool = False, ) -> t.Generator[str, None, None]: """ Build the docker image from a saved bento, yield the docker image tag """ + from bentoml import bentos + bento_tag = Tag.from_taglike(bento_tag) if image_tag is None: image_tag = bento_tag.name - logger.info(f"Building bento server docker image: {bento_tag}") - subprocess.check_call(["bentoml", "containerize", str(bento_tag), "-t", image_tag]) - yield image_tag - logger.info(f"Removing bento server docker image: {image_tag}") - subprocess.call(["docker", "rmi", image_tag]) + try: + print(f"Building bento server docker image: {bento_tag}") + bentos.containerize( + str(bento_tag), + docker_image_tag=[image_tag], + progress="plain", + features=["grpc"] if use_grpc else None, + ) + yield image_tag + finally: + if cleanup: + print(f"Removing bento server docker image: {image_tag}") + subprocess.call(["docker", "rmi", image_tag]) -@cached_contextmanager("{image_tag}, {config_file}") -def run_bento_server_in_docker( +@cached_contextmanager("{image_tag}, {config_file}, {use_grpc}") +def run_bento_server_docker( image_tag: str, - config_file: t.Optional[str] = None, - timeout: float = 40, + config_file: str | None = None, + use_grpc: bool = False, + timeout: float = 90, + host: str = "127.0.0.1", ): """ Launch a bentoml service container from a docker image, yield the host URL """ + from bentoml._internal.configuration.containers import BentoMLContainer + container_name = f"bentoml-test-{image_tag}-{hash(config_file)}" - with reserve_free_port() as port: + with reserve_free_port(enable_so_reuseport=use_grpc) as port: pass - + bind_port = "3000" cmd = [ "docker", "run", @@ -181,133 +222,138 @@ def run_bento_server_in_docker( "--name", container_name, "--publish", - f"{port}:3000", - "--env", - "BENTOML_LOG_STDOUT=true", - "--env", - "BENTOML_LOG_STDERR=true", + f"{port}:{bind_port}", ] - if config_file is not None: cmd.extend(["--env", "BENTOML_CONFIG=/home/bentoml/bentoml_config.yml"]) cmd.extend( ["-v", f"{os.path.abspath(config_file)}:/home/bentoml/bentoml_config.yml"] ) + if use_grpc: + bind_prom_port = BentoMLContainer.grpc.metrics.port.get() + cmd.extend(["--publish", f"{bind_prom_port}:{bind_prom_port}"]) cmd.append(image_tag) - - logger.info(f"Running API server docker image: {cmd}") + if use_grpc: + cmd.extend(["serve-grpc", "--production", "--enable-reflection"]) + print(f"Running API server docker image: '{' '.join(cmd)}'") with subprocess.Popen( cmd, stdin=subprocess.PIPE, encoding="utf-8", ) as proc: try: - host_url = f"127.0.0.1:{port}" - if _wait_until_api_server_ready(host_url, timeout, popen=proc): + host_url = f"{host}:{port}" + if asyncio.run( + server_warmup(host_url, timeout=timeout, popen=proc, grpc=use_grpc) + ): yield host_url else: raise RuntimeError( f"API server {host_url} failed to start within {timeout} seconds" - ) + ) from None finally: + print(f"Stopping Bento container {container_name}...") subprocess.call(["docker", "stop", container_name]) time.sleep(1) @contextmanager -def run_bento_server( +def run_bento_server_standalone( bento: str, - workdir: t.Optional[str] = None, - config_file: t.Optional[str] = None, - dev_server: bool = False, + use_grpc: bool = False, + config_file: str | None = None, timeout: float = 90, + host: str = "127.0.0.1", ): """ Launch a bentoml service directly by the bentoml CLI, yields the host URL. """ - workdir = workdir if workdir is not None else "./" - my_env = os.environ.copy() + copied = os.environ.copy() if config_file is not None: - my_env["BENTOML_CONFIG"] = os.path.abspath(config_file) - with reserve_free_port() as port: - cmd = [sys.executable, "-m", "bentoml", "serve"] - if not dev_server: - cmd += ["--production"] - if port: - cmd += ["--port", f"{port}"] - cmd += [bento] - cmd += ["--working-dir", workdir] - logger.info(f"Running command: `{cmd}`") + copied["BENTOML_CONFIG"] = os.path.abspath(config_file) + with reserve_free_port(host=host, enable_so_reuseport=use_grpc) as server_port: + cmd = [ + sys.executable, + "-m", + "bentoml", + "serve-grpc" if use_grpc else "serve", + "--production", + "--port", + f"{server_port}", + ] + if use_grpc: + cmd += ["--host", f"{host}", "--enable-reflection"] + cmd += [bento] + print(f"Running command: '{' '.join(cmd)}'") p = subprocess.Popen( cmd, stderr=subprocess.STDOUT, - env=my_env, + env=copied, encoding="utf-8", ) - try: - host_url = f"127.0.0.1:{port}" - assert _wait_until_api_server_ready(host_url, timeout=timeout, popen=p) + host_url = f"{host}:{server_port}" + assert asyncio.run( + server_warmup(host_url, timeout=timeout, popen=p, grpc=use_grpc) + ) yield host_url finally: + print(f"Stopping process [{p.pid}]...") kill_subprocess_tree(p) p.communicate() -def _start_mitm_proxy(port: int) -> None: - import uvicorn # type: ignore +def start_mitm_proxy(port: int) -> None: + import uvicorn from .utils import http_proxy_app - logger.info(f"proxy serer listen on {port}") - uvicorn.run(http_proxy_app, port=port) # type: ignore + print(f"Proxy server listen on {port}") + uvicorn.run(http_proxy_app, port=port) # type: ignore (not using ASGI3Application) @contextmanager def run_bento_server_distributed( - bento_tag: t.Union[str, "Tag"], - config_file: t.Optional[str] = None, + bento_tag: str | Tag, + config_file: str | None = None, + use_grpc: bool = False, timeout: float = 90, + host: str = "127.0.0.1", ): """ Launch a bentoml service as a simulated distributed environment(Yatai), yields the host URL. """ - with reserve_free_port() as proxy_port: - pass + import yaml + + import bentoml - logger.warning(f"Starting proxy on port {proxy_port}") + with reserve_free_port(enable_so_reuseport=use_grpc) as proxy_port: + pass + print(f"Starting proxy on port {proxy_port}") proxy_process = multiprocessing.Process( - target=_start_mitm_proxy, + target=start_mitm_proxy, args=(proxy_port,), ) proxy_process.start() - - my_env = os.environ.copy() - + copied = os.environ.copy() # to ensure yatai specified headers BP100 - my_env["YATAI_BENTO_DEPLOYMENT_NAME"] = "sdfasdf" - my_env["YATAI_BENTO_DEPLOYMENT_NAMESPACE"] = "yatai" - my_env["HTTP_PROXY"] = f"http://127.0.0.1:{proxy_port}" - + copied["YATAI_BENTO_DEPLOYMENT_NAME"] = "test-deployment" + copied["YATAI_BENTO_DEPLOYMENT_NAMESPACE"] = "yatai" + if use_grpc: + copied["GPRC_PROXY"] = f"localhost:{proxy_port}" + else: + copied["HTTP_PROXY"] = f"http://127.0.0.1:{proxy_port}" if config_file is not None: - my_env["BENTOML_CONFIG"] = os.path.abspath(config_file) - - import yaml - - import bentoml + copied["BENTOML_CONFIG"] = os.path.abspath(config_file) + runner_map = {} + processes: list[subprocess.Popen[str]] = [] bento_service = bentoml.bentos.get(bento_tag) - path = bento_service.path - with open(os.path.join(path, "bento.yaml"), "r", encoding="utf-8") as f: bentofile = yaml.safe_load(f) - - runner_map = {} - processes: t.List[subprocess.Popen[str]] = [] - for runner in bentofile["runners"]: - with reserve_free_port() as port: + with reserve_free_port(enable_so_reuseport=use_grpc) as port: runner_map[runner["name"]] = f"tcp://127.0.0.1:{port}" cmd = [ sys.executable, @@ -318,84 +364,95 @@ def run_bento_server_distributed( "--runner-name", runner["name"], "--host", - "127.0.0.1", + host, "--port", f"{port}", "--working-dir", path, ] - logger.info(f"Running command: `{cmd}`") - + print(f"Running command: '{' '.join(cmd)}'") processes.append( subprocess.Popen( cmd, encoding="utf-8", stderr=subprocess.STDOUT, - env=my_env, + env=copied, ) ) - - with reserve_free_port() as server_port: - args_pairs = [ - ("--remote-runner", f"{runner['name']}={runner_map[runner['name']]}") - for runner in bentofile["runners"] - ] - cmd = [ - sys.executable, - "-m", - "bentoml", - "start-http-server", - str(bento_tag), - "--host", - "127.0.0.1", - "--port", - f"{server_port}", - "--working-dir", - path, - *itertools.chain.from_iterable(args_pairs), - ] - logger.info(f"Running command: `{cmd}`") - + runner_args = [ + ("--remote-runner", f"{runner['name']}={runner_map[runner['name']]}") + for runner in bentofile["runners"] + ] + cmd = [ + sys.executable, + "-m", + "bentoml", + "start-http-server" if not use_grpc else "start-grpc-server", + str(bento_tag), + "--host", + host, + "--working-dir", + path, + *itertools.chain.from_iterable(runner_args), + ] + with reserve_free_port(host=host, enable_so_reuseport=use_grpc) as server_port: + cmd.extend(["--port", f"{server_port}"]) + if use_grpc: + cmd.append("--enable-reflection") + print(f"Running command: '{' '.join(cmd)}'") processes.append( subprocess.Popen( cmd, stderr=subprocess.STDOUT, encoding="utf-8", - env=my_env, + env=copied, ) ) try: - host_url = f"127.0.0.1:{server_port}" - _wait_until_api_server_ready(host_url, timeout=timeout) + host_url = f"{host}:{server_port}" + asyncio.run(server_warmup(host_url, timeout=timeout, grpc=use_grpc)) yield host_url finally: for p in processes: kill_subprocess_tree(p) for p in processes: p.communicate() - proxy_process.terminate() - proxy_process.join() + if proxy_process is not None: + proxy_process.terminate() + proxy_process.join() -@cached_contextmanager("{bento}, {project_path}, {config_file}, {deployment_mode}") +@cached_contextmanager( + "{bento_name}, {project_path}, {config_file}, {deployment_mode}, {bentoml_home}, {use_grpc}" +) def host_bento( - bento: t.Union[str, Tag, None] = None, + bento_name: str | Tag | None = None, project_path: str = ".", config_file: str | None = None, - deployment_mode: str = "standalone", + deployment_mode: t.Literal["standalone", "distributed", "docker"] = "standalone", + bentoml_home: str | None = None, + use_grpc: bool = False, clean_context: contextlib.ExitStack | None = None, + host: str = "127.0.0.1", ) -> t.Generator[str, None, None]: """ Host a bentoml service, yields the host URL. Args: - bento: a beoto tag or `module_path:service` + bento: a bento tag or :code:`module_path:service` project_path: the path to the project directory config_file: the path to the config file - deployment_mode: the deployment mode, one of `standalone`, `docker` or `distributed` + deployment_mode: the deployment mode, one of :code:`standalone`, :code:`docker` or :code:`distributed` clean_context: a contextlib.ExitStack to clean up the intermediate files, - like docker image and bentos. If None, it will be created. Used for reusing - those files in the same test session. + like docker image and bentos. If None, it will be created. Used for reusing + those files in the same test session. + bentoml_home: if set, we will change the given BentoML home folder to :code:`bentoml_home`. Default + to :code:`$HOME/bentoml` + use_grpc: if True, running gRPC tests. + host: set a given host for the bento, default to ``127.0.0.1`` + + Returns: + :obj:`str`: a generated host URL where we run the test bento. """ import bentoml @@ -404,42 +461,49 @@ def host_bento( clean_on_exit = True else: clean_on_exit = False + if bentoml_home: + from bentoml._internal.configuration.containers import BentoMLContainer + BentoMLContainer.bentoml_home.set(bentoml_home) try: - logger.info( - f"starting bento server {bento} at {project_path} " - f"with config file {config_file} " - f"in {deployment_mode} mode..." + print( + f"Starting bento server {bento_name} at '{project_path}' {'with config file '+config_file+' ' if config_file else ' '}in {deployment_mode} mode..." ) - if bento is None or not bentoml.list(bento): - bento_tag = clean_context.enter_context(bentoml_build(project_path)) + if bento_name is None or not bentoml.list(bento_name): + bento = clean_context.enter_context(bentoml_build(project_path)) else: - bento_tag = bentoml.get(bento).tag - - if deployment_mode == "docker": - image_tag = clean_context.enter_context(bentoml_containerize(bento_tag)) - with run_bento_server_in_docker( # pylint: disable=not-context-manager # cached_contextmanager not detected by pylint - image_tag, - config_file, - ) as host: - yield host - elif deployment_mode == "standalone": - with run_bento_server( - str(bento_tag), + bento = bentoml.get(bento_name) + if deployment_mode == "standalone": + with run_bento_server_standalone( + bento.path, config_file=config_file, - workdir=project_path, - ) as host: - yield host + use_grpc=use_grpc, + host=host, + ) as host_url: + yield host_url + elif deployment_mode == "docker": + container_tag = clean_context.enter_context( + bentoml_containerize(bento.tag, use_grpc=use_grpc) + ) + with run_bento_server_docker( + container_tag, + config_file=config_file, + use_grpc=use_grpc, + host=host, + ) as host_url: + yield host_url elif deployment_mode == "distributed": with run_bento_server_distributed( - str(bento_tag), + bento.tag, config_file=config_file, - ) as host: - yield host + use_grpc=use_grpc, + host=host, + ) as host_url: + yield host_url else: - raise ValueError(f"Unknown deployment mode: {deployment_mode}") + raise ValueError(f"Unknown deployment mode: {deployment_mode}") from None finally: - logger.info("shutting down bento server...") + print("Shutting down bento server...") if clean_on_exit: - logger.info("Cleaning up...") + print("Cleaning on exit...") clean_context.close() diff --git a/bentoml/testing/utils.py b/bentoml/testing/utils.py index ec088a12e8..5ec8157a13 100644 --- a/bentoml/testing/utils.py +++ b/bentoml/testing/utils.py @@ -1,15 +1,11 @@ from __future__ import annotations import typing as t -import logging from typing import TYPE_CHECKING import aiohttp import multidict -logger = logging.getLogger("bentoml.tests") - - if TYPE_CHECKING: from starlette.types import Send from starlette.types import Scope @@ -35,61 +31,78 @@ async def async_bytesio(bytes_: bytes) -> t.AsyncGenerator[bytes, None]: return await parser.parse() +def handle_assert_exception(assert_object: t.Any, obj: t.Any, msg: str): + res = assert_object + try: + if callable(assert_object): + res = assert_object(obj) + assert res + else: + assert obj == assert_object + except AssertionError: + raise ValueError(f"Expected: {res}. {msg}") from None + except Exception as e: # pylint: disable=broad-except + # if callable has some errors, then we raise it here + raise ValueError( + f"Exception while excuting '{assert_object.__name__}': {e}" + ) from None + + async def async_request( method: str, url: str, - headers: t.Union[None, t.Tuple[t.Tuple[str, str], ...], "LooseHeaders"] = None, + headers: None | tuple[tuple[str, str], ...] | LooseHeaders = None, data: t.Any = None, - timeout: t.Optional[int] = None, - assert_status: t.Union[int, t.Callable[[int], bool], None] = None, - assert_data: t.Union[bytes, t.Callable[[bytes], bool], None] = None, - assert_headers: t.Optional[t.Callable[[t.Any], bool]] = None, -) -> t.Tuple[int, "Headers", bytes]: - """ - raw async request client - """ - import aiohttp + timeout: int | None = None, + assert_status: int | t.Callable[[int], bool] | None = None, + assert_data: bytes | t.Callable[[bytes], bool] | None = None, + assert_headers: t.Callable[[t.Any], bool] | None = None, +) -> tuple[int, Headers, bytes]: from starlette.datastructures import Headers async with aiohttp.ClientSession() as sess: try: async with sess.request( method, url, data=data, headers=headers, timeout=timeout - ) as r: - r_body = await r.read() + ) as resp: + body = await resp.read() except Exception: - raise RuntimeError( - "Unable to reach host." - ) from None # suppress exception trace + raise RuntimeError("Unable to reach host.") from None if assert_status is not None: - if callable(assert_status): - assert assert_status(r.status), f"{r.status} {repr(r_body)}" - else: - assert r.status == assert_status, f"{r.status} {repr(r_body)}" - + handle_assert_exception( + assert_status, + resp.status, + f"Return status [{resp.status}] with body: {body!r}", + ) if assert_data is not None: if callable(assert_data): - assert assert_data(r_body), r_body + msg = f"'{assert_data.__name__}' returns {assert_data(body)}" else: - assert r_body == assert_data, r_body - + msg = f"Expects data '{assert_data}'" + handle_assert_exception( + assert_data, + body, + f"{msg}\nReceived response: {body}.", + ) if assert_headers is not None: - assert assert_headers(r.headers), repr(r.headers) - - headers = t.cast(t.Mapping[str, str], r.headers) - return r.status, Headers(headers), r_body + handle_assert_exception( + assert_headers, + resp.headers, + f"Headers assertion failed: {resp.headers!r}", + ) + return resp.status, Headers(resp.headers), body -def check_headers(headers: multidict.CIMultiDict[str]) -> bool: - return ( - headers.get("Yatai-Bento-Deployment-Name") == "sdfasdf" +def assert_distributed_header(headers: multidict.CIMultiDict[str]) -> None: + assert ( + headers.get("Yatai-Bento-Deployment-Name") == "test-deployment" and headers.get("Yatai-Bento-Deployment-Namespace") == "yatai" ) async def http_proxy_app(scope: Scope, receive: Receive, send: Send): """ - A simplest HTTP proxy app. To simulate the behavior of yatai + A simple HTTP proxy app that simulate the behavior of Yatai. """ if scope["type"] == "lifespan": return @@ -100,15 +113,14 @@ async def http_proxy_app(scope: Scope, receive: Receive, send: Send): tuple((k.decode(), v.decode()) for k, v in scope["headers"]) ) - assert check_headers(headers) - - bodys: list[bytes] = [] + assert_distributed_header(headers) + bodies: list[bytes] = [] while True: request_message = await receive() assert request_message["type"] == "http.request" request_body = request_message.get("body") assert isinstance(request_body, bytes) - bodys.append(request_body) + bodies.append(request_body) if not request_message["more_body"]: break @@ -116,7 +128,7 @@ async def http_proxy_app(scope: Scope, receive: Receive, send: Send): method=scope["method"], url=scope["path"], headers=headers, - data=b"".join(bodys), + data=b"".join(bodies), ) as response: await send( { @@ -135,4 +147,4 @@ async def http_proxy_app(scope: Scope, receive: Receive, send: Send): ) return - raise NotImplementedError(f"Scope {scope} is not understood") + raise NotImplementedError(f"Scope {scope} is not understood.") from None diff --git a/codecov.yml b/codecov.yml index a0e2141863..732c12e743 100644 --- a/codecov.yml +++ b/codecov.yml @@ -340,7 +340,8 @@ flags: carryforward: true paths: - "bentoml/**/*" - - bentoml/grpc/utils.py + - bentoml/grpc/interceptors/ + - bentoml/grpc/utils/ unit-tests: carryforward: true paths: diff --git a/pyproject.toml b/pyproject.toml index f72035845c..fc4f43877c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,9 @@ dynamic = ["version"] [project.scripts] bentoml = "bentoml_cli.cli:cli" +[project.entry-points.pytest11] +bentoml = "bentoml.testing.pytest.plugin" + [tool.setuptools] package-data = { "bentoml" = ["bentoml/*"], "bentoml_cli" = ["bentoml_cli/*"] } @@ -136,36 +139,41 @@ source = ["bentoml"] [tool.coverage.run] branch = true -source = ["bentoml", "bentoml_cli"] +source = ["bentoml/", "bentoml_cli/"] omit = [ - "bentoml/**/*_pb2.py", "bentoml/__main__.py", - "bentoml/_internal/types.py", - "bentoml/_internal/external_typing/*", - "bentoml/testing/*", "bentoml/io.py", + "bentoml/serve.py", + "bentoml/start.py", + "bentoml/_internal/types.py", + "bentoml/testing/", + "bentoml/grpc/v1alpha1/", + "bentoml/_internal/external_typing/", ] [tool.coverage.report] show_missing = true precision = 2 omit = [ - "*/bentoml/**/*_pb2*.py", - "*/bentoml/_internal/external_typing/*", - "*/bentoml/_internal/types.py", - "*/bentoml/testing/*", - '*/bentoml/__main__.py', + "*/bentoml/__main__.py", "*/bentoml/io.py", + "*/bentoml/serve.py", + "*/bentoml/start.py", + "*/bentoml/_internal/types.py", + "*/bentoml/testing/", + "*/bentoml/grpc/v1alpha1/", + "*/bentoml/_internal/external_typing/", ] exclude_lines = [ - "pragma: no cover", - "def __repr__", - "raise AssertionError", - "raise NotImplementedError", - "raise MissingDependencyException", - "except ImportError", + "\\#\\s*pragma: no cover", + "^\\s*def __repr__", + "^\\s*raise AssertionError", + "^\\s*raise NotImplementedError", + "^\\s*raise MissingDependencyException", + "^\\s*except ImportError", "if __name__ == .__main__.:", - "if TYPE_CHECKING:", + "^\\s*if TYPE_CHECKING:", + "^\\s*@overload( |$)", ] [tool.black] @@ -194,10 +202,18 @@ exclude = ''' extend-exclude = "(_pb2.py$|_pb2_grpc.py$)" [tool.pytest.ini_options] -addopts = "-rfEX -p pytester -p no:warnings -x --capture=tee-sys --cov-report=term-missing --cov-append" +addopts = [ + "-rfEX", + "-x", + "--capture=tee-sys", + "--tb=long", + "--import-mode=importlib", + "--cov=bentoml", + "--cov-report=term-missing:skip-covered", + "--cov-append", +] python_files = ["test_*.py", "*_test.py"] testpaths = ["tests"] -markers = ["gpus", "disable-tf-eager-execution"] [tool.pylint.main] recursive = true diff --git a/requirements/frameworks-requirements.txt b/requirements/frameworks-requirements.txt new file mode 100644 index 0000000000..4b022133ef --- /dev/null +++ b/requirements/frameworks-requirements.txt @@ -0,0 +1,20 @@ +-r tests-requirements.txt +catboost +lightgbm +mlflow +fastai +xgboost +scikit-learn +# ONNX dependencies +onnx +onnxruntime +# tensorflow dependencies +keras +tensorflow>=2.3.0;platform_system!="Darwin" +tensorflow-macos>=2.3.0;platform_system=="Darwin" +# torch-related dependencies +torch +pytorch-lightning +# huggingface dependencies +transformers +tokenizer diff --git a/requirements/tests-requirements.txt b/requirements/tests-requirements.txt index 83c6c2516c..6d4dac830b 100644 --- a/requirements/tests-requirements.txt +++ b/requirements/tests-requirements.txt @@ -8,7 +8,7 @@ pydantic pylint>=2.14.0 pytest-cov>=3.0.0 pytest>=6.2.0 -pytest-xdist +pytest-xdist[psutil] pytest-asyncio pandas scikit-learn @@ -16,5 +16,6 @@ imageio>=2.5.0 pyarrow build[virtualenv] >=0.8.0 yamllint +protobuf>=3.5.0, <3.20,!=3.19.5 grpcio-tools>=1.41.0,<1.49.0 opentelemetry-test-utils==0.33b0 diff --git a/scripts/ci/config.yml b/scripts/ci/config.yml index 569100fd18..ae1ae266a9 100644 --- a/scripts/ci/config.yml +++ b/scripts/ci/config.yml @@ -27,21 +27,24 @@ unit: is_dir: true type_tests: "unit" -general_features: +http_server: <<: *tmpl - root_test_dir: "tests/e2e/bento_server_general_features" + root_test_dir: "tests/e2e/bento_server_http" is_dir: true type_tests: "e2e" dependencies: - - "Pillow" + - Pillow + - pydantic + - fastapi -general_features_sync: +grpc_server: <<: *tmpl - root_test_dir: "tests/e2e/bento_server_general_features_sync" + root_test_dir: "tests/e2e/bento_server_grpc" is_dir: true type_tests: "e2e" dependencies: - - "Pillow" + - Pillow + - pydantic catboost: <<: *ntmpl @@ -78,7 +81,6 @@ fastai: - pandas - scikit-learn - fasttext: <<: *tmpl dependencies: @@ -197,7 +199,7 @@ statsmodels: <<: *tmpl dependencies: - "statsmodels==0.12.2" - - "scipy==1.7.3" # statsmodels 0.12.2 is using internal APIs of scipy + - "scipy==1.7.3" # statsmodels 0.12.2 is using internal APIs of scipy - "joblib" tf1: @@ -206,7 +208,6 @@ tf1: dependencies: - "tensorflow==1.15" - transformers: <<: *ntmpl dependencies: @@ -240,7 +241,7 @@ torchscript: - "-f https://download.pytorch.org/whl/torch_stable.html" - "torch==1.11.0+cpu" - "torchvision==0.12.0+cpu" - - "protobuf<4.21.0" # https://github.com/PyTorchLightning/pytorch-lightning/issues/13159 + - "protobuf<4.21.0" # https://github.com/PyTorchLightning/pytorch-lightning/issues/13159 - "psutil" pytorch_lightning: @@ -250,5 +251,5 @@ pytorch_lightning: - "torch==1.11.0+cpu" - "torchvision==0.12.0+cpu" - "pytorch_lightning==1.6.3" - - "protobuf<4.21.0" # https://github.com/PyTorchLightning/pytorch-lightning/issues/13159 + - "protobuf<4.21.0" # https://github.com/PyTorchLightning/pytorch-lightning/issues/13159 - "psutil" diff --git a/scripts/ci/run_tests.sh b/scripts/ci/run_tests.sh index 0fb4765450..c3dffcbe20 100755 --- a/scripts/ci/run_tests.sh +++ b/scripts/ci/run_tests.sh @@ -9,18 +9,21 @@ fname=$(basename "$0") dname=$(dirname "$0") +# shellcheck disable=SC1091 source "$dname/helpers.sh" set_on_failed_callback "ERR=1" GIT_ROOT=$(git rev-parse --show-toplevel) -ERR=0 - declare -a PYTESTARGS CONFIG_FILE="$dname/config.yml" REQ_FILE="/tmp/additional-requirements.txt" SKIP_DEPS=0 +ERR=0 +VERBOSE=0 +ENABLE_XDIST=1 +WORKERS=auto cd "$GIT_ROOT" || exit @@ -53,27 +56,33 @@ usage() { Running unit/integration tests with pytest and generate coverage reports. Make sure that given testcases is defined under $CONFIG_FILE. Usage: - $dname/$fname [-h|--help] [-v|--verbose] [-s|--skip_deps] + $dname/$fname [-h|--help] [-v|--verbose] [-s|--skip-deps] Flags: -h, --help show this message -v, --verbose set verbose scripts - -s, --skip_deps skip install dependencies + -s, --skip-deps skip install dependencies + -w, --workers number of workers for pytest-xdist + --disable-xdist disable pytest-xdist If pytest_additional_arguments is given, this will be appended to given tests run. Example: - $ $dname/$fname pytorch --gpus + $ $dname/$fname pytorch --run-gpu-tests HEREDOC exit 2 } parse_args() { - if [ "${#@}" -eq 0 ]; then + if [ "${#}" -eq 0 ]; then FAIL "$0 doesn't run without any arguments" exit 1 fi + if [ "${1:0:1}" = "-" ]; then + FAIL "First arguments must be a target, not a flag." + exit 1 + fi for arg in "$@"; do case "$arg" in @@ -82,9 +91,19 @@ parse_args() { ;; -v | --verbose) set -x + VERBOSE=1 + shift + ;; + -w | --workers) + shift + WORKERS="$2" + shift + ;; + --disable-xdist) + ENABLE_XDIST=0 shift ;; - -s | --skip_deps) + -s | --skip-deps) SKIP_DEPS=1 shift ;; @@ -176,10 +195,10 @@ main() { fi done - # validate_yaml + # validate_yaml parse_config "$argv" - OPTS=(--cov=bentoml --cov-config="$GIT_ROOT"/pyproject.toml --cov-report=xml:"$target.xml" --cov-report=term-missing -x) + OPTS=(--cov-config="$GIT_ROOT/pyproject.toml" --cov-report=xml:"$target.xml") if [ -n "${PYTESTARGS[*]}" ]; then # shellcheck disable=SC2206 @@ -187,7 +206,14 @@ main() { fi if [ "$fname" == "test_frameworks.py" ]; then - OPTS=("--framework" "$target" ${OPTS[@]}) + OPTS=("--framework" "$target" "${OPTS[@]}") + fi + if [ "$VERBOSE" -eq 1 ]; then + OPTS=("${OPTS[@]}" -vvv) + fi + + if [ "$type_tests" == 'unit' ] && [ "$ENABLE_XDIST" -eq 1 ] && [ "$(uname | tr '[:upper:]' '[:lower:]')" != "win32" ]; then + OPTS=("${OPTS[@]}" --dist loadfile -n "$WORKERS") fi if [ "$SKIP_DEPS" -eq 0 ]; then @@ -202,7 +228,8 @@ main() { fi if [ "$type_tests" == 'e2e' ]; then - cd "$GIT_ROOT"/"$test_dir"/"$fname" || exit 1 + p="$GIT_ROOT/$test_dir" + cd "$p" || exit 1 path="." else path="$GIT_ROOT"/"$test_dir"/"$fname" @@ -213,13 +240,11 @@ main() { # Return non-zero if pytest failed if ! test $ERR = 0; then - FAIL "$args $type_tests tests failed!" + FAIL "$type_tests tests failed!" exit 1 fi - PASS "$args $type_tests tests passed!" + PASS "$type_tests tests passed!" } main "$@" || exit 1 - -# vim: set ft=sh ts=2 sw=2 tw=0 et : diff --git a/tests/README.md b/tests/README.md deleted file mode 100644 index 51506d6e0b..0000000000 --- a/tests/README.md +++ /dev/null @@ -1,5 +0,0 @@ -```bash -tests/utils -> tests helpers and utilities -tests/integration -> integration tests -tests/unit -> unitest -``` diff --git a/tests/e2e/README.md b/tests/e2e/README.md new file mode 100644 index 0000000000..b2ac7728d4 --- /dev/null +++ b/tests/e2e/README.md @@ -0,0 +1,114 @@ +# End-to-end tests suite + +This folder contains end-to-end test suite. + +## Instruction + +To create a new test suite (for simplicity let's call our test suite `qa`), do the following: + +1. Navigate to [`config.yml`](../../scripts/ci/config.yml) and add the E2E definition: + +```yaml +qa: + <<: *tmpl + root_test_dir: "tests/e2e/qa" + is_dir: true + type_tests: "e2e" + dependencies: # add required Python dependencies here. + - Pillow + - pydantic + - grpcio-status +``` + +2. Create the folder `qa` with the following project structure: + +```bash +. +├── bentofile.yaml +├── train.py +... +├── service.py +└── tests + ├── conftest.py + ├── test_io.py + ... + └── test_meta.py +``` + +> Note that files under `tests` are merely examples, feel free to add any types of +> additional tests. + +3. Create a `train.py`: + +```python +if __name__ == "__main__": + import python_model + + import bentoml + + bentoml.picklable_model.save_model( + "py_model.case-1.grpc.e2e", + python_model.PythonFunction(), + signatures={ + "echo_json": {"batchable": True}, + "echo_object": {"batchable": False}, + "echo_ndarray": {"batchable": True}, + "double_ndarray": {"batchable": True}, + }, + external_modules=[python_model], + ) +``` + +4. Inside `tests/conftest.py`, create a `host` fixture like so: + +```python +# pylint: disable=unused-argument +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING + +import pytest + +from bentoml._internal.configuration.containers import BentoMLContainer + +if TYPE_CHECKING: + from contextlib import ExitStack + + from _pytest.main import Session + from _pytest.nodes import Item + from _pytest.config import Config + + +def pytest_collection_modifyitems( + session: Session, config: Config, items: list[Item] +) -> None: + subprocess.check_call( + [sys.executable, "-m", "train"], + env={"BENTOML_HOME": BentoMLContainer.bentoml_home.get()}, + ) + + +@pytest.fixture(scope="module") +def host( + bentoml_home: str, + deployment_mode: str, + clean_context: ExitStack, +) -> t.Generator[str, None, None]: + from bentoml.testing.server import host_bento + + with host_bento( + "service:svc", + deployment_mode=deployment_mode, + bentoml_home=bentoml_home, + clean_context=clean_context, + use_grpc=True, + ) as _host: + yield _host +``` + +5. To run the tests, navigate to `GIT_ROOT` (root directory of bentoml), and call: + +```bash +./scripts/ci/run_tests.sh qa +``` diff --git a/tests/e2e/__init__.py b/tests/e2e/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/e2e/bento_server_general_features/server_config_cors_enabled.yml b/tests/e2e/bento_server_general_features/server_config_cors_enabled.yml deleted file mode 100644 index cc51066139..0000000000 --- a/tests/e2e/bento_server_general_features/server_config_cors_enabled.yml +++ /dev/null @@ -1,9 +0,0 @@ -api_server: - cors: # standard: https://fetch.spec.whatwg.org/#http-cors-protocol - enabled: True - access_control_allow_origin: "*" - access_control_allow_methods: ["GET", "OPTIONS", "POST", "HEAD", "PUT"] - access_control_allow_credentials: True - access_control_allow_headers: Null - access_control_max_age: Null - access_control_expose_headers: ["Content-Length"] diff --git a/tests/e2e/bento_server_general_features/server_config_default.yml b/tests/e2e/bento_server_general_features/server_config_default.yml deleted file mode 100644 index 096565f0b1..0000000000 --- a/tests/e2e/bento_server_general_features/server_config_default.yml +++ /dev/null @@ -1,3 +0,0 @@ -api_server: - cors: # standard: https://fetch.spec.whatwg.org/#http-cors-protocol - enabled: False diff --git a/tests/e2e/bento_server_general_features/tests/conftest.py b/tests/e2e/bento_server_general_features/tests/conftest.py deleted file mode 100644 index 17bb94ff19..0000000000 --- a/tests/e2e/bento_server_general_features/tests/conftest.py +++ /dev/null @@ -1,100 +0,0 @@ -# type: ignore[no-untyped-def] - -import os -import typing as t -import contextlib - -import numpy as np -import psutil -import pytest - - -@pytest.fixture() -def img_file(tmpdir) -> str: - import PIL.Image - - img_file_ = tmpdir.join("test_img.bmp") - img = PIL.Image.fromarray(np.random.randint(255, size=(10, 10, 3)).astype("uint8")) - img.save(str(img_file_)) - return str(img_file_) - - -@pytest.fixture() -def bin_file(tmpdir) -> str: - bin_file_ = tmpdir.join("bin_file.bin") - with open(bin_file_, "wb") as of: - of.write("â".encode("gb18030")) - return str(bin_file_) - - -def pytest_configure(config): # pylint: disable=unused-argument - import sys - import subprocess - - cmd = f"{sys.executable} {os.path.join(os.getcwd(), 'train.py')}" - subprocess.run(cmd, shell=True, check=True) - - # use the local bentoml package in development - os.environ["BENTOML_BUNDLE_LOCAL_BUILD"] = "True" - os.environ["SETUPTOOLS_USE_DISTUTILS"] = "stdlib" - - -@pytest.fixture(scope="session", autouse=True) -def clean_context(): - stack = contextlib.ExitStack() - yield stack - stack.close() - - -@pytest.fixture( - params=[ - "server_config_default.yml", - "server_config_cors_enabled.yml", - ], - scope="session", -) -def server_config_file(request): - return request.param - - -@pytest.fixture( - params=[ - # "dev", - "standalone", - "docker", - "distributed", - ], - scope="session", -) -def deployment_mode(request) -> str: - return request.param - - -@pytest.fixture(scope="session") -def host( - deployment_mode: str, - server_config_file: str, - clean_context: contextlib.ExitStack, -) -> t.Generator[str, None, None]: - if ( - (psutil.WINDOWS or psutil.MACOS) - and os.environ.get("GITHUB_ACTION") - and deployment_mode == "docker" - ): - pytest.skip( - "due to GitHub Action's limitation, docker deployment is not supported on " - "windows/macos. But you can still run this test on macos/windows locally." - ) - - if not psutil.LINUX and deployment_mode == "distributed": - pytest.skip("distributed deployment is only supported on Linux") - - from bentoml.testing.server import host_bento - - with host_bento( - "service:svc", - config_file=server_config_file, - deployment_mode=deployment_mode, - clean_context=clean_context, - ) as host: - yield host diff --git a/tests/e2e/bento_server_general_features/tests/test_microbatch.py b/tests/e2e/bento_server_general_features/tests/test_microbatch.py deleted file mode 100644 index fa992d85ff..0000000000 --- a/tests/e2e/bento_server_general_features/tests/test_microbatch.py +++ /dev/null @@ -1,119 +0,0 @@ -# import asyncio - -# import time - -# import psutil -# import pytest - -DEFAULT_MAX_LATENCY = 10 * 1000 - - -""" - -@pytest.mark.skipif(not psutil.POSIX, reason="production server only works on POSIX") -@pytest.mark.asyncio -async def test_slow_server(host): - - A, B = 0.2, 1 - data = '{"a": %s, "b": %s}' % (A, B) - - time_start = time.time() - req_count = 10 - tasks = tuple( - pytest.async_request( - "POST", - f"http://{host}/echo_with_delay", - headers=(("Content-Type", "application/json"),), - data=data, - timeout=30, - assert_status=200, - assert_data=data.encode(), - ) - for i in range(req_count) - ) - await asyncio.gather(*tasks) - assert time.time() - time_start < 12 - - -@pytest.mark.skipif(not psutil.POSIX, reason="production server only works on POSIX") -@pytest.mark.asyncio -async def test_fast_server(host): - - A, B = 0.0002, 0.01 - data = '{"a": %s, "b": %s}' % (A, B) - - req_count = 100 - tasks = tuple( - pytest.async_request( - "POST", - f"http://{host}/echo_with_delay", - headers=(("Content-Type", "application/json"),), - data=data, - assert_status=lambda i: i in (200, 429), - ) - for i in range(req_count) - ) - await asyncio.gather(*tasks) - - time_start = time.time() - req_count = 200 - tasks = tuple( - pytest.async_request( - "POST", - f"http://{host}/echo_with_delay", - headers=(("Content-Type", "application/json"),), - data=data, - timeout=30, - assert_status=200, - assert_data=data.encode(), - ) - for i in range(req_count) - ) - await asyncio.gather(*tasks) - assert time.time() - time_start < 2 - - -@pytest.mark.skipif(not psutil.POSIX, reason="production server only works on POSIX") -@pytest.mark.asyncio -async def test_batch_size_limit(host): - - A, B = 0.0002, 0.01 - data = '{"a": %s, "b": %s}' % (A, B) - - # test for max_batch_size=None - tasks = tuple( - pytest.async_request( - "POST", - f"http://{host}/echo_batch_size", - headers=(("Content-Type", "application/json"),), - data=data, - assert_status=lambda i: i in (200, 429), - ) - for _ in range(100) - ) - await asyncio.gather(*tasks) - await asyncio.sleep(1) - - batch_bucket = [] - - tasks = tuple( - pytest.async_request( - "POST", - f"http://{host}/echo_batch_size", - headers=(("Content-Type", "application/json"),), - data=data, - assert_status=200, - assert_data=lambda d: ( - d == b"429: Too Many Requests" - or batch_bucket.append(int(d.decode())) - or True - ), - ) - for _ in range(50) - ) - await asyncio.gather(*tasks) - - # batch size could be dynamic because of the bentoml_config.yml - # microbatch.max_batch_size=Null - assert any(b > 1 for b in batch_bucket), batch_bucket -""" diff --git a/tests/e2e/bento_server_grpc/bentofile.yaml b/tests/e2e/bento_server_grpc/bentofile.yaml new file mode 100644 index 0000000000..997c2df199 --- /dev/null +++ b/tests/e2e/bento_server_grpc/bentofile.yaml @@ -0,0 +1,11 @@ +service: service:svc +exclude: + - python_model.py + - "*.xml" +python: + packages: + - pandas + - pydantic + - Pillow + - scikit-learn + - pyarrow diff --git a/tests/e2e/bento_server_grpc/context_server_interceptor.py b/tests/e2e/bento_server_grpc/context_server_interceptor.py new file mode 100644 index 0000000000..dd565cfae7 --- /dev/null +++ b/tests/e2e/bento_server_grpc/context_server_interceptor.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import typing as t +import functools +import dataclasses +from typing import TYPE_CHECKING + +from grpc import aio + +if TYPE_CHECKING: + from bentoml.grpc.types import Request + from bentoml.grpc.types import Response + from bentoml.grpc.types import RpcMethodHandler + from bentoml.grpc.types import AsyncHandlerMethod + from bentoml.grpc.types import HandlerCallDetails + from bentoml.grpc.types import BentoServicerContext + + +@dataclasses.dataclass +class Context: + usage: str + accuracy_score: float + + +class AsyncContextInterceptor(aio.ServerInterceptor): + def __init__(self, *, usage: str, accuracy_score: float) -> None: + self.context = Context(usage=usage, accuracy_score=accuracy_score) + self._record: set[str] = set() + + async def intercept_service( + self, + continuation: t.Callable[[HandlerCallDetails], t.Awaitable[RpcMethodHandler]], + handler_call_details: HandlerCallDetails, + ) -> RpcMethodHandler: + from bentoml.grpc.utils import wrap_rpc_handler + + handler = await continuation(handler_call_details) + + if handler and (handler.response_streaming or handler.request_streaming): + return handler + + def wrapper(behaviour: AsyncHandlerMethod[Response]): + @functools.wraps(behaviour) + async def new_behaviour( + request: Request, context: BentoServicerContext + ) -> Response | t.Awaitable[Response]: + self._record.update( + {f"{self.context.usage}:{self.context.accuracy_score}"} + ) + resp = await behaviour(request, context) + context.set_trailing_metadata( + tuple( + [ + (k, str(v).encode("utf-8")) + for k, v in dataclasses.asdict(self.context).items() + ] + ) + ) + return resp + + return new_behaviour + + return wrap_rpc_handler(wrapper, handler) diff --git a/tests/e2e/bento_server_grpc/python_model.py b/tests/e2e/bento_server_grpc/python_model.py new file mode 100644 index 0000000000..2acf9715e1 --- /dev/null +++ b/tests/e2e/bento_server_grpc/python_model.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from typing import Any +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd + +if TYPE_CHECKING: + from numpy.typing import NDArray + + from bentoml._internal.types import FileLike + from bentoml._internal.types import JSONSerializable + + +class PythonFunction: + def predict_file(self, files: list[FileLike[bytes]]) -> list[bytes]: + return [f.read() for f in files] + + @classmethod + def echo_json(cls, datas: JSONSerializable) -> JSONSerializable: + return datas + + @classmethod + def echo_ndarray(cls, datas: NDArray[Any]) -> NDArray[Any]: + return datas + + def double_ndarray(self, data: NDArray[Any]) -> NDArray[Any]: + assert isinstance(data, np.ndarray) + return data * 2 + + def multiply_float_ndarray( + self, arr1: NDArray[np.float32], arr2: NDArray[np.float32] + ) -> NDArray[np.float32]: + assert isinstance(arr1, np.ndarray) + assert isinstance(arr2, np.ndarray) + return arr1 * arr2 + + def double_dataframe_column(self, df: pd.DataFrame) -> pd.DataFrame: + assert isinstance(df, pd.DataFrame) + return df[["col1"]] * 2 # type: ignore (no pandas types) + + def echo_dataframe(self, df: pd.DataFrame) -> pd.DataFrame: + return df diff --git a/tests/e2e/bento_server_grpc/service.py b/tests/e2e/bento_server_grpc/service.py new file mode 100644 index 0000000000..bcc9a79e81 --- /dev/null +++ b/tests/e2e/bento_server_grpc/service.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING + +from pydantic import BaseModel +from context_server_interceptor import AsyncContextInterceptor + +import bentoml +from bentoml.io import File +from bentoml.io import JSON +from bentoml.io import Text +from bentoml.io import Image +from bentoml.io import Multipart +from bentoml.io import NumpyNdarray +from bentoml.io import PandasDataFrame +from bentoml.testing.grpc import TestServiceServicer +from bentoml._internal.utils import LazyLoader + +if TYPE_CHECKING: + import numpy as np + import pandas as pd + import PIL.Image + from numpy.typing import NDArray + + from bentoml.grpc.v1alpha1 import service_test_pb2 as pb_test + from bentoml.grpc.v1alpha1 import service_test_pb2_grpc as services_test + from bentoml._internal.types import FileLike + from bentoml._internal.types import JSONSerializable + from bentoml.picklable_model import get_runnable + from bentoml._internal.runner.runner import RunnerMethod + + RunnableImpl = get_runnable(bentoml.picklable_model.get("py_model.case-1.grpc.e2e")) + + class PythonModelRunner(bentoml.Runner): + predict_file: RunnerMethod[RunnableImpl, [list[FileLike[bytes]]], list[bytes]] + echo_json: RunnerMethod[ + RunnableImpl, [list[JSONSerializable]], list[JSONSerializable] + ] + echo_ndarray: RunnerMethod[RunnableImpl, [NDArray[t.Any]], NDArray[t.Any]] + double_ndarray: RunnerMethod[RunnableImpl, [NDArray[t.Any]], NDArray[t.Any]] + multiply_float_ndarray: RunnerMethod[ + RunnableImpl, + [NDArray[np.float32], NDArray[np.float32]], + NDArray[np.float32], + ] + double_dataframe_column: RunnerMethod[ + RunnableImpl, [pd.DataFrame], pd.DataFrame + ] + echo_dataframe: RunnerMethod[RunnableImpl, [pd.DataFrame], pd.DataFrame] + +else: + from bentoml.grpc.utils import import_generated_stubs + + pb_test, services_test = import_generated_stubs(file="service_test.proto") + np = LazyLoader("np", globals(), "numpy") + pd = LazyLoader("pd", globals(), "pandas") + PIL = LazyLoader("PIL", globals(), "PIL") + PIL.Image = LazyLoader("PIL.Image", globals(), "PIL.Image") + + +py_model = t.cast( + "PythonModelRunner", + bentoml.picklable_model.get("py_model.case-1.grpc.e2e").to_runner(), +) + +svc = bentoml.Service(name="general_grpc_service.case-1.e2e", runners=[py_model]) + +svc.mount_grpc_servicer( + TestServiceServicer, + add_servicer_fn=services_test.add_TestServiceServicer_to_server, + service_names=[v.full_name for v in pb_test.DESCRIPTOR.services_by_name.values()], +) +svc.add_grpc_interceptor(AsyncContextInterceptor, usage="NLP", accuracy_score=0.8247) + + +class IrisFeatures(BaseModel): + sepal_len: float + sepal_width: float + petal_len: float + petal_width: float + + +class IrisClassificationRequest(BaseModel): + request_id: str + iris_features: IrisFeatures + + +@svc.api(input=Text(), output=Text()) +async def bonjour(inp: str) -> str: + return f"Hello, {inp}!" + + +@svc.api(input=JSON(), output=JSON()) +async def echo_json(json_obj: JSONSerializable) -> JSONSerializable: + batched = await py_model.echo_json.async_run([json_obj]) + return batched[0] + + +@svc.api( + input=JSON(pydantic_model=IrisClassificationRequest), + output=JSON(), +) +def echo_json_validate(input_data: IrisClassificationRequest) -> dict[str, float]: + print("request_id: ", input_data.request_id) + return input_data.iris_features.dict() + + +@svc.api(input=NumpyNdarray(), output=NumpyNdarray()) +async def double_ndarray(arr: NDArray[t.Any]) -> NDArray[t.Any]: + return await py_model.double_ndarray.async_run(arr) + + +@svc.api(input=NumpyNdarray.from_sample(np.random.rand(2, 2)), output=NumpyNdarray()) +async def echo_ndarray_from_sample(arr: NDArray[t.Any]) -> NDArray[t.Any]: + assert arr.shape == (2, 2) + return await py_model.echo_ndarray.async_run(arr) + + +@svc.api(input=NumpyNdarray(shape=(2, 2), enforce_shape=True), output=NumpyNdarray()) +async def echo_ndarray_enforce_shape(arr: NDArray[t.Any]) -> NDArray[t.Any]: + assert arr.shape == (2, 2) + return await py_model.echo_ndarray.async_run(arr) + + +@svc.api( + input=NumpyNdarray(dtype=np.float32, enforce_dtype=True), output=NumpyNdarray() +) +async def echo_ndarray_enforce_dtype(arr: NDArray[t.Any]) -> NDArray[t.Any]: + assert arr.dtype == np.float32 + return await py_model.echo_ndarray.async_run(arr) + + +@svc.api(input=PandasDataFrame(orient="columns"), output=PandasDataFrame()) +async def echo_dataframe(df: pd.DataFrame) -> pd.DataFrame: + assert isinstance(df, pd.DataFrame) + return df + + +@svc.api( + input=PandasDataFrame.from_sample( + pd.DataFrame({"age": [3, 29], "height": [94, 170], "weight": [31, 115]}), + orient="columns", + ), + output=PandasDataFrame(), +) +async def echo_dataframe_from_sample(df: pd.DataFrame) -> pd.DataFrame: + assert isinstance(df, pd.DataFrame) + return df + + +@svc.api( + input=PandasDataFrame(dtype={"col1": "int64"}, orient="columns"), + output=PandasDataFrame(), +) +async def double_dataframe(df: pd.DataFrame) -> pd.DataFrame: + assert df["col1"].dtype == "int64" + output = await py_model.double_dataframe_column.async_run(df) + dfo = pd.DataFrame() + dfo["col1"] = output + return dfo + + +@svc.api(input=File(), output=File()) +async def predict_file(f: FileLike[bytes]) -> bytes: + batch_ret = await py_model.predict_file.async_run([f]) + return batch_ret[0] + + +@svc.api(input=Image(mime_type="image/bmp"), output=Image(mime_type="image/bmp")) +async def echo_image(f: PIL.Image.Image) -> NDArray[t.Any]: + assert isinstance(f, PIL.Image.Image) + return np.array(f) + + +@svc.api( + input=Multipart( + original=Image(mime_type="image/bmp"), compared=Image(mime_type="image/bmp") + ), + output=Multipart(meta=Text(), result=Image(mime_type="image/bmp")), +) +async def predict_multi_images(original: Image, compared: Image): + output_array = await py_model.multiply_float_ndarray.async_run( + np.array(original), np.array(compared) + ) + img = PIL.Image.fromarray(output_array) + return {"meta": "success", "result": img} diff --git a/tests/e2e/bento_server_grpc/tests/conftest.py b/tests/e2e/bento_server_grpc/tests/conftest.py new file mode 100644 index 0000000000..4ded9dbcfd --- /dev/null +++ b/tests/e2e/bento_server_grpc/tests/conftest.py @@ -0,0 +1,49 @@ +# pylint: disable=unused-argument +from __future__ import annotations + +import os +import sys +import typing as t +import subprocess +from typing import TYPE_CHECKING + +import psutil +import pytest + +if TYPE_CHECKING: + from contextlib import ExitStack + + from _pytest.main import Session + from _pytest.nodes import Item + from _pytest.config import Config + + +PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +def pytest_collection_modifyitems( + session: Session, config: Config, items: list[Item] +) -> None: + subprocess.check_call([sys.executable, f"{os.path.join(PROJECT_DIR, 'train.py')}"]) + + +@pytest.mark.usefixtures("change_test_dir") +@pytest.fixture(scope="module") +def host( + bentoml_home: str, + deployment_mode: t.Literal["docker", "distributed", "standalone"], + clean_context: ExitStack, +) -> t.Generator[str, None, None]: + from bentoml.testing.server import host_bento + + if psutil.WINDOWS: + pytest.skip("gRPC is not supported on Windows.") + with host_bento( + "service:svc", + deployment_mode=deployment_mode, + project_path=PROJECT_DIR, + bentoml_home=bentoml_home, + clean_context=clean_context, + use_grpc=True, + ) as _host: + yield _host diff --git a/tests/e2e/bento_server_grpc/tests/test_custom_components.py b/tests/e2e/bento_server_grpc/tests/test_custom_components.py new file mode 100644 index 0000000000..99263fa9db --- /dev/null +++ b/tests/e2e/bento_server_grpc/tests/test_custom_components.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import typing as t + +import pytest +from grpc import aio +from grpc_health.v1 import health_pb2 as pb_health +from google.protobuf import wrappers_pb2 + +from bentoml.testing.grpc import create_channel +from bentoml.grpc.v1alpha1 import service_pb2 as pb +from bentoml.grpc.v1alpha1 import service_test_pb2 as pb_test +from bentoml.grpc.v1alpha1 import service_test_pb2_grpc as services_test +from bentoml.testing.grpc.interceptors import AssertClientInterceptor + + +@pytest.mark.asyncio +async def test_success_invocation_custom_servicer(host: str) -> None: + async with create_channel(host) as channel: + HealthCheck = channel.unary_unary( + "/grpc.health.v1.Health/Check", + request_serializer=pb_health.HealthCheckRequest.SerializeToString, # type: ignore (no grpc_health type) + response_deserializer=pb_health.HealthCheckResponse.FromString, # type: ignore (no grpc_health type) + ) + health = await t.cast( + t.Awaitable[pb_health.HealthCheckResponse], + HealthCheck( + pb_health.HealthCheckRequest( + service="bentoml.testing.v1alpha1.TestService" + ) + ), + ) + assert health.status == pb_health.HealthCheckResponse.SERVING # type: ignore ( no generated enum types) + stub = services_test.TestServiceStub(channel) # type: ignore (no async types) + request = pb_test.ExecuteRequest(input="BentoML") + resp: pb_test.ExecuteResponse = await stub.Execute(request) + assert resp.output == "Hello, BentoML!" + + +@pytest.mark.asyncio +async def test_trailing_metadata_interceptors(host: str) -> None: + async with create_channel( + host, + interceptors=[ + AssertClientInterceptor( + assert_trailing_metadata=aio.Metadata.from_tuple( + (("usage", "NLP"), ("accuracy_score", "0.8247")) + ) + ) + ], + ) as channel: + Call = channel.unary_unary( + "/bentoml.grpc.v1alpha1.BentoService/Call", + request_serializer=pb.Request.SerializeToString, + response_deserializer=pb.Response.FromString, + ) + await t.cast( + t.Awaitable[pb.Request], + Call( + pb.Request( + api_name="bonjour", text=wrappers_pb2.StringValue(value="BentoML") + ) + ), + ) diff --git a/tests/e2e/bento_server_grpc/tests/test_descriptors.py b/tests/e2e/bento_server_grpc/tests/test_descriptors.py new file mode 100644 index 0000000000..deabe1fc8b --- /dev/null +++ b/tests/e2e/bento_server_grpc/tests/test_descriptors.py @@ -0,0 +1,403 @@ +from __future__ import annotations + +import io +import random +import traceback +from typing import TYPE_CHECKING +from functools import partial + +import pytest + +from bentoml.testing.grpc import create_channel +from bentoml.testing.grpc import async_client_call +from bentoml.testing.grpc import randomize_pb_ndarray +from bentoml._internal.utils import LazyType +from bentoml._internal.utils import LazyLoader + +if TYPE_CHECKING: + import grpc + import numpy as np + import pandas as pd + import PIL.Image as PILImage + from grpc import aio + from google.protobuf import struct_pb2 + from google.protobuf import wrappers_pb2 + + from bentoml._internal import external_typing as ext + from bentoml.grpc.v1alpha1 import service_pb2 as pb +else: + from bentoml.grpc.utils import import_grpc + from bentoml.grpc.utils import import_generated_stubs + + pb, _ = import_generated_stubs() + grpc, aio = import_grpc() + wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2") + struct_pb2 = LazyLoader("struct_pb2", globals(), "google.protobuf.struct_pb2") + np = LazyLoader("np", globals(), "numpy") + pd = LazyLoader("pd", globals(), "pandas") + PILImage = LazyLoader("PILImage", globals(), "PIL.Image") + + +def assert_ndarray( + resp: pb.Response, + assert_shape: list[int], + assert_dtype: pb.NDArray.DType.ValueType, +) -> bool: + __tracebackhide__ = True # Hide traceback for py.test + + dtype = resp.ndarray.dtype + try: + assert resp.ndarray.shape == assert_shape + assert dtype == assert_dtype + return True + except AssertionError: + traceback.print_exc() + return False + + +def make_iris_proto(**fields: struct_pb2.Value) -> struct_pb2.Value: + return struct_pb2.Value( + struct_value=struct_pb2.Struct( + fields={ + "request_id": struct_pb2.Value(string_value="123"), + "iris_features": struct_pb2.Value( + struct_value=struct_pb2.Struct(fields=fields) + ), + } + ) + ) + + +@pytest.mark.asyncio +async def test_numpy(host: str): + async with create_channel(host) as channel: + await async_client_call( + "double_ndarray", + channel=channel, + data={"ndarray": randomize_pb_ndarray((1000,))}, + assert_data=partial( + assert_ndarray, assert_shape=[1000], assert_dtype=pb.NDArray.DTYPE_FLOAT + ), + ) + await async_client_call( + "double_ndarray", + channel=channel, + data={"ndarray": pb.NDArray(shape=[2, 2], int32_values=[1, 2, 3, 4])}, + assert_data=lambda resp: resp.ndarray.int32_values == [2, 4, 6, 8], + ) + with pytest.raises(aio.AioRpcError): + await async_client_call( + "double_ndarray", + channel=channel, + data={"ndarray": pb.NDArray(string_values=np.array(["2", "2f"]))}, + assert_code=grpc.StatusCode.INTERNAL, + ) + await async_client_call( + "double_ndarray", + channel=channel, + data={ + "ndarray": pb.NDArray( + dtype=123, string_values=np.array(["2", "2f"]) # type: ignore (test exception) + ) + }, + assert_code=grpc.StatusCode.INVALID_ARGUMENT, + ) + await async_client_call( + "double_ndarray", + channel=channel, + data={"serialized_bytes": np.array([1, 2, 3, 4]).ravel().tobytes()}, + assert_code=grpc.StatusCode.INVALID_ARGUMENT, + ) + await async_client_call( + "double_ndarray", + channel=channel, + data={"text": wrappers_pb2.StringValue(value="asdf")}, + assert_code=grpc.StatusCode.INVALID_ARGUMENT, + ) + await async_client_call( + "echo_ndarray_enforce_shape", + channel=channel, + data={"ndarray": randomize_pb_ndarray((1000,))}, + assert_code=grpc.StatusCode.INVALID_ARGUMENT, + ) + await async_client_call( + "echo_ndarray_enforce_dtype", + channel=channel, + data={"ndarray": pb.NDArray(string_values=np.array(["2", "2f"]))}, + assert_code=grpc.StatusCode.INVALID_ARGUMENT, + ) + + +@pytest.mark.asyncio +async def test_json(host: str): + async with create_channel(host) as channel: + await async_client_call( + "echo_json", + channel=channel, + data={"json": struct_pb2.Value(string_value='"hi"')}, + assert_data=lambda resp: resp.json.string_value == '"hi"', + ) + await async_client_call( + "echo_json", + channel=channel, + data={ + "serialized_bytes": b'{"request_id": "123", "iris_features": {"sepal_len":2.34,"sepal_width":1.58, "petal_len":6.52, "petal_width":3.23}}' + }, + assert_data=lambda resp: resp.json # type: ignore (bad lambda types) + == make_iris_proto( + sepal_len=struct_pb2.Value(number_value=2.34), + sepal_width=struct_pb2.Value(number_value=1.58), + petal_len=struct_pb2.Value(number_value=6.52), + petal_width=struct_pb2.Value(number_value=3.23), + ), + ) + await async_client_call( + "echo_json_validate", + channel=channel, + data={ + "json": make_iris_proto( + **{ + k: struct_pb2.Value(number_value=random.uniform(1.0, 6.0)) + for k in [ + "sepal_len", + "sepal_width", + "petal_len", + "petal_width", + ] + } + ) + }, + ) + with pytest.raises(aio.AioRpcError): + await async_client_call( + "echo_json", + channel=channel, + data={"serialized_bytes": b"\n?xfa"}, + assert_code=grpc.StatusCode.INVALID_ARGUMENT, + ) + await async_client_call( + "echo_json", + channel=channel, + data={"text": wrappers_pb2.StringValue(value="asdf")}, + assert_code=grpc.StatusCode.INVALID_ARGUMENT, + ) + await async_client_call( + "echo_json_validate", + channel=channel, + data={ + "json": make_iris_proto( + sepal_len=struct_pb2.Value(number_value=2.34), + sepal_width=struct_pb2.Value(number_value=1.58), + petal_len=struct_pb2.Value(number_value=6.52), + ), + }, + assert_code=grpc.StatusCode.INVALID_ARGUMENT, + ) + + +@pytest.mark.asyncio +async def test_file(host: str, bin_file: str): + # Test File as binary + with open(str(bin_file), "rb") as f: + fb = f.read() + + async with create_channel(host) as channel: + await async_client_call( + "predict_file", + channel=channel, + data={"serialized_bytes": fb}, + assert_data=lambda resp: resp.file.content == fb, + ) + await async_client_call( + "predict_file", + channel=channel, + data={"file": pb.File(kind=pb.File.FILE_TYPE_BYTES, content=fb)}, + assert_data=lambda resp: resp.file.content == b"\x810\x899" + and resp.file.kind == pb.File.FILE_TYPE_BYTES, + ) + with pytest.raises(aio.AioRpcError): + await async_client_call( + "predict_file", + channel=channel, + data={"file": pb.File(kind=123, content=fb)}, # type: ignore (testing exception) + assert_code=grpc.StatusCode.INVALID_ARGUMENT, + ) + await async_client_call( + "predict_file", + channel=channel, + data={"file": pb.File(kind=pb.File.FILE_TYPE_PDF, content=fb)}, + assert_code=grpc.StatusCode.INVALID_ARGUMENT, + ) + await async_client_call( + "predict_file", + channel=channel, + data={"text": wrappers_pb2.StringValue(value="asdf")}, + assert_code=grpc.StatusCode.INVALID_ARGUMENT, + ) + + +def assert_image( + resp: pb.Response | pb.Part, + assert_kind: pb.File.FileType.ValueType, + im_file: str | ext.NpNDArray, +) -> bool: + fio = io.BytesIO(resp.file.content) + fio.name = "test.bmp" + img = PILImage.open(fio) + a1 = np.array(img) + if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(im_file): + a2 = PILImage.fromarray(im_file) + else: + assert isinstance(im_file, str) + a2 = PILImage.open(im_file) + try: + assert resp.file.kind == assert_kind + np.testing.assert_array_almost_equal(a1, np.array(a2)) + return True + except AssertionError: + traceback.print_exc() + return False + + +@pytest.mark.asyncio +async def test_image(host: str, img_file: str): + with open(str(img_file), "rb") as f: + fb = f.read() + + async with create_channel(host) as channel: + await async_client_call( + "echo_image", + channel=channel, + data={"serialized_bytes": fb}, + assert_data=partial( + assert_image, im_file=img_file, assert_kind=pb.File.FILE_TYPE_BMP + ), + ) + await async_client_call( + "echo_image", + channel=channel, + data={"file": pb.File(kind=pb.File.FILE_TYPE_BMP, content=fb)}, + assert_data=partial( + assert_image, im_file=img_file, assert_kind=pb.File.FILE_TYPE_BMP + ), + ) + with pytest.raises(aio.AioRpcError): + await async_client_call( + "echo_image", + channel=channel, + data={"file": pb.File(kind=123, content=fb)}, # type: ignore (testing exception) + assert_code=grpc.StatusCode.INVALID_ARGUMENT, + ) + await async_client_call( + "echo_image", + channel=channel, + data={"file": pb.File(kind=pb.File.FILE_TYPE_PDF, content=fb)}, + assert_code=grpc.StatusCode.INVALID_ARGUMENT, + ) + await async_client_call( + "echo_image", + channel=channel, + data={"text": wrappers_pb2.StringValue(value="asdf")}, + assert_code=grpc.StatusCode.INVALID_ARGUMENT, + ) + + +@pytest.mark.asyncio +async def test_pandas(host: str): + async with create_channel(host) as channel: + await async_client_call( + "echo_dataframe", + channel=channel, + data={ + "dataframe": pb.DataFrame( + column_names=[ + str(i) for i in pd.RangeIndex(0, 3, 1, dtype=np.int64).tolist() + ], + columns=[ + pb.Series(int32_values=[1]), + pb.Series(int32_values=[2]), + pb.Series(int32_values=[3]), + ], + ), + }, + ) + await async_client_call( + "echo_dataframe_from_sample", + channel=channel, + data={ + "dataframe": pb.DataFrame( + column_names=["age", "height", "weight"], + columns=[ + pb.Series(int64_values=[12, 23]), + pb.Series(int64_values=[40, 83]), + pb.Series(int64_values=[32, 89]), + ], + ), + }, + ) + await async_client_call( + "double_dataframe", + channel=channel, + data={ + "dataframe": pb.DataFrame( + column_names=["col1"], + columns=[pb.Series(int64_values=[23])], + ), + }, + assert_data=lambda resp: resp.dataframe # type: ignore (bad lambda types) + == pb.DataFrame( + column_names=["col1"], + columns=[pb.Series(int64_values=[46])], + ), + ) + with pytest.raises(aio.AioRpcError): + await async_client_call( + "echo_dataframe", + channel=channel, + data={ + "dataframe": pb.DataFrame( + column_names=["col1"], + columns=[pb.Series(int64_values=[23], int32_values=[23])], + ), + }, + assert_code=grpc.StatusCode.INVALID_ARGUMENT, + ) + + +def assert_multi_images(resp: pb.Response, method: str, im_file: str) -> bool: + assert method == "pred_multi_images" + img = PILImage.open(im_file) + arr = np.array(img) + expected = arr * arr + return assert_image( + resp.multipart.fields["result"], + assert_kind=pb.File.FILE_TYPE_BMP, + im_file=expected, + ) + + +@pytest.mark.asyncio +async def test_multipart(host: str, img_file: str): + with open(str(img_file), "rb") as f: + fb = f.read() + + async with create_channel(host) as channel: + await async_client_call( + "predict_multi_images", + channel=channel, + data={ + "multipart": { + "fields": { + "original": pb.Part( + file=pb.File(kind=pb.File.FILE_TYPE_BMP, content=fb) + ), + "compared": pb.Part( + file=pb.File(kind=pb.File.FILE_TYPE_BMP, content=fb) + ), + } + } + }, + assert_data=partial( + assert_multi_images, method="pred_multi_images", im_file=img_file + ), + ) diff --git a/tests/e2e/bento_server_grpc/train.py b/tests/e2e/bento_server_grpc/train.py new file mode 100644 index 0000000000..0d24797e7f --- /dev/null +++ b/tests/e2e/bento_server_grpc/train.py @@ -0,0 +1,19 @@ +if __name__ == "__main__": + import python_model + + import bentoml + + bentoml.picklable_model.save_model( + "py_model.case-1.grpc.e2e", + python_model.PythonFunction(), + signatures={ + "predict_file": {"batchable": True}, + "echo_json": {"batchable": True}, + "echo_object": {"batchable": False}, + "echo_ndarray": {"batchable": True}, + "double_ndarray": {"batchable": True}, + "multiply_float_ndarray": {"batchable": True}, + "double_dataframe_column": {"batchable": True}, + }, + external_modules=[python_model], + ) diff --git a/tests/e2e/bento_server_general_features/bentofile.yaml b/tests/e2e/bento_server_http/bentofile.yaml similarity index 100% rename from tests/e2e/bento_server_general_features/bentofile.yaml rename to tests/e2e/bento_server_http/bentofile.yaml diff --git a/tests/e2e/bento_server_http/configs/cors_enabled.yml b/tests/e2e/bento_server_http/configs/cors_enabled.yml new file mode 100644 index 0000000000..bc2ca106a3 --- /dev/null +++ b/tests/e2e/bento_server_http/configs/cors_enabled.yml @@ -0,0 +1,10 @@ +api_server: + http: + cors: # standard: https://fetch.spec.whatwg.org/#http-cors-protocol + enabled: True + access_control_allow_origin: "*" + access_control_allow_methods: ["GET", "OPTIONS", "POST", "HEAD", "PUT"] + access_control_allow_credentials: True + access_control_allow_headers: Null + access_control_max_age: Null + access_control_expose_headers: ["Content-Length"] diff --git a/tests/e2e/bento_server_http/configs/default.yml b/tests/e2e/bento_server_http/configs/default.yml new file mode 100644 index 0000000000..bbfaa3aa79 --- /dev/null +++ b/tests/e2e/bento_server_http/configs/default.yml @@ -0,0 +1,4 @@ +api_server: + http: + cors: # standard: https://fetch.spec.whatwg.org/#http-cors-protocol + enabled: False diff --git a/tests/e2e/bento_server_general_features/pickle_model.py b/tests/e2e/bento_server_http/pickle_model.py similarity index 57% rename from tests/e2e/bento_server_general_features/pickle_model.py rename to tests/e2e/bento_server_http/pickle_model.py index 4d6627b0ca..b91ea9bac1 100644 --- a/tests/e2e/bento_server_general_features/pickle_model.py +++ b/tests/e2e/bento_server_http/pickle_model.py @@ -1,10 +1,16 @@ +from __future__ import annotations + import typing as t +from typing import TYPE_CHECKING import numpy as np import pandas as pd -from bentoml._internal.types import FileLike -from bentoml._internal.types import JSONSerializable +if TYPE_CHECKING: + from numpy.typing import NDArray + + from bentoml._internal.types import FileLike + from bentoml._internal.types import JSONSerializable class PickleModel: @@ -19,29 +25,21 @@ def echo_json(cls, input_datas: JSONSerializable) -> JSONSerializable: def echo_obj(cls, input_datas: t.Any) -> t.Any: return input_datas - def echo_multi_ndarray( - self, - *input_arr: "np.ndarray[t.Any, np.dtype[t.Any]]", - ) -> t.Tuple["np.ndarray[t.Any, np.dtype[t.Any]]", ...]: + def echo_multi_ndarray(self, *input_arr: NDArray[t.Any]) -> tuple[NDArray[t.Any]]: return input_arr - def predict_ndarray( - self, - arr: "np.ndarray[t.Any, np.dtype[t.Any]]", - ) -> "np.ndarray[t.Any, np.dtype[t.Any]]": + def predict_ndarray(self, arr: NDArray[t.Any]) -> NDArray[t.Any]: assert isinstance(arr, np.ndarray) return arr * 2 def predict_multi_ndarray( - self, - arr1: "np.ndarray[t.Any, np.dtype[t.Any]]", - arr2: "np.ndarray[t.Any, np.dtype[t.Any]]", - ) -> "np.ndarray[t.Any, np.dtype[t.Any]]": + self, arr1: NDArray[t.Any], arr2: NDArray[t.Any] + ) -> NDArray[t.Any]: assert isinstance(arr1, np.ndarray) assert isinstance(arr2, np.ndarray) return (arr1 + arr2) // 2 - def predict_dataframe(self, df: "pd.DataFrame") -> "pd.DataFrame": + def predict_dataframe(self, df: pd.DataFrame) -> pd.DataFrame: assert isinstance(df, pd.DataFrame) output = df[["col1"]] * 2 # type: ignore assert isinstance(output, pd.DataFrame) diff --git a/tests/e2e/bento_server_general_features/requirements.txt b/tests/e2e/bento_server_http/requirements.txt similarity index 86% rename from tests/e2e/bento_server_general_features/requirements.txt rename to tests/e2e/bento_server_http/requirements.txt index 108615ac40..102c584db9 100644 --- a/tests/e2e/bento_server_general_features/requirements.txt +++ b/tests/e2e/bento_server_http/requirements.txt @@ -1,6 +1,6 @@ pandas pydantic -pillow +Pillow scikit-learn pyarrow fastapi diff --git a/tests/e2e/bento_server_general_features/service.py b/tests/e2e/bento_server_http/service.py similarity index 74% rename from tests/e2e/bento_server_general_features/service.py rename to tests/e2e/bento_server_http/service.py index f50b31fe5d..b6acaf3e9a 100644 --- a/tests/e2e/bento_server_general_features/service.py +++ b/tests/e2e/bento_server_http/service.py @@ -1,29 +1,38 @@ +from __future__ import annotations + import typing as t +from typing import TYPE_CHECKING import numpy as np import pandas as pd import pydantic from PIL.Image import Image as PILImage from PIL.Image import fromarray +from starlette.requests import Request import bentoml -import bentoml.picklable_model from bentoml.io import File from bentoml.io import JSON from bentoml.io import Image from bentoml.io import Multipart from bentoml.io import NumpyNdarray from bentoml.io import PandasDataFrame -from bentoml._internal.types import FileLike -from bentoml._internal.types import JSONSerializable -py_model = bentoml.picklable_model.get("py_model.case-1.e2e").to_runner() +if TYPE_CHECKING: + from numpy.typing import NDArray + from starlette.types import Send + from starlette.types import Scope + from starlette.types import ASGIApp + from starlette.types import Receive + from bentoml._internal.types import FileLike + from bentoml._internal.types import JSONSerializable -svc = bentoml.Service( - name="general_workflow_service.case-1.e2e", - runners=[py_model], -) + +py_model = bentoml.picklable_model.get("py_model.case-1.http.e2e").to_runner() + + +svc = bentoml.Service(name="general_http_service.case-1.e2e", runners=[py_model]) @svc.api(input=JSON(), output=JSON()) @@ -38,13 +47,13 @@ def echo_json_sync(json_obj: JSONSerializable) -> JSONSerializable: return batch_ret[0] -class _Schema(pydantic.BaseModel): +class ValidateSchema(pydantic.BaseModel): name: str endpoints: t.List[str] @svc.api( - input=JSON(pydantic_model=_Schema), + input=JSON(pydantic_model=ValidateSchema), output=JSON(), ) async def echo_json_enforce_structure(json_obj: JSONSerializable) -> JSONSerializable: @@ -61,9 +70,7 @@ async def echo_obj(obj: JSONSerializable) -> JSONSerializable: input=NumpyNdarray(shape=(2, 2), enforce_shape=True), output=NumpyNdarray(shape=(2, 2)), ) -async def predict_ndarray_enforce_shape( - inp: "np.ndarray[t.Any, np.dtype[t.Any]]", -) -> "np.ndarray[t.Any, np.dtype[t.Any]]": +async def predict_ndarray_enforce_shape(inp: NDArray[t.Any]) -> NDArray[t.Any]: assert inp.shape == (2, 2) return await py_model.predict_ndarray.async_run(inp) @@ -72,9 +79,7 @@ async def predict_ndarray_enforce_shape( input=NumpyNdarray(dtype="uint8", enforce_dtype=True), output=NumpyNdarray(dtype="str"), ) -async def predict_ndarray_enforce_dtype( - inp: "np.ndarray[t.Any, np.dtype[t.Any]]", -) -> "np.ndarray[t.Any, np.dtype[t.Any]]": +async def predict_ndarray_enforce_dtype(inp: NDArray[t.Any]) -> NDArray[t.Any]: assert inp.dtype == np.dtype("uint8") return await py_model.predict_ndarray.async_run(inp) @@ -94,7 +99,7 @@ async def predict_ndarray_multi_output( input=PandasDataFrame(dtype={"col1": "int64"}, orient="records"), output=PandasDataFrame(), ) -async def predict_dataframe(df: "pd.DataFrame") -> "pd.DataFrame": +async def predict_dataframe(df: pd.DataFrame) -> pd.DataFrame: assert df["col1"].dtype == "int64" output = await py_model.predict_dataframe.async_run(df) dfo = pd.DataFrame() @@ -110,18 +115,16 @@ async def predict_file(f: FileLike[bytes]) -> bytes: @svc.api(input=Image(), output=Image(mime_type="image/bmp")) -async def echo_image(f: PILImage) -> "np.ndarray[t.Any, np.dtype[t.Any]]": +async def echo_image(f: PILImage) -> NDArray[t.Any]: assert isinstance(f, PILImage) - return np.array(f) # type: ignore[arg-type] + return np.array(f) @svc.api( input=Multipart(original=Image(), compared=Image()), output=Multipart(img1=Image(), img2=Image()), ) -async def predict_multi_images( - original: t.Dict[str, Image], compared: t.Dict[str, Image] -): +async def predict_multi_images(original: dict[str, Image], compared: dict[str, Image]): output_array = await py_model.predict_multi_ndarray.async_run( np.array(original), np.array(compared) ) @@ -130,13 +133,6 @@ async def predict_multi_images( # customise the service -from starlette.types import Send -from starlette.types import Scope -from starlette.types import ASGIApp -from starlette.types import Receive -from starlette.requests import Request - - class AllowPingMiddleware: def __init__( self, @@ -154,7 +150,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: return -svc.add_asgi_middleware(AllowPingMiddleware) # type: ignore[arg-type] +svc.add_asgi_middleware(AllowPingMiddleware) # type: ignore (hint not yet supported for hooks) from fastapi import FastAPI diff --git a/tests/e2e/bento_server_http/tests/conftest.py b/tests/e2e/bento_server_http/tests/conftest.py new file mode 100644 index 0000000000..68052a720e --- /dev/null +++ b/tests/e2e/bento_server_http/tests/conftest.py @@ -0,0 +1,60 @@ +# pylint: disable=unused-argument +from __future__ import annotations + +import os +import sys +import typing as t +import subprocess +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from contextlib import ExitStack + + from _pytest.main import Session + from _pytest.nodes import Item + from _pytest.config import Config + from _pytest.fixtures import FixtureRequest as _PytestFixtureRequest + + class FixtureRequest(_PytestFixtureRequest): + param: str + + +PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +def pytest_collection_modifyitems( + session: Session, config: Config, items: list[Item] +) -> None: + subprocess.check_call([sys.executable, f"{os.path.join(PROJECT_DIR, 'train.py')}"]) + + +@pytest.fixture( + name="server_config_file", + params=["default.yml", "cors_enabled.yml"], + scope="session", +) +def fixture_server_config_file(request: FixtureRequest) -> str: + return os.path.join(PROJECT_DIR, "configs", request.param) + + +@pytest.mark.usefixtures("change_test_dir") +@pytest.fixture(scope="session") +def host( + bentoml_home: str, + deployment_mode: t.Literal["docker", "distributed", "standalone"], + server_config_file: str, + clean_context: ExitStack, +) -> t.Generator[str, None, None]: + from bentoml.testing.server import host_bento + + with host_bento( + "service:svc", + config_file=server_config_file, + project_path=PROJECT_DIR, + deployment_mode=deployment_mode, + bentoml_home=bentoml_home, + clean_context=clean_context, + ) as _host: + yield _host diff --git a/tests/e2e/bento_server_general_features/tests/test_io.py b/tests/e2e/bento_server_http/tests/test_io.py similarity index 87% rename from tests/e2e/bento_server_general_features/tests/test_io.py rename to tests/e2e/bento_server_http/tests/test_io.py index 0837c202a7..a826f7d19a 100644 --- a/tests/e2e/bento_server_general_features/tests/test_io.py +++ b/tests/e2e/bento_server_http/tests/test_io.py @@ -1,8 +1,9 @@ -# type: ignore[no-untyped-def] +from __future__ import annotations import io import sys import json +from typing import TYPE_CHECKING import numpy as np import pytest @@ -11,9 +12,16 @@ from bentoml.testing.utils import async_request from bentoml.testing.utils import parse_multipart_form +if TYPE_CHECKING: + import PIL.Image as PILImage +else: + from bentoml._internal.utils import LazyLoader + + PILImage = LazyLoader("PILImage", globals(), "PIL.Image") + @pytest.mark.asyncio -async def test_numpy(host): +async def test_numpy(host: str): await async_request( "POST", f"http://{host}/predict_ndarray_enforce_shape", @@ -55,7 +63,7 @@ async def test_numpy(host): @pytest.mark.asyncio -async def test_json(host): +async def test_json(host: str): ORIGIN = "http://bentoml.ai" await async_request( @@ -87,7 +95,7 @@ async def test_json(host): @pytest.mark.asyncio -async def test_obj(host): +async def test_obj(host: str): for obj in [1, 2.2, "str", [1, 2, 3], {"a": 1, "b": 2}]: obj_str = json.dumps(obj, separators=(",", ":")) await async_request( @@ -101,7 +109,7 @@ async def test_obj(host): @pytest.mark.asyncio -async def test_pandas(host): +async def test_pandas(host: str): import pandas as pd ORIGIN = "http://bentoml.ai" @@ -139,7 +147,7 @@ async def test_pandas(host): @pytest.mark.asyncio -async def test_file(host, bin_file): +async def test_file(host: str, bin_file: str): # Test File as binary with open(str(bin_file), "rb") as f: b = f.read() @@ -174,9 +182,7 @@ async def test_file(host, bin_file): @pytest.mark.asyncio -async def test_image(host, img_file): - import PIL.Image - +async def test_image(host: str, img_file: str): with open(str(img_file), "rb") as f1: img_bytes = f1.read() @@ -191,11 +197,11 @@ async def test_image(host, img_file): bio = io.BytesIO(body) bio.name = "test.bmp" - img = PIL.Image.open(bio) + img = PILImage.open(bio) array1 = np.array(img) - array2 = PIL.Image.open(img_file) + array2 = PILImage.open(img_file) - np.testing.assert_array_almost_equal(array1, array2) + np.testing.assert_array_almost_equal(array1, np.array(array2)) await async_request( "POST", @@ -216,12 +222,8 @@ async def test_image(host, img_file): ) -# SklearnRunner is not suppose to take multiple arguments -# TODO: move e2e tests to use a new bentoml.PickleModel module -@pytest.mark.skip @pytest.mark.asyncio -async def test_multipart_image_io(host, img_file): - import PIL.Image +async def test_multipart_image_io(host: str, img_file: str): from starlette.datastructures import UploadFile with open(img_file, "rb") as f1: @@ -230,16 +232,12 @@ async def test_multipart_image_io(host, img_file): form.add_field("original", f1.read(), content_type="image/bmp") form.add_field("compared", f2.read(), content_type="image/bmp") - status, headers, body = await async_request( - "POST", - f"http://{host}/predict_multi_images", - data=form, + _, headers, body = await async_request( + "POST", f"http://{host}/predict_multi_images", data=form, assert_status=200 ) - assert status == 200 - form = await parse_multipart_form(headers=headers, body=body) for _, v in form.items(): assert isinstance(v, UploadFile) - img = PIL.Image.open(v.file) + img = PILImage.open(v.file) assert np.array(img).shape == (10, 10, 3) diff --git a/tests/e2e/bento_server_general_features/tests/test_meta.py b/tests/e2e/bento_server_http/tests/test_meta.py similarity index 56% rename from tests/e2e/bento_server_general_features/tests/test_meta.py rename to tests/e2e/bento_server_http/tests/test_meta.py index 48c2fd35c6..1761a18cb8 100644 --- a/tests/e2e/bento_server_general_features/tests/test_meta.py +++ b/tests/e2e/bento_server_http/tests/test_meta.py @@ -1,5 +1,8 @@ # pylint: disable=redefined-outer-name -# type: ignore[no-untyped-def] + +from __future__ import annotations + +from pathlib import Path import pytest @@ -45,7 +48,11 @@ async def test_cors(host: str, server_config_file: str) -> None: "Access-Control-Request-Headers": "Content-Type", }, ) - if server_config_file == "server_config_cors_enabled.yml": + + # all test configs lives under ../configs, but we are only interested in name. + fname = Path(server_config_file).name + + if fname == "cors_enabled.yml": assert status == 200 else: assert status != 200 @@ -56,7 +63,7 @@ async def test_cors(host: str, server_config_file: str) -> None: headers={"Content-Type": "application/json", "Origin": ORIGIN}, data='"hi"', ) - if server_config_file == "server_config_cors_enabled.yml": + if fname == "cors_enabled.yml": assert status == 200 assert body == b'"hi"' assert headers["Access-Control-Allow-Origin"] in ("*", ORIGIN) @@ -69,10 +76,10 @@ async def test_cors(host: str, server_config_file: str) -> None: def test_service_init_checks(): - py_model1 = bentoml.picklable_model.get("py_model.case-1.e2e").to_runner( + py_model1 = bentoml.picklable_model.get("py_model.case-1.http.e2e").to_runner( name="invalid" ) - py_model2 = bentoml.picklable_model.get("py_model.case-1.e2e").to_runner( + py_model2 = bentoml.picklable_model.get("py_model.case-1.http.e2e").to_runner( name="invalid" ) with pytest.raises(ValueError) as excinfo: @@ -85,84 +92,11 @@ def test_service_init_checks(): def test_dunder_string(): - runner = bentoml.picklable_model.get("py_model.case-1.e2e").to_runner() + runner = bentoml.picklable_model.get("py_model.case-1.http.e2e").to_runner() svc = bentoml.Service(name="dunder_string", runners=[runner]) assert ( str(svc) - == 'bentoml.Service(name="dunder_string", runners=[py_model.case-1.e2e])' - ) - - -""" -@pytest.since_bentoml_version("0.11.0+0") -@pytest.mark.asyncio -async def test_customized_route(host): - CUSTOM_ROUTE = "$~!@%^&*()_-+=[]\\|;:,./predict" - - def path_in_docs(response_body): - d = json.loads(response_body.decode()) - return f"/{CUSTOM_ROUTE}" in d['paths'] - - await async_request( - "GET", - f"http://{host}/docs.json", - headers=(("Content-Type", "application/json"),), - assert_data=path_in_docs, - ) - - await async_request( - "POST", - f"http://{host}/{CUSTOM_ROUTE}", - headers=(("Content-Type", "application/json"),), - data=json.dumps("hello"), - assert_data=bytes('"hello"', 'ascii'), - ) - - -@pytest.mark.asyncio -async def test_customized_request_schema(host): - def has_customized_schema(doc_bytes): - json_str = doc_bytes.decode() - return "field1" in json_str - - await async_request( - "GET", - f"http://{host}/docs.json", - headers=(("Content-Type", "application/json"),), - assert_data=has_customized_schema, - ) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "metrics", - [ - pytest.param( - '_mb_request_duration_seconds_count', - marks=pytest.mark.skipif( - psutil.MACOS, reason="microbatch metrics is not shown in MacOS tests" - ), - ), - pytest.param( - '_mb_request_total', - marks=pytest.mark.skipif( - psutil.MACOS, reason="microbatch metrics is not shown in MacOS tests" - ), - ), - '_request_duration_seconds_bucket', - ], -) -async def test_api_server_metrics(host, metrics): - await async_request( - "POST", f"http://{host}/echo_json", data='"hi"', - ) - - await async_request( - "GET", - f"http://{host}/metrics", - assert_status=200, - assert_data=lambda d: metrics in d.decode(), + == 'bentoml.Service(name="dunder_string", runners=[py_model.case-1.http.e2e])' ) -""" diff --git a/tests/e2e/bento_server_general_features/train.py b/tests/e2e/bento_server_http/train.py similarity index 83% rename from tests/e2e/bento_server_general_features/train.py rename to tests/e2e/bento_server_http/train.py index 8fd5a68339..c36f264acc 100644 --- a/tests/e2e/bento_server_general_features/train.py +++ b/tests/e2e/bento_server_http/train.py @@ -1,11 +1,10 @@ -import pickle_model - -import bentoml.picklable_model +if __name__ == "__main__": + import pickle_model + import bentoml -def train(): bentoml.picklable_model.save_model( - "py_model.case-1.e2e", + "py_model.case-1.http.e2e", pickle_model.PickleModel(), signatures={ "predict_file": {"batchable": True}, @@ -18,7 +17,3 @@ def train(): }, external_modules=[pickle_model], ) - - -if __name__ == "__main__": - train() diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index d426aafe6d..75ca4e9781 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,11 +1,8 @@ import typing as t -import tempfile from typing import TYPE_CHECKING import pytest -from bentoml._internal.models import ModelStore - if TYPE_CHECKING: from _pytest.nodes import Item from _pytest.config import Config @@ -13,12 +10,6 @@ def pytest_addoption(parser: "Parser") -> None: - parser.addoption( - "--runslow", action="store_true", default=False, help="run slow tests" - ) - parser.addoption( - "--gpus", action="store_true", default=False, help="run gpus related tests" - ) parser.addoption( "--disable-tf-eager-execution", action="store_true", @@ -27,6 +18,14 @@ def pytest_addoption(parser: "Parser") -> None: ) +def pytest_configure(config: "Config") -> None: + # We will inject marker documentation here. + config.addinivalue_line( + "markers", + "requires_eager_execution: requires enable eager execution to run Tensorflow-based tests.", + ) + + def pytest_collection_modifyitems(config: "Config", items: t.List["Item"]) -> None: if config.getoption("--disable-tf-eager-execution"): try: @@ -35,20 +34,8 @@ def pytest_collection_modifyitems(config: "Config", items: t.List["Item"]) -> No disable_eager_execution() except ImportError: return - elif config.getoption("--gpus"): - return - skip_gpus = pytest.mark.skip(reason="Skip gpus tests") requires_eager_execution = pytest.mark.skip(reason="Requires eager execution") for item in items: - if "gpus" in item.keywords: - item.add_marker(skip_gpus) if "requires_eager_execution" in item.keywords: item.add_marker(requires_eager_execution) - - -def pytest_sessionstart(session): - path = tempfile.mkdtemp("bentoml-pytest") - from bentoml._internal.configuration.containers import BentoMLContainer - - BentoMLContainer.model_store.set(ModelStore(path)) diff --git a/tests/integration/frameworks/test_frameworks.py b/tests/integration/frameworks/test_frameworks.py index a2c059aecd..6ce5fd34d8 100644 --- a/tests/integration/frameworks/test_frameworks.py +++ b/tests/integration/frameworks/test_frameworks.py @@ -308,7 +308,7 @@ def test_runner_cpu( ) -@pytest.mark.gpus +@pytest.mark.requires_gpus def test_runner_nvidia_gpu( framework: types.ModuleType, test_model: FrameworkTestModel, diff --git a/tests/unit/_internal/bento/test_bento.py b/tests/unit/_internal/bento/test_bento.py index 7a208e6bae..c897088cb9 100644 --- a/tests/unit/_internal/bento/test_bento.py +++ b/tests/unit/_internal/bento/test_bento.py @@ -1,3 +1,4 @@ +# pylint: disable=unused-argument from __future__ import annotations import os @@ -141,14 +142,11 @@ def test_bento_info(tmpdir: Path): assert bentoinfo_b_from_yaml == bentoinfo_b -def build_test_bento(model_store: ModelStore) -> Bento: +def build_test_bento() -> Bento: bento_cfg = BentoBuildConfig( "simplebento.py:svc", include=["*.py", "config.json", "somefile", "*dir*", ".bentoignore"], - exclude=[ - "*.storage", - "/somefile", - ], + exclude=["*.storage", "/somefile", "/subdir2"], conda={ "environment_yml": "./environment.yaml", }, @@ -175,10 +173,10 @@ def fs_identical(fs1: fs.base.FS, fs2: fs.base.FS): @pytest.mark.usefixtures("change_test_dir") -def test_bento_export(tmpdir: "Path", dummy_model_store: "ModelStore"): +def test_bento_export(tmpdir: "Path", model_store: "ModelStore"): working_dir = os.getcwd() - testbento = build_test_bento(dummy_model_store) + testbento = build_test_bento() # Bento build will change working dir to the build_context, this will reset it os.chdir(working_dir) @@ -316,9 +314,9 @@ def test_bento_export(tmpdir: "Path", dummy_model_store: "ModelStore"): @pytest.mark.usefixtures("change_test_dir") -def test_bento(dummy_model_store: ModelStore): +def test_bento(model_store: ModelStore): start = datetime.now(timezone.utc) - bento = build_test_bento(dummy_model_store) + bento = build_test_bento() end = datetime.now(timezone.utc) assert bento.info.bentoml_version == BENTOML_VERSION diff --git a/tests/unit/_internal/io/test_file.py b/tests/unit/_internal/io/test_file.py index eebb7d6cc6..39edb5a85c 100644 --- a/tests/unit/_internal/io/test_file.py +++ b/tests/unit/_internal/io/test_file.py @@ -1,8 +1,19 @@ from __future__ import annotations +import io +from typing import TYPE_CHECKING + import pytest from bentoml.io import File +from bentoml.exceptions import BadInput + +if TYPE_CHECKING: + from bentoml.grpc.v1alpha1 import service_pb2 as pb +else: + from bentoml.grpc.utils import import_generated_stubs + + pb, _ = import_generated_stubs() def test_file_openapi_schema(): @@ -27,3 +38,47 @@ def test_file_openapi_request_responses(mime_type: str): assert responses.content assert mime_type in responses.content + + +@pytest.mark.asyncio +async def test_from_proto(bin_file: str): + with open(bin_file, "rb") as f: + content = f.read() + res = await File().from_proto(content) + assert res.read() == b"\x810\x899" + + +@pytest.mark.asyncio +async def test_exception_from_proto(): + with pytest.raises(AssertionError): + await File().from_proto(pb.NDArray(string_values="asdf")) # type: ignore (testing exceptions) + await File().from_proto("") # type: ignore (testing exceptions) + with pytest.raises(BadInput) as exc_info: + await File(mime_type="image/jpeg").from_proto( + pb.File(kind=pb.File.FILE_TYPE_BYTES, content=b"asdf") + ) + assert "Inferred mime_type from 'kind' is" in str(exc_info.value) + with pytest.raises(BadInput) as exc_info: + await File(mime_type="image/jpeg").from_proto( + pb.File(kind=123, content=b"asdf") # type: ignore (testing exceptions) + ) + assert "is not a valid File kind." in str(exc_info.value) + with pytest.raises(BadInput) as exc_info: + await File(mime_type="image/jpeg").from_proto( + pb.File(kind=pb.File.FILE_TYPE_JPEG) + ) + assert "Content is empty!" == str(exc_info.value) + + +@pytest.mark.asyncio +async def test_exception_to_proto(): + with pytest.raises(BadInput) as exc_info: + await File(mime_type="application/bentoml.vnd").to_proto(io.BytesIO(b"asdf")) + assert "doesn't have a corresponding File 'kind'" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_to_proto() -> None: + assert await File(mime_type="image/bmp").to_proto(io.BytesIO(b"asdf")) == pb.File( + kind=pb.File.FILE_TYPE_BMP, content=b"asdf" + ) diff --git a/tests/unit/_internal/io/test_image.py b/tests/unit/_internal/io/test_image.py index f16806b0d8..89a4a47d37 100644 --- a/tests/unit/_internal/io/test_image.py +++ b/tests/unit/_internal/io/test_image.py @@ -1,10 +1,36 @@ from __future__ import annotations +import io +from typing import TYPE_CHECKING + import pytest from bentoml.io import Image +from bentoml.exceptions import BadInput from bentoml.exceptions import InvalidArgument +if TYPE_CHECKING: + import numpy as np + import PIL.Image as PILImage + + from bentoml.grpc.v1alpha1 import service_pb2 as pb +else: + from bentoml.grpc.utils import import_generated_stubs + from bentoml._internal.utils import LazyLoader + + pb, _ = import_generated_stubs() + np = LazyLoader("np", globals(), "numpy") + PILImage = LazyLoader("PILImage", globals(), "PIL.Image") + + +def test_invalid_init(): + with pytest.raises(InvalidArgument) as exc_info: + Image(mime_type="application/vnd.bentoml+json") + assert "Invalid Image mime_type" in str(exc_info.value) + with pytest.raises(InvalidArgument) as exc_info: + Image(pilmode="asdf") + assert "Invalid Image pilmode" in str(exc_info.value) + def test_image_openapi_schema(): assert Image().openapi_schema().type == "string" @@ -31,3 +57,52 @@ def test_image_openapi_request_responses(mime_type: str): assert responses.content assert mime_type in responses.content + + +@pytest.mark.asyncio +async def test_from_proto(img_file: str): + with open(img_file, "rb") as f: + content = f.read() + res = await Image(mime_type="image/bmp").from_proto(content) + assert_file = PILImage.open(img_file) + np.testing.assert_array_almost_equal(np.array(res), np.array(assert_file)) + + +@pytest.mark.asyncio +async def test_exception_from_proto(): + with pytest.raises(AssertionError): + await Image().from_proto(pb.NDArray(string_values="asdf")) # type: ignore (testing exception) + await Image().from_proto("") # type: ignore (testing exception) + with pytest.raises(BadInput) as exc_info: + await Image(mime_type="image/jpeg").from_proto( + pb.File(kind=pb.File.FILE_TYPE_BYTES, content=b"asdf") + ) + assert "Inferred mime_type from 'kind' is" in str(exc_info.value) + with pytest.raises(BadInput) as exc_info: + await Image(mime_type="image/jpeg").from_proto(pb.File(kind=123, content=b"asdf")) # type: ignore (testing exception) + assert "is not a valid File kind." in str(exc_info.value) + with pytest.raises(BadInput) as exc_info: + await Image(mime_type="image/jpeg").from_proto( + pb.File(kind=pb.File.FILE_TYPE_JPEG) + ) + assert "Content is empty!" == str(exc_info.value) + + +@pytest.mark.asyncio +async def test_exception_to_proto(): + with pytest.raises(BadInput) as exc_info: + await Image().to_proto(io.BytesIO(b"asdf")) # type: ignore (testing exception) + assert "Unsupported Image type received:" in str(exc_info.value) + with pytest.raises(BadInput) as exc_info: + example = np.random.rand(255, 255, 3) + await Image(mime_type="image/sgi").to_proto(example) + assert "doesn't have a corresponding File 'kind'" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_to_proto(img_file: str) -> None: + with open(img_file, "rb") as f: + content = f.read() + img = PILImage.open(io.BytesIO(content)) + res = await Image(mime_type="image/bmp").to_proto(img) + assert res.kind == pb.File.FILE_TYPE_BMP diff --git a/tests/unit/_internal/io/test_json.py b/tests/unit/_internal/io/test_json.py index fbd40b5381..4a1c0a6472 100644 --- a/tests/unit/_internal/io/test_json.py +++ b/tests/unit/_internal/io/test_json.py @@ -15,12 +15,23 @@ import pydantic from bentoml.io import JSON +from bentoml.exceptions import BadInput +from bentoml.exceptions import UnprocessableEntity +from bentoml._internal.utils.pkg import pkg_version_info from bentoml._internal.io_descriptors.json import DefaultJsonEncoder if TYPE_CHECKING: from _pytest.logging import LogCaptureFixture + from google.protobuf import struct_pb2 + from bentoml.grpc.v1alpha1 import service_pb2 as pb from bentoml._internal.service.openapi.specification import Schema +else: + from bentoml.grpc.utils import import_generated_stubs + from bentoml._internal.utils import LazyLoader + + pb, _ = import_generated_stubs() + struct_pb2 = LazyLoader("struct_pb2", globals(), "google.protobuf.struct_pb2") @dataclass @@ -70,6 +81,24 @@ class Config: ) +@pytest.mark.skipif( + pkg_version_info("pydantic")[0] < 2 and pkg_version_info("bentoml")[:2] <= (1, 1), + reason="Pydantic 2.x is not yet supported until official releases of Pydantic.", +) +def test_not_yet_supported_pydantic(): + with pytest.raises(UnprocessableEntity) as exc_info: + JSON(pydantic_model=Nested) + assert "pydantic 2.x is not yet supported" in str(exc_info.value) + + +def test_invalid_init(): + with pytest.raises(AssertionError) as exc_info: + JSON(pydantic_model=ExampleAttrsClass) # type: ignore (testing exception) + assert "'pydantic_model' must be a subclass of 'pydantic.BaseModel'." in str( + exc_info.value + ) + + def test_json_encoder_dataclass_like(): expected = '{"name":"test","endpoints":["predict","health"]}' assert ( @@ -172,3 +201,93 @@ def test_json_openapi_request_responses(): assert responses.content assert "application/json" in responses.content + + +@pytest.mark.asyncio +async def test_from_proto(): + res = await JSON().from_proto( + b'{"request_id": "123", "iris_features": {"sepal_len":2.34,"sepal_width":1.58, "petal_len":6.52, "petal_width":3.23}}', + ) + assert res == { + "request_id": "123", + "iris_features": { + "sepal_len": 2.34, + "sepal_width": 1.58, + "petal_len": 6.52, + "petal_width": 3.23, + }, + } + res = await JSON(pydantic_model=BaseSchema).from_proto( + b'{"name":"test","endpoints":["predict","health"]}', + ) + assert isinstance(res, pydantic.BaseModel) and res == BaseSchema( + name="test", endpoints=["predict", "health"] + ) + res = await JSON(pydantic_model=Nested).from_proto( + struct_pb2.Value( + struct_value=struct_pb2.Struct( + fields={ + "toplevel": struct_pb2.Value(string_value="test"), + "nested": struct_pb2.Value( + struct_value=struct_pb2.Struct( + fields={ + "name": struct_pb2.Value(string_value="test"), + "endpoints": struct_pb2.Value( + list_value=struct_pb2.ListValue( + values=[ + struct_pb2.Value(string_value="predict"), + struct_pb2.Value(string_value="health"), + ] + ) + ), + } + ), + ), + } + ) + ), + ) + assert isinstance(res, pydantic.BaseModel) and res == Nested( + toplevel="test", + nested=BaseSchema(name="test", endpoints=["predict", "health"]), + ) + + +@pytest.mark.asyncio +async def test_exception_from_proto(): + with pytest.raises(AssertionError): + await JSON().from_proto(pb.NDArray(string_values="asdf")) # type: ignore (testing exception) + await JSON().from_proto("") # type: ignore (testing exception) + with pytest.raises(BadInput, match="Invalid JSON input received*"): + await JSON(pydantic_model=Nested).from_proto( + struct_pb2.Value(string_value="asdf") + ) + with pytest.raises(BadInput, match="Invalid JSON input received*"): + await JSON(pydantic_model=Nested).from_proto(b"") + await JSON().from_proto(b"\n?xfa") + + +@pytest.mark.asyncio +async def test_exception_to_proto(): + with pytest.raises(TypeError): + await JSON().to_proto(b"asdf") # type: ignore (testing exception) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "o", + [ + {"asdf": 1}, + ["asdf", "1"], + "asdf", + 1.0, + 1, + True, + BaseSchema(name="test", endpoints=["predict", "health"]), + np.random.rand(6, 6), + None, + ], +) +async def test_to_proto(o: t.Any) -> None: + res = await JSON().to_proto(o) + assert res and isinstance(res, struct_pb2.Value) diff --git a/tests/unit/_internal/io/test_multipart.py b/tests/unit/_internal/io/test_multipart.py index da51c22708..ef2d1ffd43 100644 --- a/tests/unit/_internal/io/test_multipart.py +++ b/tests/unit/_internal/io/test_multipart.py @@ -1,5 +1,8 @@ from __future__ import annotations +import io +from typing import TYPE_CHECKING + import pytest from bentoml.io import JSON @@ -7,16 +10,35 @@ from bentoml.io import Multipart from bentoml.exceptions import InvalidArgument -multipart = Multipart(arg1=JSON(), arg2=Image(pilmode="RGB")) +example = Multipart(arg1=JSON(), arg2=Image(mime_type="image/bmp", pilmode="RGB")) + +if TYPE_CHECKING: + import PIL.Image as PILImage + from google.protobuf import struct_pb2 + from google.protobuf import wrappers_pb2 + + from bentoml.grpc.v1alpha1 import service_pb2 as pb +else: + from bentoml.grpc.utils import import_generated_stubs + from bentoml._internal.utils import LazyLoader + + pb, _ = import_generated_stubs() + np = LazyLoader("np", globals(), "numpy") + PILImage = LazyLoader("PILImage", globals(), "PIL.Image") + wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2") + struct_pb2 = LazyLoader("struct_pb2", globals(), "google.protobuf.struct_pb2") def test_invalid_multipart(): - with pytest.raises(InvalidArgument): + with pytest.raises( + InvalidArgument, + match="Multipart IO can not contain nested Multipart IO descriptor", + ): _ = Multipart(arg1=Multipart(arg1=JSON())) def test_multipart_openapi_schema(): - schema = multipart.openapi_schema() + schema = example.openapi_schema() assert schema.type == "object" assert schema.properties @@ -24,9 +46,56 @@ def test_multipart_openapi_schema(): def test_multipart_openapi_request_responses(): - request_body = multipart.openapi_request_body() + request_body = example.openapi_request_body() assert request_body.required - responses = multipart.openapi_responses() + responses = example.openapi_responses() assert responses.content + + +@pytest.mark.asyncio +async def test_exception_from_to_proto(): + with pytest.raises(InvalidArgument): + await example.from_proto(b"") # type: ignore (test exception) + with pytest.raises(InvalidArgument) as e: + await example.from_proto( + pb.Multipart( + fields={"asdf": pb.Part(text=wrappers_pb2.StringValue(value="asdf"))} + ) + ) + assert f"'{example!r}' accepts the following keys: " in str(e.value) + with pytest.raises(InvalidArgument) as e: + await example.to_proto( + {"asdf": pb.Part(text=wrappers_pb2.StringValue(value="asdf"))} + ) + assert f"'{example!r}' accepts the following keys: " in str(e.value) + + +@pytest.mark.asyncio +async def test_multipart_from_to_proto(img_file: str): + with open(img_file, "rb") as f: + img = f.read() + obj = await example.from_proto( + pb.Multipart( + fields={ + "arg1": pb.Part( + json=struct_pb2.Value( + struct_value=struct_pb2.Struct( + fields={"asd": struct_pb2.Value(string_value="asd")} + ) + ) + ), + "arg2": pb.Part(file=pb.File(kind=pb.File.FILE_TYPE_BMP, content=img)), + } + ) + ) + assert obj["arg1"] == {"asd": "asd"} + assert_file = PILImage.open(img_file) + np.testing.assert_array_almost_equal(np.array(obj["arg2"]), np.array(assert_file)) + + message = await example.to_proto( + {"arg1": {"asd": "asd"}, "arg2": PILImage.open(io.BytesIO(img))} + ) + assert isinstance(message, pb.Multipart) + assert message.fields["arg1"].json.struct_value.fields["asd"].string_value == "asd" diff --git a/tests/unit/_internal/io/test_numpy.py b/tests/unit/_internal/io/test_numpy.py index 4f4d5765cf..ff1dc62795 100644 --- a/tests/unit/_internal/io/test_numpy.py +++ b/tests/unit/_internal/io/test_numpy.py @@ -1,3 +1,4 @@ +# pylint: disable=unused-argument from __future__ import annotations import logging @@ -9,12 +10,19 @@ from bentoml.io import NumpyNdarray from bentoml.exceptions import BadInput +from bentoml.exceptions import InvalidArgument from bentoml.exceptions import BentoMLException from bentoml._internal.service.openapi.specification import Schema if TYPE_CHECKING: from _pytest.logging import LogCaptureFixture + from bentoml.grpc.v1alpha1 import service_pb2 as pb +else: + from bentoml.grpc.utils import import_generated_stubs + + pb, _ = import_generated_stubs() + class ExampleGeneric(str, np.generic): pass @@ -35,6 +43,15 @@ def test_invalid_dtype(): assert "expects a 'numpy.array'" in str(e.value) +def test_invalid_init(): + with pytest.raises(InvalidArgument) as exc_info: + NumpyNdarray(enforce_dtype=True) + assert "'dtype' must be specified" in str(exc_info.value) + with pytest.raises(InvalidArgument) as exc_info: + NumpyNdarray(enforce_shape=True) + assert "'shape' must be specified" in str(exc_info.value) + + @pytest.mark.parametrize("dtype, expected", [("float", "number"), (">U8", "integer")]) def test_numpy_to_openapi_types(dtype: str, expected: str): assert NumpyNdarray(dtype=dtype)._openapi_types() == expected # type: ignore (private functions warning) @@ -99,3 +116,95 @@ def test_verify_numpy_ndarray(caplog: LogCaptureFixture): with caplog.at_level(logging.DEBUG): example.validate_array(np.array("asdf")) assert "Failed to reshape" in caplog.text + + +def generate_1d_array(dtype: pb.NDArray.DType.ValueType, length: int = 3): + if dtype == pb.NDArray.DTYPE_BOOL: + return [True] * length + elif dtype == pb.NDArray.DTYPE_STRING: + return ["a"] * length + else: + return [1] * length + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "dtype", + filter(lambda x: x > 0, [v.number for v in pb.NDArray.DType.DESCRIPTOR.values]), +) +async def test_from_proto(dtype: pb.NDArray.DType.ValueType) -> None: + from bentoml._internal.io_descriptors.numpy import dtypepb_to_fieldpb_map + from bentoml._internal.io_descriptors.numpy import dtypepb_to_npdtype_map + + np.testing.assert_array_equal( + await NumpyNdarray(dtype=example.dtype, shape=example.shape).from_proto( + example.ravel().tobytes(), + ), + example, + ) + # DTYPE_UNSPECIFIED + np.testing.assert_array_equal( + await NumpyNdarray().from_proto( + pb.NDArray(dtype=pb.NDArray.DType.DTYPE_UNSPECIFIED), + ), + np.empty(0), + ) + np.testing.assert_array_equal( + await NumpyNdarray().from_proto( + pb.NDArray(shape=tuple(example.shape)), + ), + np.empty(tuple(example.shape)), + ) + # different DTYPE + np.testing.assert_array_equal( + await NumpyNdarray().from_proto( + pb.NDArray( + dtype=dtype, + **{dtypepb_to_fieldpb_map()[dtype]: generate_1d_array(dtype)}, + ), + ), + np.array(generate_1d_array(dtype), dtype=dtypepb_to_npdtype_map()[dtype]), + ) + # given shape from message. + np.testing.assert_array_equal( + await NumpyNdarray().from_proto( + pb.NDArray(shape=[3, 3], float_values=[1.0] * 9), + ), + np.array([[1.0] * 3] * 3), + ) + + +@pytest.mark.asyncio +async def test_exception_from_proto(): + with pytest.raises(AssertionError): + await NumpyNdarray().from_proto(pb.NDArray(string_values="asdf")) + await NumpyNdarray().from_proto(pb.File(content=b"asdf")) # type: ignore (testing exception) + with pytest.raises(BadInput): + await NumpyNdarray().from_proto(b"asdf") + with pytest.raises(BadInput) as exc_info: + await NumpyNdarray().from_proto(pb.NDArray(dtype=123, string_values="asdf")) # type: ignore (testing exception) + assert "123 is invalid." == str(exc_info.value) + with pytest.raises(BadInput) as exc_info: + await NumpyNdarray().from_proto( + pb.NDArray(string_values="asdf", float_values=[1.0, 2.0]) + ) + assert "Array contents can only be one of" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_exception_to_proto(): + with pytest.raises(BadInput): + await NumpyNdarray(dtype=np.float32, enforce_dtype=True).to_proto( + np.array("asdf") + ) + with pytest.raises(BadInput): + await NumpyNdarray(dtype=np.dtype(np.void)).to_proto(np.array("asdf")) + + +@pytest.mark.asyncio +async def test_to_proto() -> None: + assert await NumpyNdarray().to_proto(example) == pb.NDArray( + shape=example.shape, + dtype=pb.NDArray.DType.DTYPE_DOUBLE, + double_values=example.ravel().tolist(), + ) diff --git a/tests/unit/_internal/io/test_text.py b/tests/unit/_internal/io/test_text.py index d2c3be2bfa..3b2cecb1f3 100644 --- a/tests/unit/_internal/io/test_text.py +++ b/tests/unit/_internal/io/test_text.py @@ -1,10 +1,23 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest from bentoml.io import Text from bentoml.exceptions import BentoMLException +if TYPE_CHECKING: + from google.protobuf import wrappers_pb2 + + from bentoml.grpc.v1alpha1 import service_pb2 as pb +else: + from bentoml.grpc.utils import import_generated_stubs + from bentoml._internal.utils import LazyLoader + + pb, _ = import_generated_stubs() + wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2") + def test_text_openapi_schema(): assert Text().openapi_schema().type == "string" @@ -28,3 +41,24 @@ def test_text_openapi_request_responses(): assert responses.content assert mime_type in responses.content + + +@pytest.mark.asyncio +async def test_from_proto(): + res = await Text().from_proto(wrappers_pb2.StringValue(value="asdf")) + assert res == "asdf" + res = await Text().from_proto(b"asdf") + assert res == "asdf" + + +@pytest.mark.asyncio +async def test_exception_from_proto(): + with pytest.raises(AssertionError): + await Text().from_proto(pb.NDArray(string_values="asdf")) # type: ignore (testing exception) + await Text().from_proto(b"") + + +@pytest.mark.asyncio +async def test_to_proto() -> None: + res = await Text().to_proto("asdf") + assert res.value == "asdf" diff --git a/tests/unit/_internal/models/test_model.py b/tests/unit/_internal/models/test_model.py index 3da70cdef6..5f1b9db530 100644 --- a/tests/unit/_internal/models/test_model.py +++ b/tests/unit/_internal/models/test_model.py @@ -14,7 +14,7 @@ from bentoml import Tag from bentoml.exceptions import BentoMLException -from bentoml._internal.models import ModelContext +from bentoml.testing.pytest import TEST_MODEL_CONTEXT from bentoml._internal.models import ModelOptions as InternalModelOptions from bentoml._internal.models.model import Model from bentoml._internal.models.model import ModelInfo @@ -24,16 +24,12 @@ if TYPE_CHECKING: from pathlib import Path -TEST_MODEL_CONTEXT = ModelContext( - framework_name="testing", framework_versions={"testing": "v1"} -) - TEST_PYTHON_VERSION = f"{pyver.major}.{pyver.minor}.{pyver.micro}" expected_yaml = """\ name: test version: v1 -module: test_model +module: tests.unit._internal.models.test_model labels: label: stringvalue options: @@ -85,6 +81,7 @@ class ModelOptions(InternalModelOptions): option_c: list[float] +@pytest.mark.usefixtures("change_test_dir") def test_model_info(tmpdir: "Path"): start = datetime.now(timezone.utc) modelinfo_a = ModelInfo( diff --git a/tests/unit/_internal/runner/container.py b/tests/unit/_internal/runner/test_container.py similarity index 88% rename from tests/unit/_internal/runner/container.py rename to tests/unit/_internal/runner/test_container.py index da784eb2b4..3a5473c4d3 100644 --- a/tests/unit/_internal/runner/container.py +++ b/tests/unit/_internal/runner/test_container.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +import typing as t + import numpy as np import pandas as pd import pytest @@ -7,7 +11,7 @@ @pytest.mark.parametrize("batch_dim_exc", [AssertionError]) @pytest.mark.parametrize("wrong_batch_dim", [1, 19]) -def test_default_container(batch_dim_exc, wrong_batch_dim): +def test_default_container(batch_dim_exc: t.Type[Exception], wrong_batch_dim: int): l1 = [1, 2, 3] l2 = [3, 4, 5, 6] @@ -31,7 +35,7 @@ def _generator(): yield "cherry" assert c.DefaultContainer.from_payload( - c.DefaultContainer.to_payload(_generator()) + c.DefaultContainer.to_payload(_generator(), batch_dim=0) ) == list(_generator()) assert c.DefaultContainer.from_batch_payloads( @@ -40,7 +44,7 @@ def _generator(): @pytest.mark.parametrize("batch_dim", [0, 1]) -def test_ndarray_container(batch_dim): +def test_ndarray_container(batch_dim: int): arr1 = np.ones((3, 3)) if batch_dim == 0: @@ -58,7 +62,8 @@ def test_ndarray_container(batch_dim): assert (arr2 == restored_arr2).all() assert ( - c.NdarrayContainer.from_payload(c.NdarrayContainer.to_payload(arr1)) == arr1 + c.NdarrayContainer.from_payload(c.NdarrayContainer.to_payload(arr1, batch_dim)) + == arr1 ).all() restored_batch, restored_indices = c.NdarrayContainer.from_batch_payloads( @@ -71,7 +76,7 @@ def test_ndarray_container(batch_dim): @pytest.mark.parametrize("batch_dim_exc", [AssertionError]) @pytest.mark.parametrize("wrong_batch_dim", [1, 19]) -def test_pandas_container(batch_dim_exc, wrong_batch_dim): +def test_pandas_container(batch_dim_exc: t.Type[Exception], wrong_batch_dim: int): cols = ["a", "b", "c"] arr1 = np.ones((3, 3)) @@ -89,7 +94,7 @@ def test_pandas_container(batch_dim_exc, wrong_batch_dim): assert df2.equals(restored_df2) assert c.PandasDataFrameContainer.from_payload( - c.PandasDataFrameContainer.to_payload(df1) + c.PandasDataFrameContainer.to_payload(df1, batch_dim=0) ).equals(df1) restored_batch, restored_indices = c.PandasDataFrameContainer.from_batch_payloads( diff --git a/tests/unit/_internal/runner/utils.py b/tests/unit/_internal/runner/utils.py deleted file mode 100644 index f83df27a03..0000000000 --- a/tests/unit/_internal/runner/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -import numpy as np - -from bentoml._internal.types import LazyType - - -def test_typeref(): - - # assert __eq__ - assert LazyType("numpy", "ndarray") == np.ndarray - assert LazyType("numpy", "ndarray") == LazyType(type(np.array([2, 3]))) - - # evaluate - assert LazyType("numpy", "ndarray").get_class() == np.ndarray diff --git a/tests/unit/_internal/test_configuration.py b/tests/unit/_internal/test_configuration.py index 18d303c1fa..534882efac 100644 --- a/tests/unit/_internal/test_configuration.py +++ b/tests/unit/_internal/test_configuration.py @@ -1,19 +1,112 @@ -from tempfile import NamedTemporaryFile +from __future__ import annotations +import typing as t +import logging +from typing import TYPE_CHECKING + +import pytest + +from bentoml.exceptions import BentoMLConfigException from bentoml._internal.configuration.containers import BentoMLConfiguration +if TYPE_CHECKING: + from pathlib import Path + + from _pytest.logging import LogCaptureFixture + from simple_di.providers import ConfigDictType + -def get_bentomlconfiguration_from_str(config_str: str): - tmpfile = NamedTemporaryFile(mode="w+", delete=False) - tmpfile.write(config_str) - tmpfile.flush() - tmpfile.close() +@pytest.fixture(scope="function", name="config_cls") +def fixture_config_cls(tmp_path: Path) -> t.Callable[[str], ConfigDictType]: + def inner(config: str) -> ConfigDictType: + path = tmp_path / "configuration.yaml" + path.write_text(config) + return BentoMLConfiguration(override_config_file=path.__fspath__()).as_dict() - bentoml_cfg = BentoMLConfiguration(override_config_file=tmpfile.name).as_dict() - return bentoml_cfg + return inner + + +@pytest.mark.usefixtures("config_cls") +def test_backward_configuration( + config_cls: t.Callable[[str], ConfigDictType], caplog: LogCaptureFixture +): + OLD_CONFIG = """\ +api_server: + max_request_size: 8624612341 + port: 5000 + host: 0.0.0.0 +""" + with caplog.at_level(logging.WARNING): + bentoml_cfg = config_cls(OLD_CONFIG) + assert all( + i not in bentoml_cfg["api_server"] for i in ("max_request_size", "port", "host") + ) + assert "cors" not in bentoml_cfg["api_server"] + assert bentoml_cfg["api_server"]["http"]["host"] == "0.0.0.0" + assert bentoml_cfg["api_server"]["http"]["port"] == 5000 + + +@pytest.mark.usefixtures("config_cls") +def test_validate(config_cls: t.Callable[[str], ConfigDictType]): + INVALID_CONFIG = """\ +api_server: + host: localhost +""" + with pytest.raises( + BentoMLConfigException, match="Invalid configuration file was given:*" + ): + config_cls(INVALID_CONFIG) + + +@pytest.mark.usefixtures("config_cls") +def test_backward_warning( + config_cls: t.Callable[[str], ConfigDictType], caplog: LogCaptureFixture +): + OLD_HOST = """\ +api_server: + host: 0.0.0.0 +""" + with caplog.at_level(logging.WARNING): + config_cls(OLD_HOST) + assert "field 'api_server.host' is deprecated" in caplog.text + caplog.clear() + + OLD_PORT = """\ +api_server: + port: 4096 +""" + with caplog.at_level(logging.WARNING): + config_cls(OLD_PORT) + assert "field 'api_server.port' is deprecated" in caplog.text + caplog.clear() + + OLD_MAX_REQUEST_SIZE = """\ +api_server: + max_request_size: 8624612341 +""" + with caplog.at_level(logging.WARNING): + config_cls(OLD_MAX_REQUEST_SIZE) + assert ( + "'api_server.max_request_size' is deprecated and has become obsolete." + in caplog.text + ) + caplog.clear() + + OLD_CORS = """\ +api_server: + cors: + enabled: false +""" + with caplog.at_level(logging.WARNING): + config_cls(OLD_CORS) + assert "field 'api_server.cors' is deprecated" in caplog.text + caplog.clear() -def test_bentoml_configuration_runner_override(): +@pytest.mark.usefixtures("config_cls") +def test_bentoml_configuration_runner_override( + config_cls: t.Callable[[str], ConfigDictType] +): OVERRIDE_RUNNERS = """\ runners: batching: @@ -40,7 +133,7 @@ def test_bentoml_configuration_runner_override(): enabled: True """ - bentoml_cfg = get_bentomlconfiguration_from_str(OVERRIDE_RUNNERS) + bentoml_cfg = config_cls(OVERRIDE_RUNNERS) runner_cfg = bentoml_cfg["runners"] # test_runner_1 @@ -73,13 +166,14 @@ def test_bentoml_configuration_runner_override(): assert test_runner_batching["resources"]["cpu"] == 4 # should use global -def test_runner_gpu_configuration(): +@pytest.mark.usefixtures("config_cls") +def test_runner_gpu_configuration(config_cls: t.Callable[[str], ConfigDictType]): GPU_INDEX = """\ runners: resources: nvidia.com/gpu: [1, 2, 4] """ - bentoml_cfg = get_bentomlconfiguration_from_str(GPU_INDEX) + bentoml_cfg = config_cls(GPU_INDEX) assert bentoml_cfg["runners"]["resources"] == {"nvidia.com/gpu": [1, 2, 4]} GPU_INDEX_WITH_STRING = """\ @@ -87,12 +181,14 @@ def test_runner_gpu_configuration(): resources: nvidia.com/gpu: "[1, 2, 4]" """ - bentoml_cfg = get_bentomlconfiguration_from_str(GPU_INDEX_WITH_STRING) + bentoml_cfg = config_cls(GPU_INDEX_WITH_STRING) # this behaviour can be confusing assert bentoml_cfg["runners"]["resources"] == {"nvidia.com/gpu": "[1, 2, 4]"} -RUNNER_TIMEOUTS = """\ +@pytest.mark.usefixtures("config_cls") +def test_runner_timeouts(config_cls: t.Callable[[str], ConfigDictType]): + RUNNER_TIMEOUTS = """\ runners: timeout: 50 test_runner_1: @@ -100,10 +196,7 @@ def test_runner_gpu_configuration(): test_runner_2: resources: system """ - - -def test_runner_timeouts(): - bentoml_cfg = get_bentomlconfiguration_from_str(RUNNER_TIMEOUTS) + bentoml_cfg = config_cls(RUNNER_TIMEOUTS) runner_cfg = bentoml_cfg["runners"] assert runner_cfg["timeout"] == 50 assert runner_cfg["test_runner_1"]["timeout"] == 100 diff --git a/tests/unit/_internal/test_utils.py b/tests/unit/_internal/test_utils.py index 74f5dca88b..1f9caea535 100644 --- a/tests/unit/_internal/test_utils.py +++ b/tests/unit/_internal/test_utils.py @@ -8,9 +8,20 @@ from scipy.sparse import csr_matrix import bentoml._internal.utils as utils +from bentoml._internal.types import LazyType from bentoml._internal.types import MetadataDict +def test_typeref(): + + # assert __eq__ + assert LazyType("numpy", "ndarray") == np.ndarray + assert LazyType("numpy", "ndarray") == LazyType(type(np.array([2, 3]))) + + # evaluate + assert LazyType("numpy", "ndarray").get_class() == np.ndarray + + def test_validate_labels(): inp = {"label1": "label", "label3": "anotherlabel"} diff --git a/tests/unit/_internal/utils/test_analytics.py b/tests/unit/_internal/utils/test_analytics.py index 6e3de4e122..9e363df3a7 100644 --- a/tests/unit/_internal/utils/test_analytics.py +++ b/tests/unit/_internal/utils/test_analytics.py @@ -132,7 +132,7 @@ def test_track_serve_init( mock_usage_event_debugging: MagicMock, mock_do_not_track: MagicMock, mock_post: MagicMock, - noop_service: Service, + simple_service: Service, production: bool, caplog: LogCaptureFixture, ): @@ -145,7 +145,7 @@ def test_track_serve_init( mock_response.text = "sent" analytics.usage_stats._track_serve_init( # type: ignore (private warning) - noop_service, + simple_service, production=production, serve_info=analytics.usage_stats.get_serve_info(), serve_kind="http", @@ -157,7 +157,7 @@ def test_track_serve_init( mock_usage_event_debugging.return_value = True with caplog.at_level(logging.INFO): analytics.usage_stats._track_serve_init( # type: ignore (private warning) - noop_service, + simple_service, production=production, serve_info=analytics.usage_stats.get_serve_info(), serve_kind="http", @@ -218,10 +218,12 @@ def test_filter_metrics_report( @patch("bentoml._internal.utils.analytics.usage_stats.do_not_track") -def test_track_serve_do_not_track(mock_do_not_track: MagicMock, noop_service: Service): +def test_track_serve_do_not_track( + mock_do_not_track: MagicMock, simple_service: Service +): mock_do_not_track.return_value = True with analytics.track_serve( - noop_service, + simple_service, production=False, serve_info=analytics.usage_stats.get_serve_info(), ) as output: @@ -236,18 +238,18 @@ def test_track_serve_do_not_track(mock_do_not_track: MagicMock, noop_service: Se def test_legacy_get_metrics_report( mock_prometheus_client: MagicMock, mock_do_not_track: MagicMock, - noop_service: Service, + simple_service: Service, ): mock_do_not_track.return_value = True mock_prometheus_client.multiproc.return_value = False mock_prometheus_client.text_string_to_metric_families.return_value = text_string_to_metric_families( b"""\ -# HELP BENTOML_noop_service_request_in_progress Multiprocess metric -# TYPE BENTOML_noop_service_request_in_progress gauge -BENTOML_noop_service_request_in_progress{endpoint="/predict",service_version="not available"} 0.0 -# HELP BENTOML_noop_service_request_total Multiprocess metric -# TYPE BENTOML_noop_service_request_total counter -BENTOML_noop_service_request_total{endpoint="/predict",http_response_code="200",service_version="not available"} 8.0 +# HELP BENTOML_simple_service_request_in_progress Multiprocess metric +# TYPE BENTOML_simple_service_request_in_progress gauge +BENTOML_simple_service_request_in_progress{endpoint="/predict",service_version="not available"} 0.0 +# HELP BENTOML_simple_service_request_total Multiprocess metric +# TYPE BENTOML_simple_service_request_total counter +BENTOML_simple_service_request_total{endpoint="/predict",http_response_code="200",service_version="not available"} 8.0 """.decode( "utf-8" ) @@ -276,7 +278,7 @@ def test_legacy_get_metrics_report( { "api_name": "pred_json", "http_response_code": "200", - "service_name": "noop_service", + "service_name": "simple_service", "service_version": "not available", "value": 15.0, }, @@ -291,10 +293,10 @@ def test_legacy_get_metrics_report( b"""\ # HELP bentoml_api_server_request_total Multiprocess metric # TYPE bentoml_api_server_request_total counter - bentoml_api_server_request_total{api_name="pred_json",http_response_code="200",service_name="noop_service",service_version="not available"} 15.0 + bentoml_api_server_request_total{api_name="pred_json",http_response_code="200",service_name="simple_service",service_version="not available"} 15.0 # HELP bentoml_api_server_request_in_progress Multiprocess metric # TYPE bentoml_api_server_request_in_progress gauge - bentoml_api_server_request_in_progress{api_name="pred_json",service_name="noop_service",service_version="not available"} 0.0 + bentoml_api_server_request_in_progress{api_name="pred_json",service_name="simple_service",service_version="not available"} 0.0 """.decode( "utf-8" ) @@ -303,7 +305,7 @@ def test_legacy_get_metrics_report( ) def test_get_metrics_report( mock_prometheus_client: MagicMock, - noop_service: Service, + simple_service: Service, serve_kind: str, expected: dict[str, str | float] | None, generated_metrics: t.Generator[Metric, None, None], @@ -331,7 +333,7 @@ def test_track_serve( mock_track_serve_init: MagicMock, mock_post: MagicMock, mock_do_not_track: MagicMock, - noop_service: Service, + simple_service: Service, monkeypatch: MonkeyPatch, caplog: LogCaptureFixture, ): @@ -344,7 +346,7 @@ def test_track_serve( with caplog.at_level(logging.INFO): with analytics.track_serve( - noop_service, + simple_service, production=False, metrics_client=mock_prometheus_client, serve_info=analytics.usage_stats.get_serve_info(), diff --git a/tests/conftest.py b/tests/unit/conftest.py similarity index 57% rename from tests/conftest.py rename to tests/unit/conftest.py index ccaf6e966e..99a0972c13 100644 --- a/tests/conftest.py +++ b/tests/unit/conftest.py @@ -1,99 +1,27 @@ +# pylint: disable=unused-argument from __future__ import annotations -import os import typing as t import logging -import pathlib from typing import TYPE_CHECKING import yaml import pytest +import cloudpickle import bentoml -from bentoml._internal.utils import bentoml_cattr -from bentoml._internal.models import ModelStore -from bentoml._internal.models import ModelContext -from bentoml._internal.bento.build_config import BentoBuildConfig +from bentoml.testing.pytest import TEST_MODEL_CONTEXT if TYPE_CHECKING: - from _pytest.python import Metafunc + from pathlib import Path -TEST_MODEL_CONTEXT = ModelContext( - framework_name="testing", framework_versions={"testing": "v1"} -) - - -def pytest_generate_tests(metafunc: Metafunc) -> None: - from bentoml._internal.utils import analytics - - analytics.usage_stats.do_not_track.cache_clear() - analytics.usage_stats._usage_event_debugging.cache_clear() # type: ignore (private warning) - - # used for local testing, on CI we already set DO_NOT_TRACK - os.environ["__BENTOML_DEBUG_USAGE"] = "False" - os.environ["BENTOML_DO_NOT_TRACK"] = "True" - - -@pytest.fixture(scope="function") -def noop_service(dummy_model_store: ModelStore) -> bentoml.Service: - import cloudpickle - - from bentoml.io import Text - - class NoopModel: - def predict(self, data: t.Any) -> t.Any: - return data - - with bentoml.models.create( - "noop_model", - context=TEST_MODEL_CONTEXT, - module=__name__, - signatures={"predict": {"batchable": True}}, - _model_store=dummy_model_store, - ) as model: - with open(model.path_of("test.pkl"), "wb") as f: - cloudpickle.dump(NoopModel(), f) - - ref = bentoml.models.get("noop_model", _model_store=dummy_model_store) - - class NoopRunnable(bentoml.Runnable): - SUPPORTED_RESOURCES = ("cpu",) - SUPPORTS_CPU_MULTI_THREADING = True - - def __init__(self): - self._model: NoopModel = bentoml.picklable_model.load_model(ref) - - @bentoml.Runnable.method(batchable=True) - def predict(self, data: t.Any) -> t.Any: - return self._model.predict(data) - - svc = bentoml.Service( - name="noop_service", - runners=[bentoml.Runner(NoopRunnable, models=[ref])], - ) - - @svc.api(input=Text(), output=Text()) - def noop_sync(data: str) -> str: # type: ignore - return data - - return svc - - -@pytest.fixture(scope="function", autouse=True, name="propagate_logs") -def fixture_propagate_logs() -> t.Generator[None, None, None]: - logger = logging.getLogger("bentoml") - # bentoml sets propagate to False by default, so we need to set it to True - # for pytest caplog to recognize logs - logger.propagate = True - yield - # restore propagate to False after tests - logger.propagate = False + from _pytest.fixtures import FixtureRequest @pytest.fixture(scope="function") def reload_directory( - request: pytest.FixtureRequest, tmp_path_factory: pytest.TempPathFactory -) -> t.Generator[pathlib.Path, None, None]: + request: FixtureRequest, tmp_path_factory: pytest.TempPathFactory +) -> t.Generator[Path, None, None]: """ This fixture will create an example bentoml working file directory and yield the results directory @@ -114,6 +42,9 @@ def reload_directory( ├── service.py └── train.py """ + from bentoml._internal.utils import bentoml_cattr + from bentoml._internal.bento.build_config import BentoBuildConfig + root = tmp_path_factory.mktemp("reload_directory") # create a models directory root.joinpath("models").mkdir() @@ -129,10 +60,10 @@ def reload_directory( "train.py", "fname.ipynb", ] + for f in root_file: p = root.joinpath(f) p.touch() - build_config = BentoBuildConfig( service="service.py:svc", description="A mock service", @@ -167,41 +98,63 @@ def reload_directory( yield root -@pytest.fixture(scope="function", name="change_test_dir") -def fixture_change_test_dir( - request: pytest.FixtureRequest, -) -> t.Generator[None, None, None]: - os.chdir(request.fspath.dirname) # type: ignore (bad pytest stubs) - yield - os.chdir(request.config.invocation_dir) # type: ignore (bad pytest stubs) +@pytest.fixture(scope="session") +def simple_service() -> bentoml.Service: + """ + This fixture create a simple service implementation that implements a noop runnable with two APIs: + - noop_sync: sync API that returns the input. + - invalid: an invalid API that can be used to test error handling. + """ + from bentoml.io import Text + + class NoopModel: + def predict(self, data: t.Any) -> t.Any: + return data -@pytest.fixture(scope="session", name="dummy_model_store") -def fixture_dummy_model_store(tmpdir_factory: "pytest.TempPathFactory") -> ModelStore: - store = ModelStore(tmpdir_factory.mktemp("models")) with bentoml.models.create( - "testmodel", - module=__name__, - signatures={}, + "python_function", context=TEST_MODEL_CONTEXT, - _model_store=store, - ): - pass - with bentoml.models.create( - "testmodel", module=__name__, - signatures={}, - context=TEST_MODEL_CONTEXT, - _model_store=store, - ): - pass - with bentoml.models.create( - "anothermodel", - module=__name__, - signatures={}, - context=TEST_MODEL_CONTEXT, - _model_store=store, - ): - pass + signatures={"predict": {"batchable": True}}, + ) as model: + with open(model.path_of("test.pkl"), "wb") as f: + cloudpickle.dump(NoopModel(), f) + + model_ref = bentoml.models.get("python_function") + + class NoopRunnable(bentoml.Runnable): + SUPPORTED_RESOURCES = ("cpu",) + SUPPORTS_CPU_MULTI_THREADING = True - return store + def __init__(self): + self._model: NoopModel = bentoml.picklable_model.load_model(model_ref) + + @bentoml.Runnable.method(batchable=True) + def predict(self, data: t.Any) -> t.Any: + return self._model.predict(data) + + svc = bentoml.Service( + name="simple_service", + runners=[bentoml.Runner(NoopRunnable, models=[model_ref])], + ) + + @svc.api(input=Text(), output=Text()) + def noop_sync(data: str) -> str: # type: ignore + return data + + @svc.api(input=Text(), output=Text()) + def invalid(data: str) -> str: # type: ignore + raise RuntimeError("invalid implementation.") + + return svc + + +@pytest.fixture(scope="function", name="propagate_logs") +def fixture_propagate_logs() -> t.Generator[None, None, None]: + """BentoML sets propagate to False by default, hence this fixture enable log propagation.""" + logger = logging.getLogger("bentoml") + logger.propagate = True + yield + # restore propagate to False after tests + logger.propagate = False diff --git a/tests/unit/grpc/__init__.py b/tests/unit/grpc/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/grpc/conftest.py b/tests/unit/grpc/conftest.py new file mode 100644 index 0000000000..373f5747f5 --- /dev/null +++ b/tests/unit/grpc/conftest.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import MagicMock +from unittest.mock import PropertyMock + +import pytest + +if TYPE_CHECKING: + import grpc +else: + from bentoml.grpc.utils import import_grpc + + grpc, _ = import_grpc() + + +@pytest.fixture(scope="module", name="mock_unary_unary_handler") +def fixture_mock_handler() -> MagicMock: + handler = MagicMock(spec=grpc.RpcMethodHandler) + handler.request_streaming = PropertyMock(return_value=False) + handler.response_streaming = PropertyMock(return_value=False) + return handler diff --git a/tests/unit/grpc/interceptors/__init__.py b/tests/unit/grpc/interceptors/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/grpc/interceptors/test_access.py b/tests/unit/grpc/interceptors/test_access.py new file mode 100644 index 0000000000..1284dc8c85 --- /dev/null +++ b/tests/unit/grpc/interceptors/test_access.py @@ -0,0 +1,157 @@ +# pylint: disabl=unused-argument +from __future__ import annotations + +import typing as t +import logging +import functools +from typing import TYPE_CHECKING + +import pytest + +from bentoml.grpc.utils import wrap_rpc_handler +from bentoml.testing.grpc import create_channel +from bentoml.testing.grpc import create_bento_servicer +from bentoml.testing.grpc import make_standalone_server +from bentoml.grpc.interceptors.access import AccessLogServerInterceptor +from bentoml.grpc.interceptors.opentelemetry import AsyncOpenTelemetryServerInterceptor + +if TYPE_CHECKING: + from grpc import aio + from _pytest.logging import LogCaptureFixture + from google.protobuf import wrappers_pb2 + + from bentoml import Service + from bentoml.grpc.types import Request + from bentoml.grpc.types import Response + from bentoml.grpc.types import RpcMethodHandler + from bentoml.grpc.types import AsyncHandlerMethod + from bentoml.grpc.types import HandlerCallDetails + from bentoml.grpc.types import BentoServicerContext + from bentoml.grpc.v1alpha1 import service_pb2 as pb + from bentoml.grpc.v1alpha1 import service_pb2_grpc as services + from bentoml.grpc.v1alpha1 import service_test_pb2 as pb_test + from bentoml.grpc.v1alpha1 import service_test_pb2_grpc as services_test +else: + from bentoml.grpc.utils import import_generated_stubs + from bentoml._internal.utils import LazyLoader + + pb, services = import_generated_stubs() + pb_test, services_test = import_generated_stubs(file="service_test.proto") + aio = LazyLoader("aio", globals(), "grpc.aio") + wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2") + + +class AppendMetadataInterceptor(aio.ServerInterceptor): + def __init__(self, *metadata: tuple[str, t.Any]): + self._metadata = tuple(metadata) + + async def intercept_service( + self, + continuation: t.Callable[[HandlerCallDetails], t.Awaitable[RpcMethodHandler]], + handler_call_details: HandlerCallDetails, + ) -> RpcMethodHandler: + handler = await continuation(handler_call_details) + if handler and (handler.response_streaming or handler.request_streaming): + return handler + + def wrapper(behaviour: AsyncHandlerMethod[Response]): + @functools.wraps(behaviour) + async def new_behaviour( + request: Request, context: BentoServicerContext + ) -> Response | t.Awaitable[Response]: + context.set_trailing_metadata(aio.Metadata.from_tuple(self._metadata)) + return await behaviour(request, context) + + return new_behaviour + + return wrap_rpc_handler(wrapper, handler) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("propagate_logs") +async def test_success_logs(caplog: LogCaptureFixture): + with make_standalone_server( + # we need to also setup opentelemetry interceptor + # to make sure the access log is correctly setup. + interceptors=[ + AsyncOpenTelemetryServerInterceptor(), + AccessLogServerInterceptor(), + ] + ) as (server, host_url): + try: + await server.start() + with caplog.at_level(logging.INFO, "bentoml.access"): + async with create_channel(host_url) as channel: + stub = services_test.TestServiceStub(channel) + await stub.Execute(pb_test.ExecuteRequest(input="BentoML")) + assert ( + "(scheme=http,path=/bentoml.testing.v1alpha1.TestService/Execute,type=application/grpc,size=9) (http_status=200,grpc_status=0,type=application/grpc,size=17)" + in caplog.text + ) + + finally: + await server.stop(None) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("propagate_logs") +async def test_trailing_metadata(caplog: LogCaptureFixture): + with make_standalone_server( + # we need to also setup opentelemetry interceptor + # to make sure the access log is correctly setup. + interceptors=[ + AsyncOpenTelemetryServerInterceptor(), + AppendMetadataInterceptor(("content-type", "application/grpc+python")), + AccessLogServerInterceptor(), + ] + ) as (server, host_url): + try: + await server.start() + with caplog.at_level(logging.INFO, "bentoml.access"): + async with create_channel(host_url) as channel: + stub = services_test.TestServiceStub(channel) + await stub.Execute(pb_test.ExecuteRequest(input="BentoML")) + assert "type=application/grpc+python" in caplog.text + finally: + await server.stop(None) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("propagate_logs") +async def test_access_log_exception(caplog: LogCaptureFixture, simple_service: Service): + with make_standalone_server( + # we need to also setup opentelemetry interceptor + # to make sure the access log is correctly setup. + interceptors=[ + AsyncOpenTelemetryServerInterceptor(), + AccessLogServerInterceptor(), + ] + ) as (server, host_url): + services.add_BentoServiceServicer_to_server( + create_bento_servicer(simple_service), server + ) + try: + await server.start() + with caplog.at_level(logging.INFO, "bentoml.access"): + async with create_channel(host_url) as channel: + Call = channel.unary_unary( + "/bentoml.grpc.v1alpha1.BentoService/Call", + request_serializer=pb.Request.SerializeToString, + response_deserializer=pb.Response.FromString, + ) + with pytest.raises(aio.AioRpcError): + await t.cast( + t.Awaitable[pb.Response], + Call( + pb.Request( + api_name="invalid", + text=wrappers_pb2.StringValue(value="asdf"), + ) + ), + ) + assert ( + "(scheme=http,path=/bentoml.grpc.v1alpha1.BentoService/Call,type=application/grpc,size=17) (http_status=500,grpc_status=13,type=application/grpc,size=0)" + in caplog.text + ) + finally: + await server.stop(None) diff --git a/tests/unit/grpc/interceptors/test_prometheus.py b/tests/unit/grpc/interceptors/test_prometheus.py new file mode 100644 index 0000000000..d294fed076 --- /dev/null +++ b/tests/unit/grpc/interceptors/test_prometheus.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import sys +import typing as t +import tempfile +from typing import TYPE_CHECKING +from asyncio import Future +from unittest.mock import MagicMock + +import pytest + +from bentoml.testing.grpc import create_channel +from bentoml.testing.grpc import async_client_call +from bentoml.testing.grpc import create_bento_servicer +from bentoml.testing.grpc import make_standalone_server +from bentoml.grpc.interceptors.prometheus import PrometheusServerInterceptor +from bentoml._internal.configuration.containers import BentoMLContainer + +if TYPE_CHECKING: + import grpc + from google.protobuf import wrappers_pb2 + + from bentoml import Service + from bentoml.grpc.v1alpha1 import service_pb2_grpc as services + from bentoml.grpc.v1alpha1 import service_test_pb2 as pb_test +else: + from bentoml.grpc.utils import import_grpc + from bentoml.grpc.utils import import_generated_stubs + from bentoml._internal.utils import LazyLoader + + _, services = import_generated_stubs() + pb_test, _ = import_generated_stubs(file="service_test.proto") + wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2") + grpc, aio = import_grpc() + +prom_dir = tempfile.mkdtemp("prometheus-multiproc") +BentoMLContainer.prometheus_multiproc_dir.set(prom_dir) +interceptor = PrometheusServerInterceptor() + +if "prometheus_client" in sys.modules: + mods = [m for m in sys.modules if "prometheus_client" in m] + list(map(lambda s: sys.modules.pop(s), mods)) + if not interceptor._is_setup: + interceptor._setup() + + +@pytest.mark.asyncio +async def test_metrics_invocation(mock_unary_unary_handler: MagicMock): + mhandler_call_details = MagicMock(spec=grpc.HandlerCallDetails) + mcontinuation = MagicMock(return_value=Future()) + mcontinuation.return_value.set_result(mock_unary_unary_handler) + await interceptor.intercept_service(mcontinuation, mhandler_call_details) + assert mcontinuation.call_count == 1 + assert interceptor._is_setup # type: ignore # pylint: disable=protected-access + assert ( + interceptor.metrics_request_duration + and interceptor.metrics_request_total + and interceptor.metrics_request_in_progress + ) + + +@pytest.mark.asyncio +async def test_empty_metrics(): + metrics_client = BentoMLContainer.metrics_client.get() + # This test a branch where we change inside the handler whether or not the incoming + # handler contains pb.Request + # if it isn't a pb.Request, then we just pass the handler, hence metrics should be empty + with make_standalone_server(interceptors=[interceptor]) as ( + server, + host_url, + ): + try: + await server.start() + async with create_channel(host_url) as channel: + Execute = channel.unary_unary( + "/bentoml.testing.v1alpha1.TestService/Execute", + request_serializer=pb_test.ExecuteRequest.SerializeToString, + response_deserializer=pb_test.ExecuteResponse.FromString, + ) + resp = t.cast( + t.Awaitable[pb_test.ExecuteResponse], + Execute(pb_test.ExecuteRequest(input="BentoML")), + ) + await resp + assert metrics_client.generate_latest() == b"" + finally: + await server.stop(None) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "metric_type,parent_set", + [ + ( + "counter", + ["api_name", "service_version", "http_response_code", "service_name"], + ), + ( + "histogram", + ["api_name", "service_version", "http_response_code", "service_name", "le"], + ), + ("gauge", ["api_name", "service_version", "service_name"]), + ], +) +async def test_metrics_interceptors( + simple_service: Service, + metric_type: str, + parent_set: list[str], +): + metrics_client = BentoMLContainer.metrics_client.get() + + with make_standalone_server(interceptors=[interceptor]) as ( + server, + host_url, + ): + services.add_BentoServiceServicer_to_server( + create_bento_servicer(simple_service), server + ) + try: + await server.start() + async with create_channel(host_url) as channel: + await async_client_call( + "noop_sync", + channel=channel, + data={"text": wrappers_pb2.StringValue(value="BentoML")}, + ) + for m in metrics_client.text_string_to_metric_families(): + for sample in m.samples: + if m.type == metric_type: + assert set(sample.labels).issubset(set(parent_set)) + assert ( + "api_name" in sample.labels + and sample.labels["api_name"] == "noop_sync" + ) + if m.type in ["counter", "histogram"]: + # response code is 500 because we didn't actually startup + # the service runner as well as running on_startup hooks. + # This is expected since we are testing prometheus behaviour. + assert sample.labels["http_response_code"] == "500" + + finally: + await server.stop(None) diff --git a/tests/unit/grpc/server/__init__.py b/tests/unit/grpc/server/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/grpc/server/test_config.py b/tests/unit/grpc/server/test_config.py new file mode 100644 index 0000000000..c88069138b --- /dev/null +++ b/tests/unit/grpc/server/test_config.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import typing as t +from typing import TYPE_CHECKING + +import psutil +import pytest + +from bentoml._internal.server.grpc import Config +from bentoml._internal.server.grpc import Servicer + +if TYPE_CHECKING: + from bentoml import Service + + +@pytest.fixture() +def servicer(simple_service: Service) -> Servicer: + return Servicer(simple_service) + + +@pytest.mark.skipif(not psutil.WINDOWS, reason="Windows test.") +def test_windows_config_options(servicer: Servicer) -> None: + config = Config( + servicer, + bind_address="0.0.0.0", + max_message_length=None, + max_concurrent_streams=None, + maximum_concurrent_rpcs=None, + ) + assert not config.options + + +@pytest.mark.skipif(psutil.WINDOWS, reason="Unix test.") +@pytest.mark.parametrize( + "options,expected", + [ + ( + {"max_concurrent_streams": 128}, + ( + ("grpc.so_reuseport", 1), + ("grpc.max_concurrent_streams", 128), + ("grpc.max_message_length", -1), + ("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1), + ), + ), + ( + {"max_message_length": 2048}, + ( + ("grpc.so_reuseport", 1), + ("grpc.max_message_length", 2048), + ("grpc.max_receive_message_length", 2048), + ("grpc.max_send_message_length", 2048), + ), + ), + ], +) +def test_unix_options( + servicer: Servicer, + options: dict[str, t.Any], + expected: tuple[tuple[str, t.Any], ...], +) -> None: + config = Config(servicer, bind_address="0.0.0.0", **options) + assert config.options + assert config.options == expected diff --git a/tests/unit/grpc/test_grpc_utils.py b/tests/unit/grpc/test_grpc_utils.py new file mode 100644 index 0000000000..6ea89191d2 --- /dev/null +++ b/tests/unit/grpc/test_grpc_utils.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import typing as t +from http import HTTPStatus +from unittest.mock import Mock + +import grpc +import pytest + +from bentoml.exceptions import BadInput +from bentoml.exceptions import InvalidArgument +from bentoml.exceptions import BentoMLException +from bentoml.grpc.utils import MethodName +from bentoml.grpc.utils import to_http_status +from bentoml.grpc.utils import grpc_status_code +from bentoml.grpc.utils import wrap_rpc_handler +from bentoml.grpc.utils import parse_method_name + + +@pytest.mark.parametrize( + "exception,expected", + [ + (BentoMLException, grpc.StatusCode.INTERNAL), + (InvalidArgument, grpc.StatusCode.INVALID_ARGUMENT), + (BadInput, grpc.StatusCode.INVALID_ARGUMENT), + ( + type( + "UnknownException", + (BentoMLException,), + {"error_code": HTTPStatus.ALREADY_REPORTED}, + ), + grpc.StatusCode.UNKNOWN, + ), + ], +) +def test_exception_to_grpc_status( + exception: t.Type[BentoMLException], expected: grpc.StatusCode +): + assert grpc_status_code(exception("something")) == expected + + +@pytest.mark.parametrize( + "status_code,expected", + [ + (grpc.StatusCode.OK, HTTPStatus.OK), + (grpc.StatusCode.CANCELLED, HTTPStatus.INTERNAL_SERVER_ERROR), + (grpc.StatusCode.INVALID_ARGUMENT, HTTPStatus.BAD_REQUEST), + ], +) +def test_grpc_to_http_status_code(status_code: grpc.StatusCode, expected: HTTPStatus): + assert to_http_status(status_code) == expected + + +def test_method_name(): + # Fields are correct and fully_qualified_service work. + mn = MethodName("foo.bar", "SearchService", "Search") + assert mn.package == "foo.bar" + assert mn.service == "SearchService" + assert mn.method == "Search" + assert mn.fully_qualified_service == "foo.bar.SearchService" + + +def test_empty_package_method_name(): + # fully_qualified_service works when there's no package + mn = MethodName("", "SearchService", "Search") + assert mn.fully_qualified_service == "SearchService" + + +def test_parse_method_name(): + mn, ok = parse_method_name("/foo.bar.SearchService/Search") + assert mn.package == "foo.bar" + assert mn.service == "SearchService" + assert mn.method == "Search" + assert ok + + +def test_parse_empty_package(): + # parse_method_name works with no package. + mn, _ = parse_method_name("/SearchService/Search") + assert mn.package == "" + assert mn.service == "SearchService" + assert mn.method == "Search" + + +@pytest.mark.parametrize( + "request_streaming,response_streaming,handler_fn", + [ + (True, True, "stream_stream"), + (True, False, "stream_unary"), + (False, True, "unary_stream"), + (False, False, "unary_unary"), + ], +) +def test_wrap_rpc_handler( + request_streaming: bool, + response_streaming: bool, + handler_fn: str, +): + mock_handler = Mock( + request_streaming=request_streaming, + response_streaming=response_streaming, + ) + fn = Mock() + assert wrap_rpc_handler(fn, None) is None + # wrap_rpc_handler works with None handler. + wrapped = wrap_rpc_handler(fn, mock_handler) + assert fn.call_count == 1 + assert getattr(wrapped, handler_fn) is not None