Skip to content

Commit

Permalink
test(grpc): e2e + unit tests (#2984)
Browse files Browse the repository at this point in the history
Refactor some test components, and expose test functions users can use to write their tests.

Add a bentoml pytest plugin to ensure the correct environment and folders will be set during pytest invocation.
  • Loading branch information
aarnphm committed Sep 29, 2022
1 parent b3bd5a7 commit 8ccf0a2
Show file tree
Hide file tree
Showing 72 changed files with 3,474 additions and 985 deletions.
81 changes: 46 additions & 35 deletions .github/workflows/ci.yml
Expand Up @@ -13,6 +13,11 @@ env:
LINES: 120
COLUMNS: 120

# https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#defaultsrun
defaults:
run:
shell: bash --noprofile --norc -exo pipefail {0}

jobs:
diff:
runs-on: ubuntu-latest
Expand All @@ -34,7 +39,10 @@ jobs:
- scripts/ci/config.yml
- scripts/ci/run_tests.sh
- requirements/tests-requirements.txt
protos: &protos
- "bentoml/grpc/**/*.proto"
bentoml:
- *protos
- *related
- "bentoml/**"
- "tests/**"
Expand All @@ -46,9 +54,6 @@ jobs:
codestyle_check:
runs-on: ubuntu-latest
defaults:
run:
shell: bash
needs:
- diff

Expand All @@ -72,9 +77,13 @@ jobs:
uses: actions/setup-node@v3
with:
node-version: "17"
- name: install pyright
- name: Install pyright
run: |
npm install -g npm@^7 pyright
- name: Setup bufbuild/buf
uses: bufbuild/buf-setup-action@v1.8.0
with:
github_token: ${{ github.token }}

- name: Cache pip dependencies
uses: actions/cache@v3
Expand All @@ -94,12 +103,11 @@ jobs:
run: make ci-lint
- name: Type check
run: make ci-pyright
- name: Proto check
if: ${{ (github.event_name == 'pull_request' && needs.diff.outputs.protos == 'true') || github.event_name == 'push' }}
run: buf lint --config "bentoml/grpc/buf.yaml" --error-format msvs --path "bentoml/grpc"

documentation_spelling_check:
defaults:
run:
shell: bash

runs-on: ubuntu-latest
needs:
- diff
Expand Down Expand Up @@ -138,7 +146,6 @@ jobs:
- name: Run spellcheck script
run: make spellcheck-docs
shell: bash

unit_tests:
needs:
Expand All @@ -149,9 +156,6 @@ jobs:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.7", "3.8", "3.9", "3.10"]
defaults:
run:
shell: bash

if: ${{ (github.event_name == 'pull_request' && needs.diff.outputs.bentoml == 'true') || github.event_name == 'push' }}
name: python${{ matrix.python-version }}_unit_tests (${{ matrix.os }})
Expand Down Expand Up @@ -182,17 +186,18 @@ jobs:

- name: Install dependencies
run: |
pip install .
pip install ".[grpc]"
pip install -r requirements/tests-requirements.txt
- name: Run unit tests
if: ${{ matrix.os != 'windows-latest' }}
run: make tests-unit

- name: Run unit tests (Windows)
if: ${{ matrix.os == 'windows-latest' }}
run: make tests-unit
shell: bash
run: |
OPTS=(--cov-config pyproject.toml --cov-report=xml:unit.xml -vvv)
if [ "${{ matrix.os }}" != 'windows-latest' ]; then
# we will use pytest-xdist to improve tests run-time.
OPTS=(${OPTS[@]} --dist loadfile -n auto --run-grpc-tests)
fi
# Now run the unit tests
python -m pytest tests/unit "${OPTS[@]}"
- name: Upload test coverage to Codecov
uses: codecov/codecov-action@v3
Expand All @@ -213,12 +218,13 @@ jobs:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.7", "3.8", "3.9", "3.10"]
defaults:
run:
shell: bash
server_type: ["http", "grpc"]
exclude:
- os: windows-latest
server_type: "grpc"

if: ${{ (github.event_name == 'pull_request' && needs.diff.outputs.bentoml == 'true') || github.event_name == 'push' }}
name: python${{ matrix.python-version }}_e2e_tests (${{ matrix.os }})
name: python${{ matrix.python-version }}_${{ matrix.server_type }}_e2e_tests (${{ matrix.os }})
runs-on: ${{ matrix.os }}
timeout-minutes: 20

Expand Down Expand Up @@ -256,24 +262,29 @@ jobs:
path: ${{ steps.cache-dir.outputs.dir }}
key: ${{ runner.os }}-tests-${{ hashFiles('requirements/tests-requirements.txt') }}

- name: Install dependencies
- name: Install dependencies for ${{ matrix.server_type }}-based tests.
run: |
pip install -e ".[grpc]"
pip install -r requirements/tests-requirements.txt
pip install -r tests/e2e/bento_server_general_features/requirements.txt
- name: Export Action Envvar
run: export GITHUB_ACTION=true

- name: Run tests and generate coverage report
run: ./scripts/ci/run_tests.sh general_features
if [ "${{ matrix.server_type }}" == 'grpc' ]; then
pip install -e ".[grpc]"
else
pip install -e .
fi
if [ -f "tests/e2e/bento_server_${{ matrix.server_type }}/requirements.txt" ]; then
pip install -r tests/e2e/bento_server_${{ matrix.server_type }}/requirements.txt
fi
- name: Run ${{ matrix.server_type }} tests and generate coverage report
run: ./scripts/ci/run_tests.sh ${{ matrix.server_type }}_server --verbose

- name: Upload test coverage to Codecov
uses: codecov/codecov-action@v3
with:
flags: e2e-tests
flags: e2e-tests-${{ matrix.server_type }}
name: codecov-${{ matrix.os }}-python${{ matrix.python-version }}-e2e
fail_ci_if_error: true
directory: ./
files: ./tests/e2e/bento_server_general_features/general_features.xml
files: ./tests/e2e/bento_server_${{ matrix.server_type }}/${{ matrix.server_type }}_server.xml
verbose: true

concurrency:
Expand Down
2 changes: 1 addition & 1 deletion DEVELOPMENT.md
Expand Up @@ -353,7 +353,7 @@ Flags:
If `pytest_additional_arguments` is given, the additional arguments will be passed to all of the tests run by the tests script.

Example:
$ ./scripts/ci/run_tests.sh pytorch --gpus --capture=tee-sys
$ ./scripts/ci/run_tests.sh pytorch --run-gpus-tests --capture=tee-sys
```

All tests are then defined under [config.yml](./scripts/ci/config.yml) where each field follows the following format:
Expand Down
5 changes: 4 additions & 1 deletion Makefile
Expand Up @@ -4,6 +4,7 @@ SHELL := /bin/bash
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)
USE_VERBOSE ?=false
USE_GPU ?= false
USE_GRPC ?= false

help: ## Show all Makefile targets
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
Expand Down Expand Up @@ -49,7 +50,9 @@ tests-%:
ifeq ($(USE_VERBOSE),true)
./scripts/ci/run_tests.sh -v $(type) $(__positional)
else ifeq ($(USE_GPU),true)
./scripts/ci/run_tests.sh -v $(type) --gpus $(__positional)
./scripts/ci/run_tests.sh -v $(type) --run-gpu-tests $(__positional)
else ifeq ($(USE_GPRC),true)
./scripts/ci/run_tests.sh -v $(type) --run-gprc-tests $(__positional)
else
./scripts/ci/run_tests.sh $(type) $(__positional)
endif
Expand Down
2 changes: 1 addition & 1 deletion bentoml/_internal/io_descriptors/file.py
Expand Up @@ -227,7 +227,7 @@ async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
mime_type = mapping[field.kind]
if mime_type != self._mime_type:
raise BadInput(
f"Inferred mime_type from 'kind' is '{mime_type}', while '{repr(self)}' is expecting '{self._mime_type}'",
f"Inferred mime_type from 'kind' is '{mime_type}', while '{self!r}' is expecting '{self._mime_type}'",
)
except KeyError:
raise BadInput(
Expand Down
2 changes: 1 addition & 1 deletion bentoml/_internal/io_descriptors/image.py
Expand Up @@ -358,7 +358,7 @@ async def from_proto(self, field: pb.File | bytes) -> ImageType:
mime_type = mapping[field.kind]
if mime_type != self._mime_type:
raise BadInput(
f"Inferred mime_type from 'kind' is '{mime_type}', while '{repr(self)}' is expecting '{self._mime_type}'",
f"Inferred mime_type from 'kind' is '{mime_type}', while '{self!r}' is expecting '{self._mime_type}'",
)
except KeyError:
raise BadInput(
Expand Down
50 changes: 28 additions & 22 deletions bentoml/_internal/io_descriptors/multipart.py
Expand Up @@ -143,12 +143,12 @@ async def predict(
| +--------------------------------------------------------+ |
| | | |
| | Multipart(arr=NumpyNdarray(), annotations=JSON()) | |
| | | |
| +----------------+-----------------------+---------------+ |
| | | |
| | | |
| | | |
| +----+ +---------+ |
| | | | | |
| +---------------+-----------------------+----------------+ |
| | | |
| | | |
| | | |
| +-----+ +--------+ |
| | | |
| +---------------v--------v---------+ |
| | def predict(arr, annotations): | |
Expand Down Expand Up @@ -236,28 +236,33 @@ async def to_http_response(
def validate_input_mapping(self, field: t.MutableMapping[str, t.Any]) -> None:
if len(set(field) - set(self._inputs)) != 0:
raise InvalidArgument(
f"'{repr(self)}' accepts the following keys: {set(self._inputs)}. Given {field.__class__.__qualname__} has invalid fields: {set(field) - set(self._inputs)}",
f"'{self!r}' accepts the following keys: {set(self._inputs)}. Given {field.__class__.__qualname__} has invalid fields: {set(field) - set(self._inputs)}",
) from None

async def from_proto(self, field: pb.Multipart) -> dict[str, t.Any]:
from bentoml.grpc.utils import validate_proto_fields

if isinstance(field, bytes):
raise InvalidArgument(
f"cannot use 'serialized_bytes' with {self.__class__.__name__}"
) from None
message = field.fields
self.validate_input_mapping(message)
to_populate = {self._inputs[k]: message[k] for k in self._inputs}
reqs = await asyncio.gather(
*tuple(
io_.from_proto(getattr(input_pb, io_._proto_fields[0]))
for io_, input_pb in self.io_fields_mapping(message).items()
descriptor.from_proto(
getattr(
part,
validate_proto_fields(
part.WhichOneof("representation"), descriptor
),
)
)
for descriptor, part in to_populate.items()
)
)
return dict(zip(message, reqs))

def io_fields_mapping(
self, message: t.MutableMapping[str, pb.Part]
) -> dict[IODescriptor[t.Any], pb.Part]:
return {io_: part for io_, part in zip(self._inputs.values(), message.values())}
return dict(zip(self._inputs.keys(), reqs))

async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart:
self.validate_input_mapping(obj)
Expand All @@ -268,13 +273,14 @@ async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart:
)
)
return pb.Multipart(
fields={
key: pb.Part(
**{
io_._proto_fields[0]: resp
fields=dict(
zip(
obj,
[
# TODO: support multiple proto_fields
pb.Part(**{io_._proto_fields[0]: resp})
for io_, resp in zip(self._inputs.values(), resps)
}
],
)
for key in obj
}
)
)
28 changes: 12 additions & 16 deletions bentoml/_internal/server/grpc/servicer.py
Expand Up @@ -9,6 +9,7 @@
import anyio

from bentoml.grpc.utils import grpc_status_code
from bentoml.grpc.utils import validate_proto_fields

from ....exceptions import InvalidArgument
from ....exceptions import BentoMLException
Expand All @@ -27,7 +28,6 @@
from bentoml.grpc.types import AddServicerFn
from bentoml.grpc.types import ServicerClass
from bentoml.grpc.types import BentoServicerContext
from bentoml.grpc.types import GeneratedProtocolMessageType
from bentoml.grpc.v1alpha1 import service_pb2 as pb
from bentoml.grpc.v1alpha1 import service_pb2_grpc as services

Expand Down Expand Up @@ -148,28 +148,24 @@ async def Call( # type: ignore (no async types) # pylint: disable=invalid-overr
# We will use fields descriptor to determine how to process that request.
try:
# we will check if the given fields list contains a pb.Multipart.
field = request.WhichOneof("content")
if field is None:
raise InvalidArgument("Request cannot be empty.")
accepted_fields = api.input._proto_fields + ("serialized_bytes",)
if field not in accepted_fields:
raise InvalidArgument(
f"'{api.input.__class__.__name__}' accepts one of the following fields: '{', '.join(accepted_fields)}', and none of them are found in the request message.",
) from None
input_ = await api.input.from_proto(getattr(request, field))
input_proto = getattr(
request,
validate_proto_fields(request.WhichOneof("content"), api.input),
)
input_data = await api.input.from_proto(input_proto)
if asyncio.iscoroutinefunction(api.func):
if isinstance(api.input, Multipart):
output = await api.func(**input_)
output = await api.func(**input_data)
else:
output = await api.func(input_)
output = await api.func(input_data)
else:
if isinstance(api.input, Multipart):
output = await anyio.to_thread.run_sync(api.func, **input_)
output = await anyio.to_thread.run_sync(api.func, **input_data)
else:
output = await anyio.to_thread.run_sync(api.func, input_)
protos = await api.output.to_proto(output)
output = await anyio.to_thread.run_sync(api.func, input_data)
res = await api.output.to_proto(output)
# TODO(aarnphm): support multiple proto fields
response = pb.Response(**{api.output._proto_fields[0]: protos})
response = pb.Response(**{api.output._proto_fields[0]: res})
except BentoMLException as e:
log_exception(request, sys.exc_info())
await context.abort(code=grpc_status_code(e), details=e.message)
Expand Down
2 changes: 1 addition & 1 deletion bentoml/bentos.py
Expand Up @@ -424,7 +424,7 @@ def construct_dockerfile(
with open(bento.path_of(dockerfile_path), "r") as f:
FINAL_DOCKERFILE = f"""\
{f.read()}
FROM base-{bento.info.docker.distro}
FROM base-{bento.info.docker.distro} as final
# Additional instructions for final image.
{final_instruction}
"""
Expand Down

0 comments on commit 8ccf0a2

Please sign in to comment.