diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 7e11288aa3..7ae205d5de 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,6 +1,6 @@ # syntax=docker/dockerfile:1.4-labs -FROM python:3-bullseye +FROM --platform=linux/amd64 python:3-bullseye # [Option] Install zsh ARG INSTALL_ZSH="true" @@ -12,10 +12,8 @@ ARG ENABLE_NONROOT_DOCKER="true" ARG USE_MOBY="true" # [Option] Select CLI version ARG CLI_VERSION="latest" - # Enable new "BUILDKIT" mode for Docker CLI ENV DOCKER_BUILDKIT=1 - ENV DEBIAN_FRONTEND=noninteractive # Install needed packages and setup non-root user. Use a separate RUN statement to add your @@ -23,16 +21,28 @@ ENV DEBIAN_FRONTEND=noninteractive ARG USERNAME=automatic ARG USER_UID=1000 ARG USER_GID=$USER_UID -COPY library-scripts/*.sh /tmp/library-scripts/ -RUN --mount=type=cache,target=/var/cache/apt \ - --mount=type=cache,target=/var/lib/apt \ - apt-get update \ - && apt-get install -y build-essential software-properties-common vim \ +COPY .devcontainer/library-scripts/*.sh /tmp/library-scripts/ + +RUN --mount=type=cache,target=/var/lib/apt \ + --mount=type=cache,target=/var/cache/apt \ + apt-get update -y \ + # Remove imagemagick due to https://security-tracker.debian.org/tracker/CVE-2019-10131 + && apt-get purge -y imagemagick imagemagick-6-common + +# install common packages +RUN --mount=type=cache,target=/var/lib/apt \ + --mount=type=cache,target=/var/cache/apt \ + apt-get install -y build-essential software-properties-common vim \ && /bin/bash /tmp/library-scripts/common-debian.sh "${INSTALL_ZSH}" "${USERNAME}" "${USER_UID}" "${USER_GID}" "${UPGRADE_PACKAGES}" "true" "true" \ # Use Docker script from script library to set things up && /bin/bash /tmp/library-scripts/docker-debian.sh "${ENABLE_NONROOT_DOCKER}" "/var/run/docker-host.sock" "/var/run/docker.sock" "${USERNAME}" "${USE_MOBY}" "${CLI_VERSION}" \ # Clean up - && apt-get autoremove -y && apt-get clean -y && rm -rf /var/lib/apt/lists/* /tmp/library-scripts/ + && rm -rf /var/lib/apt/lists/* /tmp/library-scripts/ + +COPY requirements/*.txt /tmp/pip-tmp/ +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --no-warn-script-location -r /tmp/pip-tmp/dev-requirements.txt -r /tmp/pip-tmp/docs-requirements.txt \ + && rm -rf /tmp/pip-tmp # Setting the ENTRYPOINT to docker-init.sh will configure non-root access to # the Docker socket if "overrideCommand": false is set in devcontainer.json. diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 8c3364e937..3cc51c5f0b 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -3,6 +3,7 @@ { "name": "BentoML", "dockerFile": "Dockerfile", + "context": "..", "containerEnv": { "BENTOML_DEBUG": "True", "BENTOML_BUNDLE_LOCAL_BUILD": "True", @@ -22,15 +23,16 @@ // Configure properties specific to VS Code. "vscode": { "extensions": [ - "ms-azuretools.vscode-docker", - "ms-python.vscode-pylance", "ms-python.python", + "ms-python.vscode-pylance", + "ms-azuretools.vscode-docker", "ms-vsliveshare.vsliveshare", "ms-python.black-formatter", "ms-python.pylint", "samuelcolvin.jinjahtml", "GitHub.copilot", - "esbenp.prettier-vscode" + "esbenp.prettier-vscode", + "VisualStudioExptTeam.intellicode-api-usage-examples" ], "settings": { "files.watcherExclude": { @@ -52,18 +54,14 @@ "[jsonc]": { "editor.defaultFormatter": "esbenp.prettier-vscode" }, - "editor.minimap.enabled": true, + "editor.minimap.enabled": false, "editor.formatOnSave": true, - "editor.wordWrapColumn": 88, + "editor.wordWrapColumn": 88 } } }, // Use 'forwardPorts' to make a list of ports inside the container available locally. - "forwardPorts": [3000, 8080, 9000, 9090], - // Link some default configs to codespace container. - "postCreateCommand": "bash ./.devcontainer/lifecycle/post-create", + "forwardPorts": [3000, 3001], // install BentoML and tools - "postStartCommand": "bash ./.devcontainer/lifecycle/post-start", - // Comment out to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root. - "remoteUser": "vscode" + "postStartCommand": "bash ./.devcontainer/lifecycle/post-start" } diff --git a/.devcontainer/lifecycle/post-create b/.devcontainer/lifecycle/post-create deleted file mode 100755 index 5e72c9360a..0000000000 --- a/.devcontainer/lifecycle/post-create +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env bash - -pip install --user -r requirements/dev-requirements.txt diff --git a/.devcontainer/lifecycle/post-start b/.devcontainer/lifecycle/post-start index c36ed88eb7..6edd411cc3 100755 --- a/.devcontainer/lifecycle/post-start +++ b/.devcontainer/lifecycle/post-start @@ -7,4 +7,16 @@ git config --global pull.ff only git fetch upstream --tags && git pull # install editable wheels & tools for bentoml -pip install --user -e ".[tracing]" --isolated +pip install -e ".[tracing,grpc]" --verbose +pip install -r requirements/dev-requirements.txt +pip install -U "grpcio-tools>=1.41.0" "mypy-protobuf>=3.3.0" +# generate stubs +OPTS=(-I. --grpc_python_out=. --python_out=. --mypy_out=. --mypy_grpc_out=.) +python -m grpc_tools.protoc "${OPTS[@]}" bentoml/grpc/v1alpha1/service.proto +python -m grpc_tools.protoc "${OPTS[@]}" bentoml/grpc/v1alpha1/service_test.proto +# uninstall broken protobuf typestubs +pip uninstall -y types-protobuf + +# setup docker buildx +docker buildx install +docker buildx ls | grep bentoml-builder &>/dev/null || docker buildx create --use --name bentoml-builder --platform linux/amd64,linux/arm64 &>/dev/null diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 76f6387247..3a56683e15 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -257,9 +257,15 @@ jobs: path: ${{ steps.cache-dir.outputs.dir }} key: ${{ runner.os }}-tests-${{ hashFiles('requirements/tests-requirements.txt') }} + # Simulate ./scripts/generate_grpc_stubs.sh + - name: Generate gRPC stubs + run: | + pip install protobuf==3.19.4 "grpcio-tools==1.41" + find bentoml/grpc/v1alpha1 -type f -name "*.proto" -exec python -m grpc_tools.protoc -I. --grpc_python_out=. --python_out=. "{}" \; + - name: Install dependencies run: | - pip install -e . + pip install -e ".[grpc]" pip install -r requirements/tests-requirements.txt pip install -r tests/e2e/bento_server_general_features/requirements.txt diff --git a/.gitignore b/.gitignore index a2baefb2a7..69360490f2 100644 --- a/.gitignore +++ b/.gitignore @@ -114,3 +114,10 @@ typings # test files catboost_info + +# ignore pyvenv +pyvenv.cfg + +# generated stub that is included in distribution +*_pb2*.py +*_pb2*.pyi diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index a1370404be..15a034c82d 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -18,22 +18,26 @@ If you are interested in proposing a new feature, make sure to create a new feat 2. Fork the BentoML project on [GitHub](https://github.com/bentoml/BentoML). 3. Clone the source code from your fork of BentoML's GitHub repository: + ```bash git clone git@github.com:username/BentoML.git && cd BentoML ``` 4. Add the BentoML upstream remote to your local BentoML clone: + ```bash git remote add upstream git@github.com:bentoml/BentoML.git ``` 5. Configure git to pull from the upstream remote: + ```bash git switch main # ensure you're on the main branch git branch --set-upstream-to=upstream/main ``` 6. Install BentoML with pip in editable mode: + ```bash pip install -e . ``` @@ -41,11 +45,13 @@ If you are interested in proposing a new feature, make sure to create a new feat This installs BentoML in an editable state. The changes you make will automatically be reflected without reinstalling BentoML. 7. Install the BentoML development requirements: + ```bash pip install -r ./requirements/dev-requirements.txt ``` 8. Test the BentoML installation either with `bash`: + ```bash bentoml --version ``` @@ -62,65 +68,72 @@ If you are interested in proposing a new feature, make sure to create a new feat

with VS Code

1. Confirm that you have the following installed: - - [Python3.7+](https://www.python.org/downloads/) - - VS Code with the [Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python) and [Pylance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance) extensions + + - [Python3.7+](https://www.python.org/downloads/) + - VS Code with the [Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python) and [Pylance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance) extensions 2. Fork the BentoML project on [GitHub](https://github.com/bentoml/BentoML). 3. Clone the GitHub repository: - 1. Open the command palette with Ctrl+Shift+P and type in 'clone'. - 2. Select 'Git: Clone(Recursive)'. - 3. Clone BentoML. + + 1. Open the command palette with Ctrl+Shift+P and type in 'clone'. + 2. Select 'Git: Clone(Recursive)'. + 3. Clone BentoML. 4. Add an BentoML upstream remote: - 1. Open the command palette and enter 'add remote'. - 2. Select 'Git: Add Remote'. - 3. Press enter to select 'Add remote' from GitHub. - 4. Enter https://github.com/bentoml/BentoML.git to select the BentoML repository. - 5. Name your remote 'upstream'. + + 1. Open the command palette and enter 'add remote'. + 2. Select 'Git: Add Remote'. + 3. Press enter to select 'Add remote' from GitHub. + 4. Enter https://github.com/bentoml/BentoML.git to select the BentoML repository. + 5. Name your remote 'upstream'. 5. Pull from the BentoML upstream remote to your main branch: - 1. Open the command palette and enter 'checkout'. - 2. Select 'Git: Checkout to...' - 3. Choose 'main' to switch to the main branch. - 4. Open the command palette again and enter 'pull from'. - 5. Click on 'Git: Pull from...' - 6. Select 'upstream'. + + 1. Open the command palette and enter 'checkout'. + 2. Select 'Git: Checkout to...' + 3. Choose 'main' to switch to the main branch. + 4. Open the command palette again and enter 'pull from'. + 5. Click on 'Git: Pull from...' + 6. Select 'upstream'. 6. Open a new terminal by clicking the Terminal dropdown at the top of the window, followed by the 'New Terminal' option. Next, add a virtual environment with this command: ```bash python -m venv .venv ``` 7. Click yes if a popup suggests switching to the virtual environment. Otherwise, go through these steps: - 1. Open any python file in the directory. - 2. Select the interpreter selector on the blue status bar at the bottom of the editor. - ![vscode-status-bar](https://user-images.githubusercontent.com/489344/166984038-75f1f4bd-c896-43ee-a7ee-1b57fda359a3.png) - - 3. Switch to the path that includes .venv from the dropdown at the top. - ![vscode-select-venv](https://user-images.githubusercontent.com/489344/166984060-170d25f5-a91f-41d3-96f4-4db3c21df7c8.png) + 1. Open any python file in the directory. + 2. Select the interpreter selector on the blue status bar at the bottom of the editor. + ![vscode-status-bar](https://user-images.githubusercontent.com/489344/166984038-75f1f4bd-c896-43ee-a7ee-1b57fda359a3.png) + + 3. Switch to the path that includes .venv from the dropdown at the top. + ![vscode-select-venv](https://user-images.githubusercontent.com/489344/166984060-170d25f5-a91f-41d3-96f4-4db3c21df7c8.png) 8. Update your PowerShell execution policies. Win+x followed by the 'a' key opens the admin Windows PowerShell. Enter the following command to allow the virtual environment activation script to run: ``` Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser ``` -
+ ## Making Changes

using the Command Line

1. Make sure you're on the main branch. + ```bash git switch main ``` 2. Use the git pull command to retrieve content from the BentoML Github repository. + ```bash git pull ``` 3. Create a new branch and switch to it. + ```bash git switch -c my-new-branch-name ``` @@ -128,11 +141,13 @@ If you are interested in proposing a new feature, make sure to create a new feat 4. Make your changes! 5. Use the git add command to save the state of files you have changed. + ```bash git add ``` 6. Commit your changes. + ```bash git commit ``` @@ -141,59 +156,71 @@ If you are interested in proposing a new feature, make sure to create a new feat ```bash git push ``` -
+

using VS Code

1. Switch to the main branch: - 1. Open the command palette with Ctrl+Shift+P. - 2. Search for 'Git: Checkout to...' - 3. Select 'main'. + + 1. Open the command palette with Ctrl+Shift+P. + 2. Search for 'Git: Checkout to...' + 3. Select 'main'. 2. Pull from the upstream remote: - 1. Open the command palette. - 2. Enter and select 'Git: Pull...' - 3. Select 'upstream'. + + 1. Open the command palette. + 2. Enter and select 'Git: Pull...' + 3. Select 'upstream'. 3. Create and change to a new branch: - 1. Type in 'Git: Create Branch...' in the command palette. - 2. Enter a branch name. + + 1. Type in 'Git: Create Branch...' in the command palette. + 2. Enter a branch name. 4. Make your changes! 5. Stage all your changes: - 1. Enter and select 'Git: Stage All Changes...' in the command palette. + + 1. Enter and select 'Git: Stage All Changes...' in the command palette. 6. Commit your changes: - 1. Open the command palette and enter 'Git: Commit'. + + 1. Open the command palette and enter 'Git: Commit'. 7. Push your changes: - 1. Enter and select 'Git: Push...' in the command palette. + 1. Enter and select 'Git: Push...' in the command palette.
- ## Run BentoML with verbose/debug logging To view internal debug loggings for development, set the `BENTOML_DEBUG` environment variable to `TRUE`: + ```bash export BENTOML_DEBUG=TRUE ``` And/or use the `--verbose` option when running `bentoml` CLI command, e.g.: + ```bash bentoml get IrisClassifier --verbose ``` ## Style check, auto-formatting, type-checking -formatter: [black](https://github.com/psf/black), [isort](https://github.com/PyCQA/isort) +formatter: [black](https://github.com/psf/black), [isort](https://github.com/PyCQA/isort), [buf](https://github.com/bufbuild/buf) -linter: [pylint](https://pylint.org/) +linter: [pylint](https://pylint.org/), [buf](https://github.com/bufbuild/buf) type checker: [pyright](https://github.com/microsoft/pyright) +We are using [buf](https://github.com/bufbuild/buf) for formatting and linting +of our proto files. Configuration can be found [here](./bentoml/grpc/buf.yaml). +Currently, we are running `buf` with docker, hence we kindly ask our developers +to have docker available. Docker installation can be found [here](https://docs.docker.com/get-docker/). + Run linter/format script: + ```bash make format @@ -201,23 +228,30 @@ make lint ``` Run type checker: + ```bash make type ``` +## Editing proto files + +The proto files for the BentoML gRPC service are located under [`bentoml/grpc`](./bentoml/grpc/). +The generated python files are not checked in the git repository, and are instead generated via this [`script`](./scripts/generate_grpc_stubs.sh). +If you edit the proto files, make sure to run `./scripts/generate_grpc_stubs.sh` to +regenerate the proto stubs. + ## Deploy with your changes Test test out your changes in an actual BentoML model deployment, you can create a new Bento with your custom BentoML source repo: 1. Install custom BentoML in editable mode. e.g.: - * git clone your bentoml fork - * `pip install -e PATH_TO_THE_FORK` + - git clone your bentoml fork + - `pip install -e PATH_TO_THE_FORK` 2. Set env var `export BENTOML_BUNDLE_LOCAL_BUILD=True` and `export SETUPTOOLS_USE_DISTUTILS=stdlib` - * make sure you have the latest setuptools installed: `pip install -U setuptools` + - make sure you have the latest setuptools installed: `pip install -U setuptools` 3. Build a new Bento with `bentoml build` in your project directory -4. The new Bento will include a wheel file built from the BentoML source, and -`bentoml containrize` will install it to override the default BentoML installation in base image - +4. The new Bento will include a wheel file built from the BentoML source, and + `bentoml containerize` will install it to override the default BentoML installation in base image ### Distribute a custom BentoML release for your team @@ -236,18 +270,18 @@ description: "file: ./README.md" include: - "*.py" python: - packages: - - pandas - - git+https://github.com/{YOUR_GITHUB_USERNAME}/bentoml@{YOUR_REVISION} + packages: + - pandas + - git+https://github.com/{YOUR_GITHUB_USERNAME}/bentoml@{YOUR_REVISION} docker: - system_packages: - - git + system_packages: + - git ``` - ## Testing Make sure to install all test dependencies: + ```bash pip install -r requirements/tests-requirements.txt ``` @@ -257,12 +291,14 @@ pip install -r requirements/tests-requirements.txt You can run unit tests in two ways: Run all unit tests directly with pytest: + ```bash # GIT_ROOT=$(git rev-parse --show-toplevel) -pytest tests/unit --cov=bentoml --cov-config="$GIT_ROOT"/setup.cfg +pytest tests/unit --cov=bentoml --cov-config="$GIT_ROOT"/pyproject.toml ``` Run all unit tests via `./scripts/ci/run_tests.sh`: + ```bash ./scripts/ci/run_tests.sh unit @@ -273,6 +309,7 @@ make tests-unit ### Integration tests Run given tests after defining a target under `scripts/ci/config.yml` with `run_tests.sh`: + ```bash # example: run Keras TF1 integration tests ./scripts/ci/run_tests.sh keras_tf1 @@ -281,18 +318,26 @@ Run given tests after defining a target under `scripts/ci/config.yml` with `run_ ### E2E tests ```bash -# example: run e2e tests to check for general features -./scripts/ci/run_tests.sh general_features +# example: run e2e tests to check for http general features +./scripts/ci/run_tests.sh http_server +``` + +### Running the whole suite + +To run the whole test suite, minus frameworks integration, you can use: + +```bash +make tests-suite ``` ### Adding new test suite - + If you are adding new ML framework support, it is recommended that you also add a separate test suite in our CI. Currently we are using GitHub Actions to manage our CI/CD workflow. We recommend using [`nektos/act`](https://github.com/nektos/act) to run and test Actions locally. - The following tests script [run_tests.sh](./scripts/ci/run_tests.sh) can be used to run tests locally. + ```bash ./scripts/ci/run_tests.sh -h Running unit/integration tests with pytest and generate coverage reports. Make sure that given testcases is defined under ./scripts/ci/config.yml. @@ -312,6 +357,7 @@ Example: ``` All tests are then defined under [config.yml](./scripts/ci/config.yml) where each field follows the following format: + ```yaml : &tmpl root_test_dir: "tests/integration/frameworks" @@ -324,24 +370,25 @@ All tests are then defined under [config.yml](./scripts/ci/config.yml) where eac By default, each of our frameworks tests files with the format: `test__impl.py`. If `is_dir` set to `true` we will try to match the given `` under `root_test_dir` to run tests from. -| Keys | Type | Defintions | -|------|------|------------| -|`root_test_dir`| ``| root directory to run a given tests | -|`is_dir`| ``| whether `target` is a directory instead of a file | -|`override_name_or_path`| ``| optional way to override a tests file name if doesn't match our convention | -|`dependencies`| ``| define additional dependencies required to run the tests, accepts `requirements.txt` format | -|`external_scripts`| ``| optional shell scripts that can be run on top of `./scripts/ci/run_tests.sh` for given testsuite | -|`type_tests`| ``| define type of tests for given `target` | +| Keys | Type | Defintions | +| ----------------------- | --------------------------------------- | ------------------------------------------------------------------------------------------------ | +| `root_test_dir` | `` | root directory to run a given tests | +| `is_dir` | `` | whether `target` is a directory instead of a file | +| `override_name_or_path` | `` | optional way to override a tests file name if doesn't match our convention | +| `dependencies` | `` | define additional dependencies required to run the tests, accepts `requirements.txt` format | +| `external_scripts` | `` | optional shell scripts that can be run on top of `./scripts/ci/run_tests.sh` for given testsuite | +| `type_tests` | `` | define type of tests for given `target` | When `type_tests` is set to `e2e`, `./scripts/ci/run_tests.sh` will change current directory into the given `root_test_dir`, and will run the testsuite from there. The reason why we encourage developers to use the scripts in CI is that under the hood when we use pytest, we will create a custom report for the given tests. This report can then be used as carryforward flags on codecov for consistent reporting. Example: + ```yaml # e2e tests -general_features: - root_test_dir: "tests/e2e/bento_server_general_features" +http: + root_test_dir: "tests/e2e/bento_server_http" is_dir: true type_tests: "e2e" dependencies: @@ -359,20 +406,18 @@ pytorch_lightning: Refer to [config.yml](./scripts/ci/config.yml) for more examples. - ## Python tools ecosystem -Currently, BentoML is [PEP518](https://www.python.org/dev/peps/pep-0518/) compatible via `setup.cfg` and `pyproject.toml`. - We also define most of our config for Python tools where: - - `pyproject.toml` contains config for `setuptools`, `black`, `pytest`, `pylint`, `isort`, `pyright` - - `setup.cfg` contains metadata for `bentoml` library and `coverage` +Currently, BentoML is [PEP518](https://www.python.org/dev/peps/pep-0518/) compatible. We define package configuration via [`pyproject.toml`][https://github.com/bentoml/bentoml/blob/main/pyproject.toml]. ## Benchmark + BentoML has moved its benchmark to [`bentoml/benchmark`](https://github.com/bentoml/benchmark). ## Optional: git hooks BentoML also provides git hooks that developers can install with: + ```bash make hooks ``` @@ -385,6 +430,7 @@ on how to create a pull request on github. Name your pull request with one of the following prefixes, e.g. "feat: add support for PyTorch". This is based on the [Conventional Commits specification](https://www.conventionalcommits.org/en/v1.0.0/#summary) + - feat: (new feature for the user, not a new feature for build script) - fix: (bug fix for the user, not a fix to a build script) - docs: (changes to the documentation) @@ -406,4 +452,3 @@ your pull request. Refers to [BentoML Documentation Guide](./docs/README.md) for how to build and write docs. - diff --git a/MANIFEST.in b/MANIFEST.in index 1bcf3f99be..4117e26c8c 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,12 +1,17 @@ # All files tracked by Git are included in PyPI distribution +include bentoml/grpc/**/*_pb2*.py +include bentoml/grpc/**/*.pyi # Files to exclude in PyPI distribution exclude CONTRIBUTING.md GOVERNANCE.md CODE_OF_CONDUCT.md DEVELOPMENT.md SECURITY.md exclude Makefile MANIFEST.in -exclude *.yml +exclude *.yml *.yaml exclude .git* +exclude bentoml/grpc/buf.yaml +exclude bentoml/_internal/frameworks/FRAMEWORK_TEMPLATE_PY # Directories to exclude in PyPI package +prune .devcontainer prune requirements prune tests prune typings @@ -18,6 +23,7 @@ prune .git* prune */__pycache__ prune */.DS_Store prune */.ipynb_checkpoints +prune */README* # Patterns to exclude from any directory global-exclude *.py[cod] diff --git a/Makefile b/Makefile index d6ce033305..274e0139ca 100644 --- a/Makefile +++ b/Makefile @@ -8,19 +8,25 @@ USE_GPU ?= false help: ## Show all Makefile targets @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' +.PHONY: format format-proto lint lint-proto type style clean format: ## Running code formatter: black and isort @./scripts/tools/formatter.sh +format-proto: ## Running proto formatter: buf + @echo "Formatting proto files..." + docker run --init --rm --volume $(GIT_ROOT):/workspace --workdir /workspace bufbuild/buf format --config "/workspace/bentoml/grpc/buf.yaml" -w --path "bentoml/grpc" lint: ## Running lint checker: pylint @./scripts/tools/linter.sh +lint-proto: ## Running proto lint checker: buf + @echo "Linting proto files..." + docker run --init --rm --volume $(GIT_ROOT):/workspace --workdir /workspace bufbuild/buf lint --config "/workspace/bentoml/grpc/buf.yaml" --error-format msvs --path "bentoml/grpc" type: ## Running type checker: pyright @./scripts/tools/type_checker.sh +style: format lint format-proto lint-proto ## Running formatter and linter clean: ## Clean all generated files @echo "Cleaning all generated files..." @cd $(GIT_ROOT)/docs && make clean @cd $(GIT_ROOT) || exit 1 @find . -type f -name '*.py[co]' -delete -o -type d -name __pycache__ -delete -hooks: __check_defined_FORCE ## Install pre-defined hooks - @./scripts/install_hooks.sh ci-%: @@ -33,8 +39,10 @@ ci-format: ci-black ci-isort ## Running format check in CI: black, isort .PHONY: ci-lint ci-lint: ci-pylint ## Running lint check in CI: pylint +.PHONY: tests-suite +tests-suite: tests-unit tests-http_server tests-grpc_server ## Running BentoML tests suite (unit, e2e, integration) -tests-%: check-defined-USE_GPU check-defined-USE_VERBOSE +tests-%: $(eval type :=$(subst tests-, , $@)) $(eval RUN_ARGS:=$(wordlist 2,$(words $(MAKECMDGOALS)),$(MAKECMDGOALS))) $(eval __positional:=$(foreach t, $(RUN_ARGS), -$(t))) @@ -78,15 +86,3 @@ install-spellchecker-deps: ## Inform users to install enchant depending on their @echo Make sure to install enchant from your distros package manager @exit 1 endif - -check_defined = $(strip $(foreach 1,$1, $(call __check_defined,$1,$(strip $(value 2))))) -__check_defined = \ - $(if $(value $1),, \ - $(error Undefined $1$(if $2, ($2))$(if $(value @), \ - required by target `$@`))) -check-defined-% : __check_defined_FORCE - $(eval $@_target := $(subst check-defined-, ,$@)) - @:$(call check_defined, $*, $@_target) - -.PHONY : __check_defined_FORCE -__check_defined_FORCE: diff --git a/bentoml/_internal/bento/build_config.py b/bentoml/_internal/bento/build_config.py index bd58f6cbce..906fba6a72 100644 --- a/bentoml/_internal/bento/build_config.py +++ b/bentoml/_internal/bento/build_config.py @@ -64,7 +64,7 @@ def _convert_python_version(py_version: str | None) -> str | None: target_python_version = f"{major}.{minor}" if target_python_version != py_version: logger.warning( - "BentoML will install the latest python%s instead of the specified version %s. To use the exact python version, use a custom docker base image. See https://docs.bentoml.org/en/latest/concepts/bento.html#custom-base-image-advanced", + "BentoML will install the latest 'python%s' instead of the specified 'python%s'. To use the exact python version, use a custom docker base image. See https://docs.bentoml.org/en/latest/concepts/bento.html#custom-base-image-advanced", target_python_version, py_version, ) @@ -165,20 +165,27 @@ def __attrs_post_init__(self): if self.base_image is not None: if self.distro is not None: logger.warning( - f"docker base_image {self.base_image} is used, 'distro={self.distro}' option is ignored.", + "docker base_image %s is used, 'distro=%s' option is ignored.", + self.base_image, + self.distro, ) if self.python_version is not None: logger.warning( - f"docker base_image {self.base_image} is used, 'python={self.python_version}' option is ignored.", + "docker base_image %s is used, 'python=%s' option is ignored.", + self.base_image, + self.python_version, ) if self.cuda_version is not None: logger.warning( - f"docker base_image {self.base_image} is used, 'cuda_version={self.cuda_version}' option is ignored.", + "docker base_image %s is used, 'cuda_version=%s' option is ignored.", + self.base_image, + self.cuda_version, ) if self.system_packages: logger.warning( - f"docker base_image {self.base_image} is used, " - f"'system_packages={self.system_packages}' option is ignored.", + "docker base_image %s is used, 'system_packages=%s' option is ignored.", + self.base_image, + self.system_packages, ) if self.distro is not None and self.cuda_version is not None: @@ -225,14 +232,14 @@ def write_to_bento( try: setup_script = resolve_user_filepath(self.setup_script, build_ctx) except FileNotFoundError as e: - raise InvalidArgument(f"Invalid setup_script file: {e}") + raise InvalidArgument(f"Invalid setup_script file: {e}") from None if not os.access(setup_script, os.X_OK): message = f"{setup_script} is not executable." if not psutil.WINDOWS: raise InvalidArgument( f"{message} Ensure the script has a shebang line, then run 'chmod +x {setup_script}'." - ) - raise InvalidArgument(message) + ) from None + raise InvalidArgument(message) from None copy_file_to_fs_folder( setup_script, bento_fs, docker_folder, "setup_script" ) @@ -428,11 +435,14 @@ class PythonOptions: def __attrs_post_init__(self): if self.requirements_txt and self.packages: logger.warning( - f'Build option python: `requirements_txt="{self.requirements_txt}"` found, will ignore the option: `packages="{self.packages}"`.' + "Build option python: 'requirements_txt={self.requirements_txt}' found, will ignore the option: 'packages=%s'.", + self.requirements_txt, + self.packages, ) if self.no_index and (self.index_url or self.extra_index_url): logger.warning( - f'Build option python: `no_index="{self.no_index}"` found, will ignore `index_url` and `extra_index_url` option when installing PyPI packages.' + "Build option python: 'no_index=%s' found, will ignore 'index_url' and 'extra_index_url' option when installing PyPI packages.", + self.no_index, ) def is_empty(self) -> bool: @@ -479,8 +489,10 @@ def write_to_bento(self, bento_fs: FS, build_ctx: str) -> None: pip_args.extend(self.pip_args.split()) with bento_fs.open(fs.path.combine(py_folder, "install.sh"), "w") as f: - args = " ".join(map(quote, pip_args)) if pip_args else "" - install_script_content = ( + args = ["--no-warn-script-location"] + if pip_args: + args.extend(pip_args) + install_sh = ( """\ #!/usr/bin/env bash set -exuo pipefail @@ -488,8 +500,8 @@ def write_to_bento(self, bento_fs: FS, build_ctx: str) -> None: # Parent directory https://stackoverflow.com/a/246128/8643197 BASEDIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]:-$0}"; )" &> /dev/null && pwd 2> /dev/null; )" -PIP_ARGS=(--no-warn-script-location """ - + args +PIP_ARGS=(""" + + " ".join(map(quote, args)) + """) # BentoML by default generates two requirement files: @@ -525,11 +537,11 @@ def write_to_bento(self, bento_fs: FS, build_ctx: str) -> None: echo "WARNING: using BentoML version ${existing_bentoml_version}" fi else -pip install bentoml=="$BENTOML_VERSION" + pip install bentoml=="$BENTOML_VERSION" fi """ ) - f.write(install_script_content) + f.write(install_sh) if self.requirements_txt is not None: requirements_txt_file = resolve_user_filepath( diff --git a/bentoml/_internal/bento/build_dev_bentoml_whl.py b/bentoml/_internal/bento/build_dev_bentoml_whl.py index ae7c28aaae..2b57662314 100644 --- a/bentoml/_internal/bento/build_dev_bentoml_whl.py +++ b/bentoml/_internal/bento/build_dev_bentoml_whl.py @@ -1,18 +1,17 @@ from __future__ import annotations import os -import shutil import logging -import tempfile - -from bentoml.exceptions import BentoMLException from ..utils.pkg import source_locations +from ...exceptions import BentoMLException +from ...exceptions import MissingDependencyException from ..configuration import is_pypi_installed_bentoml logger = logging.getLogger(__name__) BENTOML_DEV_BUILD = "BENTOML_BUNDLE_LOCAL_BUILD" +_exc_message = f"'{BENTOML_DEV_BUILD}=True', which requires the 'pypa/build' package. Install development dependencies with 'pip install -r requirements/dev-requirements.txt' and try again." def build_bentoml_editable_wheel(target_path: str) -> None: @@ -28,11 +27,27 @@ def build_bentoml_editable_wheel(target_path: str) -> None: # skip this entirely if BentoML is installed from PyPI return + try: + from build.env import IsolatedEnvBuilder + + from build import ProjectBuilder + except ModuleNotFoundError as e: + raise MissingDependencyException(_exc_message) from e + # Find bentoml module path module_location = source_locations("bentoml") if not module_location: raise BentoMLException("Could not find bentoml module location.") + try: + from importlib import import_module + + _ = import_module("bentoml.grpc.v1alpha1.service_pb2") + except ModuleNotFoundError: + raise ModuleNotFoundError( + f"Generated stubs are not found. Make sure to run '{module_location}/scripts/generate_grpc_stubs.sh' beforehand to generate gRPC stubs." + ) from None + pyproject = os.path.abspath(os.path.join(module_location, "..", "pyproject.toml")) # this is for BentoML developer to create Service containing custom development @@ -42,17 +57,12 @@ def build_bentoml_editable_wheel(target_path: str) -> None: logger.info( "BentoML is installed in `editable` mode; building BentoML distribution with the local BentoML code base. The built wheel file will be included in the target bento." ) - try: - from build import ProjectBuilder - except ModuleNotFoundError: - raise BentoMLException( - f"Environment variable {BENTOML_DEV_BUILD}=True detected, which requires the `pypa/build` package. Make sure to install all dev dependencies via `pip install -r requirements/dev-requirements.txt` and try again." - ) - - with tempfile.TemporaryDirectory() as dist_dir: + with IsolatedEnvBuilder() as env: builder = ProjectBuilder(os.path.dirname(pyproject)) - builder.build("wheel", dist_dir) - shutil.copytree(dist_dir, target_path) + builder.python_executable = env.executable + builder.scripts_dir = env.scripts_dir + env.install(builder.build_system_requires) + builder.build("wheel", target_path) else: logger.info( "Custom BentoML build is detected. For a Bento to use the same build at serving time, add your custom BentoML build to the pip packages list, e.g. `packages=['git+https://github.com/bentoml/bentoml.git@13dfb36']`" diff --git a/bentoml/_internal/bento/docker.py b/bentoml/_internal/bento/docker/__init__.py similarity index 96% rename from bentoml/_internal/bento/docker.py rename to bentoml/_internal/bento/docker/__init__.py index 433b25423e..0a7a954e3b 100644 --- a/bentoml/_internal/bento/docker.py +++ b/bentoml/_internal/bento/docker/__init__.py @@ -6,8 +6,8 @@ import attr -from ...exceptions import InvalidArgument -from ...exceptions import BentoMLException +from bentoml.exceptions import InvalidArgument +from bentoml.exceptions import BentoMLException if TYPE_CHECKING: P = t.ParamSpec("P") @@ -109,7 +109,7 @@ class DistroSpec: ), ) - supported_cuda_versions: t.List[str] = attr.field( + supported_cuda_versions: t.Optional[t.List[str]] = attr.field( default=None, validator=attr.validators.optional( attr.validators.deep_iterable( diff --git a/bentoml/_internal/bento/docker/entrypoint.sh b/bentoml/_internal/bento/docker/entrypoint.sh index fd29ad0276..0f33ea19bf 100755 --- a/bentoml/_internal/bento/docker/entrypoint.sh +++ b/bentoml/_internal/bento/docker/entrypoint.sh @@ -10,19 +10,23 @@ _is_sourced() { } _main() { - # for backwards compatibility with the yatai<1.0.0, adapting the old "yatai" command to the new "start" command - if [ "${#}" -gt 0 ] && [ "${1}" = 'python' ] && [ "${2}" = '-m' ] && ([ "${3}" = 'bentoml._internal.server.cli.runner' ] || [ "${3}" = "bentoml._internal.server.cli.api_server" ]); then + # For backwards compatibility with the yatai<1.0.0, adapting the old "yatai" command to the new "start" command. + if [ "${#}" -gt 0 ] && [ "${1}" = 'python' ] && [ "${2}" = '-m' ] && { [ "${3}" = 'bentoml._internal.server.cli.runner' ] || [ "${3}" = "bentoml._internal.server.cli.api_server" ]; }; then # SC2235, use { } to avoid subshell overhead if [ "${3}" = 'bentoml._internal.server.cli.runner' ]; then set -- bentoml start-runner-server "${@:4}" elif [ "${3}" = 'bentoml._internal.server.cli.api_server' ]; then set -- bentoml start-http-server "${@:4}" fi - # if no arg or first arg looks like a flag - elif [ -z "$@" ] || [ "${1:0:1}" = '-' ]; then + # If no arg or first arg looks like a flag. + elif [[ "$#" -eq 0 ]] || [[ "${1:0:1}" =~ '-' ]]; then + # This is provided for backwards compatibility with places where user may have + # discover this easter egg and use it in their scripts to run the container. if [[ -v BENTOML_SERVE_COMPONENT ]]; then echo "\$BENTOML_SERVE_COMPONENT is set! Calling 'bentoml start-*' instead" if [ "${BENTOML_SERVE_COMPONENT}" = 'http_server' ]; then - set -- bentoml start-rest-server "$@" "$BENTO_PATH" + set -- bentoml start-http-server "$@" "$BENTO_PATH" + elif [ "${BENTOML_SERVE_COMPONENT}" = 'grpc_server' ]; then + set -- bentoml start-grpc-server "$@" "$BENTO_PATH" elif [ "${BENTOML_SERVE_COMPONENT}" = 'runner' ]; then set -- bentoml start-runner-server "$@" "$BENTO_PATH" fi @@ -30,13 +34,21 @@ _main() { set -- bentoml serve --production "$@" "$BENTO_PATH" fi fi - - # Overide the BENTOML_PORT if PORT env var is present. Used for Heroku **and Yatai** + # Overide the BENTOML_PORT if PORT env var is present. Used for Heroku and Yatai. if [[ -v PORT ]]; then echo "\$PORT is set! Overiding \$BENTOML_PORT with \$PORT ($PORT)" export BENTOML_PORT=$PORT fi - exec "$@" + # Handle serve and start commands that is passed to the container. + # Assuming that serve and start commands are the first arguments + # Note that this is the recommended way going forward to run all bentoml containers. + if [ "${#}" -gt 0 ] && { [ "${1}" = 'serve' ] || [ "${1}" = 'serve-http' ] || [ "${1}" = 'serve-grpc' ] || [ "${1}" = 'start-http-server' ] || [ "${1}" = 'start-grpc-server' ] || [ "${1}" = 'start-runner-server' ]; }; then + exec bentoml "$@" "$BENTO_PATH" + else + # otherwise default to run whatever the command is + # This should allow running bash, sh, python, etc + exec "$@" + fi } if ! _is_sourced; then diff --git a/bentoml/_internal/bento/docker/templates/base.j2 b/bentoml/_internal/bento/docker/templates/base.j2 index c6bf21e4ea..5c72e34024 100644 --- a/bentoml/_internal/bento/docker/templates/base.j2 +++ b/bentoml/_internal/bento/docker/templates/base.j2 @@ -2,7 +2,7 @@ {# users can use these values #} {% import '_macros.j2' as common %} {% set bento__entrypoint = bento__entrypoint | default(expands_bento_path("env", "docker", "entrypoint.sh", bento_path=bento__path)) %} -# syntax = docker/dockerfile:1.4-labs +# syntax = docker/dockerfile:1.4.3 # # =========================================== # @@ -12,7 +12,7 @@ # Block SETUP_BENTO_BASE_IMAGE {% block SETUP_BENTO_BASE_IMAGE %} -FROM {{ __base_image__ }} +FROM {{ __base_image__ }} as base-{{ __options__distro }} ENV LANG=C.UTF-8 @@ -21,7 +21,6 @@ ENV LC_ALL=C.UTF-8 ENV PYTHONIOENCODING=UTF-8 ENV PYTHONUNBUFFERED=1 - {% endblock %} # Block SETUP_BENTO_USER @@ -34,7 +33,6 @@ RUN groupadd -g $BENTO_USER_GID -o $BENTO_USER && useradd -m -u $BENTO_USER_UID {% block SETUP_BENTO_ENVARS %} {% if __options__env is not none %} {% for key, value in __options__env.items() -%} - ENV {{ key }}={{ value }} {% endfor -%} {% endif -%} @@ -46,25 +44,20 @@ ENV BENTOML_HOME={{ bento__home }} RUN mkdir $BENTO_PATH && chown {{ bento__user }}:{{ bento__user }} $BENTO_PATH -R WORKDIR $BENTO_PATH -# init related components COPY --chown={{ bento__user }}:{{ bento__user }} . ./ - {% endblock %} # Block SETUP_BENTO_COMPONENTS {% block SETUP_BENTO_COMPONENTS %} - {% set __install_python_scripts__ = expands_bento_path("env", "python", "install.sh", bento_path=bento__path) %} {% set __pip_cache__ = common.mount_cache("/root/.cache/pip") %} # install python packages with install.sh RUN {{ __pip_cache__ }} bash -euxo pipefail {{ __install_python_scripts__ }} - {% if __options__setup_script is not none %} {% set __setup_script__ = expands_bento_path("env", "docker", "setup_script", bento_path=bento__path) %} RUN chmod +x {{ __setup_script__ }} RUN {{ __setup_script__ }} {% endif %} - {% endblock %} # Block SETUP_BENTO_ENTRYPOINT @@ -72,6 +65,9 @@ RUN {{ __setup_script__ }} # Default port for BentoServer EXPOSE 3000 +# Expose Prometheus port +EXPOSE {{ __prometheus_port__ }} + RUN chmod +x {{ bento__entrypoint }} USER bentoml diff --git a/bentoml/_internal/bento/gen.py b/bentoml/_internal/bento/gen.py index 65db6a1a44..0c4935a02f 100644 --- a/bentoml/_internal/bento/gen.py +++ b/bentoml/_internal/bento/gen.py @@ -5,14 +5,16 @@ import logging from sys import version_info from typing import TYPE_CHECKING +from dataclasses import asdict +from dataclasses import dataclass -import attr from jinja2 import Environment from jinja2.loaders import FileSystemLoader +from ..utils import bentoml_cattr from ..utils import resolve_user_filepath from .docker import DistroSpec -from ..configuration import CLEAN_BENTOML_VERSION +from ..configuration.containers import BentoMLContainer logger = logging.getLogger(__name__) @@ -22,7 +24,7 @@ from .build_config import DockerOptions TemplateFunc = t.Callable[[DockerOptions], t.Dict[str, t.Any]] - GenericFunc = t.Callable[P, t.Any] + F = t.Callable[P, t.Any] BENTO_UID_GID = 1034 BENTO_USER = "bentoml" @@ -45,33 +47,34 @@ def expands_bento_path(*path: str, bento_path: str = BENTO_PATH) -> str: return "/".join([bento_path, *path]) -J2_FUNCTION: dict[str, GenericFunc[t.Any]] = { - "expands_bento_path": expands_bento_path, -} +J2_FUNCTION: dict[str, F[t.Any]] = {"expands_bento_path": expands_bento_path} + +to_preserved_field: t.Callable[[str], str] = lambda s: f"__{s}__" +to_bento_field: t.Callable[[str], str] = lambda s: f"bento__{s}" +to_options_field: t.Callable[[str], str] = lambda s: f"__options__{s}" -@attr.frozen(on_setattr=None, eq=False, repr=False) +@dataclass class ReservedEnv: base_image: str - supported_architectures: list[str] - bentoml_version: str = attr.field(default=CLEAN_BENTOML_VERSION) - python_version: str = attr.field( - default=f"{version_info.major}.{version_info.minor}" - ) + python_version: str = f"{version_info.major}.{version_info.minor}" - def todict(self): - return {f"__{k}__": v for k, v in attr.asdict(self).items()} + def asdict(self) -> dict[str, t.Any]: + return { + **{to_preserved_field(k): v for k, v in asdict(self).items()}, + "__prometheus_port__": BentoMLContainer.grpc.metrics.port.get(), + } -@attr.frozen(on_setattr=None, eq=False, repr=False) +@dataclass class CustomizableEnv: - uid_gid: int = attr.field(default=BENTO_UID_GID) - user: str = attr.field(default=BENTO_USER) - home: str = attr.field(default=BENTO_HOME) - path: str = attr.field(default=BENTO_PATH) + uid_gid: int = BENTO_UID_GID + user: str = BENTO_USER + home: str = BENTO_HOME + path: str = BENTO_PATH - def todict(self) -> dict[str, str]: - return {f"bento__{k}": v for k, v in attr.asdict(self).items()} + def asdict(self) -> dict[str, t.Any]: + return {to_bento_field(k): v for k, v in asdict(self).items()} def get_templates_variables( @@ -85,6 +88,10 @@ def get_templates_variables( distro = options.distro cuda_version = options.cuda_version python_version = options.python_version + + # these values will be set at with_defaults() if not provided + # so distro and python_version won't be None here. + assert distro and python_version spec = DistroSpec.from_distro( distro, cuda=cuda_version is not None, conda=use_conda ) @@ -97,28 +104,32 @@ def get_templates_variables( else: python_version = python_version base_image = spec.image.format(spec_version=python_version) - supported_architecture = spec.supported_architectures else: base_image = options.base_image - # TODO: allow user to specify supported architectures of the base image - supported_architecture = ["amd64"] logger.info( - f"BentoML will not install Python to custom base images; ensure the base image '{base_image}' has Python installed." + "BentoML will not install Python to custom base images; ensure the base image '%s' has Python installed.", + base_image, ) # environment returns are - # __base_image__, __supported_architectures__, __bentoml_version__, __python_version_full__ + # __base_image__, __python_version__, __prometheus_port__ # bento__uid_gid, bento__user, bento__home, bento__path # __options__distros, __options__base_image, __options_env, __options_system_packages, __options_setup_script return { - **{f"__options__{k}": v for k, v in attr.asdict(options).items()}, - **CustomizableEnv().todict(), - **ReservedEnv(base_image, supported_architecture).todict(), + **{ + to_options_field(k): v + for k, v in bentoml_cattr.unstructure(options).items() + }, + **CustomizableEnv().asdict(), + **ReservedEnv(base_image=base_image).asdict(), } def generate_dockerfile( - options: DockerOptions, build_ctx: str, *, use_conda: bool + options: DockerOptions, + build_ctx: str, + *, + use_conda: bool, ) -> str: """ Generate a Dockerfile that containerize a Bento. @@ -179,6 +190,7 @@ def generate_dockerfile( if options.base_image is not None: base = "base.j2" else: + assert distro # distro will be set via 'with_defaults()' spec = DistroSpec.from_distro(distro, cuda=use_cuda, conda=use_conda) base = f"{spec.release_type}_{distro}.j2" diff --git a/bentoml/_internal/configuration/__init__.py b/bentoml/_internal/configuration/__init__.py index 8061ec6dcc..bbc1144108 100644 --- a/bentoml/_internal/configuration/__init__.py +++ b/bentoml/_internal/configuration/__init__.py @@ -5,6 +5,7 @@ from functools import lru_cache from bentoml.exceptions import BentoMLException +from bentoml.exceptions import BentoMLConfigException try: import importlib.metadata as importlib_metadata @@ -29,6 +30,8 @@ class version_mod: DEBUG_ENV_VAR = "BENTOML_DEBUG" QUIET_ENV_VAR = "BENTOML_QUIET" CONFIG_ENV_VAR = "BENTOML_CONFIG" +# https://github.com/grpc/grpc/blob/master/doc/environment_variables.md +GRPC_DEBUG_ENV_VAR = "GRPC_VERBOSITY" def expand_env_var(env_var: str) -> str: @@ -96,6 +99,7 @@ def get_bentoml_config_file_from_env() -> t.Optional[str]: def set_debug_mode(enabled: bool) -> None: os.environ[DEBUG_ENV_VAR] = str(enabled) + os.environ[GRPC_DEBUG_ENV_VAR] = "DEBUG" logger.info( f"{'Enabling' if enabled else 'Disabling'} debug mode for current BentoML session" @@ -109,9 +113,9 @@ def get_debug_mode() -> bool: def set_quiet_mode(enabled: bool) -> None: - os.environ[DEBUG_ENV_VAR] = str(enabled) - # do not log setting quiet mode + os.environ[QUIET_ENV_VAR] = str(enabled) + os.environ[GRPC_DEBUG_ENV_VAR] = "NONE" def get_quiet_mode() -> bool: @@ -131,15 +135,15 @@ def load_global_config(bentoml_config_file: t.Optional[str] = None): if bentoml_config_file: if not bentoml_config_file.endswith((".yml", ".yaml")): - raise Exception( + raise BentoMLConfigException( "BentoML config file specified in ENV VAR does not end with `.yaml`: " f"`BENTOML_CONFIG={bentoml_config_file}`" - ) + ) from None if not os.path.isfile(bentoml_config_file): raise FileNotFoundError( "BentoML config file specified in ENV VAR not found: " f"`BENTOML_CONFIG={bentoml_config_file}`" - ) + ) from None bentoml_configuration = BentoMLConfiguration( override_config_file=bentoml_config_file, diff --git a/bentoml/_internal/configuration/containers.py b/bentoml/_internal/configuration/containers.py index 35da7fdfe8..b83c728d89 100644 --- a/bentoml/_internal/configuration/containers.py +++ b/bentoml/_internal/configuration/containers.py @@ -54,20 +54,19 @@ "grpc", "http", ) -_check_sample_rate: t.Callable[[float], None] = ( - lambda sample_rate: logger.warning( - "Tracing enabled, but sample_rate is unset or zero. No traces will be collected. " - "Please refer to https://docs.bentoml.org/en/latest/guides/tracing.html for more details." - ) - if sample_rate == 0.0 - else None -) _larger_than: t.Callable[[int | float], t.Callable[[int | float], bool]] = ( lambda target: lambda val: val > target ) _larger_than_zero: t.Callable[[int | float], bool] = _larger_than(0) +def _check_sample_rate(sample_rate: float) -> None: + if sample_rate == 0.0: + logger.warning( + "Tracing enabled, but sample_rate is unset or zero. No traces will be collected. Please refer to https://docs.bentoml.org/en/latest/guides/tracing.html for more details." + ) + + def _is_ip_address(addr: str) -> bool: import socket @@ -109,12 +108,9 @@ def _is_ip_address(addr: str) -> bool: SCHEMA = Schema( { "api_server": { - "port": And(int, _larger_than_zero), - "host": And(str, _is_ip_address), - "backlog": And(int, _larger_than(64)), "workers": Or(And(int, _larger_than_zero), None), "timeout": And(int, _larger_than_zero), - "max_request_size": And(int, _larger_than_zero), + "backlog": And(int, _larger_than(64)), Optional("ssl"): { Optional("certfile"): Or(str, None), Optional("keyfile"): Or(str, None), @@ -143,14 +139,30 @@ def _is_ip_address(addr: str) -> bool: "response_content_type": Or(bool, None), }, }, - "cors": { - "enabled": bool, - "access_control_allow_origin": Or(str, None), - "access_control_allow_credentials": Or(bool, None), - "access_control_allow_headers": Or([str], str, None), - "access_control_allow_methods": Or([str], str, None), - "access_control_max_age": Or(int, None), - "access_control_expose_headers": Or([str], str, None), + "http": { + "host": And(str, _is_ip_address), + "port": And(int, _larger_than_zero), + "cors": { + "enabled": bool, + "access_control_allow_origin": Or(str, None), + "access_control_allow_credentials": Or(bool, None), + "access_control_allow_headers": Or([str], str, None), + "access_control_allow_methods": Or([str], str, None), + "access_control_max_age": Or(int, None), + "access_control_expose_headers": Or([str], str, None), + }, + }, + "grpc": { + "host": And(str, _is_ip_address), + "port": And(int, _larger_than_zero), + "metrics": { + "port": And(int, _larger_than_zero), + "host": And(str, _is_ip_address), + }, + "reflection": {"enabled": bool}, + "max_concurrent_streams": Or(int, None), + "max_message_length": Or(int, None), + "maximum_concurrent_rpcs": Or(int, None), }, }, "runners": { @@ -193,6 +205,10 @@ def _is_ip_address(addr: str) -> bool: } ) +_WARNING_MESSAGE = ( + "field 'api_server.%s' is deprecated and has been renamed to 'api_server.http.%s'" +) + class BentoMLConfiguration: def __init__( @@ -224,7 +240,35 @@ def __init__( f"Config file {override_config_file} not found" ) with open(override_config_file, "rb") as f: - override_config = yaml.safe_load(f) + override_config: dict[str, t.Any] = yaml.safe_load(f) + + # compatibility layer with old configuration pre gRPC features + # api_server.[cors|port|host] -> api_server.http.$^ + if "api_server" in override_config: + user_api_config = override_config["api_server"] + # max_request_size is deprecated + if "max_request_size" in user_api_config: + logger.warning( + "'api_server.max_request_size' is deprecated and has become obsolete." + ) + user_api_config.pop("max_request_size") + # check if user are using older configuration + if "http" not in user_api_config: + user_api_config["http"] = {} + # then migrate these fields to newer configuration fields. + for field in ["port", "host", "cors"]: + if field in user_api_config: + old_field = user_api_config.pop(field) + user_api_config["http"][field] = old_field + logger.warning(_WARNING_MESSAGE, field, field) + + config_merger.merge(override_config["api_server"], user_api_config) + + assert all( + key not in override_config["api_server"] + for key in ["cors", "max_request_size", "host", "port"] + ) + config_merger.merge(self.config, override_config) global_runner_cfg = {k: self.config["runners"][k] for k in RUNNER_CFG_KEYS} @@ -250,9 +294,13 @@ def __init__( def override(self, keys: t.List[str], value: t.Any): if keys is None: - raise BentoMLConfigException("Configuration override key is None.") + raise BentoMLConfigException( + "Configuration override key is None." + ) from None if len(keys) == 0: - raise BentoMLConfigException("Configuration override key is empty.") + raise BentoMLConfigException( + "Configuration override key is empty." + ) from None if value is None: return @@ -261,7 +309,7 @@ def override(self, keys: t.List[str], value: t.Any): if key not in c: raise BentoMLConfigException( "Configuration override key is invalid, %s" % keys - ) + ) from None c = c[key] c[keys[-1]] = value @@ -330,6 +378,9 @@ def session_id() -> str: api_server_config = config.api_server runners_config = config.runners + grpc = api_server_config.grpc + http = api_server_config.http + development_mode = providers.Static(True) @providers.SingletonFactory @@ -342,23 +393,20 @@ def serve_info() -> ServeInfo: @providers.SingletonFactory @staticmethod def access_control_options( - allow_origins: t.List[str] = Provide[ - config.api_server.cors.access_control_allow_origin - ], - allow_credentials: t.List[str] = Provide[ - config.api_server.cors.access_control_allow_credentials - ], - expose_headers: t.List[str] = Provide[ - config.api_server.cors.access_control_expose_headers - ], - allow_methods: t.List[str] = Provide[ - config.api_server.cors.access_control_allow_methods - ], - allow_headers: t.List[str] = Provide[ - config.api_server.cors.access_control_allow_headers - ], - max_age: int = Provide[config.api_server.cors.access_control_max_age], - ) -> t.Dict[str, t.Union[t.List[str], int]]: + allow_origins: str | None = Provide[http.cors.access_control_allow_origin], + allow_credentials: bool + | None = Provide[http.cors.access_control_allow_credentials], + expose_headers: list[str] + | str + | None = Provide[http.cors.access_control_expose_headers], + allow_methods: list[str] + | str + | None = Provide[http.cors.access_control_allow_methods], + allow_headers: list[str] + | str + | None = Provide[http.cors.access_control_allow_headers], + max_age: int | None = Provide[http.cors.access_control_max_age], + ) -> dict[str, list[str] | str | int]: kwargs = dict( allow_origins=allow_origins, allow_credentials=allow_credentials, @@ -368,15 +416,15 @@ def access_control_options( max_age=max_age, ) - filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} + filtered_kwargs: dict[str, list[str] | str | int] = { + k: v for k, v in kwargs.items() if v is not None + } return filtered_kwargs api_server_workers = providers.Factory[int]( lambda workers: workers or (multiprocessing.cpu_count() // 2) + 1, - config.api_server.workers, + api_server_config.workers, ) - service_port = config.api_server.port - service_host = config.api_server.host prometheus_multiproc_dir = providers.Factory[str]( os.path.join, @@ -388,7 +436,7 @@ def access_control_options( @staticmethod def metrics_client( multiproc_dir: str = Provide[prometheus_multiproc_dir], - ) -> "PrometheusClient": + ) -> PrometheusClient: from ..server.metrics.prometheus import PrometheusClient return PrometheusClient(multiproc_dir=multiproc_dir) @@ -414,6 +462,7 @@ def tracer_provider( from opentelemetry.sdk.environment_variables import OTEL_SERVICE_NAME from opentelemetry.sdk.environment_variables import OTEL_RESOURCE_ATTRIBUTES + from ...exceptions import InvalidArgument from ..utils.telemetry import ParentBasedTraceIdRatio if sample_rate is None: @@ -442,15 +491,10 @@ def tracer_provider( ) if tracer_type == "zipkin" and zipkin_server_url is not None: - # pylint: disable=no-name-in-module # https://github.com/open-telemetry/opentelemetry-python-contrib/issues/290 - from opentelemetry.exporter.zipkin.json import ( - ZipkinExporter, # type: ignore (no opentelemetry types) - ) + from opentelemetry.exporter.zipkin.json import ZipkinExporter - exporter = ZipkinExporter( # type: ignore (no opentelemetry types) - endpoint=zipkin_server_url, - ) - provider.add_span_processor(BatchSpanProcessor(exporter)) # type: ignore (no opentelemetry types) + exporter = ZipkinExporter(endpoint=zipkin_server_url) + provider.add_span_processor(BatchSpanProcessor(exporter)) _check_sample_rate(sample_rate) return provider elif ( @@ -458,16 +502,12 @@ def tracer_provider( and jaeger_server_address is not None and jaeger_server_port is not None ): - # pylint: disable=no-name-in-module # https://github.com/open-telemetry/opentelemetry-python-contrib/issues/290 - from opentelemetry.exporter.jaeger.thrift import ( - JaegerExporter, # type: ignore (no opentelemetry types) - ) + from opentelemetry.exporter.jaeger.thrift import JaegerExporter - exporter = JaegerExporter( # type: ignore (no opentelemetry types) - agent_host_name=jaeger_server_address, - agent_port=jaeger_server_port, + exporter = JaegerExporter( + agent_host_name=jaeger_server_address, agent_port=jaeger_server_port ) - provider.add_span_processor(BatchSpanProcessor(exporter)) # type: ignore (no opentelemetry types) + provider.add_span_processor(BatchSpanProcessor(exporter)) _check_sample_rate(sample_rate) return provider elif ( @@ -476,21 +516,16 @@ def tracer_provider( and otlp_server_url is not None ): if otlp_server_protocol == "grpc": - # pylint: disable=no-name-in-module # https://github.com/open-telemetry/opentelemetry-python-contrib/issues/290 - from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( - OTLPSpanExporter, # type: ignore (no opentelemetry types) - ) + from opentelemetry.exporter.otlp.proto.grpc import trace_exporter elif otlp_server_protocol == "http": - # pylint: disable=no-name-in-module # https://github.com/open-telemetry/opentelemetry-python-contrib/issues/290 - from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( - OTLPSpanExporter, # type: ignore (no opentelemetry types) - ) - - exporter = OTLPSpanExporter( # type: ignore (no opentelemetry types) - endpoint=otlp_server_url, - ) - provider.add_span_processor(BatchSpanProcessor(exporter)) # type: ignore (no opentelemetry types) + from opentelemetry.exporter.otlp.proto.http import trace_exporter + else: + raise InvalidArgument( + f"Invalid otlp protocol: {otlp_server_protocol}" + ) from None + exporter = trace_exporter.OTLPSpanExporter(endpoint=otlp_server_url) + provider.add_span_processor(BatchSpanProcessor(exporter)) _check_sample_rate(sample_rate) return provider else: @@ -499,9 +534,7 @@ def tracer_provider( @providers.SingletonFactory @staticmethod def tracing_excluded_urls( - excluded_urls: t.Optional[t.Union[str, t.List[str]]] = Provide[ - config.tracing.excluded_urls - ], + excluded_urls: str | list[str] | None = Provide[config.tracing.excluded_urls], ): from opentelemetry.util.http import ExcludeList from opentelemetry.util.http import parse_excluded_urls @@ -520,7 +553,7 @@ def tracing_excluded_urls( @providers.SingletonFactory @staticmethod def duration_buckets( - metrics: dict[str, t.Any] = Provide[config.api_server.metrics] + metrics: dict[str, t.Any] = Provide[api_server_config.metrics] ) -> tuple[float, ...]: """ Returns a tuple of duration buckets in seconds. If not explicitly configured, diff --git a/bentoml/_internal/configuration/default_configuration.yaml b/bentoml/_internal/configuration/default_configuration.yaml index 4d7cba78f4..8fefc0a7f5 100644 --- a/bentoml/_internal/configuration/default_configuration.yaml +++ b/bentoml/_internal/configuration/default_configuration.yaml @@ -1,42 +1,53 @@ api_server: - port: 3000 - host: 0.0.0.0 - backlog: 2048 workers: 1 timeout: 60 - max_request_size: 20971520 + backlog: 2048 metrics: - enabled: True + enabled: true namespace: bentoml_api_server logging: access: - enabled: True - request_content_length: True - request_content_type: True - response_content_length: True - response_content_type: True - cors: - enabled: False - access_control_allow_origin: Null - access_control_allow_credentials: Null - access_control_allow_methods: Null - access_control_allow_headers: Null - access_control_max_age: Null - access_control_expose_headers: Null + enabled: true + request_content_length: true + request_content_type: true + response_content_length: true + response_content_type: true + http: + host: 0.0.0.0 + port: 3000 + cors: + enabled: false + access_control_allow_origin: ~ + access_control_allow_credentials: ~ + access_control_allow_methods: ~ + access_control_allow_headers: ~ + access_control_max_age: ~ + access_control_expose_headers: ~ + grpc: + host: 0.0.0.0 + port: 3000 + max_concurrent_streams: ~ + maximum_concurrent_rpcs: ~ + max_message_length: -1 + reflection: + enabled: false + metrics: + host: 0.0.0.0 + port: 3001 runners: batching: - enabled: True + enabled: true max_batch_size: 100 max_latency_ms: 10000 resources: ~ logging: access: - enabled: True - request_content_length: True - request_content_type: True - response_content_length: True - response_content_type: True + enabled: true + request_content_length: true + request_content_type: true + response_content_length: true + response_content_type: true metrics: enabled: True namespace: bentoml_runner @@ -44,16 +55,16 @@ runners: tracing: type: zipkin - sample_rate: Null - excluded_urls: Null + sample_rate: ~ + excluded_urls: ~ zipkin: - url: Null + url: ~ jaeger: - address: Null - port: Null + address: ~ + port: ~ otlp: - protocol: Null - url: Null + protocol: ~ + url: ~ logging: formatting: diff --git a/bentoml/_internal/external_typing/__init__.py b/bentoml/_internal/external_typing/__init__.py index de2e56dcf1..9087543eb1 100644 --- a/bentoml/_internal/external_typing/__init__.py +++ b/bentoml/_internal/external_typing/__init__.py @@ -4,19 +4,22 @@ if TYPE_CHECKING: from typing import Literal - from pandas import Series as _PdSeries + F = t.Callable[..., t.Any] + + from pandas import Series as PdSeries from pandas import DataFrame as PdDataFrame + from pandas._typing import Dtype as PdDType + from pandas._typing import DtypeArg as PdDTypeArg from pyarrow.plasma import ObjectID from pyarrow.plasma import PlasmaClient - PdSeries = _PdSeries[t.Any] DataFrameOrient = Literal["split", "records", "index", "columns", "values", "table"] SeriesOrient = Literal["split", "records", "index", "table"] # numpy is always required by bentoml from numpy import generic as NpGeneric from numpy.typing import NDArray as _NDArray - from numpy.typing import DTypeLike as NpDTypeLike # type: ignore (incomplete numpy types) + from numpy.typing import DTypeLike as NpDTypeLike NpNDArray = _NDArray[t.Any] @@ -30,9 +33,13 @@ from .starlette import ASGIReceive from .starlette import AsgiMiddleware + WSGIApp = t.Callable[[F, t.Mapping[str, t.Any]], t.Iterable[bytes]] + __all__ = [ "PdSeries", "PdDataFrame", + "PdDType", + "PdDTypeArg", "DataFrameOrient", "SeriesOrient", "ObjectID", @@ -51,4 +58,6 @@ "ASGISend", "ASGIReceive", "ASGIMessage", + # misc + "WSGIApp", ] diff --git a/bentoml/_internal/io_descriptors/base.py b/bentoml/_internal/io_descriptors/base.py index 09c669b698..4c090ba1aa 100644 --- a/bentoml/_internal/io_descriptors/base.py +++ b/bentoml/_internal/io_descriptors/base.py @@ -8,10 +8,11 @@ if TYPE_CHECKING: from types import UnionType - from typing_extensions import Self from starlette.requests import Request from starlette.responses import Response + from bentoml.grpc.types import ProtoField + from ..types import LazyType from ..context import InferenceApiContext as Context from ..service.openapi.specification import Schema @@ -39,20 +40,12 @@ class IODescriptor(ABC, t.Generic[IOType]): HTTP_METHODS = ["POST"] - _init_str: str = "" - _mime_type: str - - def __new__(cls: t.Type[Self], *args: t.Any, **kwargs: t.Any) -> Self: - self = super().__new__(cls) - # default mime type is application/json - self._mime_type = "application/json" - self._init_str = cls.__qualname__ - - return self + _rpc_content_type: str = "application/grpc" + _proto_fields: tuple[ProtoField] def __repr__(self) -> str: - return self._init_str + return self.__class__.__qualname__ @abstractmethod def input_type(self) -> InputType: @@ -83,3 +76,11 @@ async def to_http_response( self, obj: IOType, ctx: Context | None = None ) -> Response: ... + + @abstractmethod + async def from_proto(self, field: t.Any) -> IOType: + ... + + @abstractmethod + async def to_proto(self, obj: IOType) -> t.Any: + ... diff --git a/bentoml/_internal/io_descriptors/file.py b/bentoml/_internal/io_descriptors/file.py index 5a21d0ab90..452088578c 100644 --- a/bentoml/_internal/io_descriptors/file.py +++ b/bentoml/_internal/io_descriptors/file.py @@ -13,6 +13,7 @@ from .base import IODescriptor from ..types import FileLike from ..utils.http import set_cookies +from ...exceptions import BadInput from ...exceptions import BentoMLException from ..service.openapi import SUCCESS_DESCRIPTION from ..service.openapi.specification import Schema @@ -23,10 +24,17 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: + from bentoml.grpc.v1alpha1 import service_pb2 as pb + from ..context import InferenceApiContext as Context FileKind: t.TypeAlias = t.Literal["binaryio", "textio"] -FileType: t.TypeAlias = t.Union[io.IOBase, t.IO[bytes], FileLike[bytes]] +else: + from bentoml.grpc.utils import import_generated_stubs + + pb, _ = import_generated_stubs() + +FileType = t.Union[io.IOBase, t.IO[bytes], FileLike[bytes]] class File(IODescriptor[FileType]): @@ -100,16 +108,16 @@ async def predict(input_pdf: io.BytesIO[Any]) -> NDArray[Any]: """ + _proto_fields = ("file",) + def __new__( # pylint: disable=arguments-differ # returning subclass from new cls, kind: FileKind = "binaryio", mime_type: str | None = None ) -> File: mime_type = mime_type if mime_type is not None else "application/octet-stream" - if kind == "binaryio": res = object.__new__(BytesIOFile) else: raise ValueError(f"invalid File kind '{kind}'") - res._mime_type = mime_type return res @@ -134,11 +142,7 @@ def openapi_responses(self) -> OpenAPIResponse: content={self._mime_type: MediaType(schema=self.openapi_schema())}, ) - async def to_http_response( - self, - obj: FileType, - ctx: Context | None = None, - ): + async def to_http_response(self, obj: FileType, ctx: Context | None = None): if isinstance(obj, bytes): body = obj else: @@ -155,6 +159,31 @@ async def to_http_response( res = Response(body) return res + async def to_proto(self, obj: FileType) -> pb.File: + from bentoml.grpc.utils import mimetype_to_filetype_pb_map + + if isinstance(obj, bytes): + body = obj + else: + body = obj.read() + + try: + kind = mimetype_to_filetype_pb_map()[self._mime_type] + except KeyError: + raise BadInput( + f"{self._mime_type} doesn't have a corresponding File 'kind'" + ) from None + + return pb.File(kind=kind, content=body) + + if TYPE_CHECKING: + + async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]: + ... + + async def from_http_request(self, request: Request) -> t.IO[bytes]: + ... + class BytesIOFile(File): async def from_http_request(self, request: Request) -> t.IO[bytes]: @@ -183,3 +212,29 @@ async def from_http_request(self, request: Request) -> t.IO[bytes]: raise BentoMLException( f"File should have Content-Type '{self._mime_type}' or 'multipart/form-data', got {content_type} instead" ) + + async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]: + from bentoml.grpc.utils import filetype_pb_to_mimetype_map + + mapping = filetype_pb_to_mimetype_map() + # check if the request message has the correct field + if isinstance(field, bytes): + content = field + else: + assert isinstance(field, pb.File) + if field.kind: + try: + mime_type = mapping[field.kind] + if mime_type != self._mime_type: + raise BadInput( + f"Inferred mime_type from 'kind' is '{mime_type}', while '{repr(self)}' is expecting '{self._mime_type}'", + ) + except KeyError: + raise BadInput( + f"{field.kind} is not a valid File kind. Accepted file kind: {[names for names,_ in pb.File.FileType.items()]}", + ) from None + content = field.content + if not content: + raise BadInput("Content is empty!") from None + + return FileLike[bytes](io.BytesIO(content), "") diff --git a/bentoml/_internal/io_descriptors/image.py b/bentoml/_internal/io_descriptors/image.py index af49128e0c..88502cea07 100644 --- a/bentoml/_internal/io_descriptors/image.py +++ b/bentoml/_internal/io_descriptors/image.py @@ -15,7 +15,6 @@ from ..utils.http import set_cookies from ...exceptions import BadInput from ...exceptions import InvalidArgument -from ...exceptions import InternalServerError from ..service.openapi import SUCCESS_DESCRIPTION from ..service.openapi.specification import Schema from ..service.openapi.specification import Response as OpenAPIResponse @@ -25,8 +24,11 @@ if TYPE_CHECKING: from types import UnionType + import PIL import PIL.Image + from bentoml.grpc.v1alpha1 import service_pb2 as pb + from .. import external_typing as ext from ..context import InferenceApiContext as Context @@ -34,17 +36,21 @@ "1", "CMYK", "F", "HSV", "I", "L", "LAB", "P", "RGB", "RGBA", "RGBX", "YCbCr" ] else: + from bentoml.grpc.utils import import_generated_stubs # NOTE: pillow-simd only benefits users who want to do preprocessing # TODO: add options for users to choose between simd and native mode - _exc = f"'Pillow' is required to use {__name__}. Install with: 'pip install -U Pillow'." + _exc = "'Pillow' is required to use the Image IO descriptor. Install it with: 'pip install -U Pillow'." PIL = LazyLoader("PIL", globals(), "PIL", exc_msg=_exc) PIL.Image = LazyLoader("PIL.Image", globals(), "PIL.Image", exc_msg=_exc) + pb, _ = import_generated_stubs() + + # NOTES: we will keep type in quotation to avoid backward compatibility # with numpy < 1.20, since we will use the latest stubs from the main branch of numpy. # that enable a new way to type hint an ndarray. -ImageType: t.TypeAlias = t.Union["PIL.Image.Image", "ext.NpNDArray"] +ImageType = t.Union["PIL.Image.Image", "ext.NpNDArray"] DEFAULT_PIL_MODE = "RGB" @@ -135,32 +141,26 @@ async def predict_image(f: Image) -> NDArray[Any]: :obj:`Image`: IO Descriptor that either a :code:`PIL.Image.Image` or a :code:`np.ndarray` representing an image. """ - MIME_EXT_MAPPING: t.Dict[str, str] = {} + MIME_EXT_MAPPING: dict[str, str] = {} + + _proto_fields = ("file",) def __init__( self, pilmode: _Mode | None = DEFAULT_PIL_MODE, mime_type: str = "image/jpeg", ): - try: - import PIL.Image - except ImportError: - raise InternalServerError( - "`Pillow` is required to use {__name__}\n Instructions: `pip install -U Pillow`" - ) PIL.Image.init() self.MIME_EXT_MAPPING.update({v: k for k, v in PIL.Image.MIME.items()}) if mime_type.lower() not in self.MIME_EXT_MAPPING: # pragma: no cover raise InvalidArgument( - f"Invalid Image mime_type '{mime_type}', " - f"Supported mime types are {', '.join(PIL.Image.MIME.values())} " - ) + f"Invalid Image mime_type '{mime_type}'. Supported mime types are {', '.join(PIL.Image.MIME.values())}." + ) from None if pilmode is not None and pilmode not in PIL.Image.MODES: # pragma: no cover raise InvalidArgument( - f"Invalid Image pilmode '{pilmode}', " - f"Supported PIL modes are {', '.join(PIL.Image.MODES)} " - ) + f"Invalid Image pilmode '{pilmode}'. Supported PIL modes are {', '.join(PIL.Image.MODES)}." + ) from None self._mime_type = mime_type.lower() self._pilmode: _Mode | None = pilmode @@ -197,13 +197,12 @@ async def from_http_request(self, request: Request) -> ImageType: bytes_ = await request.body() else: raise BadInput( - f"{self.__class__.__name__} should get `multipart/form-data`, " - f"`{self._mime_type}` or `image/*`, got {content_type} instead" + f"{self.__class__.__name__} should get 'multipart/form-data', '{self._mime_type}' or 'image/*', got '{content_type}' instead." ) try: return PIL.Image.open(io.BytesIO(bytes_)) - except PIL.UnidentifiedImageError: - raise BadInput("Failed reading image file uploaded") from None + except PIL.UnidentifiedImageError as e: + raise BadInput(f"Failed reading image file uploaded: {e}") from None async def to_http_response( self, obj: ImageType, ctx: Context | None = None @@ -213,10 +212,9 @@ async def to_http_response( elif LazyType[PIL.Image.Image]("PIL.Image.Image").isinstance(obj): image = obj else: - raise InternalServerError( - f"Unsupported Image type received: {type(obj)}, `{self.__class__.__name__}`" - " only supports `np.ndarray` and `PIL.Image`" - ) + raise BadInput( + f"Unsupported Image type received: '{type(obj)}', the Image IO descriptor only supports 'np.ndarray' and 'PIL.Image'." + ) from None filename = f"output.{self._format.lower()}" ret = io.BytesIO() @@ -248,3 +246,52 @@ async def to_http_response( media_type=self._mime_type, headers={"content-disposition": content_disposition}, ) + + async def from_proto(self, field: pb.File | bytes) -> ImageType: + from bentoml.grpc.utils import filetype_pb_to_mimetype_map + + mapping = filetype_pb_to_mimetype_map() + # check if the request message has the correct field + if isinstance(field, bytes): + content = field + else: + assert isinstance(field, pb.File) + if field.kind: + try: + mime_type = mapping[field.kind] + if mime_type != self._mime_type: + raise BadInput( + f"Inferred mime_type from 'kind' is '{mime_type}', while '{repr(self)}' is expecting '{self._mime_type}'", + ) + except KeyError: + raise BadInput( + f"{field.kind} is not a valid File kind. Accepted file kind: {[names for names,_ in pb.File.FileType.items()]}", + ) from None + content = field.content + if not content: + raise BadInput("Content is empty!") from None + + return PIL.Image.open(io.BytesIO(content)) + + async def to_proto(self, obj: ImageType) -> pb.File: + from bentoml.grpc.utils import mimetype_to_filetype_pb_map + + if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(obj): + image = PIL.Image.fromarray(obj, mode=self._pilmode) + elif LazyType["PIL.Image.Image"]("PIL.Image.Image").isinstance(obj): + image = obj + else: + raise BadInput( + f"Unsupported Image type received: '{type(obj)}', the Image IO descriptor only supports 'np.ndarray' and 'PIL.Image'.", + ) from None + ret = io.BytesIO() + image.save(ret, format=self._format) + + try: + kind = mimetype_to_filetype_pb_map()[self._mime_type] + except KeyError: + raise BadInput( + f"{self._mime_type} doesn't have a corresponding File 'kind'", + ) from None + + return pb.File(kind=kind, content=ret.getvalue()) diff --git a/bentoml/_internal/io_descriptors/json.py b/bentoml/_internal/io_descriptors/json.py index 648cf7de7e..c17bffc7ab 100644 --- a/bentoml/_internal/io_descriptors/json.py +++ b/bentoml/_internal/io_descriptors/json.py @@ -10,12 +10,14 @@ from starlette.requests import Request from starlette.responses import Response +from bentoml.exceptions import BadInput + from .base import IODescriptor from ..types import LazyType from ..utils import LazyLoader from ..utils import bentoml_cattr +from ..utils.pkg import pkg_version_info from ..utils.http import set_cookies -from ...exceptions import BadInput from ..service.openapi import REF_PREFIX from ..service.openapi import SUCCESS_DESCRIPTION from ..service.openapi.specification import Schema @@ -28,26 +30,28 @@ import pydantic import pydantic.schema as schema + from google.protobuf import struct_pb2 from .. import external_typing as ext from ..context import InferenceApiContext as Context - _Serializable = ext.NpNDArray | ext.PdDataFrame | t.Type[pydantic.BaseModel] | type else: _exc_msg = "'pydantic' must be installed to use 'pydantic_model'. Install with 'pip install pydantic'." pydantic = LazyLoader("pydantic", globals(), "pydantic", exc_msg=_exc_msg) schema = LazyLoader("schema", globals(), "pydantic.schema", exc_msg=_exc_msg) + # lazy load our proto generated. + struct_pb2 = LazyLoader("struct_pb2", globals(), "google.protobuf.struct_pb2") + # lazy load numpy for processing ndarray. + np = LazyLoader("np", globals(), "numpy") JSONType = t.Union[str, t.Dict[str, t.Any], "pydantic.BaseModel", None] -MIME_TYPE_JSON = "application/json" - logger = logging.getLogger(__name__) class DefaultJsonEncoder(json.JSONEncoder): - def default(self, o: _Serializable) -> t.Any: + def default(self, o: type) -> t.Any: if dataclasses.is_dataclass(o): return dataclasses.asdict(o) if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(o): @@ -63,9 +67,8 @@ def default(self, o: _Serializable) -> t.Any: if "__root__" in obj_dict: obj_dict = obj_dict.get("__root__") return obj_dict - if attr.has(o): # type: ignore (trivial case) + if attr.has(o): return bentoml_cattr.unstructure(o) - return super().default(o) @@ -168,7 +171,9 @@ def classify(input_data: IrisFeatures) -> NDArray[Any]: :obj:`JSON`: IO Descriptor that represents JSON format. """ - _mime_type: str = MIME_TYPE_JSON + _proto_fields = ("json",) + # default mime type is application/json + _mime_type = "application/json" def __init__( self, @@ -177,7 +182,11 @@ def __init__( validate_json: bool | None = None, json_encoder: t.Type[json.JSONEncoder] = DefaultJsonEncoder, ): - if pydantic_model: + if pydantic_model is not None: + if pkg_version_info("pydantic")[0] >= 2: + raise BadInput( + "pydantic 2.x is not yet supported. Add upper bound to 'pydantic': 'pip install \"pydantic<2\"'" + ) from None assert issubclass( pydantic_model, pydantic.BaseModel ), "'pydantic_model' must be a subclass of 'pydantic.BaseModel'." @@ -236,12 +245,12 @@ async def from_http_request(self, request: Request) -> JSONType: except json.JSONDecodeError as e: raise BadInput(f"Invalid JSON input received: {e}") from None - if self._pydantic_model is not None: + if self._pydantic_model: try: pydantic_model = self._pydantic_model.parse_obj(json_obj) return pydantic_model except pydantic.ValidationError as e: - raise BadInput(f"Invalid JSON input received: {e}") from e + raise BadInput(f"Invalid JSON input received: {e}") from None else: return json_obj @@ -269,11 +278,57 @@ async def to_http_response( if ctx is not None: res = Response( json_str, - media_type=MIME_TYPE_JSON, + media_type=self._mime_type, headers=ctx.response.metadata, # type: ignore (bad starlette types) status_code=ctx.response.status_code, ) set_cookies(res, ctx.response.cookies) return res else: - return Response(json_str, media_type=MIME_TYPE_JSON) + return Response(json_str, media_type=self._mime_type) + + async def from_proto(self, field: struct_pb2.Value | bytes) -> JSONType: + from google.protobuf.json_format import MessageToDict + + if isinstance(field, bytes): + content = field + if self._pydantic_model: + try: + return self._pydantic_model.parse_raw(content) + except pydantic.ValidationError as e: + raise BadInput(f"Invalid JSON input received: {e}") from None + try: + parsed = json.loads(content) + except json.JSONDecodeError as e: + raise BadInput(f"Invalid JSON input received: {e}") from None + else: + assert isinstance(field, struct_pb2.Value) + parsed = MessageToDict(field, preserving_proto_field_name=True) + + if self._pydantic_model: + try: + return self._pydantic_model.parse_obj(parsed) + except pydantic.ValidationError as e: + raise BadInput(f"Invalid JSON input received: {e}") from None + return parsed + + async def to_proto(self, obj: JSONType) -> struct_pb2.Value: + if LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(obj): + obj = obj.dict() + msg = struct_pb2.Value() + # To handle None cases. + if obj is not None: + from google.protobuf.json_format import ParseDict + + if isinstance(obj, (dict, str, list, float, int, bool)): + # ParseDict handles google.protobuf.Struct type + # directly if given object has a supported type + ParseDict(obj, msg) + else: + # If given object doesn't have a supported type, we will + # use given JSON encoder to convert it to dictionary + # and then parse it to google.protobuf.Struct. + # Note that if a custom JSON encoder is used, it mustn't + # take any arguments. + ParseDict(self._json_encoder().default(obj), msg) + return msg diff --git a/bentoml/_internal/io_descriptors/multipart.py b/bentoml/_internal/io_descriptors/multipart.py index c6f9190dd8..afc62190b4 100644 --- a/bentoml/_internal/io_descriptors/multipart.py +++ b/bentoml/_internal/io_descriptors/multipart.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing as t +import asyncio from typing import TYPE_CHECKING from starlette.requests import Request @@ -21,11 +22,17 @@ if TYPE_CHECKING: from types import UnionType + from bentoml.grpc.v1alpha1 import service_pb2 as pb + from ..types import LazyType from ..context import InferenceApiContext as Context +else: + from bentoml.grpc.utils import import_generated_stubs + + pb, _ = import_generated_stubs() -class Multipart(IODescriptor[t.Any]): +class Multipart(IODescriptor[t.Dict[str, t.Any]]): """ :obj:`Multipart` defines API specification for the inputs/outputs of a Service, where inputs/outputs of a Service can receive/send a **multipart** request/responses as specified in your API function signature. @@ -153,14 +160,18 @@ async def predict( :obj:`Multipart`: IO Descriptor that represents a Multipart request/response. """ + _proto_fields = ("multipart",) + _mime_type = "multipart/form-data" + def __init__(self, **inputs: IODescriptor[t.Any]): - for descriptor in inputs.values(): - if isinstance(descriptor, Multipart): # pragma: no cover - raise InvalidArgument( - "Multipart IO can not contain nested Multipart IO descriptor" - ) - self._inputs: dict[str, t.Any] = inputs - self._mime_type = "multipart/form-data" + if any(isinstance(descriptor, Multipart) for descriptor in inputs.values()): + raise InvalidArgument( + "Multipart IO can not contain nested Multipart IO descriptor" + ) from None + self._inputs = inputs + + def __repr__(self) -> str: + return f"Multipart({','.join([f'{k}={v}' for k,v in zip(self._inputs, map(repr, self._inputs.values()))])})" def input_type( self, @@ -171,7 +182,7 @@ def input_type( if isinstance(inp_type, dict): raise TypeError( "A multipart descriptor cannot take a multi-valued I/O descriptor as input" - ) + ) from None res[k] = inp_type return res @@ -202,22 +213,68 @@ async def from_http_request(self, request: Request) -> dict[str, t.Any]: if ctype != b"multipart/form-data": raise BentoMLException( f"{self.__class__.__name__} only accepts `multipart/form-data` as Content-Type header, got {ctype} instead." - ) - - res: dict[str, t.Any] = dict() - reqs = await populate_multipart_requests(request) + ) from None - for k, i in self._inputs.items(): - req = reqs[k] - v = await i.from_http_request(req) - res[k] = v - return res + to_populate = zip( + self._inputs.values(), (await populate_multipart_requests(request)).values() + ) + reqs = await asyncio.gather( + *tuple(io_.from_http_request(req) for io_, req in to_populate) + ) + return dict(zip(self._inputs, reqs)) async def to_http_response( self, obj: dict[str, t.Any], ctx: Context | None = None ) -> Response: - res_mapping: dict[str, Response] = {} - for k, io_ in self._inputs.items(): - data = obj[k] - res_mapping[k] = await io_.to_http_response(data, ctx) - return await concat_to_multipart_response(res_mapping, ctx) + resps = await asyncio.gather( + *tuple( + io_.to_http_response(obj[key], ctx) for key, io_ in self._inputs.items() + ) + ) + return await concat_to_multipart_response(dict(zip(self._inputs, resps)), ctx) + + def validate_input_mapping(self, field: t.MutableMapping[str, t.Any]) -> None: + if len(set(field) - set(self._inputs)) != 0: + raise InvalidArgument( + f"'{repr(self)}' accepts the following keys: {set(self._inputs)}. Given {field.__class__.__qualname__} has invalid fields: {set(field) - set(self._inputs)}", + ) from None + + async def from_proto(self, field: pb.Multipart) -> dict[str, t.Any]: + if isinstance(field, bytes): + raise InvalidArgument( + f"cannot use 'serialized_bytes' with {self.__class__.__name__}" + ) from None + message = field.fields + self.validate_input_mapping(message) + reqs = await asyncio.gather( + *tuple( + io_.from_proto(getattr(input_pb, io_._proto_fields[0])) + for io_, input_pb in self.io_fields_mapping(message).items() + ) + ) + return dict(zip(message, reqs)) + + def io_fields_mapping( + self, message: t.MutableMapping[str, pb.Part] + ) -> dict[IODescriptor[t.Any], pb.Part]: + return {io_: part for io_, part in zip(self._inputs.values(), message.values())} + + async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart: + self.validate_input_mapping(obj) + resps = await asyncio.gather( + *tuple( + io_.to_proto(data) + for io_, data in zip(self._inputs.values(), obj.values()) + ) + ) + return pb.Multipart( + fields={ + key: pb.Part( + **{ + io_._proto_fields[0]: resp + for io_, resp in zip(self._inputs.values(), resps) + } + ) + for key in obj + } + ) diff --git a/bentoml/_internal/io_descriptors/numpy.py b/bentoml/_internal/io_descriptors/numpy.py index 9ae561d47a..aa730a7dea 100644 --- a/bentoml/_internal/io_descriptors/numpy.py +++ b/bentoml/_internal/io_descriptors/numpy.py @@ -4,19 +4,20 @@ import typing as t import logging from typing import TYPE_CHECKING +from functools import lru_cache from starlette.requests import Request from starlette.responses import Response from .base import IODescriptor -from .json import MIME_TYPE_JSON from ..types import LazyType +from ..utils import LazyLoader from ..utils.http import set_cookies from ...exceptions import BadInput +from ...exceptions import InvalidArgument from ...exceptions import BentoMLException -from ...exceptions import InternalServerError +from ...exceptions import UnprocessableEntity from ..service.openapi import SUCCESS_DESCRIPTION -from ..utils.lazy_loader import LazyLoader from ..service.openapi.specification import Schema from ..service.openapi.specification import Response as OpenAPIResponse from ..service.openapi.specification import MediaType @@ -25,18 +26,79 @@ if TYPE_CHECKING: import numpy as np + from bentoml.grpc.v1alpha1 import service_pb2 as pb + from .. import external_typing as ext from ..context import InferenceApiContext as Context else: + from bentoml.grpc.utils import import_generated_stubs + + pb, _ = import_generated_stubs() np = LazyLoader("np", globals(), "numpy") logger = logging.getLogger(__name__) -def _is_matched_shape( - left: t.Optional[t.Tuple[int, ...]], - right: t.Optional[t.Tuple[int, ...]], -) -> bool: # pragma: no cover +# TODO: support the following types for for protobuf message: +# - support complex64, complex128, object and struct types +# - BFLOAT16, QINT32, QINT16, QUINT16, QINT8, QUINT8 +# +# For int16, uint16, int8, uint8 -> specify types in NumpyNdarray + using int_values. +# +# For bfloat16, half (float16) -> specify types in NumpyNdarray + using float_values. +# +# for string_values, use dict[pb.NDArray.DType.ValueType, ext.NpDTypeLike]: + # pb.NDArray.Dtype -> np.dtype + return { + pb.NDArray.DTYPE_FLOAT: np.dtype("float32"), + pb.NDArray.DTYPE_DOUBLE: np.dtype("double"), + pb.NDArray.DTYPE_INT32: np.dtype("int32"), + pb.NDArray.DTYPE_INT64: np.dtype("int64"), + pb.NDArray.DTYPE_UINT32: np.dtype("uint32"), + pb.NDArray.DTYPE_UINT64: np.dtype("uint64"), + pb.NDArray.DTYPE_BOOL: np.dtype("bool"), + pb.NDArray.DTYPE_STRING: np.dtype(" dict[pb.NDArray.DType.ValueType, str]: + return {k: npdtype_to_fieldpb_map()[v] for k, v in dtypepb_to_npdtype_map().items()} + + +@lru_cache(maxsize=1) +def fieldpb_to_npdtype_map() -> dict[str, ext.NpDTypeLike]: + # str -> np.dtype + return {k: np.dtype(v) for k, v in FIELDPB_TO_NPDTYPE_NAME_MAP.items()} + + +@lru_cache(maxsize=1) +def npdtype_to_dtypepb_map() -> dict[ext.NpDTypeLike, pb.NDArray.DType.ValueType]: + # np.dtype -> pb.NDArray.Dtype + return {v: k for k, v in dtypepb_to_npdtype_map().items()} + + +@lru_cache(maxsize=1) +def npdtype_to_fieldpb_map() -> dict[ext.NpDTypeLike, str]: + # np.dtype -> str + return {v: k for k, v in fieldpb_to_npdtype_map().items()} + + +def _is_matched_shape(left: tuple[int, ...], right: tuple[int, ...]) -> bool: if (left is None) or (right is None): return False @@ -52,6 +114,7 @@ def _is_matched_shape( return True +# TODO: when updating docs, add examples with gRPCurl class NumpyNdarray(IODescriptor["ext.NpNDArray"]): """ :obj:`NumpyNdarray` defines API specification for the inputs/outputs of a Service, where @@ -135,6 +198,9 @@ async def predict(input_array: np.ndarray) -> np.ndarray: :obj:`~bentoml._internal.io_descriptors.IODescriptor`: IO Descriptor that represents a :code:`np.ndarray`. """ + _proto_fields = ("ndarray",) + _mime_type = "application/json" + def __init__( self, dtype: str | ext.NpDTypeLike | None = None, @@ -143,15 +209,11 @@ def __init__( enforce_shape: bool = False, ): if dtype and not isinstance(dtype, np.dtype): - # Convert from primitive type or type string, e.g.: - # np.dtype(float) - # np.dtype("float64") + # Convert from primitive type or type string, e.g.: np.dtype(float) or np.dtype("float64") try: dtype = np.dtype(dtype) except TypeError as e: - raise BentoMLException( - f'NumpyNdarray: Invalid dtype "{dtype}": {e}' - ) from e + raise UnprocessableEntity(f'Invalid dtype "{dtype}": {e}') from None self._dtype = dtype self._shape = shape @@ -160,6 +222,15 @@ def __init__( self._sample_input = None + if self._enforce_dtype and not self._dtype: + raise InvalidArgument( + "'dtype' must be specified when 'enforce_dtype=True'" + ) from None + if self._enforce_shape and not self._shape: + raise InvalidArgument( + "'shape' must be specified when 'enforce_shape=True'" + ) from None + def _openapi_types(self) -> str: # convert numpy dtypes to openapi compatible types. var_type = "integer" @@ -195,7 +266,9 @@ def openapi_components(self) -> dict[str, t.Any] | None: def openapi_example(self) -> t.Any: if self.sample_input is not None: if isinstance(self.sample_input, np.generic): - raise BadInput("NumpyNdarray: sample_input must be a numpy array.") + raise BadInput( + "NumpyNdarray: sample_input must be a numpy array." + ) from None return self.sample_input.tolist() return @@ -219,33 +292,33 @@ def openapi_responses(self) -> OpenAPIResponse: }, ) - def _verify_ndarray( - self, obj: ext.NpNDArray, exception_cls: t.Type[Exception] = BadInput + def validate_array( + self, arr: ext.NpNDArray, exception_cls: t.Type[Exception] = BadInput ) -> ext.NpNDArray: - if self._dtype is not None and self._dtype != obj.dtype: + if self._dtype is not None and self._dtype != arr.dtype: # ‘same_kind’ means only safe casts or casts within a kind, like float64 # to float32, are allowed. - if np.can_cast(obj.dtype, self._dtype, casting="same_kind"): - obj = obj.astype(self._dtype, casting="same_kind") # type: ignore + if np.can_cast(arr.dtype, self._dtype, casting="same_kind"): + arr = arr.astype(self._dtype, casting="same_kind") # type: ignore else: - msg = f'{self.__class__.__name__}: Expecting ndarray of dtype "{self._dtype}", but "{obj.dtype}" was received.' + msg = f'{self.__class__.__name__}: Expecting ndarray of dtype "{self._dtype}", but "{arr.dtype}" was received.' if self._enforce_dtype: - raise exception_cls(msg) + raise exception_cls(msg) from None else: logger.debug(msg) - if self._shape is not None and not _is_matched_shape(self._shape, obj.shape): - msg = f'{self.__class__.__name__}: Expecting ndarray of shape "{self._shape}", but "{obj.shape}" was received.' + if self._shape is not None and not _is_matched_shape(self._shape, arr.shape): + msg = f'{self.__class__.__name__}: Expecting ndarray of shape "{self._shape}", but "{arr.shape}" was received.' if self._enforce_shape: - raise exception_cls(msg) + raise exception_cls(msg) from None try: - obj = obj.reshape(self._shape) + arr = arr.reshape(self._shape) except ValueError as e: logger.debug(f"{msg} Failed to reshape: {e}.") - return obj + return arr - async def from_http_request(self, request: Request) -> "ext.NpNDArray": + async def from_http_request(self, request: Request) -> ext.NpNDArray: """ Process incoming requests and convert incoming objects to ``numpy.ndarray``. @@ -262,7 +335,7 @@ async def from_http_request(self, request: Request) -> "ext.NpNDArray": except ValueError: res = np.array(obj) - return self._verify_ndarray(res) + return self.validate_array(res) async def to_http_response(self, obj: ext.NpNDArray, ctx: Context | None = None): """ @@ -276,18 +349,19 @@ async def to_http_response(self, obj: ext.NpNDArray, ctx: Context | None = None) HTTP Response of type ``starlette.responses.Response``. This can be accessed via cURL or any external web traffic. """ - obj = self._verify_ndarray(obj, InternalServerError) + obj = self.validate_array(obj) + if ctx is not None: res = Response( json.dumps(obj.tolist()), - media_type=MIME_TYPE_JSON, + media_type=self._mime_type, headers=ctx.response.metadata, # type: ignore (bad starlette types) status_code=ctx.response.status_code, ) set_cookies(res, ctx.response.cookies) return res else: - return Response(json.dumps(obj.tolist()), media_type=MIME_TYPE_JSON) + return Response(json.dumps(obj.tolist()), media_type=self._mime_type) @classmethod def from_sample( @@ -336,8 +410,8 @@ async def predict(input: NDArray[np.int16]) -> NDArray[Any]: """ if isinstance(sample_input, np.generic): raise BentoMLException( - "NumpyNdarray.from_sample() expects a numpy.array, not numpy.generic." - ) + "'NumpyNdarray.from_sample()' expects a 'numpy.array', not 'numpy.generic'." + ) from None inst = cls( dtype=sample_input.dtype, @@ -348,3 +422,88 @@ async def predict(input: NDArray[np.int16]) -> NDArray[Any]: inst.sample_input = sample_input return inst + + async def from_proto(self, field: pb.NDArray | bytes) -> ext.NpNDArray: + """ + Process incoming protobuf request and convert it to ``numpy.ndarray`` + + Args: + request: Incoming RPC request message. + context: grpc.ServicerContext + + Returns: + a ``numpy.ndarray`` object. This can then be used + inside users defined logics. + """ + if isinstance(field, bytes): + if not self._dtype: + raise BadInput( + "'serialized_bytes' requires specifying 'dtype'." + ) from None + + dtype: ext.NpDTypeLike = self._dtype + array = np.frombuffer(field, dtype=self._dtype) + else: + assert isinstance(field, pb.NDArray) + if field.dtype == pb.NDArray.DTYPE_UNSPECIFIED: + dtype = None + else: + try: + dtype = dtypepb_to_npdtype_map()[field.dtype] + except KeyError: + raise BadInput(f"{field.dtype} is invalid.") from None + if dtype is not None: + values_array = getattr(field, dtypepb_to_fieldpb_map()[field.dtype]) + else: + fieldpb = [ + f.name for f, _ in field.ListFields() if f.name.endswith("_values") + ] + if len(fieldpb) == 0: + # input message doesn't have any fields. + return np.empty(shape=field.shape or 0) + elif len(fieldpb) > 1: + # when there are more than two values provided in the proto. + raise BadInput( + f"Array contents can only be one of given values key. Use one of '{fieldpb}' instead.", + ) from None + + dtype: ext.NpDTypeLike = fieldpb_to_npdtype_map()[fieldpb[0]] + values_array = getattr(field, fieldpb[0]) + try: + array = np.array(values_array, dtype=dtype) + except ValueError: + array = np.array(values_array) + + if field.shape: + array = np.reshape(array, field.shape) + + return self.validate_array(array) + + async def to_proto(self, obj: ext.NpNDArray) -> pb.NDArray: + """ + Process given objects and convert it to grpc protobuf response. + + Args: + obj: `np.ndarray` that will be serialized to protobuf + context: grpc.aio.ServicerContext from grpc.aio.Server + Returns: + `io_descriptor_pb2.Array`: + Protobuf representation of given `np.ndarray` + """ + try: + obj = self.validate_array(obj) + except BadInput as e: + raise e from None + + try: + fieldpb = npdtype_to_fieldpb_map()[obj.dtype] + dtypepb = npdtype_to_dtypepb_map()[obj.dtype] + return pb.NDArray( + dtype=dtypepb, + shape=tuple(obj.shape), + **{fieldpb: obj.ravel().tolist()}, + ) + except KeyError: + raise BadInput( + f"Unsupported dtype '{obj.dtype}' for response message.", + ) from None diff --git a/bentoml/_internal/io_descriptors/pandas.py b/bentoml/_internal/io_descriptors/pandas.py index b01256bfc5..135a44b0ce 100644 --- a/bentoml/_internal/io_descriptors/pandas.py +++ b/bentoml/_internal/io_descriptors/pandas.py @@ -4,19 +4,20 @@ import typing as t import logging import functools -import importlib.util from enum import Enum from typing import TYPE_CHECKING +from concurrent.futures import ThreadPoolExecutor from starlette.requests import Request from starlette.responses import Response from .base import IODescriptor -from .json import MIME_TYPE_JSON from ..types import LazyType +from ..utils.pkg import find_spec from ..utils.http import set_cookies from ...exceptions import BadInput from ...exceptions import InvalidArgument +from ...exceptions import UnprocessableEntity from ...exceptions import MissingDependencyException from ..service.openapi import SUCCESS_DESCRIPTION from ..utils.lazy_loader import LazyLoader @@ -28,15 +29,20 @@ if TYPE_CHECKING: import pandas as pd + from bentoml.grpc.v1alpha1 import service_pb2 as pb + from .. import external_typing as ext from ..context import InferenceApiContext as Context else: + from bentoml.grpc.utils import import_generated_stubs + + pb, _ = import_generated_stubs() pd = LazyLoader( "pd", globals(), "pandas", - exc_msg="`pandas` is required to use PandasDataFrame or PandasSeries. Install with `pip install -U pandas`", + exc_msg='pandas" is required to use PandasDataFrame or PandasSeries. Install with "pip install -U pandas"', ) logger = logging.getLogger(__name__) @@ -45,9 +51,9 @@ # Check for parquet support @functools.lru_cache(maxsize=1) def get_parquet_engine() -> str: - if importlib.util.find_spec("pyarrow") is not None: + if find_spec("pyarrow") is not None: return "pyarrow" - elif importlib.util.find_spec("fastparquet") is not None: + elif find_spec("fastparquet") is not None: return "fastparquet" else: logger.warning( @@ -72,9 +78,7 @@ def _openapi_types(item: str) -> str: # pragma: no cover return "object" -def _openapi_schema( - dtype: bool | dict[str, t.Any] | None -) -> Schema: # pragma: no cover +def _openapi_schema(dtype: bool | ext.PdDTypeArg | None) -> Schema: # pragma: no cover if isinstance(dtype, dict): return Schema( type="object", @@ -111,15 +115,12 @@ def _infer_serialization_format_from_request( return SerializationFormat.CSV elif content_type: logger.debug( - "Unknown content-type (%s), falling back to %s serialization format.", - content_type, - default_format, + f"Unknown content-type ('{content_type}'), falling back to '{default_format}' serialization format.", ) return default_format else: logger.debug( - "Content-type not specified, falling back to %s serialization format.", - default_format, + f"Content-type not specified, falling back to '{default_format}' serialization format.", ) return default_format @@ -203,7 +204,7 @@ def predict(input_arr): - :obj:`split` - :code:`dict[str, Any]` ↦ {``idx`` ↠ ``[idx]``, ``columns`` ↠ ``[columns]``, ``data`` ↠ ``[values]``} - :obj:`records` - :code:`list[Any]` ↦ [{``column`` ↠ ``value``}, ..., {``column`` ↠ ``value``}] - :obj:`index` - :code:`dict[str, Any]` ↦ {``idx`` ↠ {``column`` ↠ ``value``}} - - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` -> {``index`` ↠ ``value``}} + - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` ↠ {``index`` ↠ ``value``}} - :obj:`values` - :code:`dict[str, Any]` ↦ Values arrays columns: List of columns name that users wish to update. apply_column_names: Whether to update incoming DataFrame columns. If :code:`apply_column_names=True`, @@ -248,12 +249,14 @@ def predict(input_df: pd.DataFrame) -> pd.DataFrame: :obj:`PandasDataFrame`: IO Descriptor that represents a :code:`pd.DataFrame`. """ + _proto_fields = ("dataframe",) + def __init__( self, orient: ext.DataFrameOrient = "records", - apply_column_names: bool = False, columns: list[str] | None = None, - dtype: bool | dict[str, t.Any] | None = None, + apply_column_names: bool = False, + dtype: bool | ext.PdDTypeArg | None = None, enforce_dtype: bool = False, shape: tuple[int, ...] | None = None, enforce_shape: bool = False, @@ -324,49 +327,21 @@ async def from_http_request(self, request: Request) -> ext.PdDataFrame: _validate_serialization_format(serialization_format) obj = await request.body() - if self._enforce_dtype: - if self._dtype is None: - logger.warning( - "`dtype` is None or undefined, while `enforce_dtype`=True" - ) - # TODO(jiang): check dtype - if serialization_format is SerializationFormat.JSON: + assert not isinstance(self._dtype, bool) res = pd.read_json(io.BytesIO(obj), dtype=self._dtype, orient=self._orient) elif serialization_format is SerializationFormat.PARQUET: res = pd.read_parquet(io.BytesIO(obj), engine=get_parquet_engine()) elif serialization_format is SerializationFormat.CSV: + assert not isinstance(self._dtype, bool) res: ext.PdDataFrame = pd.read_csv(io.BytesIO(obj), dtype=self._dtype) else: raise InvalidArgument( f"Unknown serialization format ({serialization_format})." - ) + ) from None assert isinstance(res, pd.DataFrame) - - if self._apply_column_names: - if self._columns is None: - logger.warning( - "`columns` is None or undefined, while `apply_column_names`=True" - ) - elif len(self._columns) != res.shape[1]: - raise BadInput( - "length of `columns` does not match the columns of incoming data" - ) - else: - res.columns = pd.Index(self._columns) - if self._enforce_shape: - if self._shape is None: - logger.warning( - "`shape` is None or undefined, while `enforce_shape`=True" - ) - else: - assert all( - left == right - for left, right in zip(self._shape, res.shape) # type: ignore (shape type) - if left != -1 and right != -1 - ), f"incoming has shape {res.shape} where enforced shape to be {self._shape}" - return res + return self.validate_dataframe(res) async def to_http_response( self, obj: ext.PdDataFrame, ctx: Context | None = None @@ -381,6 +356,7 @@ async def to_http_response( HTTP Response of type `starlette.responses.Response`. This can be accessed via cURL or any external web traffic. """ + obj = self.validate_dataframe(obj) # For the response it doesn't make sense to enforce the same serialization format as specified # by the request's headers['content-type']. Instead we simply use the _default_format. @@ -399,7 +375,7 @@ async def to_http_response( else: raise InvalidArgument( f"Unknown serialization format ({serialization_format})." - ) + ) from None if ctx is not None: res = Response( @@ -420,7 +396,7 @@ def from_sample( orient: ext.DataFrameOrient = "records", apply_column_names: bool = True, enforce_shape: bool = True, - enforce_dtype: bool = False, + enforce_dtype: bool = True, default_format: t.Literal["json", "parquet", "csv"] = "json", ) -> PandasDataFrame: """ @@ -435,7 +411,7 @@ def from_sample( - :obj:`split` - :code:`dict[str, Any]` ↦ {``idx`` ↠ ``[idx]``, ``columns`` ↠ ``[columns]``, ``data`` ↠ ``[values]``} - :obj:`records` - :code:`list[Any]` ↦ [{``column`` ↠ ``value``}, ..., {``column`` ↠ ``value``}] - :obj:`index` - :code:`dict[str, Any]` ↦ {``idx`` ↠ {``column`` ↠ ``value``}} - - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` -> {``index`` ↠ ``value``}} + - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` ↠ {``index`` ↠ ``value``}} - :obj:`values` - :code:`dict[str, Any]` ↦ Values arrays apply_column_names: Update incoming DataFrame columns. ``columns`` must be specified at function signature. If you don't want to enforce a specific columns @@ -469,22 +445,144 @@ def from_sample( @svc.api(input=input_spec, output=PandasDataFrame()) def predict(inputs: pd.DataFrame) -> pd.DataFrame: ... """ - columns = [str(x) for x in list(sample_input.columns)] - inst = cls( orient=orient, enforce_shape=enforce_shape, shape=sample_input.shape, apply_column_names=apply_column_names, - columns=columns, + columns=[str(x) for x in list(sample_input.columns)], enforce_dtype=enforce_dtype, - dtype=None, # TODO: not breaking atm + dtype=True, # set to True to infer from given input default_format=default_format, ) inst.sample_input = sample_input return inst + def validate_dataframe( + self, dataframe: ext.PdDataFrame, exception_cls: t.Type[Exception] = BadInput + ) -> ext.PdDataFrame: + + if not LazyType["ext.PdDataFrame"]("pd.DataFrame").isinstance(dataframe): + raise InvalidArgument( + f"return object is not of type 'pd.DataFrame', got type '{type(dataframe)}' instead" + ) from None + + # TODO: dtype check + # if self._dtype is not None and self._dtype != dataframe.dtypes: + # msg = f'{self.__class__.__name__}: Expecting DataFrame of dtype "{self._dtype}", but "{dataframe.dtypes}" was received.' + # if self._enforce_dtype: + # raise exception_cls(msg) from None + + if self._columns is not None and len(self._columns) != dataframe.shape[1]: + msg = f"length of 'columns' ({len(self._columns)}) does not match the # of columns of incoming data." + if self._apply_column_names: + raise BadInput(msg) from None + else: + logger.debug(msg) + dataframe.columns = pd.Index(self._columns) + + # TODO: convert from wide to long format (melt()) + if self._shape is not None and self._shape != dataframe.shape: + msg = f'{self.__class__.__name__}: Expecting DataFrame of shape "{self._shape}", but "{dataframe.shape}" was received.' + if self._enforce_shape and not all( + left == right + for left, right in zip(self._shape, dataframe.shape) + if left != -1 and right != -1 + ): + raise exception_cls(msg) from None + + return dataframe + + async def from_proto(self, field: pb.DataFrame | bytes) -> ext.PdDataFrame: + """ + Process incoming protobuf request and convert it to ``pandas.DataFrame`` + + Args: + request: Incoming RPC request message. + context: grpc.ServicerContext + + Returns: + a ``pandas.DataFrame`` object. This can then be used + inside users defined logics. + """ + # TODO: support different serialization format + if isinstance(field, bytes): + # TODO: handle serialized_bytes for dataframe + raise NotImplementedError( + 'Currently not yet implemented. Use "dataframe" instead.' + ) + else: + # note that there is a current bug where we don't check for + # dtype of given fields per Series to match with types of a given + # columns, hence, this would result in a wrong DataFrame that is not + # expected by our users. + assert isinstance(field, pb.DataFrame) + # columns orient: { column_name : {index : columns.series._value}} + if self._orient != "columns": + raise BadInput( + f"'dataframe' field currently only supports 'columns' orient. Make sure to set 'orient=columns' in {self.__class__.__name__}." + ) from None + data: list[t.Any] = [] + + def process_columns_contents(content: pb.Series) -> dict[str, t.Any]: + # To be use inside a ThreadPoolExecutor to handle + # large tabular data + if len(content.ListFields()) != 1: + raise BadInput( + f"Array contents can only be one of given values key. Use one of '{list(map(lambda f: f[0].name,content.ListFields()))}' instead." + ) from None + return {str(i): c for i, c in enumerate(content.ListFields()[0][1])} + + with ThreadPoolExecutor(max_workers=10) as executor: + futures = executor.map(process_columns_contents, field.columns) + data.extend([i for i in list(futures)]) + dataframe = pd.DataFrame( + dict(zip(field.column_names, data)), + columns=t.cast(t.List[str], field.column_names), + ) + return self.validate_dataframe(dataframe) + + async def to_proto(self, obj: ext.PdDataFrame) -> pb.DataFrame: + """ + Process given objects and convert it to grpc protobuf response. + + Args: + obj: ``pandas.DataFrame`` that will be serialized to protobuf + context: grpc.aio.ServicerContext from grpc.aio.Server + Returns: + ``service_pb2.Response``: + Protobuf representation of given ``pandas.DataFrame`` + """ + from bentoml._internal.io_descriptors.numpy import npdtype_to_fieldpb_map + + # TODO: support different serialization format + obj = self.validate_dataframe(obj) + mapping = npdtype_to_fieldpb_map() + # note that this is not safe, since we are not checking the dtype of the series + # FIXME(aarnphm): validate and handle mix columns dtype + # currently we don't support ExtensionDtype + columns_name: list[str] = list(map(str, obj.columns)) + not_supported: list[ext.PdDType] = list( + filter( + lambda x: x not in mapping, + map(lambda x: t.cast("ext.PdSeries", obj[x]).dtype, columns_name), + ) + ) + if len(not_supported) > 0: + raise UnprocessableEntity( + f'dtype in column "{obj.columns}" is not currently supported.' + ) from None + return pb.DataFrame( + column_names=columns_name, + columns=[ + pb.Series( + **{mapping[t.cast("ext.NpDTypeLike", obj[col].dtype)]: obj[col]} + ) + for col in columns_name + ], + ) + class PandasSeries(IODescriptor["ext.PdSeries"]): """ @@ -551,7 +649,7 @@ def predict(input_arr): - :obj:`split` - :code:`dict[str, Any]` ↦ {``idx`` ↠ ``[idx]``, ``columns`` ↠ ``[columns]``, ``data`` ↠ ``[values]``} - :obj:`records` - :code:`list[Any]` ↦ [{``column`` ↠ ``value``}, ..., {``column`` ↠ ``value``}] - :obj:`index` - :code:`dict[str, Any]` ↦ {``idx`` ↠ {``column`` ↠ ``value``}} - - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` -> {``index`` ↠ ``value``}} + - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` ↠ {``index`` ↠ ``value``}} - :obj:`values` - :code:`dict[str, Any]` ↦ Values arrays columns: List of columns name that users wish to update. apply_column_names (`bool`, `optional`, default to :code:`False`): @@ -573,8 +671,8 @@ def predict(input_arr): from bentoml.io import PandasSeries @svc.api(input=PandasSeries(shape=(51,10), enforce_shape=True), output=PandasSeries()) - def infer(input_df: pd.DataFrame) -> pd.DataFrame: - # if input_df have shape (40,9), it will throw out errors + def infer(input_series: pd.Series) -> pd.Series: + # if input_series have shape (40,9), it will throw out errors ... enforce_shape: Whether to enforce a certain shape. If ``enforce_shape=True`` then ``shape`` must be specified. @@ -582,12 +680,13 @@ def infer(input_df: pd.DataFrame) -> pd.DataFrame: :obj:`PandasSeries`: IO Descriptor that represents a :code:`pd.Series`. """ - _mime_type: str = MIME_TYPE_JSON + _proto_fields = ("series",) + _mime_type = "application/json" def __init__( self, orient: ext.SeriesOrient = "records", - dtype: bool | dict[str, t.Any] | None = None, + dtype: ext.PdDTypeArg | None = None, enforce_dtype: bool = False, shape: tuple[int, ...] | None = None, enforce_shape: bool = False, @@ -630,29 +729,13 @@ async def from_http_request(self, request: Request) -> ext.PdSeries: a ``pd.Series`` object. This can then be used inside users defined logics. """ obj = await request.body() - if self._enforce_dtype: - if self._dtype is None: - logger.warning( - "`dtype` is None or undefined, while `enforce_dtype=True`" - ) - - # TODO(jiang): check dtypes when enforce_dtype is set - res = pd.read_json(obj, typ="series", orient=self._orient, dtype=self._dtype) - - assert isinstance(res, pd.Series) - - if self._enforce_shape: - if self._shape is None: - logger.warning( - "`shape` is None or undefined, while `enforce_shape`=True" - ) - else: - assert all( - left == right - for left, right in zip(self._shape, res.shape) - if left != -1 and right != -1 - ), f"incoming has shape {res.shape} where enforced shape to be {self._shape}" - return res + res: ext.PdSeries = pd.read_json( + obj, + typ="series", + orient=self._orient, + dtype=self._dtype, + ) + return self.validate_series(res) async def to_http_response( self, obj: t.Any, ctx: Context | None = None @@ -665,19 +748,43 @@ async def to_http_response( Returns: HTTP Response of type ``starlette.responses.Response``. This can be accessed via cURL or any external web traffic. """ - if not LazyType["ext.PdSeries"](pd.Series).isinstance(obj): - raise InvalidArgument( - f"return object is not of type `pd.Series`, got type {type(obj)} instead" - ) - + obj = self.validate_series(obj) if ctx is not None: res = Response( obj.to_json(orient=self._orient), - media_type=MIME_TYPE_JSON, + media_type=self._mime_type, headers=ctx.response.headers, # type: ignore (bad starlette types) status_code=ctx.response.status_code, ) set_cookies(res, ctx.response.cookies) return res else: - return Response(obj.to_json(orient=self._orient), media_type=MIME_TYPE_JSON) + return Response( + obj.to_json(orient=self._orient), media_type=self._mime_type + ) + + def validate_series( + self, series: ext.PdSeries, exception_cls: t.Type[Exception] = BadInput + ) -> ext.PdSeries: + # TODO: dtype check + if not LazyType["ext.PdSeries"]("pd.Series").isinstance(series): + raise InvalidArgument( + f"return object is not of type 'pd.Series', got type '{type(series)}' instead" + ) from None + # TODO: convert from wide to long format (melt()) + if self._shape is not None and self._shape != series.shape: + msg = f"{self.__class__.__name__}: Expecting Series of shape '{self._shape}', but '{series.shape}' was received." + if self._enforce_shape and not all( + left == right + for left, right in zip(self._shape, series.shape) + if left != -1 and right != -1 + ): + raise exception_cls(msg) from None + + return series + + async def from_proto(self, field: pb.Series | bytes) -> ext.PdSeries: + raise NotImplementedError("Currently not yet implemented.") + + async def to_proto(self, obj: ext.PdSeries) -> pb.Series: + raise NotImplementedError("Currently not yet implemented.") diff --git a/bentoml/_internal/io_descriptors/text.py b/bentoml/_internal/io_descriptors/text.py index db372250cb..c85ab74212 100644 --- a/bentoml/_internal/io_descriptors/text.py +++ b/bentoml/_internal/io_descriptors/text.py @@ -11,14 +11,18 @@ from .base import IODescriptor from ..utils.http import set_cookies from ..service.openapi import SUCCESS_DESCRIPTION +from ..utils.lazy_loader import LazyLoader +from ..service.openapi.specification import Schema +from ..service.openapi.specification import Response as OpenAPIResponse from ..service.openapi.specification import MediaType +from ..service.openapi.specification import RequestBody if TYPE_CHECKING: - from ..context import InferenceApiContext as Context + from google.protobuf import wrappers_pb2 -from ..service.openapi.specification import Schema -from ..service.openapi.specification import Response as OpenAPIResponse -from ..service.openapi.specification import RequestBody + from ..context import InferenceApiContext as Context +else: + wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2") MIME_TYPE = "text/plain" @@ -86,13 +90,14 @@ def predict(text: str) -> str: :obj:`Text`: IO Descriptor that represents strings type. """ + _proto_fields = ("text",) + _mime_type = MIME_TYPE + def __init__(self, *args: t.Any, **kwargs: t.Any): if args or kwargs: raise BentoMLException( - "'Text' is not designed to take any args or kwargs during initialization." - ) - - self._mime_type = MIME_TYPE + f"'{self.__class__.__name__}' is not designed to take any args or kwargs during initialization." + ) from None def input_type(self) -> t.Type[str]: return str @@ -123,11 +128,21 @@ async def to_http_response(self, obj: str, ctx: Context | None = None) -> Respon if ctx is not None: res = Response( obj, - media_type=MIME_TYPE, + media_type=self._mime_type, headers=ctx.response.metadata, # type: ignore (bad starlette types) status_code=ctx.response.status_code, ) set_cookies(res, ctx.response.cookies) return res else: - return Response(obj, media_type=MIME_TYPE) + return Response(obj, media_type=self._mime_type) + + async def from_proto(self, field: wrappers_pb2.StringValue | bytes) -> str: + if isinstance(field, bytes): + return field.decode("utf-8") + else: + assert isinstance(field, wrappers_pb2.StringValue) + return field.value + + async def to_proto(self, obj: str) -> wrappers_pb2.StringValue: + return wrappers_pb2.StringValue(value=obj) diff --git a/bentoml/_internal/runner/runner.py b/bentoml/_internal/runner/runner.py index 488e5b804b..0b931f3e69 100644 --- a/bentoml/_internal/runner/runner.py +++ b/bentoml/_internal/runner/runner.py @@ -225,7 +225,7 @@ def init_local(self, quiet: bool = False) -> None: init local runnable instance, for testing and debugging only """ if not quiet: - logger.warning("'Runner.init_local' is for debugging and testing only") + logger.warning("'Runner.init_local' is for debugging and testing only.") self._init_local() diff --git a/bentoml/_internal/runner/runner_handle/remote.py b/bentoml/_internal/runner/runner_handle/remote.py index 2d31fcf972..d557cb5a84 100644 --- a/bentoml/_internal/runner/runner_handle/remote.py +++ b/bentoml/_internal/runner/runner_handle/remote.py @@ -18,7 +18,8 @@ from ...runner.utils import PAYLOAD_META_HEADER from ...configuration.containers import BentoMLContainer -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: + import yarl from aiohttp import BaseConnector from aiohttp.client import ClientSession @@ -84,7 +85,7 @@ def _get_conn(self) -> BaseConnector: ) self._addr = f"http://{parsed.netloc}" else: - raise ValueError(f"Unsupported bind scheme: {parsed.scheme}") + raise ValueError(f"Unsupported bind scheme: {parsed.scheme}") from None return self._conn @property @@ -99,10 +100,7 @@ def _client( or self._client_cache.closed or self._loop.is_closed() ): - import yarl - from opentelemetry.instrumentation.aiohttp_client import ( - create_trace_config, # type: ignore (missing type stubs) - ) + from opentelemetry.instrumentation.aiohttp_client import create_trace_config def strip_query_params(url: yarl.URL) -> str: return str(url.with_query(None)) @@ -145,7 +143,7 @@ async def async_run_method( if not payload_params.map(lambda i: i.batch_size).all_equal(): raise ValueError( "All batchable arguments must have the same batch size." - ) + ) from None path = "" if __bentoml_method.name == "__call__" else __bentoml_method.name async with self._client.post( @@ -164,30 +162,26 @@ async def async_run_method( if resp.status != 200: raise RemoteException( f"An exception occurred in remote runner {self._runner.name}: [{resp.status}] {body.decode()}" - ) + ) from None try: meta_header = resp.headers[PAYLOAD_META_HEADER] except KeyError: raise RemoteException( - f"Bento payload decode error: {PAYLOAD_META_HEADER} header not set. " - "An exception might have occurred in the remote server." - f"[{resp.status}] {body.decode()}" + f"Bento payload decode error: {PAYLOAD_META_HEADER} header not set. An exception might have occurred in the remote server. [{resp.status}] {body.decode()}" ) from None try: content_type = resp.headers["Content-Type"] except KeyError: raise RemoteException( - f"Bento payload decode error: Content-Type header not set. " - "An exception might have occurred in the remote server." - f"[{resp.status}] {body.decode()}" + f"Bento payload decode error: Content-Type header not set. An exception might have occurred in the remote server. [{resp.status}] {body.decode()}" ) from None if not content_type.lower().startswith("application/vnd.bentoml."): raise RemoteException( f"Bento payload decode error: invalid Content-Type '{content_type}'." - ) + ) from None if content_type == "application/vnd.bentoml.multiple_outputs": payloads = pickle.loads(body) @@ -200,7 +194,7 @@ async def async_run_method( data=body, meta=json.loads(meta_header), container=container ) except JSONDecodeError: - raise ValueError(f"Bento payload decode error: {meta_header}") + raise ValueError(f"Bento payload decode error: {meta_header}") from None return AutoContainer.from_payload(payload) @@ -212,11 +206,14 @@ def run_method( ) -> R: import anyio - return anyio.from_thread.run( # type: ignore (pyright cannot infer the return type) - self.async_run_method, - __bentoml_method, - *args, - **kwargs, + return t.cast( + "R", + anyio.from_thread.run( + self.async_run_method, + __bentoml_method, + *args, + **kwargs, + ), ) def __del__(self) -> None: diff --git a/bentoml/_internal/server/__init__.py b/bentoml/_internal/server/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/bentoml/_internal/server/grpc/__init__.py b/bentoml/_internal/server/grpc/__init__.py new file mode 100644 index 0000000000..1a92ea6fbb --- /dev/null +++ b/bentoml/_internal/server/grpc/__init__.py @@ -0,0 +1,5 @@ +from .config import Config +from .server import Server +from .servicer import Servicer + +__all__ = ["Server", "Config", "Servicer"] diff --git a/bentoml/_internal/server/grpc/config.py b/bentoml/_internal/server/grpc/config.py new file mode 100644 index 0000000000..80735fedfb --- /dev/null +++ b/bentoml/_internal/server/grpc/config.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import sys +import typing as t +from typing import TYPE_CHECKING + +from simple_di import inject +from simple_di import Provide + +from ...configuration.containers import BentoMLContainer + +if TYPE_CHECKING: + import grpc + + from .servicer import Servicer + + +class Config: + @inject + def __init__( + self, + servicer: Servicer, + bind_address: str, + enable_reflection: bool = Provide[BentoMLContainer.grpc.reflection.enabled], + max_message_length: int + | None = Provide[BentoMLContainer.grpc.max_message_length], + max_concurrent_streams: int + | None = Provide[BentoMLContainer.grpc.max_concurrent_streams], + maximum_concurrent_rpcs: int + | None = Provide[BentoMLContainer.grpc.maximum_concurrent_rpcs], + migration_thread_pool_workers: int = 1, + graceful_shutdown_timeout: float | None = None, + ) -> None: + self.servicer = servicer + + # Note that the max_workers are used inside ThreadPoolExecutor. + # This ThreadPoolExecutor are used by aio.Server() to execute non-AsyncIO RPC handlers. + # Setting it to 1 makes it thread-safe for sync APIs. + self.migration_thread_pool_workers = migration_thread_pool_workers + + # maximum_concurrent_rpcs defines the maximum number of concurrent RPCs this server + # will service before returning RESOURCE_EXHAUSTED status. + # Set to None will indicate no limit. + self.maximum_concurrent_rpcs = maximum_concurrent_rpcs + + self.max_message_length = max_message_length + + self.max_concurrent_streams = max_concurrent_streams + + self.bind_address = bind_address + self.enable_reflection = enable_reflection + self.graceful_shutdown_timeout = graceful_shutdown_timeout + + @property + def options(self) -> grpc.aio.ChannelArgumentType: + options: grpc.aio.ChannelArgumentType = [] + + if sys.platform != "win32": + # https://github.com/grpc/grpc/blob/master/include/grpc/impl/codegen/grpc_types.h#L294 + # Eventhough GRPC_ARG_ALLOW_REUSEPORT is set to 1 by default, we want still + # want to explicitly set it to 1 so that we can spawn multiple gRPC servers in + # production settings. + options.append(("grpc.so_reuseport", 1)) + + if self.max_concurrent_streams: + options.append(("grpc.max_concurrent_streams", self.max_concurrent_streams)) + + if self.max_message_length: + options.extend( + ( + ("grpc.max_message_length", self.max_message_length), + ("grpc.max_receive_message_length", self.max_message_length), + ("grpc.max_send_message_length", self.max_message_length), + ) + ) + + return tuple(options) + + @property + def handlers(self) -> t.Sequence[grpc.GenericRpcHandler] | None: + # Note that currently BentoML doesn't provide any specific + # handlers for gRPC. If users have any specific handlers, + # BentoML will pass it through to grpc.aio.Server + return self.servicer.bento_service.grpc_handlers diff --git a/bentoml/_internal/server/grpc/server.py b/bentoml/_internal/server/grpc/server.py new file mode 100644 index 0000000000..b95416f7ff --- /dev/null +++ b/bentoml/_internal/server/grpc/server.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import typing as t +import asyncio +import logging +from typing import TYPE_CHECKING + +from ...utils import LazyLoader +from ...utils import cached_property + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + import grpc + from grpc import aio + from grpc_health.v1 import health_pb2 as pb_health + from grpc_health.v1 import health_pb2_grpc as services_health + from typing_extensions import Self + + from bentoml.grpc.v1alpha1 import service_pb2_grpc as services + + from .config import Config +else: + from bentoml.grpc.utils import import_grpc + from bentoml.grpc.utils import import_generated_stubs + from bentoml._internal.utils import LazyLoader + + grpc, aio = import_grpc() + _, services = import_generated_stubs() + health_exception_msg = "'grpcio-health-checking' is required for using health checking endpoints. Install with 'pip install grpcio-health-checking'." + pb_health = LazyLoader( + "pb_health", + globals(), + "grpc_health.v1.health_pb2", + exc_msg=health_exception_msg, + ) + services_health = LazyLoader( + "services_health", + globals(), + "grpc_health.v1.health_pb2_grpc", + exc_msg=health_exception_msg, + ) + + +class Server: + """An async implementation of a gRPC server.""" + + def __init__(self, config: Config): + self.config = config + self.servicer = config.servicer + self.loaded = False + + @cached_property + def loop(self) -> asyncio.AbstractEventLoop: + return asyncio.get_event_loop() + + def load(self) -> Self: + from concurrent.futures import ThreadPoolExecutor + + assert not self.loaded + if not bool(self.servicer): + self.servicer.load() + assert self.servicer.loaded + + self.server = aio.server( + migration_thread_pool=ThreadPoolExecutor( + max_workers=self.config.migration_thread_pool_workers + ), + options=self.config.options, + maximum_concurrent_rpcs=self.config.maximum_concurrent_rpcs, + handlers=self.config.handlers, + interceptors=self.servicer.interceptors_stack, + ) + self.loaded = True + + return self + + def run(self) -> None: + if not self.loaded: + self.load() + assert self.loaded + + try: + self.loop.run_until_complete(self.serve()) + finally: + try: + self.loop.call_soon_threadsafe( + lambda: asyncio.ensure_future(self.shutdown()) + ) + except Exception as e: # pylint: disable=broad-except + raise RuntimeError(f"Server failed unexpectedly: {e}") from None + + async def serve(self) -> None: + self.add_insecure_port(self.config.bind_address) + await self.startup() + await self.wait_for_termination() + + async def startup(self) -> None: + from bentoml.exceptions import MissingDependencyException + + # Running on_startup callback. + await self.servicer.startup() + # register bento servicer + services.add_BentoServiceServicer_to_server( + self.servicer.bento_servicer, self.server + ) + services_health.add_HealthServicer_to_server( + self.servicer.health_servicer, self.server + ) + + service_names = self.servicer.service_names + # register custom servicer + for ( + user_servicer, + add_servicer_fn, + user_service_names, + ) in self.servicer.mount_servicers: + add_servicer_fn(user_servicer(), self.server) + service_names += tuple(user_service_names) + + if self.config.enable_reflection: + try: + # reflection is required for health checking to work. + from grpc_reflection.v1alpha import reflection + except ImportError: + raise MissingDependencyException( + "reflection is enabled, which requires 'grpcio-reflection' to be installed. Install with 'pip install grpcio-reflection'." + ) + service_names += (reflection.SERVICE_NAME,) + reflection.enable_server_reflection(service_names, self.server) + + # mark all services as healthy + for service in service_names: + await self.servicer.health_servicer.set( + service, pb_health.HealthCheckResponse.SERVING # type: ignore (no types available) + ) + await self.server.start() + + async def shutdown(self): + # Running on_startup callback. + await self.servicer.shutdown() + await self.server.stop(grace=self.config.graceful_shutdown_timeout) + await self.servicer.health_servicer.enter_graceful_shutdown() + self.loop.stop() + + async def wait_for_termination(self, timeout: int | None = None) -> bool: + return await self.server.wait_for_termination(timeout=timeout) + + def add_insecure_port(self, address: str) -> int: + return self.server.add_insecure_port(address) + + def add_secure_port(self, address: str, credentials: grpc.ServerCredentials) -> int: + return self.server.add_secure_port(address, credentials) + + def add_generic_rpc_handlers( + self, generic_rpc_handlers: t.Sequence[grpc.GenericRpcHandler] + ) -> None: + self.server.add_generic_rpc_handlers(generic_rpc_handlers) diff --git a/bentoml/_internal/server/grpc/servicer.py b/bentoml/_internal/server/grpc/servicer.py new file mode 100644 index 0000000000..cbce08b3a6 --- /dev/null +++ b/bentoml/_internal/server/grpc/servicer.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import sys +import typing as t +import asyncio +import logging +from typing import TYPE_CHECKING + +import anyio + +from bentoml.grpc.utils import grpc_status_code + +from ....exceptions import InvalidArgument +from ....exceptions import BentoMLException + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from logging import _ExcInfoType as ExcInfoType # type: ignore (private warning) + + import grpc + from grpc import aio + from grpc_health.v1 import health + from typing_extensions import Self + + from bentoml.grpc.types import Interceptors + from bentoml.grpc.types import AddServicerFn + from bentoml.grpc.types import ServicerClass + from bentoml.grpc.types import BentoServicerContext + from bentoml.grpc.types import GeneratedProtocolMessageType + from bentoml.grpc.v1alpha1 import service_pb2 as pb + from bentoml.grpc.v1alpha1 import service_pb2_grpc as services + + from ...service.service import Service + +else: + from bentoml.grpc.utils import import_grpc + from bentoml.grpc.utils import import_generated_stubs + + from ...utils import LazyLoader + + pb, services = import_generated_stubs() + grpc, aio = import_grpc() + health = LazyLoader( + "health", + globals(), + "grpc_health.v1.health", + exc_msg="'grpcio-health-checking' is required for using health checking endpoints. Install with 'pip install grpcio-health-checking'.", + ) + containers = LazyLoader( + "containers", globals(), "google.protobuf.internal.containers" + ) + + +def log_exception(request: pb.Request, exc_info: ExcInfoType) -> None: + # gRPC will always send a POST request. + logger.error("Exception on /%s [POST]", request.api_name, exc_info=exc_info) + + +class Servicer: + """Create an instance of gRPC Servicer.""" + + def __init__( + self: Self, + service: Service, + on_startup: t.Sequence[t.Callable[[], t.Any]] | None = None, + on_shutdown: t.Sequence[t.Callable[[], t.Any]] | None = None, + mount_servicers: t.Sequence[tuple[ServicerClass, AddServicerFn, list[str]]] + | None = None, + interceptors: Interceptors | None = None, + ) -> None: + self.bento_service = service + + self.on_startup = [] if not on_startup else list(on_startup) + self.on_shutdown = [] if not on_shutdown else list(on_shutdown) + self.mount_servicers = [] if not mount_servicers else list(mount_servicers) + self.interceptors = [] if not interceptors else list(interceptors) + self.loaded = False + + def load(self): + assert not self.loaded + + self.interceptors_stack = self.build_interceptors_stack() + + self.bento_servicer = create_bento_servicer(self.bento_service) + + # Create a health check servicer. We use the non-blocking implementation + # to avoid thread starvation. + self.health_servicer = health.aio.HealthServicer() + + self.service_names = tuple( + service.full_name for service in pb.DESCRIPTOR.services_by_name.values() + ) + (health.SERVICE_NAME,) + self.loaded = True + + def build_interceptors_stack(self) -> list[aio.ServerInterceptor]: + return list(map(lambda x: x(), self.interceptors)) + + async def startup(self): + for handler in self.on_startup: + if is_async_iterable(handler): + await handler() + else: + handler() + + async def shutdown(self): + for handler in self.on_shutdown: + if is_async_iterable(handler): + await handler() + else: + handler() + + def __bool__(self): + return self.loaded + + +def is_async_iterable(obj: t.Any) -> bool: # pragma: no cover + return asyncio.iscoroutinefunction(obj) or ( + callable(obj) and asyncio.iscoroutinefunction(obj.__call__) + ) + + +def create_bento_servicer(service: Service) -> services.BentoServiceServicer: + """ + This is the actual implementation of BentoServicer. + Main inference entrypoint will be invoked via /bentoml.grpc..BentoService/Call + """ + from ...io_descriptors.multipart import Multipart + + class BentoServiceImpl(services.BentoServiceServicer): + """An asyncio implementation of BentoService servicer.""" + + async def Call( # type: ignore (no async types) # pylint: disable=invalid-overridden-method + self, + request: pb.Request, + context: BentoServicerContext, + ) -> pb.Response | None: + if request.api_name not in service.apis: + raise InvalidArgument( + f"given 'api_name' is not defined in {service.name}", + ) from None + + api = service.apis[request.api_name] + response = pb.Response() + + # NOTE: since IODescriptor._proto_fields is a tuple, the order is preserved. + # This is important so that we know the order of fields to process. + # We will use fields descriptor to determine how to process that request. + try: + # we will check if the given fields list contains a pb.Multipart. + field = request.WhichOneof("content") + if field is None: + raise InvalidArgument("Request cannot be empty.") + accepted_fields = api.input._proto_fields + ("serialized_bytes",) + if field not in accepted_fields: + raise InvalidArgument( + f"'{api.input.__class__.__name__}' accepts one of the following fields: '{', '.join(accepted_fields)}', and none of them are found in the request message.", + ) from None + input_ = await api.input.from_proto(getattr(request, field)) + if asyncio.iscoroutinefunction(api.func): + if isinstance(api.input, Multipart): + output = await api.func(**input_) + else: + output = await api.func(input_) + else: + if isinstance(api.input, Multipart): + output = await anyio.to_thread.run_sync(api.func, **input_) + else: + output = await anyio.to_thread.run_sync(api.func, input_) + protos = await api.output.to_proto(output) + # TODO(aarnphm): support multiple proto fields + response = pb.Response(**{api.output._proto_fields[0]: protos}) + except BentoMLException as e: + log_exception(request, sys.exc_info()) + await context.abort(code=grpc_status_code(e), details=e.message) + except (RuntimeError, TypeError, NotImplementedError): + log_exception(request, sys.exc_info()) + await context.abort( + code=grpc.StatusCode.INTERNAL, + details="A runtime error has occurred, see stacktrace from logs.", + ) + except Exception: # pylint: disable=broad-except + log_exception(request, sys.exc_info()) + await context.abort( + code=grpc.StatusCode.INTERNAL, + details="An error has occurred in BentoML user code when handling this request, find the error details in server logs.", + ) + return response + + return BentoServiceImpl() diff --git a/bentoml/_internal/server/grpc_app.py b/bentoml/_internal/server/grpc_app.py new file mode 100644 index 0000000000..c14756929b --- /dev/null +++ b/bentoml/_internal/server/grpc_app.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import typing as t +import logging +from typing import TYPE_CHECKING +from functools import partial + +from simple_di import inject +from simple_di import Provide + +from ..configuration.containers import BentoMLContainer + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + + from bentoml.grpc.types import Interceptors + + from ..service import Service + from .grpc.servicer import Servicer + + OnStartup = list[t.Callable[[], None | t.Coroutine[t.Any, t.Any, None]]] + + +class GRPCAppFactory: + """ + GRPCApp creates an async gRPC API server based on APIs defined with a BentoService via BentoService#apis. + This is a light wrapper around GRPCServer with addition to `on_startup` and `on_shutdown` hooks. + + Note that even though the code are similar with BaseAppFactory, gRPC protocol is different from ASGI. + """ + + @inject + def __init__( + self, + bento_service: Service, + *, + enable_metrics: bool = Provide[ + BentoMLContainer.api_server_config.metrics.enabled + ], + ) -> None: + self.bento_service = bento_service + self.enable_metrics = enable_metrics + + @property + def on_startup(self) -> OnStartup: + on_startup: OnStartup = [self.bento_service.on_grpc_server_startup] + if BentoMLContainer.development_mode.get(): + for runner in self.bento_service.runners: + on_startup.append(partial(runner.init_local, quiet=True)) + else: + for runner in self.bento_service.runners: + on_startup.append(runner.init_client) + + return on_startup + + @property + def on_shutdown(self) -> list[t.Callable[[], None]]: + on_shutdown = [self.bento_service.on_grpc_server_shutdown] + for runner in self.bento_service.runners: + on_shutdown.append(runner.destroy) + + return on_shutdown + + def __call__(self) -> Servicer: + from .grpc import Servicer + + return Servicer( + self.bento_service, + on_startup=self.on_startup, + on_shutdown=self.on_shutdown, + mount_servicers=self.bento_service.mount_servicers, + interceptors=self.interceptors, + ) + + @property + def interceptors(self) -> Interceptors: + # Note that order of interceptors is important here. + + from bentoml.grpc.interceptors.opentelemetry import ( + AsyncOpenTelemetryServerInterceptor, + ) + + interceptors: Interceptors = [AsyncOpenTelemetryServerInterceptor] + + if self.enable_metrics: + from bentoml.grpc.interceptors.prometheus import PrometheusServerInterceptor + + interceptors.append(PrometheusServerInterceptor) + + if BentoMLContainer.api_server_config.logging.access.enabled.get(): + from bentoml.grpc.interceptors.access import AccessLogServerInterceptor + + access_logger = logging.getLogger("bentoml.access") + if access_logger.getEffectiveLevel() <= logging.INFO: + interceptors.append(AccessLogServerInterceptor) + + # add users-defined interceptors. + interceptors.extend(self.bento_service.interceptors) + + return interceptors diff --git a/bentoml/_internal/server/http/__init__.py b/bentoml/_internal/server/http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/bentoml/_internal/server/access.py b/bentoml/_internal/server/http/access.py similarity index 95% rename from bentoml/_internal/server/access.py rename to bentoml/_internal/server/http/access.py index 972fa44c79..3c146abf26 100644 --- a/bentoml/_internal/server/access.py +++ b/bentoml/_internal/server/http/access.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import logging from timeit import default_timer from typing import TYPE_CHECKING from contextvars import ContextVar if TYPE_CHECKING: - from .. import external_typing as ext + from ... import external_typing as ext REQ_CONTENT_LENGTH = "REQUEST_CONTENT_LENGTH" REQ_CONTENT_TYPE = "REQUEST_CONTENT_TYPE" @@ -41,7 +43,7 @@ class AccessLogMiddleware: def __init__( self, - app: "ext.ASGIApp", + app: ext.ASGIApp, has_request_content_length: bool = False, has_request_content_type: bool = False, has_response_content_length: bool = False, @@ -56,9 +58,9 @@ def __init__( async def __call__( self, - scope: "ext.ASGIScope", - receive: "ext.ASGIReceive", - send: "ext.ASGISend", + scope: ext.ASGIScope, + receive: ext.ASGIReceive, + send: ext.ASGISend, ) -> None: if not scope["type"].startswith("http"): await self.app(scope, receive, send) diff --git a/bentoml/_internal/server/instruments.py b/bentoml/_internal/server/http/instruments.py similarity index 94% rename from bentoml/_internal/server/instruments.py rename to bentoml/_internal/server/http/instruments.py index 06e8df2359..f692f62ca4 100644 --- a/bentoml/_internal/server/instruments.py +++ b/bentoml/_internal/server/http/instruments.py @@ -8,25 +8,23 @@ from simple_di import inject from simple_di import Provide -from ..utils.metrics import metric_name -from ..configuration.containers import BentoMLContainer +from ...context import component_context +from ...utils.metrics import metric_name +from ...configuration.containers import BentoMLContainer if TYPE_CHECKING: - from .. import external_typing as ext - from ..server.metrics.prometheus import PrometheusClient + from ... import external_typing as ext + from ...server.metrics.prometheus import PrometheusClient logger = logging.getLogger(__name__) -START_TIME_VAR: "contextvars.ContextVar[float]" = contextvars.ContextVar( - "START_TIME_VAR" -) -STATUS_VAR: "contextvars.ContextVar[int]" = contextvars.ContextVar("STATUS_VAR") -from ..context import component_context +START_TIME_VAR: contextvars.ContextVar[float] = contextvars.ContextVar("START_TIME_VAR") +STATUS_VAR: contextvars.ContextVar[int] = contextvars.ContextVar("STATUS_VAR") class HTTPTrafficMetricsMiddleware: def __init__( self, - app: "ext.ASGIApp", + app: ext.ASGIApp, namespace: str = "bentoml_api_server", ): self.app = app @@ -36,7 +34,7 @@ def __init__( @inject def _setup( self, - metrics_client: "PrometheusClient" = Provide[BentoMLContainer.metrics_client], + metrics_client: PrometheusClient = Provide[BentoMLContainer.metrics_client], duration_buckets: tuple[float, ...] = Provide[ BentoMLContainer.duration_buckets ], @@ -105,9 +103,9 @@ def _setup( async def __call__( self, - scope: "ext.ASGIScope", - receive: "ext.ASGIReceive", - send: "ext.ASGISend", + scope: ext.ASGIScope, + receive: ext.ASGIReceive, + send: ext.ASGISend, ) -> None: if not self._is_setup: self._setup() diff --git a/bentoml/_internal/server/service_app.py b/bentoml/_internal/server/http_app.py similarity index 95% rename from bentoml/_internal/server/service_app.py rename to bentoml/_internal/server/http_app.py index e20737a0ac..3e4e33ac9a 100644 --- a/bentoml/_internal/server/service_app.py +++ b/bentoml/_internal/server/http_app.py @@ -85,9 +85,9 @@ def log_exception(request: Request, exc_info: t.Any) -> None: ) -class ServiceAppFactory(BaseAppFactory): +class HTTPAppFactory(BaseAppFactory): """ - ServiceApp creates a REST API server based on APIs defined with a BentoService + HTTPApp creates a REST API server based on APIs defined with a BentoService via BentoService#apis. Each InferenceAPI will become one endpoint exposed on the REST server, and the RequestHandler defined on each InferenceAPI object will be used to handle Request object before feeding the @@ -98,10 +98,8 @@ class ServiceAppFactory(BaseAppFactory): def __init__( self, bento_service: Service, - enable_access_control: bool = Provide[ - BentoMLContainer.api_server_config.cors.enabled - ], - access_control_options: dict[str, list[str] | int] = Provide[ + enable_access_control: bool = Provide[BentoMLContainer.http.cors.enabled], + access_control_options: dict[str, list[str] | str | int] = Provide[ BentoMLContainer.access_control_options ], enable_metrics: bool = Provide[ @@ -217,14 +215,14 @@ def middlewares(self) -> list[Middleware]: # metrics middleware if self.enable_metrics: - from .instruments import HTTPTrafficMetricsMiddleware + from .http.instruments import HTTPTrafficMetricsMiddleware middlewares.append(Middleware(HTTPTrafficMetricsMiddleware)) # otel middleware - import opentelemetry.instrumentation.asgi as otel_asgi # type: ignore + import opentelemetry.instrumentation.asgi as otel_asgi - def client_request_hook(span: Span, _scope: dict[str, t.Any]) -> None: + def client_request_hook(span: Span, _: dict[str, t.Any]) -> None: if span is not None: trace_context.request_id = span.context.span_id @@ -241,7 +239,7 @@ def client_request_hook(span: Span, _scope: dict[str, t.Any]) -> None: access_log_config = BentoMLContainer.api_server_config.logging.access if access_log_config.enabled.get(): - from .access import AccessLogMiddleware + from .http.access import AccessLogMiddleware access_logger = logging.getLogger("bentoml.access") if access_logger.getEffectiveLevel() <= logging.INFO: diff --git a/bentoml/_internal/server/metrics/prometheus.py b/bentoml/_internal/server/metrics/prometheus.py index f18a5084af..ce2230e066 100644 --- a/bentoml/_internal/server/metrics/prometheus.py +++ b/bentoml/_internal/server/metrics/prometheus.py @@ -1,10 +1,17 @@ -# type: ignore[reportMissingTypeStubs] +from __future__ import annotations + import os import sys import typing as t import logging +from typing import TYPE_CHECKING from functools import partial +if TYPE_CHECKING: + from prometheus_client.metrics_core import Metric + + from ... import external_typing as ext + logger = logging.getLogger(__name__) @@ -13,7 +20,7 @@ def __init__( self, *, multiproc: bool = True, - multiproc_dir: t.Optional[str] = None, + multiproc_dir: str | None = None, ): """ Set up multiproc_dir for prometheus to work in multiprocess mode, @@ -87,6 +94,9 @@ def start_http_server(self, port: int, addr: str = "") -> None: registry=self.registry, ) + def make_wsgi_app(self) -> ext.WSGIApp: + return self.prometheus_client.make_wsgi_app(registry=self.registry) # type: ignore (unfinished prometheus types) + def generate_latest(self): if self.multiproc: registry = self.prometheus_client.CollectorRegistry() @@ -95,6 +105,13 @@ def generate_latest(self): else: return self.prometheus_client.generate_latest() + def text_string_to_metric_families(self) -> t.Generator[Metric, None, None]: + from prometheus_client.parser import text_string_to_metric_families + + yield from text_string_to_metric_families( + self.generate_latest().decode("utf-8") + ) + @property def CONTENT_TYPE_LATEST(self) -> str: return self.prometheus_client.CONTENT_TYPE_LATEST diff --git a/bentoml/_internal/server/runner_app.py b/bentoml/_internal/server/runner_app.py index f463590ba6..97a59f5395 100644 --- a/bentoml/_internal/server/runner_app.py +++ b/bentoml/_internal/server/runner_app.py @@ -183,13 +183,13 @@ def client_request_hook(span: Span, _scope: t.Dict[str, t.Any]) -> None: ) if self.enable_metrics: - from .instruments import RunnerTrafficMetricsMiddleware + from .http.instruments import RunnerTrafficMetricsMiddleware middlewares.append(Middleware(RunnerTrafficMetricsMiddleware)) access_log_config = BentoMLContainer.runners_config.logging.access if access_log_config.enabled.get(): - from .access import AccessLogMiddleware + from .http.access import AccessLogMiddleware access_logger = logging.getLogger("bentoml.access") if access_logger.getEffectiveLevel() <= logging.INFO: diff --git a/bentoml/_internal/service/service.py b/bentoml/_internal/service/service.py index 3c654438e2..864be97f18 100644 --- a/bentoml/_internal/service/service.py +++ b/bentoml/_internal/service/service.py @@ -3,6 +3,7 @@ import typing as t import logging from typing import TYPE_CHECKING +from functools import partial import attr @@ -16,13 +17,19 @@ from ..io_descriptors import IODescriptor if TYPE_CHECKING: + import grpc + + from bentoml.grpc.types import AddServicerFn + from bentoml.grpc.types import ServicerClass + from .. import external_typing as ext from ..bento import Bento + from ..server.grpc.servicer import Servicer from .openapi.specification import OpenAPISpecification +else: + from bentoml.grpc.utils import import_grpc - WSGI_APP = t.Callable[ - [t.Callable[..., t.Any], t.Mapping[str, t.Any]], t.Iterable[bytes] - ] + grpc, _ = import_grpc() logger = logging.getLogger(__name__) @@ -82,6 +89,7 @@ class Service: runners: t.List[Runner] models: t.List[Model] + # starlette related mount_apps: t.List[t.Tuple[ext.ASGIApp, str, str]] = attr.field( init=False, factory=list ) @@ -89,6 +97,16 @@ class Service: t.Tuple[t.Type[ext.AsgiMiddleware], t.Dict[str, t.Any]] ] = attr.field(init=False, factory=list) + # gRPC related + mount_servicers: list[tuple[ServicerClass, AddServicerFn, list[str]]] = attr.field( + init=False, factory=list + ) + interceptors: list[partial[grpc.aio.ServerInterceptor]] = attr.field( + init=False, factory=list + ) + grpc_handlers: list[grpc.GenericRpcHandler] = attr.field(init=False, factory=list) + + # list of APIs from @svc.api apis: t.Dict[str, InferenceAPI] = attr.field(init=False, factory=dict) # Tag/Bento are only set when the service was loaded from a bento @@ -197,11 +215,23 @@ def on_asgi_app_startup(self) -> None: def on_asgi_app_shutdown(self) -> None: pass + def on_grpc_server_startup(self) -> None: + pass + + def on_grpc_server_shutdown(self) -> None: + pass + + @property + def grpc_servicer(self) -> Servicer: + from ..server.grpc_app import GRPCAppFactory + + return GRPCAppFactory(self)() + @property def asgi_app(self) -> "ext.ASGIApp": - from ..server.service_app import ServiceAppFactory + from ..server.http_app import HTTPAppFactory - return ServiceAppFactory(self)() + return HTTPAppFactory(self)() def mount_asgi_app( self, app: "ext.ASGIApp", path: str = "/", name: t.Optional[str] = None @@ -209,7 +239,7 @@ def mount_asgi_app( self.mount_apps.append((app, path, name)) # type: ignore def mount_wsgi_app( - self, app: WSGI_APP, path: str = "/", name: t.Optional[str] = None + self, app: ext.WSGIApp, path: str = "/", name: t.Optional[str] = None ) -> None: # TODO: Migrate to a2wsgi from starlette.middleware.wsgi import WSGIMiddleware @@ -217,10 +247,46 @@ def mount_wsgi_app( self.mount_apps.append((WSGIMiddleware(app), path, name)) # type: ignore def add_asgi_middleware( - self, middleware_cls: t.Type["ext.AsgiMiddleware"], **options: t.Any + self, middleware_cls: t.Type[ext.AsgiMiddleware], **options: t.Any ) -> None: self.middlewares.append((middleware_cls, options)) + def mount_grpc_servicer( + self, + servicer_cls: ServicerClass, + add_servicer_fn: AddServicerFn, + service_names: list[str], + ) -> None: + self.mount_servicers.append((servicer_cls, add_servicer_fn, service_names)) + + def add_grpc_interceptor( + self, interceptor_cls: t.Type[grpc.aio.ServerInterceptor], **options: t.Any + ) -> None: + from bentoml.exceptions import BadInput + + if not issubclass(interceptor_cls, grpc.aio.ServerInterceptor): + if isinstance(interceptor_cls, partial): + if options: + logger.debug( + "'%s' is a partial class, hence '%s' will be ignored.", + interceptor_cls, + options, + ) + if not issubclass(interceptor_cls.func, grpc.aio.ServerInterceptor): + raise BadInput( + "'partial' class is not a subclass of 'grpc.aio.ServerInterceptor'." + ) + self.interceptors.append(interceptor_cls) + else: + raise BadInput( + f"{interceptor_cls} is not a subclass of 'grpc.aio.ServerInterceptor'." + ) + + self.interceptors.append(partial(interceptor_cls, **options)) + + def add_grpc_handlers(self, handlers: list[grpc.GenericRpcHandler]) -> None: + self.grpc_handlers.extend(handlers) + def on_load_bento(svc: Service, bento: Bento): object.__setattr__(svc, "bento", bento) diff --git a/bentoml/_internal/utils/__init__.py b/bentoml/_internal/utils/__init__.py index a240d04c9f..71ad30f259 100644 --- a/bentoml/_internal/utils/__init__.py +++ b/bentoml/_internal/utils/__init__.py @@ -6,6 +6,8 @@ import random import socket import typing as t +import inspect +import logging import functools import contextlib from typing import overload @@ -56,48 +58,41 @@ "validate_or_create_dir", "display_path_under_home", "rich_console", + "experimental", ] +_EXPERIMENTAL_APIS: set[str] = set() -@overload -def kwargs_transformers( - func: GenericFunction[t.Concatenate[str, bool, t.Iterable[str], P]], - *, - transformer: GenericFunction[t.Any], -) -> GenericFunction[t.Concatenate[str, t.Iterable[str], bool, P]]: - ... +def _warn_experimental(f: t.Any): + api_name = f.__name__ if inspect.isfunction(f) else repr(f) + if api_name not in _EXPERIMENTAL_APIS: + _EXPERIMENTAL_APIS.add(api_name) + msg = "'%s' is an EXPERIMENTAL API and is currently not yet stable. Proceed with caution!" + logger = logging.getLogger(f.__module__) + logger.warning(msg, api_name) -@overload -def kwargs_transformers( - func: None = None, *, transformer: GenericFunction[t.Any] -) -> GenericFunction[t.Any]: - ... +def experimental(f: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]: + @functools.wraps(f) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any: + _warn_experimental(f) + return f(*args, **kwargs) -def kwargs_transformers( - _func: t.Callable[..., t.Any] | None = None, - *, - transformer: GenericFunction[t.Any], -) -> GenericFunction[t.Any]: - def decorator(func: GenericFunction[t.Any]) -> t.Callable[P, t.Any]: - @functools.wraps(func) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any: - return func(*args, **{k: transformer(v) for k, v in kwargs.items()}) + return wrapper - return wrapper - if _func is None: - return decorator - return decorator(_func) +def add_experimental_docstring(f: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]: + f.__doc__ = "[EXPERIMENTAL] " + (f.__doc__ if f.__doc__ is not None else "") + return f -@t.overload +@overload def first_not_none(*args: T | None, default: T) -> T: ... -@t.overload +@overload def first_not_none(*args: T | None) -> T | None: ... @@ -179,13 +174,27 @@ def _(*args: P.args, **kwargs: P.kwargs) -> t.Optional[_T_co]: @contextlib.contextmanager def reserve_free_port( host: str = "localhost", + port: int | None = None, prefix: t.Optional[str] = None, max_retry: int = 50, + enable_so_reuseport: bool = False, ) -> t.Iterator[int]: """ detect free port and reserve until exit the context """ + import psutil + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if enable_so_reuseport: + if psutil.WINDOWS: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + elif psutil.MACOS or psutil.FREEBSD: + sock.setsockopt(socket.SOL_SOCKET, 0x10000, 1) # SO_REUSEPORT_LB + else: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0: + raise RuntimeError("Failed to set SO_REUSEPORT.") from None if prefix is not None: prefix_num = int(prefix) * 10 ** (5 - len(prefix)) suffix_range = min(65535 - prefix_num, 10 ** (5 - len(prefix))) @@ -199,13 +208,17 @@ def reserve_free_port( continue else: raise RuntimeError( - f"cannot find free port with prefix {prefix} after {max_retry} retries" - ) + f"Cannot find free port with prefix {prefix} after {max_retry} retries." + ) from None else: - sock.bind((host, 0)) - port = sock.getsockname()[1] - yield port - sock.close() + if port: + sock.bind((host, port)) + else: + sock.bind((host, 0)) + try: + yield sock.getsockname()[1] + finally: + sock.close() def copy_file_to_fs_folder( @@ -359,8 +372,6 @@ def __call__( @contextlib.contextmanager @functools.wraps(func) def _func(*args: "P.args", **kwargs: "P.kwargs") -> t.Any: - import inspect - bound_args = inspect.signature(func).bind(*args, **kwargs) bound_args.apply_defaults() if self._cache_key_template: diff --git a/bentoml/_internal/utils/buildx.py b/bentoml/_internal/utils/buildx.py index 34bb1d9139..53248be33a 100644 --- a/bentoml/_internal/utils/buildx.py +++ b/bentoml/_internal/utils/buildx.py @@ -140,8 +140,8 @@ def build( build_args: dict[str, str] | None, build_context: dict[str, str] | None, builder: str | None, - cache_from: str | list[str] | dict[str, str] | None, - cache_to: str | list[str] | dict[str, str] | None, + cache_from: str | t.Iterable[str] | dict[str, str] | None, + cache_to: str | t.Iterable[str] | dict[str, str] | None, cgroup_parent: str | None, file: PathType | None, iidfile: PathType | None, @@ -150,14 +150,14 @@ def build( metadata_file: PathType | None, network: str | None, no_cache: bool, - no_cache_filter: list[str] | None, + no_cache_filter: t.Iterable[str] | None, output: str | dict[str, str] | None, - platform: str | list[str] | None, + platform: str | t.Iterable[str] | None, progress: t.Literal["auto", "tty", "plain"], pull: bool, push: bool, quiet: bool, - secrets: str | list[str] | None, + secrets: str | t.Iterable[str] | None, shm_size: str | int | None, rm: bool, ssh: str | None, diff --git a/bentoml/_internal/utils/platform.py b/bentoml/_internal/utils/platform.py deleted file mode 100644 index 0811724bbf..0000000000 --- a/bentoml/_internal/utils/platform.py +++ /dev/null @@ -1,29 +0,0 @@ -import signal -import subprocess -from typing import TYPE_CHECKING - -import psutil - -if TYPE_CHECKING: - import typing as t - - -def kill_subprocess_tree(p: "subprocess.Popen[t.Any]") -> None: - """ - Tell the process to terminate and kill all of its children. Availabe both on Windows and Linux. - Note: It will return immediately rather than wait for the process to terminate. - - Args: - p: subprocess.Popen object - """ - if psutil.WINDOWS: - subprocess.call(["taskkill", "/F", "/T", "/PID", str(p.pid)]) - else: - p.terminate() - - -def cancel_subprocess(p: "subprocess.Popen[t.Any]") -> None: - if psutil.WINDOWS: - p.send_signal(signal.CTRL_C_EVENT) # type: ignore - else: - p.send_signal(signal.SIGINT) diff --git a/bentoml/bentos.py b/bentoml/bentos.py index 27d2c7c3bf..57585062a1 100644 --- a/bentoml/bentos.py +++ b/bentoml/bentos.py @@ -1,3 +1,4 @@ +# pylint: disable=unused-argument """ User facing python APIs for managing local bentos and build new bentos """ @@ -5,10 +6,14 @@ from __future__ import annotations import os +import sys import typing as t import logging +import tempfile +import contextlib import subprocess from typing import TYPE_CHECKING +from functools import partial from simple_di import inject from simple_di import Provide @@ -22,10 +27,14 @@ from ._internal.bento.build_config import BentoBuildConfig from ._internal.configuration.containers import BentoMLContainer +if sys.version_info >= (3, 8): + from shutil import copytree +else: + from backports.shutil_copytree import copytree + if TYPE_CHECKING: from ._internal.bento import BentoStore from ._internal.types import PathType - from ._internal.models import ModelStore logger = logging.getLogger(__name__) @@ -261,7 +270,6 @@ def build( version: t.Optional[str] = None, build_ctx: t.Optional[str] = None, _bento_store: "BentoStore" = Provide[BentoMLContainer.bento_store], - _model_store: "ModelStore" = Provide[BentoMLContainer.model_store], ) -> "Bento": """ User-facing API for building a Bento. The available build options are identical to the keys of a @@ -287,7 +295,6 @@ def build( version: Override the default auto generated version str build_ctx: Build context directory, when used as _bento_store: save Bento created to this BentoStore - _model_store: pull Models required from this ModelStore Returns: Bento: a Bento instance representing the materialized Bento saved in BentoStore @@ -355,7 +362,6 @@ def build_bentofile( version: t.Optional[str] = None, build_ctx: t.Optional[str] = None, _bento_store: "BentoStore" = Provide[BentoMLContainer.bento_store], - _model_store: "ModelStore" = Provide[BentoMLContainer.model_store], ) -> "Bento": """ Build a Bento base on options specified in a bentofile.yaml file. @@ -368,7 +374,6 @@ def build_bentofile( version: Override the default auto generated version str build_ctx: Build context directory, when used as _bento_store: save Bento created to this BentoStore - _model_store: pull Models required from this ModelStore """ try: bentofile = resolve_user_filepath(bentofile, build_ctx) @@ -388,18 +393,70 @@ def build_bentofile( return bento +@contextlib.contextmanager +def construct_dockerfile( + bento: Bento, + *, + features: t.Sequence[str] | None = None, + docker_final_stage: str | None = None, +) -> t.Generator[tuple[str, str], None, None]: + dockerfile_path = os.path.join("env", "docker", "Dockerfile") + final_instruction = "" + if features is not None: + features = [l for s in map(lambda x: x.split(","), features) for l in s] + if not all(f in FEATURES for f in features): + raise InvalidArgument( + f"Available features are: {FEATURES}. Invalid fields from provided: {set(features) - set(FEATURES)}" + ) + final_instruction += f"""\ +RUN --mount=type=cache,target=/root/.cache/pip pip install bentoml[{','.join(features)}] +""" + if docker_final_stage: + final_instruction += f"""\ +{docker_final_stage} +""" + with open(bento.path_of(dockerfile_path), "r") as f: + FINAL_DOCKERFILE = f"""\ +{f.read()} +FROM base-{bento.info.docker.distro} +# Additional instructions for final image. +{final_instruction} +""" + if final_instruction != "": + with tempfile.TemporaryDirectory("bento-tmp") as tmpdir: + copytree(bento.path, tmpdir, dirs_exist_ok=True) + with open(os.path.join(tmpdir, dockerfile_path), "w") as dockerfile: + dockerfile.write(FINAL_DOCKERFILE) + yield tmpdir, dockerfile.name + else: + yield bento.path, dockerfile_path + + +# Sync with BentoML extra dependencies +FEATURES = [ + "tracing", + "grpc", + "tracing.zipkin", + "tracing.jaeger", + "tracing.otlp", +] + + @inject def containerize( tag: Tag | str, - docker_image_tag: t.List[str] | None = None, + docker_image_tag: t.Iterable[str] | None = None, *, + # containerize options + features: t.Sequence[str] | None = None, + # docker options add_host: dict[str, str] | None = None, allow: t.List[str] | None = None, build_args: dict[str, str] | None = None, build_context: dict[str, str] | None = None, builder: str | None = None, - cache_from: str | t.List[str] | dict[str, str] | None = None, - cache_to: str | t.List[str] | dict[str, str] | None = None, + cache_from: str | t.Iterable[str] | dict[str, str] | None = None, + cache_to: str | t.Iterable[str] | dict[str, str] | None = None, cgroup_parent: str | None = None, iidfile: PathType | None = None, labels: dict[str, str] | None = None, @@ -407,14 +464,14 @@ def containerize( metadata_file: PathType | None = None, network: str | None = None, no_cache: bool = False, - no_cache_filter: t.List[str] | None = None, + no_cache_filter: t.Iterable[str] | None = None, output: str | dict[str, str] | None = None, - platform: str | t.List[str] | None = None, + platform: str | t.Iterable[str] | None = None, progress: t.Literal["auto", "tty", "plain"] = "auto", pull: bool = False, push: bool = False, quiet: bool = False, - secrets: str | t.List[str] | None = None, + secrets: str | t.Iterable[str] | None = None, shm_size: str | int | None = None, rm: bool = False, ssh: str | None = None, @@ -423,6 +480,8 @@ def containerize( _bento_store: "BentoStore" = Provide[BentoMLContainer.bento_store], ) -> bool: + import psutil + from bentoml._internal.utils import buildx env = {"DOCKER_BUILDKIT": "1", "DOCKER_SCAN_SUGGEST": "false"} @@ -431,61 +490,58 @@ def containerize( if not docker_image_tag: docker_image_tag = [str(bento.tag)] - dockerfile_path = os.path.join("env", "docker", "Dockerfile") - logger.info(f"Building docker image for {bento}...") - try: - buildx.build( - subprocess_env=env, - cwd=bento.path, - file=dockerfile_path, - tags=docker_image_tag, - add_host=add_host, - allow=allow, - build_args=build_args, - build_context=build_context, - builder=builder, - cache_from=cache_from, - cache_to=cache_to, - cgroup_parent=cgroup_parent, - iidfile=iidfile, - labels=labels, - load=load, - metadata_file=metadata_file, - network=network, - no_cache=no_cache, - no_cache_filter=no_cache_filter, - output=output, - platform=platform, - progress=progress, - pull=pull, - push=push, - quiet=quiet, - secrets=secrets, - shm_size=shm_size, - rm=rm, - ssh=ssh, - target=target, - ulimit=ulimit, + if platform and not psutil.LINUX and platform != "linux/amd64": + logger.warning( + 'Current platform is set to "%s". To avoid issue, we recommend you to build the container with x86_64 (amd64): "bentoml containerize %s --platform linux/amd64"', + ",".join(platform), + str(bento.tag), ) + run_buildx = partial( + buildx.build, + subprocess_env=env, + tags=docker_image_tag, + add_host=add_host, + allow=allow, + build_args=build_args, + build_context=build_context, + builder=builder, + cache_from=cache_from, + cache_to=cache_to, + cgroup_parent=cgroup_parent, + iidfile=iidfile, + labels=labels, + load=load, + metadata_file=metadata_file, + network=network, + no_cache=no_cache, + no_cache_filter=no_cache_filter, + output=output, + platform=platform, + progress=progress, + pull=pull, + push=push, + quiet=quiet, + secrets=secrets, + shm_size=shm_size, + rm=rm, + ssh=ssh, + target=target, + ulimit=ulimit, + ) + clean_context = contextlib.ExitStack() + required = clean_context.enter_context( + construct_dockerfile(bento, features=features) + ) + try: + build_path, dockerfile_path = required + run_buildx(cwd=build_path, file=dockerfile_path) + return True except subprocess.CalledProcessError as e: logger.error(f"Failed building docker image: {e}") - if platform != "linux/amd64": - logger.debug( - f"""If you run into the following error: "failed to solve: pull access denied, repository does not exist or may require authorization: server message: insufficient_scope: authorization failed". This means Docker doesn't have context of your build platform {platform}. By default BentoML will set target build platform to the current machine platform via `uname -m`. Try again by specifying to build x86_64 (amd64) platform: bentoml containerize {str(bento.tag)} --platform linux/amd64""" - ) return False - else: - logger.info( - 'Successfully built docker image for "%s" with tags "%s"', - str(bento.tag), - ",".join(docker_image_tag), - ) - logger.info( - 'To run your newly built Bento container, use one of the above tags, and pass it to "docker run". i.e: "docker run -it --rm -p 3000:3000 %s"', - docker_image_tag[0], - ) - return True + finally: + clean_context.close() __all__ = [ diff --git a/bentoml/exceptions.py b/bentoml/exceptions.py index 34924e7b28..3e9e292e9b 100644 --- a/bentoml/exceptions.py +++ b/bentoml/exceptions.py @@ -73,6 +73,14 @@ class NotFound(BentoMLException): error_code = HTTPStatus.NOT_FOUND +class UnprocessableEntity(BentoMLException): + """ + Raise when API server receiving unprocessable entity request + """ + + error_code = HTTPStatus.UNPROCESSABLE_ENTITY + + class TooManyRequests(BentoMLException): """ Raise when incoming requests exceeds the capacity of a server diff --git a/bentoml/grpc/__init__.py b/bentoml/grpc/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/bentoml/grpc/buf.yaml b/bentoml/grpc/buf.yaml new file mode 100644 index 0000000000..4e26d9abb8 --- /dev/null +++ b/bentoml/grpc/buf.yaml @@ -0,0 +1,22 @@ +version: v1 +lint: + use: + - DEFAULT + - COMMENT_ENUM + - COMMENT_MESSAGE + - COMMENT_RPC + - COMMENT_SERVICE + except: + - DIRECTORY_SAME_PACKAGE + - RPC_REQUEST_STANDARD_NAME + - RPC_RESPONSE_STANDARD_NAME + ignore_only: + DEFAULT: + - bentoml/grpc/v1alpha1/service_test.proto + ENUM_VALUE_PREFIX: + - bentoml/grpc/v1alpha1/service.proto + enum_zero_value_suffix: _UNSPECIFIED + rpc_allow_same_request_response: true + rpc_allow_google_protobuf_empty_requests: true + rpc_allow_google_protobuf_empty_responses: true + service_suffix: Service diff --git a/bentoml/grpc/interceptors/__init__.py b/bentoml/grpc/interceptors/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/bentoml/grpc/interceptors/access.py b/bentoml/grpc/interceptors/access.py new file mode 100644 index 0000000000..a89cc0f0a1 --- /dev/null +++ b/bentoml/grpc/interceptors/access.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import typing as t +import logging +import functools +from timeit import default_timer +from typing import TYPE_CHECKING + +from bentoml.grpc.utils import to_http_status +from bentoml.grpc.utils import wrap_rpc_handler +from bentoml.grpc.utils import GRPC_CONTENT_TYPE + +if TYPE_CHECKING: + import grpc + from grpc import aio + from grpc.aio._typing import MetadataType # pylint: disable=unused-import + + from bentoml.grpc.types import Request + from bentoml.grpc.types import Response + from bentoml.grpc.types import RpcMethodHandler + from bentoml.grpc.types import AsyncHandlerMethod + from bentoml.grpc.types import HandlerCallDetails + from bentoml.grpc.types import BentoServicerContext + from bentoml.grpc.v1alpha1 import service_pb2 as pb +else: + from bentoml.grpc.utils import import_grpc + from bentoml.grpc.utils import import_generated_stubs + + pb, _ = import_generated_stubs() + grpc, aio = import_grpc() + + +class AccessLogServerInterceptor(aio.ServerInterceptor): + """ + An asyncio interceptor for access logging. + """ + + async def intercept_service( + self, + continuation: t.Callable[[HandlerCallDetails], t.Awaitable[RpcMethodHandler]], + handler_call_details: HandlerCallDetails, + ) -> RpcMethodHandler: + logger = logging.getLogger("bentoml.access") + handler = await continuation(handler_call_details) + method_name = handler_call_details.method + + if handler and (handler.response_streaming or handler.request_streaming): + return handler + + def wrapper(behaviour: AsyncHandlerMethod[Response]): + @functools.wraps(behaviour) + async def new_behaviour( + request: Request, context: BentoServicerContext + ) -> Response | t.Awaitable[Response]: + content_type = GRPC_CONTENT_TYPE + trailing_metadata: MetadataType | None = context.trailing_metadata() + if trailing_metadata: + trailing = dict(trailing_metadata) + content_type = trailing.get("content-type", GRPC_CONTENT_TYPE) + + response = pb.Response() + start = default_timer() + try: + response = await behaviour(request, context) + except Exception as e: # pylint: disable=broad-except + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(str(e)) + finally: + latency = max(default_timer() - start, 0) * 1000 + + req = [ + "scheme=http", # TODO: support https when ssl is added + f"path={method_name}", + f"type={content_type}", + f"size={request.ByteSize()}", + ] + + # Note that in order AccessLogServerInterceptor to work, the + # interceptor must be added to the server after AsyncOpenTeleServerInterceptor + # and PrometheusServerInterceptor. + typed_context_code = t.cast(grpc.StatusCode, context.code()) + resp = [ + f"http_status={to_http_status(typed_context_code)}", + f"grpc_status={typed_context_code.value[0]}", + f"type={content_type}", + f"size={response.ByteSize()}", + ] + + logger.info( + "%s (%s) (%s) %.3fms", + context.peer(), + ",".join(req), + ",".join(resp), + latency, + ) + return response + + return new_behaviour + + return wrap_rpc_handler(wrapper, handler) diff --git a/bentoml/grpc/interceptors/opentelemetry.py b/bentoml/grpc/interceptors/opentelemetry.py new file mode 100644 index 0000000000..b36b19a12c --- /dev/null +++ b/bentoml/grpc/interceptors/opentelemetry.py @@ -0,0 +1,277 @@ +from __future__ import annotations + +import typing as t +import logging +import functools +from typing import TYPE_CHECKING +from contextlib import asynccontextmanager + +from simple_di import inject +from simple_di import Provide +from opentelemetry import trace +from opentelemetry.context import attach +from opentelemetry.context import detach +from opentelemetry.propagate import extract +from opentelemetry.trace.status import Status +from opentelemetry.trace.status import StatusCode +from opentelemetry.semconv.trace import SpanAttributes + +from bentoml.grpc.utils import wrap_rpc_handler +from bentoml.grpc.utils import GRPC_CONTENT_TYPE +from bentoml.grpc.utils import parse_method_name +from bentoml._internal.utils.pkg import get_pkg_version +from bentoml._internal.configuration.containers import BentoMLContainer + +if TYPE_CHECKING: + import grpc + from grpc import aio + from grpc.aio._typing import MetadataKey + from grpc.aio._typing import MetadataType + from grpc.aio._typing import MetadataValue + from opentelemetry.trace import Span + from opentelemetry.sdk.trace import TracerProvider + + from bentoml.grpc.types import Request + from bentoml.grpc.types import Response + from bentoml.grpc.types import RpcMethodHandler + from bentoml.grpc.types import AsyncHandlerMethod + from bentoml.grpc.types import HandlerCallDetails + from bentoml.grpc.types import BentoServicerContext +else: + from bentoml.grpc.utils import import_grpc + + grpc, aio = import_grpc() + +logger = logging.getLogger(__name__) + + +class _OpenTelemetryServicerContext(aio.ServicerContext["Request", "Response"]): + def __init__(self, servicer_context: BentoServicerContext, active_span: Span): + self._servicer_context = servicer_context + self._active_span = active_span + self._code = grpc.StatusCode.OK + self._details = "" + super().__init__() + + def __getattr__(self, attr: str) -> t.Any: + return getattr(self._servicer_context, attr) + + async def read(self) -> Request: + return await self._servicer_context.read() + + async def write(self, message: Response) -> None: + return await self._servicer_context.write(message) + + def trailing_metadata(self) -> aio.Metadata: + return self._servicer_context.trailing_metadata() # type: ignore (unfinished type) + + def auth_context(self) -> t.Mapping[str, t.Iterable[bytes]]: + return self._servicer_context.auth_context() + + def peer_identity_key(self) -> str | None: + return self._servicer_context.peer_identity_key() + + def peer_identities(self) -> t.Iterable[bytes] | None: + return self._servicer_context.peer_identities() + + def peer(self) -> str: + return self._servicer_context.peer() + + def disable_next_message_compression(self) -> None: + self._servicer_context.disable_next_message_compression() + + def set_compression(self, compression: grpc.Compression) -> None: + return self._servicer_context.set_compression(compression) + + def invocation_metadata(self) -> aio.Metadata | None: + return self._servicer_context.invocation_metadata() + + def set_trailing_metadata(self, trailing_metadata: MetadataType) -> None: + self._servicer_context.set_trailing_metadata(trailing_metadata) + + async def send_initial_metadata(self, initial_metadata: MetadataType) -> None: + return await self._servicer_context.send_initial_metadata(initial_metadata) + + async def abort( + self, + code: grpc.StatusCode, + details: str = "", + trailing_metadata: MetadataType = tuple(), + ) -> None: + self._code = code + self._details = details + self._active_span.set_attribute( + SpanAttributes.RPC_GRPC_STATUS_CODE, code.value[0] + ) + self._active_span.set_status( + Status(status_code=StatusCode.ERROR, description=f"{code}:{details}") + ) + return await self._servicer_context.abort( + code, details=details, trailing_metadata=trailing_metadata + ) + + def set_code(self, code: grpc.StatusCode) -> None: + self._code = code + details = self._details or code.value[1] + self._active_span.set_attribute( + SpanAttributes.RPC_GRPC_STATUS_CODE, code.value[0] + ) + if code != grpc.StatusCode.OK: + self._active_span.set_status( + Status(status_code=StatusCode.ERROR, description=f"{code}:{details}") + ) + return self._servicer_context.set_code(code) + + def code(self) -> grpc.StatusCode: + return self._code + + def set_details(self, details: str) -> None: + self._details = details + if self._code != grpc.StatusCode.OK: + self._active_span.set_status( + Status( + status_code=StatusCode.ERROR, description=f"{self._code}:{details}" + ) + ) + return self._servicer_context.set_details(details) + + def details(self) -> str: + return self._details + + +# Since opentelemetry doesn't provide an async implementation for the server interceptor, +# we will need to create an async implementation ourselves. +# By doing this we will have more control over how to handle span and context propagation. +# +# Until there is a solution upstream, this implementation is sufficient for our needs. +class AsyncOpenTelemetryServerInterceptor(aio.ServerInterceptor): + @inject + def __init__( + self, + *, + tracer_provider: TracerProvider = Provide[BentoMLContainer.tracer_provider], + schema_url: str | None = None, + ): + self._tracer = tracer_provider.get_tracer( + "opentelemetry.instrumentation.grpc", + get_pkg_version("opentelemetry-instrumentation-grpc"), + schema_url=schema_url, + ) + + @asynccontextmanager + async def set_remote_context( + self, servicer_context: BentoServicerContext + ) -> t.AsyncGenerator[None, None]: + metadata = servicer_context.invocation_metadata() + if metadata: + md: dict[MetadataKey, MetadataValue] = {m.key: m.value for m in metadata} + ctx = extract(md) + token = attach(ctx) + try: + yield + finally: + detach(token) + else: + yield + + def start_span( + self, + method_name: str, + context: BentoServicerContext, + set_status_on_exception: bool = False, + ) -> t.ContextManager[Span]: + attributes: dict[str, str | bytes] = { + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[0], + } + + # method_name shouldn't be none, otherwise + # it will never reach this point. + method_rpc, _ = parse_method_name(method_name) + attributes.update( + { + SpanAttributes.RPC_METHOD: method_rpc.method, + SpanAttributes.RPC_SERVICE: method_rpc.fully_qualified_service, + } + ) + + # add some attributes from the metadata + metadata = context.invocation_metadata() + if metadata: + dct: dict[str, str | bytes] = dict(metadata) + if "user-agent" in dct: + attributes["rpc.user_agent"] = dct["user-agent"] + + # get trailing metadata + trailing_metadata: MetadataType | None = context.trailing_metadata() + if trailing_metadata: + trailing = dict(trailing_metadata) + attributes["rpc.content_type"] = trailing.get( + "content-type", GRPC_CONTENT_TYPE + ) + + # Split up the peer to keep with how other telemetry sources + # do it. This looks like: + # * ipv6:[::1]:57284 + # * ipv4:127.0.0.1:57284 + # * ipv4:10.2.1.1:57284,127.0.0.1:57284 + # + # the process ip and port would be [::1] 57284 + try: + ipv4_addr = context.peer().split(",")[0] + ip, port = ipv4_addr.split(":", 1)[1].rsplit(":", 1) + attributes.update( + { + SpanAttributes.NET_PEER_IP: ip, + SpanAttributes.NET_PEER_PORT: port, + } + ) + # other telemetry sources add this, so we will too + if ip in ("[::1]", "127.0.0.1"): + attributes[SpanAttributes.NET_PEER_NAME] = "localhost" + except IndexError: + logger.warning(f"Failed to parse peer address '{context.peer()}'") + + return self._tracer.start_as_current_span( + name=method_name, + kind=trace.SpanKind.SERVER, + attributes=attributes, + set_status_on_exception=set_status_on_exception, + ) + + async def intercept_service( + self, + continuation: t.Callable[[HandlerCallDetails], t.Awaitable[RpcMethodHandler]], + handler_call_details: HandlerCallDetails, + ) -> RpcMethodHandler: + handler = await continuation(handler_call_details) + method_name = handler_call_details.method + + # Currently not support streaming RPCs. + if handler and (handler.response_streaming or handler.request_streaming): + return handler + + def wrapper(behaviour: AsyncHandlerMethod[Response]): + @functools.wraps(behaviour) + async def new_behaviour( + request: Request, context: BentoServicerContext + ) -> Response | t.Awaitable[Response]: + + async with self.set_remote_context(context): + with self.start_span(method_name, context) as span: + # wrap context + wrapped_context = _OpenTelemetryServicerContext(context, span) + + # And now we run the actual RPC. + try: + return await behaviour(request, wrapped_context) + except Exception as e: + # We are interested in uncaught exception, otherwise + # it will be handled by gRPC. + if type(e) != Exception: + span.record_exception(e) + raise e + + return new_behaviour + + return wrap_rpc_handler(wrapper, handler) diff --git a/bentoml/grpc/interceptors/prometheus.py b/bentoml/grpc/interceptors/prometheus.py new file mode 100644 index 0000000000..62424b7d5e --- /dev/null +++ b/bentoml/grpc/interceptors/prometheus.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import typing as t +import logging +import functools +import contextvars +from timeit import default_timer +from typing import TYPE_CHECKING + +from simple_di import inject +from simple_di import Provide + +from bentoml.grpc.utils import to_http_status +from bentoml.grpc.utils import wrap_rpc_handler +from bentoml._internal.context import component_context +from bentoml._internal.configuration.containers import BentoMLContainer + +START_TIME_VAR: contextvars.ContextVar[float] = contextvars.ContextVar("START_TIME_VAR") + +if TYPE_CHECKING: + import grpc + from grpc import aio + + from bentoml.grpc.types import Request + from bentoml.grpc.types import Response + from bentoml.grpc.types import RpcMethodHandler + from bentoml.grpc.types import AsyncHandlerMethod + from bentoml.grpc.types import HandlerCallDetails + from bentoml.grpc.types import BentoServicerContext + from bentoml.grpc.v1alpha1 import service_pb2 as pb + from bentoml._internal.server.metrics.prometheus import PrometheusClient +else: + from bentoml.grpc.utils import import_grpc + from bentoml.grpc.utils import import_generated_stubs + + pb, _ = import_generated_stubs() + grpc, aio = import_grpc() + + +logger = logging.getLogger(__name__) + + +class PrometheusServerInterceptor(aio.ServerInterceptor): + """ + An async interceptor for Prometheus metrics. + """ + + def __init__(self, *, namespace: str = "bentoml_api_server"): + self._is_setup = False + self.namespace = namespace + + @inject + def _setup( + self, + metrics_client: PrometheusClient = Provide[BentoMLContainer.metrics_client], + duration_buckets: tuple[float, ...] = Provide[ + BentoMLContainer.duration_buckets + ], + ): # pylint: disable=attribute-defined-outside-init + self.metrics_request_duration = metrics_client.Histogram( + namespace=self.namespace, + name="request_duration_seconds", + documentation="API GRPC request duration in seconds", + labelnames=[ + "api_name", + "service_name", + "service_version", + "http_response_code", + ], + buckets=duration_buckets, + ) + self.metrics_request_total = metrics_client.Counter( + namespace=self.namespace, + name="request_total", + documentation="Total number of GRPC requests", + labelnames=[ + "api_name", + "service_name", + "service_version", + "http_response_code", + ], + ) + self.metrics_request_in_progress = metrics_client.Gauge( + namespace=self.namespace, + name="request_in_progress", + documentation="Total number of GRPC requests in progress now", + labelnames=["api_name", "service_name", "service_version"], + multiprocess_mode="livesum", + ) + self._is_setup = True + + async def intercept_service( + self, + continuation: t.Callable[[HandlerCallDetails], t.Awaitable[RpcMethodHandler]], + handler_call_details: HandlerCallDetails, + ) -> RpcMethodHandler: + if not self._is_setup: + self._setup() + + handler = await continuation(handler_call_details) + + if handler and (handler.response_streaming or handler.request_streaming): + return handler + + START_TIME_VAR.set(default_timer()) + + def wrapper(behaviour: AsyncHandlerMethod[Response]): + @functools.wraps(behaviour) + async def new_behaviour( + request: Request, context: BentoServicerContext + ) -> Response | t.Awaitable[Response]: + if not isinstance(request, pb.Request): + return await behaviour(request, context) + + api_name = request.api_name + + # instrument request total count + self.metrics_request_total.labels( + api_name=api_name, + service_name=component_context.bento_name, + service_version=component_context.bento_version, + http_response_code=to_http_status( + t.cast(grpc.StatusCode, context.code()) + ), + ).inc() + + # instrument request duration + assert START_TIME_VAR.get() != 0 + total_time = max(default_timer() - START_TIME_VAR.get(), 0) + self.metrics_request_duration.labels( # type: ignore (unfinished prometheus types) + api_name=api_name, + service_name=component_context.bento_name, + service_version=component_context.bento_version, + http_response_code=to_http_status( + t.cast(grpc.StatusCode, context.code()) + ), + ).observe( + total_time + ) + START_TIME_VAR.set(0) + # instrument request in progress + with self.metrics_request_in_progress.labels( + api_name=api_name, + service_version=component_context.bento_version, + service_name=component_context.bento_name, + ).track_inprogress(): + response = await behaviour(request, context) + return response + + return new_behaviour + + return wrap_rpc_handler(wrapper, handler) diff --git a/bentoml/grpc/types.py b/bentoml/grpc/types.py new file mode 100644 index 0000000000..9fa1dceee3 --- /dev/null +++ b/bentoml/grpc/types.py @@ -0,0 +1,108 @@ +# pragma: no cover +""" +Specific types for BentoService gRPC server. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import typing as t + from functools import partial + + import grpc + from grpc import aio + + from bentoml.grpc.v1alpha1.service_pb2 import Request + from bentoml.grpc.v1alpha1.service_pb2 import Response + from bentoml.grpc.v1alpha1.service_pb2_grpc import BentoServiceServicer + + P = t.TypeVar("P") + + BentoServicerContext = aio.ServicerContext[Request, Response] + + RequestDeserializerFn = t.Callable[[Request | None], object] | None + ResponseSerializerFn = t.Callable[[bytes], Response | None] | None + + HandlerMethod = t.Callable[[Request, BentoServicerContext], P] + AsyncHandlerMethod = t.Callable[[Request, BentoServicerContext], t.Awaitable[P]] + + class RpcMethodHandler( + t.NamedTuple( + "RpcMethodHandler", + request_streaming=bool, + response_streaming=bool, + request_deserializer=RequestDeserializerFn, + response_serializer=ResponseSerializerFn, + unary_unary=t.Optional[HandlerMethod[Response]], + unary_stream=t.Optional[HandlerMethod[Response]], + stream_unary=t.Optional[HandlerMethod[Response]], + stream_stream=t.Optional[HandlerMethod[Response]], + ), + grpc.RpcMethodHandler, + ): + """An implementation of a single RPC method.""" + + request_streaming: bool + response_streaming: bool + request_deserializer: RequestDeserializerFn + response_serializer: ResponseSerializerFn + unary_unary: t.Optional[HandlerMethod[Response]] + unary_stream: t.Optional[HandlerMethod[Response]] + stream_unary: t.Optional[HandlerMethod[Response]] + stream_stream: t.Optional[HandlerMethod[Response]] + + class HandlerCallDetails( + t.NamedTuple( + "HandlerCallDetails", method=str, invocation_metadata=aio.Metadata + ), + grpc.HandlerCallDetails, + ): + """Describes an RPC that has just arrived for service. + + Attributes: + method: The method name of the RPC. + invocation_metadata: A sequence of metadatum, a key-value pair included in the HTTP header. + An example is: ``('binary-metadata-bin', b'\\x00\\xFF')`` + """ + + method: str + invocation_metadata: aio.Metadata + + # Servicer types + ServicerImpl = t.TypeVar("ServicerImpl") + Servicer = t.Annotated[ServicerImpl, object] + ServicerClass = t.Type[Servicer[t.Any]] + AddServicerFn = t.Callable[[Servicer[t.Any], aio.Server | grpc.Server], None] + + # accepted proto fields + ProtoField = t.Annotated[ + str, + t.Literal[ + "dataframe", + "file", + "json", + "ndarray", + "series", + "text", + "multipart", + "serialized_bytes", + ], + ] + + Interceptors = list[ + t.Callable[[], aio.ServerInterceptor] | partial[aio.ServerInterceptor] + ] + + # types defined for client interceptors + BentoUnaryUnaryCall = aio.UnaryUnaryCall[Request, Response] + + __all__ = [ + "Request", + "Response", + "BentoServicerContext", + "BentoServiceServicer", + "HandlerCallDetails", + "RpcMethodHandler", + "BentoUnaryUnaryCall", + ] diff --git a/bentoml/grpc/utils/__init__.py b/bentoml/grpc/utils/__init__.py new file mode 100644 index 0000000000..4f252146cc --- /dev/null +++ b/bentoml/grpc/utils/__init__.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import typing as t +import logging +from http import HTTPStatus +from typing import TYPE_CHECKING +from functools import lru_cache +from dataclasses import dataclass + +from bentoml._internal.utils.lazy_loader import LazyLoader + +if TYPE_CHECKING: + import types + from enum import Enum + + import grpc + from google.protobuf import descriptor as descriptor_mod + + from bentoml.exceptions import BentoMLException + from bentoml.grpc.types import RpcMethodHandler + from bentoml.grpc.v1alpha1 import service_pb2 as pb + + # We need this here so that __all__ is detected due to lazy import + def import_generated_stubs( + version: str = "v1alpha1", + ) -> tuple[types.ModuleType, types.ModuleType]: + ... + + def import_grpc() -> tuple[types.ModuleType, types.ModuleType]: + ... + +else: + from bentoml.grpc.utils._import_hook import import_grpc + from bentoml.grpc.utils._import_hook import import_generated_stubs + + pb, _ = import_generated_stubs() + grpc, _ = import_grpc() + descriptor_mod = LazyLoader( + "descriptor_mod", globals(), "google.protobuf.descriptor" + ) + +__all__ = [ + "grpc_status_code", + "parse_method_name", + "to_http_status", + "GRPC_CONTENT_TYPE", + "import_generated_stubs", + "import_grpc", +] + +logger = logging.getLogger(__name__) + +# content-type is always application/grpc +GRPC_CONTENT_TYPE = "application/grpc" + + +def get_field_by_name( + descriptor: descriptor_mod.FieldDescriptor | descriptor_mod.Descriptor, + field: str, +) -> descriptor_mod.FieldDescriptor: + if isinstance(descriptor, descriptor_mod.FieldDescriptor): + # descriptor is a FieldDescriptor + return descriptor.message_type.fields_by_name[field] + elif isinstance(descriptor, descriptor_mod.Descriptor): + # descriptor is a Descriptor + return descriptor.fields_by_name[field] + else: + raise NotImplementedError(f"Type {type(descriptor)} is not yet supported.") + + +def is_map_field(field: descriptor_mod.FieldDescriptor) -> bool: + return ( + field.type == descriptor_mod.FieldDescriptor.TYPE_MESSAGE + and field.message_type.has_options + and field.message_type.GetOptions().map_entry + ) + + +@lru_cache(maxsize=1) +def http_status_to_grpc_status_map() -> dict[Enum, grpc.StatusCode]: + # Maps HTTP status code to grpc.StatusCode + from http import HTTPStatus + + return { + HTTPStatus.OK: grpc.StatusCode.OK, + HTTPStatus.UNAUTHORIZED: grpc.StatusCode.UNAUTHENTICATED, + HTTPStatus.FORBIDDEN: grpc.StatusCode.PERMISSION_DENIED, + HTTPStatus.NOT_FOUND: grpc.StatusCode.UNIMPLEMENTED, + HTTPStatus.TOO_MANY_REQUESTS: grpc.StatusCode.UNAVAILABLE, + HTTPStatus.BAD_GATEWAY: grpc.StatusCode.UNAVAILABLE, + HTTPStatus.SERVICE_UNAVAILABLE: grpc.StatusCode.UNAVAILABLE, + HTTPStatus.GATEWAY_TIMEOUT: grpc.StatusCode.DEADLINE_EXCEEDED, + HTTPStatus.BAD_REQUEST: grpc.StatusCode.INVALID_ARGUMENT, + HTTPStatus.INTERNAL_SERVER_ERROR: grpc.StatusCode.INTERNAL, + HTTPStatus.UNPROCESSABLE_ENTITY: grpc.StatusCode.FAILED_PRECONDITION, + } + + +@lru_cache(maxsize=1) +def grpc_status_to_http_status_map() -> dict[grpc.StatusCode, Enum]: + return {v: k for k, v in http_status_to_grpc_status_map().items()} + + +@lru_cache(maxsize=1) +def filetype_pb_to_mimetype_map() -> dict[pb.File.FileType.ValueType, str]: + return { + pb.File.FILE_TYPE_CSV: "text/csv", + pb.File.FILE_TYPE_PLAINTEXT: "text/plain", + pb.File.FILE_TYPE_JSON: "application/json", + pb.File.FILE_TYPE_BYTES: "application/octet-stream", + pb.File.FILE_TYPE_PDF: "application/pdf", + pb.File.FILE_TYPE_PNG: "image/png", + pb.File.FILE_TYPE_JPEG: "image/jpeg", + pb.File.FILE_TYPE_GIF: "image/gif", + pb.File.FILE_TYPE_TIFF: "image/tiff", + pb.File.FILE_TYPE_BMP: "image/bmp", + pb.File.FILE_TYPE_WEBP: "image/webp", + pb.File.FILE_TYPE_SVG: "image/svg+xml", + } + + +@lru_cache(maxsize=1) +def mimetype_to_filetype_pb_map() -> dict[str, pb.File.FileType.ValueType]: + return {v: k for k, v in filetype_pb_to_mimetype_map().items()} + + +def grpc_status_code(err: BentoMLException) -> grpc.StatusCode: + """ + Convert BentoMLException.error_code to grpc.StatusCode. + """ + return http_status_to_grpc_status_map().get(err.error_code, grpc.StatusCode.UNKNOWN) + + +def to_http_status(status_code: grpc.StatusCode) -> int: + """ + Convert grpc.StatusCode to HTTPStatus. + """ + status = grpc_status_to_http_status_map().get( + status_code, HTTPStatus.INTERNAL_SERVER_ERROR + ) + + return status.value + + +@dataclass +class MethodName: + """ + Represents a gRPC method name. + + Attributes: + package: This is defined by `package foo.bar`, designation in the protocol buffer definition + service: service name in protocol buffer definition (eg: service SearchService { ... }) + method: method name + """ + + package: str = "" + service: str = "" + method: str = "" + + @property + def fully_qualified_service(self): + """return the service name prefixed with package""" + return f"{self.package}.{self.service}" if self.package else self.service + + +def parse_method_name(method_name: str) -> tuple[MethodName, bool]: + """ + Infers the grpc service and method name from the handler_call_details. + e.g. /package.ServiceName/MethodName + """ + method = method_name.split("/", maxsplit=2) + # sanity check for method. + if len(method) != 3: + return MethodName(), False + _, package_service, method = method + *packages, service = package_service.rsplit(".", maxsplit=1) + package = packages[0] if packages else "" + return MethodName(package, service, method), True + + +def wrap_rpc_handler( + wrapper: t.Callable[..., t.Any], + handler: RpcMethodHandler | None, +) -> RpcMethodHandler | None: + if not handler: + return None + if not handler.request_streaming and not handler.response_streaming: + assert handler.unary_unary + return handler._replace(unary_unary=wrapper(handler.unary_unary)) + elif not handler.request_streaming and handler.response_streaming: + assert handler.unary_stream + return handler._replace(unary_stream=wrapper(handler.unary_stream)) + elif handler.request_streaming and not handler.response_streaming: + assert handler.stream_unary + return handler._replace(stream_unary=wrapper(handler.stream_unary)) + elif handler.request_streaming and handler.response_streaming: + assert handler.stream_stream + return handler._replace(stream_stream=wrapper(handler.stream_stream)) + else: + raise RuntimeError(f"RPC method handler {handler} does not exist.") from None diff --git a/bentoml/grpc/utils/_import_hook.py b/bentoml/grpc/utils/_import_hook.py new file mode 100644 index 0000000000..285e0abec0 --- /dev/null +++ b/bentoml/grpc/utils/_import_hook.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from pathlib import Path + +if TYPE_CHECKING: + import types + + +def import_generated_stubs( + version: str = "v1alpha1", + file: str = "service.proto", +) -> tuple[types.ModuleType, types.ModuleType]: + """ + Import generated stubs. + + Args: + version: The version of the proto file to import. + file: The name of the proto file to import. + + Returns: + A tuple of the generated stubs for the proto file. + + Examples: + + .. code-block:: python + + from bentoml.grpc.utils import import_generated_stubs + + # given proto file bentoml/grpc/v1alpha2/service.proto exists + pb, services = import_generated_stubs(version="v1alpha2", file="service.proto") + """ + # generate git root from this file's path + from bentoml._internal.utils import LazyLoader + + GIT_ROOT = Path(__file__).parent.parent.parent.parent + + exception_message = f"Generated stubs for '{version}/{file}' are missing. To generate stubs, run '{GIT_ROOT}/scripts/generate_grpc_stubs.sh'" + file = file.split(".")[0] + + service_pb2 = LazyLoader( + f"{file}_pb2", + globals(), + f"bentoml.grpc.{version}.{file}_pb2", + exc_msg=exception_message, + ) + service_pb2_grpc = LazyLoader( + f"{file}_pb2_grpc", + globals(), + f"bentoml.grpc.{version}.{file}_pb2_grpc", + exc_msg=exception_message, + ) + return service_pb2, service_pb2_grpc + + +def import_grpc() -> tuple[types.ModuleType, types.ModuleType]: + from bentoml._internal.utils import LazyLoader + + exception_message = "'grpcio' is required for gRPC support. Install with 'pip install bentoml[grpc]'." + grpc = LazyLoader( + "grpc", + globals(), + "grpc", + exc_msg=exception_message, + ) + aio = LazyLoader("aio", globals(), "grpc.aio", exc_msg=exception_message) + return grpc, aio diff --git a/bentoml/grpc/v1alpha1/__init__.py b/bentoml/grpc/v1alpha1/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/bentoml/grpc/v1alpha1/service.proto b/bentoml/grpc/v1alpha1/service.proto new file mode 100644 index 0000000000..c166ee715c --- /dev/null +++ b/bentoml/grpc/v1alpha1/service.proto @@ -0,0 +1,279 @@ +syntax = "proto3"; + +package bentoml.grpc.v1alpha1; + +import "google/protobuf/struct.proto"; +import "google/protobuf/wrappers.proto"; + +// cc_enable_arenas pre-allocate memory for given message to improve speed. (C++ only) +option cc_enable_arenas = true; +option go_package = "github.com/bentoml/grpc/v1alpha1"; +option java_multiple_files = true; +option java_outer_classname = "ServiceProto"; +option java_package = "com.bentoml.grpc.v1alpha1"; +option objc_class_prefix = "SVC"; +option py_generic_services = true; + +// a gRPC BentoServer. +service BentoService { + // Call handles methodcaller of given API entrypoint. + rpc Call(Request) returns (Response) {} +} + +// Request message for incoming Call. +message Request { + // api_name defines the API entrypoint to call. + // api_name is the name of the function defined in bentoml.Service. + // Example: + // + // @svc.api(input=NumpyNdarray(), output=File()) + // def predict(input: NDArray[float]) -> bytes: + // ... + // + // api_name is "predict" in this case. + string api_name = 1; + + oneof content { + // NDArray represents a n-dimensional array of arbitrary type. + NDArray ndarray = 3; + + // DataFrame represents any tabular data type. We are using + // DataFrame as a trivial representation for tabular type. + DataFrame dataframe = 5; + + // Series portrays a series of values. This can be used for + // representing Series types in tabular data. + Series series = 6; + + // File represents for any arbitrary file type. This can be + // plaintext, image, video, audio, etc. + File file = 7; + + // Text represents a string inputs. + google.protobuf.StringValue text = 8; + + // JSON is represented by using google.protobuf.Value. + // see https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/struct.proto + google.protobuf.Value json = 9; + + // Multipart represents a multipart message. + // It comprises of a mapping from given type name to a subset of aforementioned types. + Multipart multipart = 10; + + // serialized_bytes is for data serialized in BentoML's internal serialization format. + bytes serialized_bytes = 2; + } + + // Tensor is similiar to ndarray but with a name + // We are reserving it for now for future use. + // repeated Tensor tensors = 4; + reserved 4, 11 to 13; +} + +// Request message for incoming Call. +message Response { + oneof content { + // NDArray represents a n-dimensional array of arbitrary type. + NDArray ndarray = 1; + + // DataFrame represents any tabular data type. We are using + // DataFrame as a trivial representation for tabular type. + DataFrame dataframe = 3; + + // Series portrays a series of values. This can be used for + // representing Series types in tabular data. + Series series = 5; + + // File represents for any arbitrary file type. This can be + // plaintext, image, video, audio, etc. + File file = 6; + + // Text represents a string inputs. + google.protobuf.StringValue text = 7; + + // JSON is represented by using google.protobuf.Value. + // see https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/struct.proto + google.protobuf.Value json = 8; + + // Multipart represents a multipart message. + // It comprises of a mapping from given type name to a subset of aforementioned types. + Multipart multipart = 9; + + // serialized_bytes is for data serialized in BentoML's internal serialization format. + bytes serialized_bytes = 2; + } + // Tensor is similiar to ndarray but with a name + // We are reserving it for now for future use. + // repeated Tensor tensors = 4; + reserved 4, 10 to 13; +} + +// Part represents possible value types for multipart message. +// These are the same as the types in Request message. +message Part { + oneof representation { + // NDArray represents a n-dimensional array of arbitrary type. + NDArray ndarray = 1; + + // DataFrame represents any tabular data type. We are using + // DataFrame as a trivial representation for tabular type. + DataFrame dataframe = 3; + + // Series portrays a series of values. This can be used for + // representing Series types in tabular data. + Series series =5; + + // File represents for any arbitrary file type. This can be + // plaintext, image, video, audio, etc. + File file = 6; + + // Text represents a string inputs. + google.protobuf.StringValue text = 7; + + // JSON is represented by using google.protobuf.Value. + // see https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/struct.proto + google.protobuf.Value json = 8; + } + + // Tensor is similiar to ndarray but with a name + // We are reserving it for now for future use. + // Tensor tensors = 4; + reserved 2, 9 to 13; +} + +// Multipart represents a multipart message. +// It comprises of a mapping from given type name to a subset of aforementioned types. +message Multipart { + map fields = 1; +} + +// File represents for any arbitrary file type. This can be +// plaintext, image, video, audio, etc. +message File { + // FileType represents possible file type to be handled by BentoML. + // Currently, we only support plaintext (Text()), image (Image()), and file (File()). + // TODO: support audio and video streaming file types. + enum FileType { + FILE_TYPE_UNSPECIFIED = 0; + + // file types + FILE_TYPE_CSV = 1; + FILE_TYPE_PLAINTEXT = 2; + FILE_TYPE_JSON = 3; + FILE_TYPE_BYTES = 4; + FILE_TYPE_PDF = 5; + + // image types + FILE_TYPE_PNG = 6; + FILE_TYPE_JPEG = 7; + FILE_TYPE_GIF = 8; + FILE_TYPE_BMP = 9; + FILE_TYPE_TIFF = 10; + FILE_TYPE_WEBP = 11; + FILE_TYPE_SVG = 12; + } + + // optional type of file, let it be csv, text, parquet, etc. + optional FileType kind = 1; + + // contents of file as bytes. + bytes content = 2; +} + +// DataFrame represents any tabular data type. We are using +// DataFrame as a trivial representation for tabular type. +// This message carries given implementation of tabular data based on given orientation. +// TODO: support index, records, etc. +message DataFrame { + // columns name + repeated string column_names = 1; + + // columns orient. + // { column ↠ { index ↠ value } } + repeated Series columns = 2; +} + +// Series portrays a series of values. This can be used for +// representing Series types in tabular data. +message Series { + // A bool parameter value + repeated bool bool_values = 1 [packed = true]; + + // A float parameter value + repeated float float_values = 2 [packed = true]; + + // A int32 parameter value + repeated int32 int32_values = 3 [packed = true]; + + // A int64 parameter value + repeated int64 int64_values = 6 [packed = true]; + + // A string parameter value + repeated string string_values = 5; + + // represents a double parameter value. + repeated double double_values = 4 [packed = true]; +} + +// NDArray represents a n-dimensional array of arbitrary type. +message NDArray { + // Represents data type of a given array. + enum DType { + // Represents a None type. + DTYPE_UNSPECIFIED = 0; + + // Represents an float type. + DTYPE_FLOAT = 1; + + // Represents an double type. + DTYPE_DOUBLE = 2; + + // Represents a bool type. + DTYPE_BOOL = 3; + + // Represents an int32 type. + DTYPE_INT32 = 4; + + // Represents an int64 type. + DTYPE_INT64 = 5; + + // Represents a uint32 type. + DTYPE_UINT32 = 6; + + // Represents a uint64 type. + DTYPE_UINT64 = 7; + + // Represents a string type. + DTYPE_STRING = 8; + } + + // DTYPE is the data type of given array + DType dtype = 1; + + // shape is the shape of given array. + repeated int32 shape = 2; + + // represents a string parameter value. + repeated string string_values = 5; + + // represents a float parameter value. + repeated float float_values = 3 [packed = true]; + + // represents a double parameter value. + repeated double double_values = 4 [packed = true]; + + // represents a bool parameter value. + repeated bool bool_values = 6 [packed = true]; + + // represents a int32 parameter value. + repeated int32 int32_values = 7 [packed = true]; + + // represents a int64 parameter value. + repeated int64 int64_values = 8 [packed = true]; + + // represents a uint32 parameter value. + repeated uint32 uint32_values = 9 [packed = true]; + + // represents a uint64 parameter value. + repeated uint64 uint64_values = 10 [packed = true]; +} diff --git a/bentoml/grpc/v1alpha1/service_test.proto b/bentoml/grpc/v1alpha1/service_test.proto new file mode 100644 index 0000000000..8fccaac169 --- /dev/null +++ b/bentoml/grpc/v1alpha1/service_test.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; + +package bentoml.testing.v1alpha1; + +option cc_enable_arenas = true; +option go_package = "github.com/bentoml/testing/v1alpha1"; +option optimize_for = SPEED; +option py_generic_services = true; + +// Represents a request for TestService. +message ExecuteRequest { + string input = 1; +} + +// Represents a response from TestService. +message ExecuteResponse { + string output = 1; +} + +// Use for testing interceptors per RPC call. +service TestService { + // Unary API + rpc Execute(ExecuteRequest) returns (ExecuteResponse); +} diff --git a/bentoml/serve.py b/bentoml/serve.py index 0d2541cf8f..fe34741577 100644 --- a/bentoml/serve.py +++ b/bentoml/serve.py @@ -9,27 +9,35 @@ import logging import tempfile import contextlib +from typing import TYPE_CHECKING from pathlib import Path +from functools import partial import psutil from simple_di import inject from simple_di import Provide -from bentoml import load - -from ._internal.log import SERVER_LOGGING_CONFIG -from ._internal.utils import reserve_free_port -from ._internal.resource import CpuResource -from ._internal.utils.uri import path_to_uri -from ._internal.utils.circus import create_standalone_arbiter -from ._internal.utils.analytics import track_serve +from ._internal.utils import experimental from ._internal.configuration.containers import BentoMLContainer +if TYPE_CHECKING: + from circus.watcher import Watcher + + logger = logging.getLogger(__name__) +PROMETHEUS_MESSAGE = ( + 'Prometheus metrics for %s BentoServer from "%s" can be accessed at %s.' +) SCRIPT_RUNNER = "bentoml_cli.worker.runner" SCRIPT_API_SERVER = "bentoml_cli.worker.http_api_server" SCRIPT_DEV_API_SERVER = "bentoml_cli.worker.http_dev_api_server" +SCRIPT_GRPC_API_SERVER = "bentoml_cli.worker.grpc_api_server" +SCRIPT_GRPC_DEV_API_SERVER = "bentoml_cli.worker.grpc_dev_api_server" +SCRIPT_GRPC_PROMETHEUS_SERVER = "bentoml_cli.worker.grpc_prometheus_server" + +API_SERVER_NAME = "_bento_api_server" +PROMETHEUS_SERVER_NAME = "_prometheus_server" @inject @@ -75,12 +83,90 @@ def ensure_prometheus_dir( return alternative +def create_watcher( + name: str, + args: list[str], + *, + use_sockets: bool = True, + **kwargs: t.Any, +) -> Watcher: + from circus.watcher import Watcher + + return Watcher( + name=name, + cmd=sys.executable, + args=args, + copy_env=True, + stop_children=True, + use_sockets=use_sockets, + **kwargs, + ) + + +def log_grpcui_instruction(port: int) -> None: + # logs instruction on how to start gRPCUI + docker_run = partial( + "docker run -it --rm {network_args} fullstorydev/grpcui -plaintext {platform_deps}:{port}".format, + port=port, + ) + message = "To use gRPC UI, run the following command: '%s', followed by opening 'http://0.0.0.0:8080' in your browser of choice." + + linux_instruction = docker_run( + platform_deps="0.0.0.0", network_args="--network=host" + ) + mac_win_instruction = docker_run( + platform_deps="host.docker.internal", network_args="-p 8080:8080" + ) + + if os.path.exists("/.dockerenv"): + logger.info( + "Detected running Bento inside an OCI container. In order to use gRPC UI, do as follows: If your local machine are either MacOS or Windows , then use '%s'. Otherwise use '%s'.", + mac_win_instruction, + linux_instruction, + ) + elif psutil.WINDOWS or psutil.MACOS: + logger.info(message, mac_win_instruction) + elif psutil.LINUX: + logger.info(message, linux_instruction) + + +def construct_ssl_args( + ssl_certfile: str | None, + ssl_keyfile: str | None, + ssl_keyfile_password: str | None, + ssl_version: int | None, + ssl_cert_reqs: int | None, + ssl_ca_certs: str | None, + ssl_ciphers: str | None, +) -> list[str]: + args: list[str] = [] + + # Add optional SSL args if they exist + if ssl_certfile: + args.extend(["--ssl-certfile", str(ssl_certfile)]) + if ssl_keyfile: + args.extend(["--ssl-keyfile", str(ssl_keyfile)]) + if ssl_keyfile_password: + args.extend(["--ssl-keyfile-password", ssl_keyfile_password]) + if ssl_ca_certs: + args.extend(["--ssl-ca-certs", str(ssl_ca_certs)]) + + # match with default uvicorn values. + if ssl_version: + args.extend(["--ssl-version", str(ssl_version)]) + if ssl_cert_reqs: + args.extend(["--ssl-cert-reqs", str(ssl_cert_reqs)]) + if ssl_ciphers: + args.extend(["--ssl-ciphers", ssl_ciphers]) + return args + + @inject -def serve_development( +def serve_http_development( bento_identifier: str, working_dir: str, - port: int = Provide[BentoMLContainer.api_server_config.port], - host: str = Provide[BentoMLContainer.api_server_config.host], + port: int = Provide[BentoMLContainer.http.port], + host: str = Provide[BentoMLContainer.http.host], backlog: int = Provide[BentoMLContainer.api_server_config.backlog], bentoml_home: str = Provide[BentoMLContainer.bentoml_home], ssl_certfile: str | None = Provide[BentoMLContainer.api_server_config.ssl.certfile], @@ -94,70 +180,61 @@ def serve_development( ssl_ciphers: str | None = Provide[BentoMLContainer.api_server_config.ssl.ciphers], reload: bool = False, ) -> None: - working_dir = os.path.realpath(os.path.expanduser(working_dir)) - svc = load(bento_identifier, working_dir=working_dir) # verify service loading + from circus.sockets import CircusSocket - from circus.sockets import CircusSocket # type: ignore - from circus.watcher import Watcher # type: ignore - - prometheus_dir = ensure_prometheus_dir() + from bentoml import load - watchers: t.List[Watcher] = [] + from ._internal.log import SERVER_LOGGING_CONFIG + from ._internal.utils.circus import create_standalone_arbiter + from ._internal.utils.analytics import track_serve - circus_sockets: t.List[CircusSocket] = [] - circus_sockets.append( - CircusSocket( - name="_bento_api_server", - host=host, - port=port, - backlog=backlog, - ) - ) + working_dir = os.path.realpath(os.path.expanduser(working_dir)) + svc = load(bento_identifier, working_dir=working_dir) - args: list[str | int] = [ - "-m", - SCRIPT_DEV_API_SERVER, - bento_identifier, - "--fd", - "$(circus.sockets._bento_api_server)", - "--working-dir", - working_dir, - "--prometheus-dir", - prometheus_dir, - ] + prometheus_dir = ensure_prometheus_dir() - # Add optional SSL args if they exist - if ssl_certfile: - args.extend(["--ssl-certfile", str(ssl_certfile)]) - if ssl_keyfile: - args.extend(["--ssl-keyfile", str(ssl_keyfile)]) - if ssl_keyfile_password: - args.extend(["--ssl-keyfile-password", ssl_keyfile_password]) - if ssl_ca_certs: - args.extend(["--ssl-ca-certs", str(ssl_ca_certs)]) + watchers: list[Watcher] = [] - # match with default uvicorn values. - if ssl_version: - args.extend(["--ssl-version", int(ssl_version)]) - if ssl_cert_reqs: - args.extend(["--ssl-cert-reqs", int(ssl_cert_reqs)]) - if ssl_ciphers: - args.extend(["--ssl-ciphers", ssl_ciphers]) + circus_sockets: list[CircusSocket] = [ + CircusSocket(name=API_SERVER_NAME, host=host, port=port, backlog=backlog) + ] watchers.append( - Watcher( + create_watcher( name="dev_api_server", - cmd=sys.executable, - args=args, - copy_env=True, - stop_children=True, - use_sockets=True, + args=[ + "-m", + SCRIPT_DEV_API_SERVER, + bento_identifier, + "--fd", + f"$(circus.sockets.{API_SERVER_NAME})", + "--working-dir", + working_dir, + "--prometheus-dir", + prometheus_dir, + *construct_ssl_args( + ssl_certfile=ssl_certfile, + ssl_keyfile=ssl_keyfile, + ssl_keyfile_password=ssl_keyfile_password, + ssl_version=ssl_version, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + ssl_ciphers=ssl_ciphers, + ), + ], working_dir=working_dir, # we don't want to close stdin for child process in case user use debugger. # See https://circus.readthedocs.io/en/latest/for-ops/configuration/ close_child_stdin=False, ) ) + if BentoMLContainer.api_server_config.metrics.enabled.get(): + logger.info( + PROMETHEUS_MESSAGE, + "HTTP", + bento_identifier, + f"http://{host}:{port}/metrics", + ) plugins = [] if reload: @@ -169,7 +246,7 @@ def serve_development( "--reload is passed. BentoML will watch file changes based on 'bentofile.yaml' and '.bentoignore' respectively." ) - # initialize dictionary with {} is faster than using dict() + # NOTE: {} is faster than dict() plugins = [ # reloader plugin { @@ -178,6 +255,7 @@ def serve_development( "bentoml_home": bentoml_home, }, ] + arbiter = create_standalone_arbiter( watchers, sockets=circus_sockets, @@ -190,8 +268,11 @@ def serve_development( with track_serve(svc, production=False): arbiter.start( cb=lambda _: logger.info( # type: ignore - f'Starting development BentoServer from "{bento_identifier}" ' - f"running on http://{host}:{port} (Press CTRL+C to quit)" + 'Starting development %s BentoServer from "%s" running on http://%s:%d (Press CTRL+C to quit)', + "HTTP", + bento_identifier, + host, + port, ), ) @@ -200,11 +281,11 @@ def serve_development( @inject -def serve_production( +def serve_http_production( bento_identifier: str, working_dir: str, - port: int = Provide[BentoMLContainer.api_server_config.port], - host: str = Provide[BentoMLContainer.api_server_config.host], + port: int = Provide[BentoMLContainer.http.port], + host: str = Provide[BentoMLContainer.http.host], backlog: int = Provide[BentoMLContainer.api_server_config.backlog], api_workers: int | None = None, ssl_certfile: str | None = Provide[BentoMLContainer.api_server_config.ssl.certfile], @@ -217,11 +298,18 @@ def serve_production( ssl_ca_certs: str | None = Provide[BentoMLContainer.api_server_config.ssl.ca_certs], ssl_ciphers: str | None = Provide[BentoMLContainer.api_server_config.ssl.ciphers], ) -> None: + from bentoml import load + + from ._internal.utils import reserve_free_port + from ._internal.resource import CpuResource + from ._internal.utils.uri import path_to_uri + from ._internal.utils.circus import create_standalone_arbiter + from ._internal.utils.analytics import track_serve + working_dir = os.path.realpath(os.path.expanduser(working_dir)) svc = load(bento_identifier, working_dir=working_dir, standalone_load=True) from circus.sockets import CircusSocket # type: ignore - from circus.watcher import Watcher # type: ignore watchers: t.List[Watcher] = [] circus_socket_map: t.Dict[str, CircusSocket] = {} @@ -245,9 +333,8 @@ def serve_production( ) watchers.append( - Watcher( + create_watcher( name=f"runner_{runner.name}", - cmd=sys.executable, args=[ "-m", SCRIPT_RUNNER, @@ -265,10 +352,7 @@ def serve_production( "--prometheus-dir", prometheus_dir, ], - copy_env=True, - stop_children=True, working_dir=working_dir, - use_sockets=True, numprocesses=runner.scheduled_worker_count, ) ) @@ -289,9 +373,8 @@ def serve_production( ) watchers.append( - Watcher( + create_watcher( name=f"runner_{runner.name}", - cmd=sys.executable, args=[ "-m", SCRIPT_RUNNER, @@ -304,15 +387,10 @@ def serve_production( working_dir, "--no-access-log", "--worker-id", - "$(circus.wid)", + "$(CIRCUS.WID)", "--worker-env-map", json.dumps(runner.scheduled_worker_env_map), - "--prometheus-dir", - prometheus_dir, ], - copy_env=True, - stop_children=True, - use_sockets=True, working_dir=working_dir, numprocesses=runner.scheduled_worker_count, ) @@ -323,64 +401,57 @@ def serve_production( else: raise NotImplementedError("Unsupported platform: {}".format(sys.platform)) - logger.debug("Runner map: %s", runner_bind_map) + logger.debug(f"Runner map: {runner_bind_map}") - circus_socket_map["_bento_api_server"] = CircusSocket( - name="_bento_api_server", + circus_socket_map[API_SERVER_NAME] = CircusSocket( + name=API_SERVER_NAME, host=host, port=port, backlog=backlog, ) - args: list[str | int] = [ - "-m", - SCRIPT_API_SERVER, - bento_identifier, - "--fd", - "$(circus.sockets._bento_api_server)", - "--runner-map", - json.dumps(runner_bind_map), - "--working-dir", - working_dir, - "--backlog", - f"{backlog}", - "--worker-id", - "$(CIRCUS.WID)", - "--prometheus-dir", - prometheus_dir, - ] - - # Add optional SSL args if they exist - if ssl_certfile: - args.extend(["--ssl-certfile", str(ssl_certfile)]) - if ssl_keyfile: - args.extend(["--ssl-keyfile", str(ssl_keyfile)]) - if ssl_keyfile_password: - args.extend(["--ssl-keyfile-password", ssl_keyfile_password]) - if ssl_ca_certs: - args.extend(["--ssl-ca-certs", str(ssl_ca_certs)]) - - # match with default uvicorn values. - if ssl_version: - args.extend(["--ssl-version", int(ssl_version)]) - if ssl_cert_reqs: - args.extend(["--ssl-cert-reqs", int(ssl_cert_reqs)]) - if ssl_ciphers: - args.extend(["--ssl-ciphers", ssl_ciphers]) - watchers.append( - Watcher( + create_watcher( name="api_server", - cmd=sys.executable, - args=args, - copy_env=True, - numprocesses=api_workers or math.ceil(CpuResource.from_system()), - stop_children=True, - use_sockets=True, + args=[ + "-m", + SCRIPT_API_SERVER, + bento_identifier, + "--fd", + f"$(circus.sockets.{API_SERVER_NAME})", + "--runner-map", + json.dumps(runner_bind_map), + "--working-dir", + working_dir, + "--backlog", + f"{backlog}", + "--worker-id", + "$(CIRCUS.WID)", + "--prometheus-dir", + prometheus_dir, + *construct_ssl_args( + ssl_certfile=ssl_certfile, + ssl_keyfile=ssl_keyfile, + ssl_keyfile_password=ssl_keyfile_password, + ssl_version=ssl_version, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + ssl_ciphers=ssl_ciphers, + ), + ], working_dir=working_dir, + numprocesses=api_workers or math.ceil(CpuResource.from_system()), ) ) + if BentoMLContainer.api_server_config.metrics.enabled.get(): + logger.info( + PROMETHEUS_MESSAGE, + "HTTP", + bento_identifier, + f"http://{host}:{port}/metrics", + ) + arbiter = create_standalone_arbiter( watchers=watchers, sockets=list(circus_socket_map.values()), @@ -390,8 +461,400 @@ def serve_production( try: arbiter.start( cb=lambda _: logger.info( # type: ignore - f'Starting production BentoServer from "{bento_identifier}" ' - f"running on http://{host}:{port} (Press CTRL+C to quit)" + 'Starting production %s BentoServer from "%s" running on http://%s:%d (Press CTRL+C to quit)', + "HTTP", + bento_identifier, + host, + port, + ), + ) + finally: + if uds_path is not None: + shutil.rmtree(uds_path) + + +@experimental +@inject +def serve_grpc_development( + bento_identifier: str, + working_dir: str, + port: int = Provide[BentoMLContainer.grpc.port], + host: str = Provide[BentoMLContainer.grpc.host], + bentoml_home: str = Provide[BentoMLContainer.bentoml_home], + reload: bool = False, + reflection: bool = Provide[BentoMLContainer.grpc.reflection.enabled], + max_concurrent_streams: int + | None = Provide[BentoMLContainer.grpc.max_concurrent_streams], + backlog: int = Provide[BentoMLContainer.api_server_config.backlog], +) -> None: + from circus.sockets import CircusSocket + + from bentoml import load + + from ._internal.log import SERVER_LOGGING_CONFIG + from ._internal.utils import reserve_free_port + from ._internal.utils.circus import create_standalone_arbiter + from ._internal.utils.analytics import track_serve + + working_dir = os.path.realpath(os.path.expanduser(working_dir)) + svc = load(bento_identifier, working_dir=working_dir) + + prometheus_dir = ensure_prometheus_dir() + + watchers: list[Watcher] = [] + + circus_sockets: list[CircusSocket] = [] + + if not reflection: + logger.info( + "'reflection' is disabled by default. Tools such as gRPCUI or grpcurl relies on server reflection. To use those, pass '--enable-reflection' to the CLI." + ) + else: + log_grpcui_instruction(port) + + with contextlib.ExitStack() as port_stack: + api_port = port_stack.enter_context( + reserve_free_port(host, port=port, enable_so_reuseport=True) + ) + + args = [ + "-m", + SCRIPT_GRPC_DEV_API_SERVER, + bento_identifier, + "--host", + host, + "--port", + str(api_port), + "--working-dir", + working_dir, + ] + + if reflection: + args.append("--enable-reflection") + if max_concurrent_streams: + args.extend( + [ + "--max-concurrent-streams", + str(max_concurrent_streams), + ] + ) + + # use circus_sockets. CircusSocket support for SO_REUSEPORT + watchers.append( + create_watcher( + name="grpc_dev_api_server", + args=args, + use_sockets=False, + working_dir=working_dir, + # we don't want to close stdin for child process in case user use debugger. + # See https://circus.readthedocs.io/en/latest/for-ops/configuration/ + close_child_stdin=False, + ) + ) + + if BentoMLContainer.api_server_config.metrics.enabled.get(): + metrics_host = BentoMLContainer.grpc.metrics.host.get() + metrics_port = BentoMLContainer.grpc.metrics.port.get() + + circus_sockets.append( + CircusSocket( + name=PROMETHEUS_SERVER_NAME, + host=metrics_host, + port=metrics_port, + backlog=backlog, + ) + ) + + watchers.append( + create_watcher( + name="prom_server", + args=[ + "-m", + SCRIPT_GRPC_PROMETHEUS_SERVER, + "--fd", + f"$(circus.sockets.{PROMETHEUS_SERVER_NAME})", + "--prometheus-dir", + prometheus_dir, + "--backlog", + f"{backlog}", + ], + working_dir=working_dir, + numprocesses=1, + singleton=True, + # we don't want to close stdin for child process in case user use debugger. + # See https://circus.readthedocs.io/en/latest/for-ops/configuration/ + close_child_stdin=False, + ) + ) + + logger.info( + PROMETHEUS_MESSAGE, + "gRPC", + bento_identifier, + f"http://{metrics_host}:{metrics_port}", + ) + + plugins = [] + + if reload: + if sys.platform == "win32": + logger.warning( + "Due to circus limitations, output from the reloader plugin will not be shown on Windows." + ) + logger.debug( + "--reload is passed. BentoML will watch file changes based on 'bentofile.yaml' and '.bentoignore' respectively." + ) + + # NOTE: {} is faster than dict() + plugins = [ + # reloader plugin + { + "use": "bentoml._internal.utils.circus.watchfilesplugin.ServiceReloaderPlugin", + "working_dir": working_dir, + "bentoml_home": bentoml_home, + }, + ] + + arbiter = create_standalone_arbiter( + watchers, + sockets=circus_sockets, + plugins=plugins, + debug=True if sys.platform != "win32" else False, + loggerconfig=SERVER_LOGGING_CONFIG, + loglevel="ERROR", + ) + + with track_serve(svc, production=False): + arbiter.start( + cb=lambda _: logger.info( # type: ignore + 'Starting development %s BentoServer from "%s" running on http://%s:%d (Press CTRL+C to quit)', + "gRPC", + bento_identifier, + host, + port, + ), + ) + + +@experimental +@inject +def serve_grpc_production( + bento_identifier: str, + working_dir: str, + port: int = Provide[BentoMLContainer.grpc.port], + host: str = Provide[BentoMLContainer.grpc.host], + backlog: int = Provide[BentoMLContainer.api_server_config.backlog], + api_workers: int | None = None, + reflection: bool = Provide[BentoMLContainer.grpc.reflection.enabled], + max_concurrent_streams: int + | None = Provide[BentoMLContainer.grpc.max_concurrent_streams], +) -> None: + from bentoml import load + from bentoml.exceptions import UnprocessableEntity + + from ._internal.utils import reserve_free_port + from ._internal.resource import CpuResource + from ._internal.utils.uri import path_to_uri + from ._internal.utils.circus import create_standalone_arbiter + from ._internal.utils.analytics import track_serve + + working_dir = os.path.realpath(os.path.expanduser(working_dir)) + svc = load(bento_identifier, working_dir=working_dir, standalone_load=True) + + from circus.sockets import CircusSocket # type: ignore + + watchers: list[Watcher] = [] + circus_socket_map: dict[str, CircusSocket] = {} + runner_bind_map: dict[str, str] = {} + uds_path = None + + prometheus_dir = ensure_prometheus_dir() + + # Check whether users are running --grpc on windows + # also raising warning if users running on MacOS or FreeBSD + if psutil.WINDOWS: + raise UnprocessableEntity( + "'grpc' is not supported on Windows with '--production'. The reason being SO_REUSEPORT socket option is only available on UNIX system, and gRPC implementation depends on this behaviour." + ) + if psutil.MACOS or psutil.FREEBSD: + logger.warning( + "Due to gRPC implementation on exposing SO_REUSEPORT, '--production' behaviour on %s is not correct. We recommend to containerize BentoServer as a Linux container instead.", + "MacOS" if psutil.MACOS else "FreeBSD", + ) + + if psutil.POSIX: + # use AF_UNIX sockets for Circus + uds_path = tempfile.mkdtemp() + for runner in svc.runners: + sockets_path = os.path.join(uds_path, f"{id(runner)}.sock") + assert len(sockets_path) < MAX_AF_UNIX_PATH_LENGTH + + runner_bind_map[runner.name] = path_to_uri(sockets_path) + circus_socket_map[runner.name] = CircusSocket( + name=runner.name, + path=sockets_path, + backlog=backlog, + ) + + watchers.append( + create_watcher( + name=f"runner_{runner.name}", + args=[ + "-m", + SCRIPT_RUNNER, + bento_identifier, + "--runner-name", + runner.name, + "--fd", + f"$(circus.sockets.{runner.name})", + "--working-dir", + working_dir, + "--worker-id", + "$(CIRCUS.WID)", + "--worker-env-map", + json.dumps(runner.scheduled_worker_env_map), + ], + working_dir=working_dir, + numprocesses=runner.scheduled_worker_count, + ) + ) + + elif psutil.WINDOWS: + # Windows doesn't (fully) support AF_UNIX sockets + with contextlib.ExitStack() as port_stack: + for runner in svc.runners: + runner_port = port_stack.enter_context(reserve_free_port()) + runner_host = "127.0.0.1" + + runner_bind_map[runner.name] = f"tcp://{runner_host}:{runner_port}" + circus_socket_map[runner.name] = CircusSocket( + name=runner.name, + host=runner_host, + port=runner_port, + backlog=backlog, + ) + + watchers.append( + create_watcher( + name=f"runner_{runner.name}", + args=[ + "-m", + SCRIPT_RUNNER, + bento_identifier, + "--runner-name", + runner.name, + "--fd", + f"$(circus.sockets.{runner.name})", + "--working-dir", + working_dir, + "--no-access-log", + "--worker-id", + "$(circus.wid)", + "--worker-env-map", + json.dumps(runner.scheduled_worker_env_map), + "--prometheus-dir", + prometheus_dir, + ], + working_dir=working_dir, + numprocesses=runner.scheduled_worker_count, + ) + ) + # reserve one more to avoid conflicts + port_stack.enter_context(reserve_free_port()) + else: + raise NotImplementedError("Unsupported platform: {}".format(sys.platform)) + + logger.debug(f"Runner map: {runner_bind_map}") + + with contextlib.ExitStack() as port_stack: + api_port = port_stack.enter_context( + reserve_free_port(host, port=port, enable_so_reuseport=True) + ) + args = [ + "-m", + SCRIPT_GRPC_API_SERVER, + bento_identifier, + "--host", + host, + "--port", + str(api_port), + "--runner-map", + json.dumps(runner_bind_map), + "--working-dir", + working_dir, + "--worker-id", + "$(CIRCUS.WID)", + ] + if reflection: + args.append("--enable-reflection") + + if max_concurrent_streams: + args.extend( + [ + "--max-concurrent-streams", + str(max_concurrent_streams), + ] + ) + + watchers.append( + create_watcher( + name="grpc_api_server", + args=args, + use_sockets=False, + working_dir=working_dir, + numprocesses=api_workers or math.ceil(CpuResource.from_system()), + ) + ) + + if BentoMLContainer.api_server_config.metrics.enabled.get(): + metrics_host = BentoMLContainer.grpc.metrics.host.get() + metrics_port = BentoMLContainer.grpc.metrics.port.get() + + circus_socket_map[PROMETHEUS_SERVER_NAME] = CircusSocket( + name=PROMETHEUS_SERVER_NAME, + host=metrics_host, + port=metrics_port, + backlog=backlog, + ) + + watchers.append( + create_watcher( + name="prom_server", + args=[ + "-m", + SCRIPT_GRPC_PROMETHEUS_SERVER, + "--fd", + f"$(circus.sockets.{PROMETHEUS_SERVER_NAME})", + "--prometheus-dir", + prometheus_dir, + "--backlog", + f"{backlog}", + ], + working_dir=working_dir, + numprocesses=1, + singleton=True, + ) + ) + + logger.info( + PROMETHEUS_MESSAGE, + "gRPC", + bento_identifier, + f"http://{metrics_host}:{metrics_port}", + ) + arbiter = create_standalone_arbiter( + watchers=watchers, sockets=list(circus_socket_map.values()) + ) + + with track_serve(svc, production=True): + try: + arbiter.start( + cb=lambda _: logger.info( # type: ignore + 'Starting production %s BentoServer from "%s" running on http://%s:%d (Press CTRL+C to quit)', + "gRPC", + bento_identifier, + host, + port, ), ) finally: diff --git a/bentoml/start.py b/bentoml/start.py index fc327c48d1..1541f61e89 100644 --- a/bentoml/start.py +++ b/bentoml/start.py @@ -4,80 +4,26 @@ import sys import json import math -import shutil import typing as t import logging -import tempfile import contextlib -from pathlib import Path from simple_di import inject from simple_di import Provide -from bentoml import load - -from ._internal.utils import reserve_free_port -from ._internal.resource import CpuResource -from ._internal.utils.circus import create_standalone_arbiter -from ._internal.utils.analytics import track_serve from ._internal.configuration.containers import BentoMLContainer logger = logging.getLogger(__name__) SCRIPT_RUNNER = "bentoml_cli.worker.runner" SCRIPT_API_SERVER = "bentoml_cli.worker.http_api_server" -SCRIPT_DEV_API_SERVER = "bentoml_cli.worker.http_dev_api_server" +SCRIPT_GRPC_API_SERVER = "bentoml_cli.worker.grpc_api_server" +SCRIPT_GRPC_PROMETHEUS_SERVER = "bentoml_cli.worker.grpc_prometheus_server" API_SERVER = "api_server" RUNNER = "runner" -@inject -def ensure_prometheus_dir( - directory: str = Provide[BentoMLContainer.prometheus_multiproc_dir], - clean: bool = True, - use_alternative: bool = True, -) -> str: - try: - path = Path(directory) - if path.exists(): - if not path.is_dir() or any(path.iterdir()): - if clean: - shutil.rmtree(str(path)) - path.mkdir() - return str(path.absolute()) - else: - raise RuntimeError( - "Prometheus multiproc directory {} is not empty".format(path) - ) - else: - return str(path.absolute()) - else: - path.mkdir(parents=True) - return str(path.absolute()) - except shutil.Error as e: - if not use_alternative: - raise RuntimeError( - f"Failed to clean the prometheus multiproc directory {directory}: {e}" - ) - except OSError as e: - if not use_alternative: - raise RuntimeError( - f"Failed to create the prometheus multiproc directory {directory}: {e}" - ) - assert use_alternative - alternative = tempfile.mkdtemp() - logger.warning( - f"Failed to ensure the prometheus multiproc directory {directory}, " - f"using alternative: {alternative}", - ) - BentoMLContainer.prometheus_multiproc_dir.set(alternative) - return alternative - - -MAX_AF_UNIX_PATH_LENGTH = 103 - - @inject def start_runner_server( bento_identifier: str, @@ -90,6 +36,13 @@ def start_runner_server( """ Experimental API for serving a BentoML runner. """ + from bentoml import load + + from .serve import ensure_prometheus_dir + from ._internal.utils import reserve_free_port + from ._internal.utils.circus import create_standalone_arbiter + from ._internal.utils.analytics import track_serve + working_dir = os.path.realpath(os.path.expanduser(working_dir)) svc = load(bento_identifier, working_dir=working_dir, standalone_load=True) @@ -98,7 +51,6 @@ def start_runner_server( watchers: t.List[Watcher] = [] circus_socket_map: t.Dict[str, CircusSocket] = {} - uds_path = None ensure_prometheus_dir() @@ -147,25 +99,19 @@ def start_runner_server( f"Runner {runner_name} not found in the service: `{bento_identifier}`, " f"available runners: {[r.name for r in svc.runners]}" ) - arbiter = create_standalone_arbiter( watchers=watchers, sockets=list(circus_socket_map.values()), ) - with track_serve(svc, production=True, component=RUNNER): - try: - arbiter.start( - cb=lambda _: logger.info( # type: ignore - 'Starting RunnerServer from "%s"\n running on http://%s:%s (Press CTRL+C to quit)', - bento_identifier, - host, - port, - ), - ) - finally: - if uds_path is not None: - shutil.rmtree(uds_path) + arbiter.start( + cb=lambda _: logger.info( # type: ignore + 'Starting RunnerServer from "%s" running on http://%s:%s (Press CTRL+C to quit)', + bento_identifier, + host, + port, + ), + ) @inject @@ -187,14 +133,24 @@ def start_http_server( ssl_ca_certs: str | None = Provide[BentoMLContainer.api_server_config.ssl.ca_certs], ssl_ciphers: str | None = Provide[BentoMLContainer.api_server_config.ssl.ciphers], ) -> None: + from bentoml import load + + from .serve import create_watcher + from .serve import API_SERVER_NAME + from .serve import construct_ssl_args + from .serve import PROMETHEUS_MESSAGE + from .serve import ensure_prometheus_dir + from ._internal.resource import CpuResource + from ._internal.utils.circus import create_standalone_arbiter + from ._internal.utils.analytics import track_serve + working_dir = os.path.realpath(os.path.expanduser(working_dir)) svc = load(bento_identifier, working_dir=working_dir, standalone_load=True) runner_requirements = {runner.name for runner in svc.runners} if not runner_requirements.issubset(set(runner_map)): raise ValueError( - f"{bento_identifier} requires runners {runner_requirements}, but only " - f"{set(runner_map)} are provided" + f"{bento_identifier} requires runners {runner_requirements}, but only {set(runner_map)} are provided." ) from circus.sockets import CircusSocket # type: ignore @@ -202,81 +158,202 @@ def start_http_server( watchers: t.List[Watcher] = [] circus_socket_map: t.Dict[str, CircusSocket] = {} - uds_path = None prometheus_dir = ensure_prometheus_dir() logger.debug("Runner map: %s", runner_map) - circus_socket_map["_bento_api_server"] = CircusSocket( - name="_bento_api_server", + circus_socket_map[API_SERVER_NAME] = CircusSocket( + name=API_SERVER_NAME, host=host, port=port, backlog=backlog, ) - args: list[str | int] = [ - "-m", - SCRIPT_API_SERVER, - bento_identifier, - "--fd", - "$(circus.sockets._bento_api_server)", - "--runner-map", - json.dumps(runner_map), - "--working-dir", - working_dir, - "--backlog", - f"{backlog}", - "--worker-id", - "$(CIRCUS.WID)", - "--prometheus-dir", - prometheus_dir, - ] - - # Add optional SSL args if they exist - if ssl_certfile: - args.extend(["--ssl-certfile", str(ssl_certfile)]) - if ssl_keyfile: - args.extend(["--ssl-keyfile", str(ssl_keyfile)]) - if ssl_keyfile_password: - args.extend(["--ssl-keyfile-password", ssl_keyfile_password]) - if ssl_ca_certs: - args.extend(["--ssl-ca-certs", str(ssl_ca_certs)]) - - # match with default uvicorn values. - if ssl_version: - args.extend(["--ssl-version", int(ssl_version)]) - if ssl_cert_reqs: - args.extend(["--ssl-cert-reqs", int(ssl_cert_reqs)]) - if ssl_ciphers: - args.extend(["--ssl-ciphers", ssl_ciphers]) - watchers.append( - Watcher( - name=API_SERVER, - cmd=sys.executable, - args=args, - copy_env=True, - numprocesses=api_workers or math.ceil(CpuResource.from_system()), - stop_children=True, - use_sockets=True, + create_watcher( + name="api_server", + args=[ + "-m", + SCRIPT_API_SERVER, + bento_identifier, + "--fd", + f"$(circus.sockets.{API_SERVER_NAME})", + "--runner-map", + json.dumps(runner_map), + "--working-dir", + working_dir, + "--backlog", + f"{backlog}", + "--worker-id", + "$(CIRCUS.WID)", + "--prometheus-dir", + prometheus_dir, + *construct_ssl_args( + ssl_certfile=ssl_certfile, + ssl_keyfile=ssl_keyfile, + ssl_keyfile_password=ssl_keyfile_password, + ssl_version=ssl_version, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + ssl_ciphers=ssl_ciphers, + ), + ], working_dir=working_dir, + numprocesses=api_workers or math.ceil(CpuResource.from_system()), ) ) + if BentoMLContainer.api_server_config.metrics.enabled.get(): + logger.info( + PROMETHEUS_MESSAGE, + "HTTP", + bento_identifier, + f"http://{host}:{port}/metrics", + ) arbiter = create_standalone_arbiter( watchers=watchers, sockets=list(circus_socket_map.values()), ) - with track_serve(svc, production=True, component=API_SERVER): - try: - arbiter.start( - cb=lambda _: logger.info( # type: ignore - f'Starting bare Bento API server from "{bento_identifier}" ' - f"running on http://{host}:{port} (Press CTRL+C to quit)" - ), + arbiter.start( + cb=lambda _: logger.info( # type: ignore + 'Starting bare %s BentoServer from "%s" running on http://%s:%d (Press CTRL+C to quit)', + "HTTP", + bento_identifier, + host, + port, + ), + ) + + +@inject +def start_grpc_server( + bento_identifier: str, + runner_map: dict[str, str], + working_dir: str, + port: int = Provide[BentoMLContainer.grpc.port], + host: str = Provide[BentoMLContainer.grpc.host], + backlog: int = Provide[BentoMLContainer.api_server_config.backlog], + api_workers: int | None = None, + reflection: bool = Provide[BentoMLContainer.grpc.reflection.enabled], + max_concurrent_streams: int + | None = Provide[BentoMLContainer.grpc.max_concurrent_streams], +) -> None: + from bentoml import load + + from .serve import create_watcher + from .serve import PROMETHEUS_MESSAGE + from .serve import ensure_prometheus_dir + from .serve import PROMETHEUS_SERVER_NAME + from ._internal.utils import reserve_free_port + from ._internal.resource import CpuResource + from ._internal.utils.circus import create_standalone_arbiter + from ._internal.utils.analytics import track_serve + + working_dir = os.path.realpath(os.path.expanduser(working_dir)) + svc = load(bento_identifier, working_dir=working_dir, standalone_load=True) + + runner_requirements = {runner.name for runner in svc.runners} + if not runner_requirements.issubset(set(runner_map)): + raise ValueError( + f"{bento_identifier} requires runners {runner_requirements}, but only {set(runner_map)} are provided." + ) + + from circus.sockets import CircusSocket # type: ignore + from circus.watcher import Watcher # type: ignore + + watchers: list[Watcher] = [] + circus_socket_map: dict[str, CircusSocket] = {} + prometheus_dir = ensure_prometheus_dir() + logger.debug("Runner map: %s", runner_map) + with contextlib.ExitStack() as port_stack: + api_port = port_stack.enter_context( + reserve_free_port(host=host, port=port, enable_so_reuseport=True) + ) + + args = [ + "-m", + SCRIPT_GRPC_API_SERVER, + bento_identifier, + "--host", + host, + "--port", + str(api_port), + "--runner-map", + json.dumps(runner_map), + "--working-dir", + working_dir, + "--worker-id", + "$(CIRCUS.WID)", + ] + if reflection: + args.append("--enable-reflection") + + if max_concurrent_streams: + args.extend( + [ + "--max-concurrent-streams", + str(max_concurrent_streams), + ] ) - finally: - if uds_path is not None: - shutil.rmtree(uds_path) + + watchers.append( + create_watcher( + name="grpc_api_server", + args=args, + use_sockets=False, + working_dir=working_dir, + numprocesses=api_workers or math.ceil(CpuResource.from_system()), + ) + ) + + if BentoMLContainer.api_server_config.metrics.enabled.get(): + metrics_host = BentoMLContainer.grpc.metrics.host.get() + metrics_port = BentoMLContainer.grpc.metrics.port.get() + + circus_socket_map[PROMETHEUS_SERVER_NAME] = CircusSocket( + name=PROMETHEUS_SERVER_NAME, + host=metrics_host, + port=metrics_port, + backlog=backlog, + ) + + watchers.append( + create_watcher( + name="prom_server", + args=[ + "-m", + SCRIPT_GRPC_PROMETHEUS_SERVER, + "--fd", + f"$(circus.sockets.{PROMETHEUS_SERVER_NAME})", + "--prometheus-dir", + prometheus_dir, + "--backlog", + f"{backlog}", + ], + working_dir=working_dir, + numprocesses=1, + singleton=True, + ) + ) + + logger.info( + PROMETHEUS_MESSAGE, + "gRPC", + bento_identifier, + f"http://{metrics_host}:{metrics_port}", + ) + arbiter = create_standalone_arbiter( + watchers=watchers, sockets=list(circus_socket_map.values()) + ) + with track_serve(svc, production=True, component=API_SERVER): + arbiter.start( + cb=lambda _: logger.info( # type: ignore + 'Starting bare %s BentoServer from "%s" running on http://%s:%d (Press CTRL+C to quit)', + "gRPC", + bento_identifier, + host, + port, + ), + ) diff --git a/bentoml/testing/server.py b/bentoml/testing/server.py index d520e0a88d..8553f0b69f 100644 --- a/bentoml/testing/server.py +++ b/bentoml/testing/server.py @@ -18,10 +18,11 @@ from typing import TYPE_CHECKING from contextlib import contextmanager +import psutil + from .._internal.tag import Tag from .._internal.utils import reserve_free_port from .._internal.utils import cached_contextmanager -from .._internal.utils.platform import kill_subprocess_tree logger = logging.getLogger("bentoml") @@ -75,6 +76,19 @@ async def async_request( return r.status, Headers(headers), r_body +def kill_subprocess_tree(p: subprocess.Popen[t.Any]) -> None: + """ + Tell the process to terminate and kill all of its children. Availabe both on Windows and Linux. + Note: It will return immediately rather than wait for the process to terminate. + Args: + p: subprocess.Popen object + """ + if psutil.WINDOWS: + subprocess.call(["taskkill", "/F", "/T", "/PID", str(p.pid)]) + else: + p.terminate() + + def _wait_until_api_server_ready( host_url: str, timeout: float, diff --git a/bentoml_cli/containerize.py b/bentoml_cli/containerize.py index e6964ca23e..bf5d6959f7 100644 --- a/bentoml_cli/containerize.py +++ b/bentoml_cli/containerize.py @@ -3,14 +3,11 @@ import sys import typing as t import logging +from typing import TYPE_CHECKING -import click - -from bentoml.bentos import containerize as containerize_bento -from bentoml._internal.utils import kwargs_transformers -from bentoml._internal.utils.docker import validate_tag - -logger = logging.getLogger("bentoml") +if TYPE_CHECKING: + F = t.Callable[..., t.Any] + from click import Group def containerize_transformer( @@ -23,7 +20,15 @@ def containerize_transformer( return value -def add_containerize_command(cli: click.Group) -> None: +def add_containerize_command(cli: Group) -> None: + import click + + from bentoml.bentos import FEATURES + from bentoml.bentos import containerize as containerize_bento + from bentoml_cli.utils import kwargs_transformers + from bentoml._internal.utils.docker import validate_tag + from bentoml._internal.configuration.containers import BentoMLContainer + @cli.command() @click.argument("bento_tag", type=click.STRING) @click.option( @@ -160,35 +165,43 @@ def add_containerize_command(cli: click.Group) -> None: @click.option( "--ulimit", type=click.STRING, default=None, help="Ulimit options (default [])." ) + @click.option( + "--enable-features", + multiple=True, + nargs=1, + metavar="[features,]", + help=f"Enable additional BentoML features. Available features are: {', '.join(FEATURES)}.", + ) @kwargs_transformers(transformer=containerize_transformer) def containerize( # type: ignore bento_tag: str, - docker_image_tag: list[str], - add_host: t.Iterable[str], - allow: t.Iterable[str], - build_arg: t.List[str], - build_context: t.List[str], + docker_image_tag: tuple[str], + add_host: tuple[str], + allow: tuple[str], + build_arg: tuple[str], + build_context: tuple[str], builder: str, - cache_from: t.List[str], - cache_to: t.List[str], + cache_from: tuple[str], + cache_to: tuple[str], cgroup_parent: str, iidfile: str, - label: t.List[str], + label: tuple[str], load: bool, network: str, metadata_file: str, no_cache: bool, - no_cache_filter: t.List[str], - output: t.List[str], - platform: t.List[str], + no_cache_filter: tuple[str], + output: tuple[str], + platform: tuple[str], progress: t.Literal["auto", "tty", "plain"], pull: bool, push: bool, - secret: t.List[str], + secret: tuple[str], shm_size: str, ssh: str, target: str, ulimit: str, + enable_features: tuple[str], ) -> None: """Containerizes given Bento into a ready-to-use Docker image. @@ -217,11 +230,11 @@ def containerize( # type: ignore By doing so, BentoML will leverage Docker Buildx features such as multi-node builds for cross-platform images, Full BuildKit capabilities with all of the familiar UI from 'docker build'. - - We also pass all given args for 'docker buildx' through 'bentoml containerize' with ease. """ from bentoml._internal.utils import buildx + logger = logging.getLogger("bentoml") + # run health check whether buildx is install locally buildx.health() @@ -253,7 +266,7 @@ def containerize( # type: ignore key, value = label_str.split("=") labels[key] = value - output_ = None + output_: dict[str, t.Any] | None = None if output: output_ = {} for arg in output: @@ -276,6 +289,9 @@ def containerize( # type: ignore exit_code = not containerize_bento( bento_tag, docker_image_tag=docker_image_tag, + # containerize options + features=enable_features, + # docker options add_host=add_hosts, allow=allow_, build_args=build_args, @@ -291,7 +307,7 @@ def containerize( # type: ignore network=network, no_cache=no_cache, no_cache_filter=no_cache_filter, - output=output_, # type: ignore + output=output_, platform=platform, progress=progress, pull=pull, @@ -303,4 +319,17 @@ def containerize( # type: ignore target=target, ulimit=ulimit, ) + if not exit_code: + grpc_metrics_port = BentoMLContainer.grpc.metrics.port.get() + logger.info( + 'Successfully built docker image for "%s" with tags "%s"', + str(bento_tag), + ",".join(docker_image_tag), + ) + logger.info( + 'To run your newly built Bento container, use one of the above tags, and pass it to "docker run". i.e: "docker run -it --rm -p 3000:3000 %s". To use gRPC, pass "-e BENTOML_USE_GRPC=true -p %s:%s" to "docker run".', + docker_image_tag[0], + grpc_metrics_port, + grpc_metrics_port, + ) sys.exit(exit_code) diff --git a/bentoml_cli/serve.py b/bentoml_cli/serve.py index e34f54abe8..cde029fdbc 100644 --- a/bentoml_cli/serve.py +++ b/bentoml_cli/serve.py @@ -15,7 +15,7 @@ def add_serve_command(cli: click.Group) -> None: from bentoml._internal.log import configure_server_logging from bentoml._internal.configuration.containers import BentoMLContainer - @cli.command() + @cli.command(aliases=["serve-http"]) @click.argument("bento", type=click.STRING, default=".") @click.option( "--production", @@ -26,9 +26,10 @@ def add_serve_command(cli: click.Group) -> None: show_default=True, ) @click.option( + "-p", "--port", type=click.INT, - default=BentoMLContainer.service_port.get(), + default=BentoMLContainer.http.port.get(), help="The port to listen on for the REST api server", envvar="BENTOML_PORT", show_default=True, @@ -36,9 +37,10 @@ def add_serve_command(cli: click.Group) -> None: @click.option( "--host", type=click.STRING, - default=BentoMLContainer.service_host.get(), - help="The host to bind for the REST api server [defaults: 127.0.0.1(dev), 0.0.0.0(production)]", + default=BentoMLContainer.http.host.get(), + help="The host to bind for the REST api server", envvar="BENTOML_HOST", + show_default=True, ) @click.option( "--api-workers", @@ -46,6 +48,7 @@ def add_serve_command(cli: click.Group) -> None: default=None, help="Specify the number of API server workers to start. Default to number of available CPU cores in production mode", envvar="BENTOML_API_WORKERS", + show_default=True, ) @click.option( "--backlog", @@ -74,42 +77,49 @@ def add_serve_command(cli: click.Group) -> None: type=str, default=None, help="SSL certificate file", + show_default=True, ) @click.option( "--ssl-keyfile", type=str, default=None, help="SSL key file", + show_default=True, ) @click.option( "--ssl-keyfile-password", type=str, default=None, help="SSL keyfile password", + show_default=True, ) @click.option( "--ssl-version", type=int, default=None, help="SSL version to use (see stdlib 'ssl' module)", + show_default=True, ) @click.option( "--ssl-cert-reqs", type=int, default=None, help="Whether client certificate is required (see stdlib 'ssl' module)", + show_default=True, ) @click.option( "--ssl-ca-certs", type=str, default=None, help="CA certificates file", + show_default=True, ) @click.option( "--ssl-ciphers", type=str, default=None, help="Ciphers to use (see stdlib 'ssl' module)", + show_default=True, ) def serve( # type: ignore (unused warning) bento: str, @@ -128,36 +138,38 @@ def serve( # type: ignore (unused warning) ssl_ca_certs: str | None, ssl_ciphers: str | None, ) -> None: - """Start a :code:`BentoServer` from a given ``BENTO`` 🍱 + """Start a HTTP BentoServer from a given 🍱 - ``BENTO`` is the serving target, it can be the import as: - - the import path of a :code:`bentoml.Service` instance - - a tag to a Bento in local Bento store - - a folder containing a valid `bentofile.yaml` build file with a `service` field, which provides the import path of a :code:`bentoml.Service` instance - - a path to a built Bento (for internal & debug use only) + \b + BENTO is the serving target, it can be the import as: + - the import path of a 'bentoml.Service' instance + - a tag to a Bento in local Bento store + - a folder containing a valid 'bentofile.yaml' build file with a 'service' field, which provides the import path of a 'bentoml.Service' instance + - a path to a built Bento (for internal & debug use only) e.g.: \b Serve from a bentoml.Service instance source code (for development use only): - :code:`bentoml serve fraud_detector.py:svc` + 'bentoml serve fraud_detector.py:svc' \b Serve from a Bento built in local store: - :code:`bentoml serve fraud_detector:4tht2icroji6zput3suqi5nl2` - :code:`bentoml serve fraud_detector:latest` + 'bentoml serve fraud_detector:4tht2icroji6zput3suqi5nl2' + 'bentoml serve fraud_detector:latest' \b Serve from a Bento directory: - :code:`bentoml serve ./fraud_detector_bento` + 'bentoml serve ./fraud_detector_bento' \b - If :code:`--reload` is provided, BentoML will detect code and model store changes during development, and restarts the service automatically. + If '--reload' is provided, BentoML will detect code and model store changes during development, and restarts the service automatically. - The `--reload` flag will: - - be default, all file changes under `--working-dir` (default to current directory) will trigger a restart - - when specified, respect :obj:`include` and :obj:`exclude` under :obj:`bentofile.yaml` as well as the :obj:`.bentoignore` file in `--working-dir`, for code and file changes - - all model store changes will also trigger a restart (new model saved or existing model removed) + \b + The '--reload' flag will: + - be default, all file changes under '--working-dir' (default to current directory) will trigger a restart + - when specified, respect 'include' and 'exclude' under 'bentofile.yaml' as well as the '.bentoignore' file in '--working-dir', for code and file changes + - all model store changes will also trigger a restart (new model saved or existing model removed) """ configure_server_logging() @@ -170,9 +182,9 @@ def serve( # type: ignore (unused warning) "'--reload' is not supported with '--production'; ignoring" ) - from bentoml.serve import serve_production + from bentoml.serve import serve_http_production - serve_production( + serve_http_production( bento, working_dir=working_dir, port=port, @@ -188,13 +200,13 @@ def serve( # type: ignore (unused warning) ssl_ciphers=ssl_ciphers, ) else: - from bentoml.serve import serve_development + from bentoml.serve import serve_http_development - serve_development( + serve_http_development( bento, working_dir=working_dir, port=port, - host=DEFAULT_DEV_SERVER_HOST if host is None else host, + host=DEFAULT_DEV_SERVER_HOST if not host else host, reload=reload, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile, @@ -204,3 +216,155 @@ def serve( # type: ignore (unused warning) ssl_ca_certs=ssl_ca_certs, ssl_ciphers=ssl_ciphers, ) + + from bentoml._internal.utils import add_experimental_docstring + + @cli.command(name="serve-grpc") + @click.argument("bento", type=click.STRING, default=".") + @click.option( + "--production", + type=click.BOOL, + help="Run the BentoServer in production mode", + is_flag=True, + default=False, + show_default=True, + ) + @click.option( + "-p", + "--port", + type=click.INT, + default=BentoMLContainer.grpc.port.get(), + help="The port to listen on for the REST api server", + envvar="BENTOML_PORT", + show_default=True, + ) + @click.option( + "--host", + type=click.STRING, + default=BentoMLContainer.grpc.host.get(), + help="The host to bind for the gRPC server", + envvar="BENTOML_HOST", + show_default=True, + ) + @click.option( + "--api-workers", + type=click.INT, + default=None, + help="Specify the number of API server workers to start. Default to number of available CPU cores in production mode", + envvar="BENTOML_API_WORKERS", + show_default=True, + ) + @click.option( + "--reload", + type=click.BOOL, + is_flag=True, + help="Reload Service when code changes detected, this is only available in development mode", + default=False, + show_default=True, + ) + @click.option( + "--backlog", + type=click.INT, + default=BentoMLContainer.api_server_config.backlog.get(), + help="The maximum number of pending connections.", + show_default=True, + ) + @click.option( + "--working-dir", + type=click.Path(), + help="When loading from source code, specify the directory to find the Service instance", + default=".", + show_default=True, + ) + @click.option( + "--enable-reflection", + is_flag=True, + default=BentoMLContainer.grpc.reflection.enabled.get(), + type=click.BOOL, + help="Enable reflection.", + show_default=True, + ) + @click.option( + "--max-concurrent-streams", + default=BentoMLContainer.grpc.max_concurrent_streams.get(), + type=click.INT, + help="Maximum number of concurrent incoming streams to allow on a http2 connection.", + show_default=True, + ) + @add_experimental_docstring + def serve_grpc( # type: ignore (unused warning) + bento: str, + production: bool, + port: int, + host: str, + api_workers: int | None, + backlog: int, + reload: bool, + working_dir: str, + enable_reflection: bool, + max_concurrent_streams: int | None, + ): + """Start a gRPC BentoServer from a given 🍱 + + \b + BENTO is the serving target, it can be the import as: + - the import path of a 'bentoml.Service' instance + - a tag to a Bento in local Bento store + - a folder containing a valid 'bentofile.yaml' build file with a 'service' field, which provides the import path of a 'bentoml.Service' instance + - a path to a built Bento (for internal & debug use only) + + e.g.: + + \b + Serve from a bentoml.Service instance source code (for development use only): + 'bentoml serve-grpc fraud_detector.py:svc' + + \b + Serve from a Bento built in local store: + 'bentoml serve-grpc fraud_detector:4tht2icroji6zput3suqi5nl2' + 'bentoml serve-grpc fraud_detector:latest' + + \b + Serve from a Bento directory: + 'bentoml serve-grpc ./fraud_detector_bento' + + If '--reload' is provided, BentoML will detect code and model store changes during development, and restarts the service automatically. + + \b + The '--reload' flag will: + - be default, all file changes under '--working-dir' (default to current directory) will trigger a restart + - when specified, respect 'include' and 'exclude' under 'bentofile.yaml' as well as the '.bentoignore' file in '--working-dir', for code and file changes + - all model store changes will also trigger a restart (new model saved or existing model removed) + """ + configure_server_logging() + if production: + if reload: + logger.warning( + "'--reload' is not supported with '--production'; ignoring" + ) + + from bentoml.serve import serve_grpc_production + + serve_grpc_production( + bento, + working_dir=working_dir, + port=port, + host=host, + backlog=backlog, + api_workers=api_workers, + max_concurrent_streams=max_concurrent_streams, + reflection=enable_reflection, + ) + else: + from bentoml.serve import serve_grpc_development + + serve_grpc_development( + bento, + working_dir=working_dir, + port=port, + backlog=backlog, + reload=reload, + host=DEFAULT_DEV_SERVER_HOST if not host else host, + max_concurrent_streams=max_concurrent_streams, + reflection=enable_reflection, + ) diff --git a/bentoml_cli/start.py b/bentoml_cli/start.py index 786f5edb3d..6fd61bdd2f 100644 --- a/bentoml_cli/start.py +++ b/bentoml_cli/start.py @@ -12,7 +12,7 @@ def add_start_command(cli: click.Group) -> None: - from bentoml._internal.log import configure_server_logging + from bentoml._internal.utils import add_experimental_docstring from bentoml._internal.configuration.containers import BentoMLContainer @cli.command(hidden=True) @@ -41,7 +41,7 @@ def add_start_command(cli: click.Group) -> None: @click.option( "--port", type=click.INT, - default=BentoMLContainer.service_port.get(), + default=BentoMLContainer.http.port.get(), help="The port to listen on for the REST api server", envvar="BENTOML_PORT", show_default=True, @@ -49,7 +49,7 @@ def add_start_command(cli: click.Group) -> None: @click.option( "--host", type=click.STRING, - default=BentoMLContainer.service_host.get(), + default=BentoMLContainer.http.host.get(), help="The host to bind for the REST api server [defaults: 127.0.0.1(dev), 0.0.0.0(production)]", envvar="BENTOML_HOST", ) @@ -60,6 +60,13 @@ def add_start_command(cli: click.Group) -> None: help="The maximum number of pending connections.", show_default=True, ) + @click.option( + "--api-workers", + type=click.INT, + default=None, + help="Specify the number of API server workers to start. Default to number of available CPU cores in production mode", + envvar="BENTOML_API_WORKERS", + ) @click.option( "--working-dir", type=click.Path(), @@ -109,6 +116,7 @@ def add_start_command(cli: click.Group) -> None: default=None, help="Ciphers to use (see stdlib 'ssl' module)", ) + @add_experimental_docstring def start_http_server( # type: ignore (unused warning) bento: str, remote_runner: list[str] | None, @@ -118,6 +126,7 @@ def start_http_server( # type: ignore (unused warning) host: str, backlog: int, working_dir: str, + api_workers: int | None, ssl_certfile: str | None, ssl_keyfile: str | None, ssl_keyfile_password: str | None, @@ -126,7 +135,9 @@ def start_http_server( # type: ignore (unused warning) ssl_ca_certs: str | None, ssl_ciphers: str | None, ) -> None: - configure_server_logging() + """ + Start a HTTP API server standalone. This will be used inside Yatai. + """ if sys.path[0] != working_dir: sys.path.insert(0, working_dir) @@ -155,6 +166,7 @@ def start_http_server( # type: ignore (unused warning) port=port, host=host, backlog=backlog, + api_workers=api_workers, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile, ssl_keyfile_password=ssl_keyfile_password, @@ -183,7 +195,7 @@ def start_http_server( # type: ignore (unused warning) @click.option( "--port", type=click.INT, - default=BentoMLContainer.service_port.get(), + default=BentoMLContainer.http.port.get(), help="The port to listen on for the REST api server", envvar="BENTOML_PORT", show_default=True, @@ -191,7 +203,7 @@ def start_http_server( # type: ignore (unused warning) @click.option( "--host", type=click.STRING, - default=BentoMLContainer.service_host.get(), + default=BentoMLContainer.http.host.get(), help="The host to bind for the REST api server [defaults: 127.0.0.1(dev), 0.0.0.0(production)]", envvar="BENTOML_HOST", ) @@ -209,6 +221,7 @@ def start_http_server( # type: ignore (unused warning) default=".", show_default=True, ) + @add_experimental_docstring def start_runner_server( # type: ignore (unused warning) bento: str, runner_name: str, @@ -218,7 +231,9 @@ def start_runner_server( # type: ignore (unused warning) backlog: int, working_dir: str, ) -> None: - configure_server_logging() + """ + Start Runner server standalone. This will be used inside Yatai. + """ if sys.path[0] != working_dir: sys.path.insert(0, working_dir) @@ -238,3 +253,95 @@ def start_runner_server( # type: ignore (unused warning) host=host, backlog=backlog, ) + + @cli.command(hidden=True) + @click.argument("bento", type=click.STRING, default=".") + @click.option( + "--remote-runner", + type=click.STRING, + multiple=True, + envvar="BENTOML_SERVE_RUNNER_MAP", + help="JSON string of runners map", + ) + @click.option( + "--port", + type=click.INT, + default=BentoMLContainer.grpc.port.get(), + help="The port to listen on for the gRPC server", + envvar="BENTOML_PORT", + show_default=True, + ) + @click.option( + "--host", + type=click.STRING, + default=BentoMLContainer.grpc.host.get(), + help="The host to bind for the gRPC server (defaults: 0.0.0.0)", + envvar="BENTOML_HOST", + ) + @click.option( + "--backlog", + type=click.INT, + default=BentoMLContainer.api_server_config.backlog.get(), + help="The maximum number of pending connections.", + show_default=True, + ) + @click.option( + "--working-dir", + type=click.Path(), + help="When loading from source code, specify the directory to find the Service instance", + default=".", + show_default=True, + ) + @click.option( + "--api-workers", + type=click.INT, + default=None, + help="Specify the number of API server workers to start. Default to number of available CPU cores in production mode", + envvar="BENTOML_API_WORKERS", + ) + @click.option( + "--enable-reflection", + is_flag=True, + default=BentoMLContainer.grpc.reflection.enabled.get(), + type=click.BOOL, + help="Enable reflection.", + ) + @click.option( + "--max-concurrent-streams", + default=BentoMLContainer.grpc.max_concurrent_streams.get(), + type=click.INT, + help="Maximum number of concurrent incoming streams to allow on a http2 connection.", + ) + @add_experimental_docstring + def start_grpc_server( # type: ignore (unused warning) + bento: str, + remote_runner: list[str] | None, + port: int, + host: str, + backlog: int, + api_workers: int | None, + working_dir: str, + enable_reflection: bool, + max_concurrent_streams: int | None, + ) -> None: + """ + Start a gRPC API server standalone. This will be used inside Yatai. + """ + if sys.path[0] != working_dir: + sys.path.insert(0, working_dir) + + from bentoml.start import start_grpc_server + + runner_map = dict([s.split("=", maxsplit=2) for s in remote_runner or []]) + logger.info(" Using remote runners: %s", runner_map) + start_grpc_server( + bento, + runner_map=runner_map, + working_dir=working_dir, + port=port, + host=host, + backlog=backlog, + api_workers=api_workers, + reflection=enable_reflection, + max_concurrent_streams=max_concurrent_streams, + ) diff --git a/bentoml_cli/utils.py b/bentoml_cli/utils.py index 62593975a9..bc60ed9dde 100644 --- a/bentoml_cli/utils.py +++ b/bentoml_cli/utils.py @@ -15,6 +15,8 @@ from bentoml.exceptions import BentoMLException from bentoml._internal.log import configure_logging +from bentoml._internal.configuration import DEBUG_ENV_VAR +from bentoml._internal.configuration import QUIET_ENV_VAR from bentoml._internal.configuration import get_debug_mode from bentoml._internal.configuration import set_debug_mode from bentoml._internal.configuration import set_quiet_mode @@ -28,6 +30,7 @@ from click import Command from click import Context from click import Parameter + from click import HelpFormatter P = t.ParamSpec("P") @@ -49,9 +52,39 @@ def __call__( # pylint: disable=no-method-argument logger = logging.getLogger("bentoml") +def kwargs_transformers( + _func: F[t.Any] | None = None, + *, + transformer: F[t.Any], +) -> F[t.Any]: + def decorator(func: F[t.Any]) -> t.Callable[P, t.Any]: + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any: + return func(*args, **{k: transformer(v) for k, v in kwargs.items()}) + + return wrapper + + if _func is None: + return decorator + return decorator(_func) + + class BentoMLCommandGroup(click.Group): - """Click command class customized for BentoML CLI, allow specifying a default + """ + Click command class customized for BentoML CLI, allow specifying a default command for each group defined. + + This command groups will also introduce support for aliases for commands. + + Example: + + .. code-block:: python + + @click.group(cls=BentoMLCommandGroup) + def cli(): ... + + @cli.command(aliases=["serve-http"]) + def serve(): ... """ NUMBER_OF_COMMON_PARAMS = 3 @@ -65,6 +98,7 @@ def bentoml_common_params(func: F[P]) -> WrappedCLI[bool, bool]: "--quiet", is_flag=True, default=False, + envvar=QUIET_ENV_VAR, help="Suppress all warnings and info logs", ) @click.option( @@ -72,6 +106,7 @@ def bentoml_common_params(func: F[P]) -> WrappedCLI[bool, bool]: "--debug", is_flag=True, default=False, + envvar=DEBUG_ENV_VAR, help="Generate debug information", ) @click.option( @@ -173,10 +208,17 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any: return t.cast("ClickFunctionWrapper[t.Any]", wrapper) + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: + super(BentoMLCommandGroup, self).__init__(*args, **kwargs) + # these two dictionaries will store known aliases for commands and groups + self._commands: dict[str, list[str]] = {} + self._aliases: dict[str, str] = {} + def command(self, *args: t.Any, **kwargs: t.Any) -> t.Callable[[F[P]], Command]: if "context_settings" not in kwargs: kwargs["context_settings"] = {} kwargs["context_settings"]["max_content_width"] = 120 + aliases = kwargs.pop("aliases", None) def wrapper(func: F[P]) -> Command: # add common parameters to command. @@ -191,10 +233,48 @@ def wrapper(func: F[P]) -> Command: wrapped.__click_params__[-self.NUMBER_OF_COMMON_PARAMS :] + wrapped.__click_params__[: -self.NUMBER_OF_COMMON_PARAMS] ) - return super(BentoMLCommandGroup, self).command(*args, **kwargs)(wrapped) + cmd = super(BentoMLCommandGroup, self).command(*args, **kwargs)(wrapped) + # add aliases to a given commands if it is specified. + if aliases is not None: + assert cmd.name + self._commands[cmd.name] = aliases + self._aliases.update({alias: cmd.name for alias in aliases}) + return cmd return wrapper + def resolve_alias(self, cmd_name: str): + return self._aliases[cmd_name] if cmd_name in self._aliases else cmd_name + + def get_command(self, ctx: Context, cmd_name: str) -> Command | None: + cmd_name = self.resolve_alias(cmd_name) + return super(BentoMLCommandGroup, self).get_command(ctx, cmd_name) + + def format_commands(self, ctx: Context, formatter: HelpFormatter) -> None: + rows: list[tuple[str, str]] = [] + sub_commands = self.list_commands(ctx) + + max_len = max(len(cmd) for cmd in sub_commands) + limit = formatter.width - 6 - max_len + + for sub_command in sub_commands: + cmd = self.get_command(ctx, sub_command) + if cmd is None: + continue + # If the command is hidden, then we skip it. + if hasattr(cmd, "hidden") and cmd.hidden: + continue + if sub_command in self._commands: + aliases = ",".join(sorted(self._commands[sub_command])) + sub_command = "%s (%s)" % (sub_command, aliases) + # this cmd_help is available since click>=7 + # BentoML requires click>=7. + cmd_help = cmd.get_short_help_str(limit) + rows.append((sub_command, cmd_help)) + if rows: + with formatter.section("Commands"): + formatter.write_dl(rows) + def resolve_command( self, ctx: Context, args: list[str] ) -> tuple[str | None, Command | None, list[str]]: @@ -237,13 +317,9 @@ def unparse_click_params( Unparse click call to a list of arguments. Used to modify some parameters and restore to system command. The goal is to unpack cases where parameters can be parsed multiple times. - Refers to ./buildx.py for examples of this usage. This is also used to unparse parameters for running API server. - Args: - params (`dict[str, t.Any]`): - The dictionary of the parameters that is parsed from click.Context. - command_params (`list[click.Parameter]`): - The list of paramters (Arguments/Options) that is part of a given command. + params: The dictionary of the parameters that is parsed from click.Context. + command_params: The list of paramters (Arguments/Options) that is part of a given command. Returns: Unparsed list of arguments that can be redirected to system commands. diff --git a/bentoml_cli/worker/grpc_api_server.py b/bentoml_cli/worker/grpc_api_server.py new file mode 100644 index 0000000000..be6a35ee49 --- /dev/null +++ b/bentoml_cli/worker/grpc_api_server.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import json +import typing as t + +import click + + +@click.command() +@click.argument("bento_identifier", type=click.STRING, required=False, default=".") +@click.option("--host", type=click.STRING, required=False, default=None) +@click.option("--port", type=click.INT, required=False, default=None) +@click.option( + "--runner-map", + type=click.STRING, + envvar="BENTOML_RUNNER_MAP", + help="JSON string of runners map, default sets to envars `BENTOML_RUNNER_MAP`", +) +@click.option( + "--working-dir", + type=click.Path(exists=True), + help="Working directory for the API server", +) +@click.option( + "--worker-id", + required=False, + type=click.INT, + default=None, + help="If set, start the server as a bare worker with the given worker ID. Otherwise start a standalone server with a supervisor process.", +) +@click.option( + "--enable-reflection", + type=click.BOOL, + is_flag=True, + help="Enable reflection.", + default=False, +) +@click.option( + "--max-concurrent-streams", + type=click.INT, + help="Maximum number of concurrent incoming streams to allow on a HTTP2 connection.", + default=None, +) +def main( + bento_identifier: str, + host: str, + port: int, + runner_map: str | None, + working_dir: str | None, + worker_id: int | None, + enable_reflection: bool, + max_concurrent_streams: int | None, +): + """ + Start BentoML API server. + \b + This is an internal API, users should not use this directly. Instead use `bentoml serve-grpc [--options]` + """ + + import bentoml + from bentoml._internal.log import configure_server_logging + from bentoml._internal.context import component_context + from bentoml._internal.configuration.containers import BentoMLContainer + + component_context.component_type = "grpc_api_server" + component_context.component_index = worker_id + configure_server_logging() + + BentoMLContainer.development_mode.set(False) + if runner_map is not None: + BentoMLContainer.remote_runner_mapping.set(json.loads(runner_map)) + + svc = bentoml.load(bento_identifier, working_dir=working_dir, standalone_load=True) + + # setup context + if svc.tag is None: + component_context.bento_name = f"*{svc.__class__.__name__}" + component_context.bento_version = "not available" + else: + component_context.bento_name = svc.tag.name + component_context.bento_version = svc.tag.version + + from bentoml._internal.server import grpc + + grpc_options: dict[str, t.Any] = {"enable_reflection": enable_reflection} + if max_concurrent_streams: + grpc_options["max_concurrent_streams"] = int(max_concurrent_streams) + + grpc.Server( + grpc.Config(svc.grpc_servicer, bind_address=f"{host}:{port}", **grpc_options) + ).run() + + +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter diff --git a/bentoml_cli/worker/grpc_dev_api_server.py b/bentoml_cli/worker/grpc_dev_api_server.py new file mode 100644 index 0000000000..b672e5605d --- /dev/null +++ b/bentoml_cli/worker/grpc_dev_api_server.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import typing as t + +import click + + +@click.command() +@click.argument("bento_identifier", type=click.STRING, required=False, default=".") +@click.option("--host", type=click.STRING, required=False, default=None) +@click.option("--port", type=click.INT, required=False, default=None) +@click.option("--working-dir", required=False, type=click.Path(), default=None) +@click.option( + "--enable-reflection", + type=click.BOOL, + is_flag=True, + help="Enable reflection.", + default=False, +) +@click.option( + "--max-concurrent-streams", + type=int, + help="Maximum number of concurrent incoming streams to allow on a http2 connection.", + default=None, +) +def main( + bento_identifier: str, + host: str, + port: int, + working_dir: str | None, + enable_reflection: bool, + max_concurrent_streams: int | None, +): + import psutil + + from bentoml import load + from bentoml._internal.log import configure_server_logging + from bentoml._internal.context import component_context + from bentoml._internal.configuration.containers import BentoMLContainer + + component_context.component_type = "grpc_dev_api_server" + configure_server_logging() + + svc = load(bento_identifier, working_dir=working_dir, standalone_load=True) + if not port: + port = BentoMLContainer.grpc.port.get() + if not host: + host = BentoMLContainer.grpc.host.get() + + # setup context + if svc.tag is None: + component_context.bento_name = f"*{svc.__class__.__name__}" + component_context.bento_version = "not available" + else: + component_context.bento_name = svc.tag.name + component_context.bento_version = svc.tag.version + if psutil.WINDOWS: + import asyncio + + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # type: ignore + + from bentoml._internal.server import grpc + + grpc_options: dict[str, t.Any] = {"enable_reflection": enable_reflection} + if max_concurrent_streams: + grpc_options["max_concurrent_streams"] = int(max_concurrent_streams) + + grpc.Server( + grpc.Config(svc.grpc_servicer, bind_address=f"{host}:{port}", **grpc_options) + ).run() + + +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter diff --git a/bentoml_cli/worker/grpc_prometheus_server.py b/bentoml_cli/worker/grpc_prometheus_server.py new file mode 100644 index 0000000000..47a1d8a4e6 --- /dev/null +++ b/bentoml_cli/worker/grpc_prometheus_server.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import typing as t +import logging +from typing import TYPE_CHECKING + +import click + +if TYPE_CHECKING: + from bentoml._internal import external_typing as ext + +logger = logging.getLogger("bentoml") + + +class GenerateLatestMiddleware: + def __init__(self, app: ext.ASGIApp): + from bentoml._internal.configuration.containers import BentoMLContainer + + self.app = app + self.metrics_client = BentoMLContainer.metrics_client.get() + + async def __call__( + self, scope: ext.ASGIScope, receive: ext.ASGIReceive, send: ext.ASGISend + ) -> None: + assert scope["type"] == "http" + assert scope["path"] == "/" + + from starlette.responses import Response + + return await Response( + self.metrics_client.generate_latest(), + status_code=200, + media_type=self.metrics_client.CONTENT_TYPE_LATEST, + )(scope, receive, send) + + +@click.command() +@click.option("--fd", type=click.INT, required=True) +@click.option("--backlog", type=click.INT, default=2048) +@click.option( + "--prometheus-dir", + type=click.Path(exists=True), + help="Required by prometheus to pass the metrics in multi-process mode", +) +def main(fd: int, backlog: int, prometheus_dir: str | None): + """ + Start a standalone Prometheus server to use with gRPC. + \b + This is an internal API, users should not use this directly. Instead use 'bentoml serve-grpc'. + Prometheus then can be accessed at localhost:9090 + """ + + import socket + + import psutil + import uvicorn + from starlette.middleware import Middleware + from starlette.applications import Starlette + from starlette.middleware.wsgi import WSGIMiddleware # TODO: a2wsgi + + from bentoml._internal.log import configure_server_logging + from bentoml._internal.context import component_context + from bentoml._internal.configuration import get_debug_mode + from bentoml._internal.configuration.containers import BentoMLContainer + + component_context.component_type = "prom_server" + + configure_server_logging() + + BentoMLContainer.development_mode.set(False) + metrics_client = BentoMLContainer.metrics_client.get() + if prometheus_dir is not None: + BentoMLContainer.prometheus_multiproc_dir.set(prometheus_dir) + + # create a ASGI app that wraps around the default HTTP prometheus server. + prom_app = Starlette( + debug=get_debug_mode(), middleware=[Middleware(GenerateLatestMiddleware)] + ) + prom_app.mount("/", WSGIMiddleware(metrics_client.make_wsgi_app())) + sock = socket.socket(fileno=fd) + + uvicorn_options: dict[str, t.Any] = { + "backlog": backlog, + "log_config": None, + "workers": 1, + } + if psutil.WINDOWS: + uvicorn_options["loop"] = "asyncio" + import asyncio + + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # type: ignore + uvicorn.Server(uvicorn.Config(prom_app, **uvicorn_options)).run(sockets=[sock]) + + +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter diff --git a/codecov.yml b/codecov.yml index f3f83382f6..6171b4a18a 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,5 +1,5 @@ github_checks: - annotations: false + annotations: false comment: layout: "reach, diff, files" behavior: default @@ -22,7 +22,7 @@ ignore: coverage: precision: 2 round: down - range: "80...100" + range: "80...100" status: default_rules: flag_coverage_not_uploaded_behavior: exclude @@ -63,7 +63,8 @@ coverage: - transformers - xgboost - unit-tests - - e2e-tests + - e2e-tests-http + - e2e-tests-grpc catboost: target: auto threshold: 10% @@ -78,12 +79,12 @@ coverage: target: auto threshold: 10% flags: - - easyocr + - easyocr evalml: target: auto threshold: 10% flags: - - evalml + - evalml fastai: target: auto threshold: 10% @@ -208,7 +209,8 @@ coverage: target: auto threshold: 10% flags: - - e2e-tests + - e2e-tests-http + - e2e-tests-grpc unit: target: auto threshold: 10% @@ -328,11 +330,16 @@ flags: carryforward: true paths: - bentoml/_internal/frameworks/xgboost.py - e2e-tests: + e2e-tests-http: carryforward: true paths: - "bentoml/**/*" - bentoml/models.py + e2e-tests-grpc: + carryforward: true + paths: + - "bentoml/**/*" + - bentoml/grpc/utils.py unit-tests: carryforward: true paths: diff --git a/pyproject.toml b/pyproject.toml index 73cf80c53d..9dc787a2dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,5 @@ [build-system] -requires = [ - # sync with setup.py until we discard non-pep-517/518 - "wheel", - "setuptools>=59.0", - "setuptools-scm[toml]>=6.2.3", -] +requires = ["setuptools<=60", "setuptools_scm[toml]>=6.2", "wheel"] build-backend = "setuptools.build_meta" [tool.setuptools_scm] @@ -13,6 +8,43 @@ git_describe_command = "git describe --dirty --tags --long --first-parent" version_scheme = "post-release" fallback_version = "0.0.0" +[tool.coverage.paths] +source = ["bentoml"] + +[tool.coverage.run] +branch = true +source = ["bentoml", "bentoml_cli"] +omit = [ + "bentoml/**/*_pb2.py", + "bentoml/__main__.py", + "bentoml/_internal/types.py", + "bentoml/_internal/external_typing/*", + "bentoml/testing/*", + "bentoml/io.py", +] + +[tool.coverage.report] +show_missing = true +precision = 2 +omit = [ + "*/bentoml/**/*_pb2*.py", + "*/bentoml/_internal/external_typing/*", + "*/bentoml/_internal/types.py", + "*/bentoml/testing/*", + '*/bentoml/__main__.py', + "*/bentoml/io.py", +] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", + "raise MissingDependencyException", + "except ImportError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + [tool.black] line-length = 88 exclude = ''' @@ -24,104 +56,177 @@ exclude = ''' | \.venv | _build | build + | venv + | lib | dist | typings + | bentoml/grpc )/ | bentoml/_version.py ) ''' +extend-exclude = "(_pb2.py$|_pb2_grpc.py$)" [tool.pytest.ini_options] -addopts = "-rfEX -p pytester -p no:warnings" -python_files = ["test_*.py","*_test.py"] +addopts = "-rfEX -p pytester -p no:warnings -x --capture=tee-sys" +python_files = ["test_*.py", "*_test.py"] testpaths = ["tests"] -markers = ["gpus"] - -[tool.pylint.master] -ignore=".ipynb_checkpoints,typings,bentoml/_internal/external_typing" -ignore-paths=".*.pyi" -unsafe-load-any-extension="no" -extension-pkg-whitelist="numpy,tensorflow,torch,paddle,keras,pydantic" -jobs=4 -persistent="yes" -suggestion-mode="yes" -max-line-length=88 - -[tool.pylint.messages_control] -disable="import-error,print-statement,parameter-unpacking,unpacking-in-except,old-raise-syntax,backtick,import-star-module-level,raw-checker-failed,bad-inline-option,locally-disabled,file-ignored,suppressed-message,useless-suppression,deprecated-pragma,apply-builtin,basestring-builtin,buffer-builtin,cmp-builtin,coerce-builtin,execfile-builtin,file-builtin,long-builtin,raw_input-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,no-absolute-import,old-division,dict-iter-method,dict-view-method,next-method-called,metaclass-assignment,indexing-exception,raising-string,reload-builtin,oct-method,hex-method,nonzero-method,cmp-method,input-builtin,round-builtin,intern-builtin,unichr-builtin,map-builtin-not-iterating,zip-builtin-not-iterating,range-builtin-not-iterating,filter-builtin-not-iterating,using-cmp-argument,div-method,idiv-method,rdiv-method,exception-message-attribute,invalid-str-codec,sys-max-int,bad-python3-import,deprecated-string-function,deprecated-str-translate-call,deprecated-itertools-function,deprecated-types-field,next-method-defined,dict-items-not-iterating,dict-keys-not-iterating,dict-values-not-iterating,deprecated-operator-function,deprecated-urllib-function,xreadlines-attribute,deprecated-sys-function,exception-escape,comprehension-escape,logging-fstring-interpolation,logging-format-interpolation,logging-not-lazy,C,R,fixme,protected-access,no-member,unsubscriptable-object,raise-missing-from,isinstance-second-argument-not-valid-type,attribute-defined-outside-init,relative-beyond-top-level" -enable="c-extension-no-member" - -[tool.pylint.reports] -evaluation="10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)" -msg-template="{msg_id}:{symbol} [{line:0>3d}:{column:0>2d}] {obj}: {msg}" -output-format="colorized" -reports="no" -score="yes" +markers = ["gpus", "disable-tf-eager-execution"] -[tool.pylint.refactoring] -max-nested-blocks=5 -never-returning-functions="optparse.Values,sys.exit" +[tool.pylint.main] +recursive = true +extension-pkg-allow-list = [ + "numpy", + "tensorflow", + "torch", + "paddle", + "onnxruntime", + "onnx", + "pydantic.schema", +] +ignore-paths = ["typings", "bentoml/_internal/external_typing", "bentoml/grpc"] +disable = [ + "import-error", + "print-statement", + "parameter-unpacking", + "unpacking-in-except", + "old-raise-syntax", + "backtick", + "raw-checker-failed", + "bad-inline-option", + "locally-disabled", + "file-ignored", + "suppressed-message", + "useless-suppression", + "deprecated-pragma", + "apply-builtin", + "basestring-builtin", + "buffer-builtin", + "cmp-builtin", + "coerce-builtin", + "execfile-builtin", + "file-builtin", + "long-builtin", + "raw_input-builtin", + "reduce-builtin", + "standarderror-builtin", + "coerce-method", + "delslice-method", + "getslice-method", + "setslice-method", + "no-absolute-import", + "old-division", + "dict-iter-method", + "dict-view-method", + "next-method-called", + "metaclass-assignment", + "indexing-exception", + "raising-string", + "reload-builtin", + "oct-method", + "hex-method", + "nonzero-method", + "cmp-method", + "input-builtin", + "round-builtin", + "intern-builtin", + "unichr-builtin", + "map-builtin-not-iterating", + "zip-builtin-not-iterating", + "range-builtin-not-iterating", + "filter-builtin-not-iterating", + "using-cmp-argument", + "exception-message-attribute", + "invalid-str-codec", + "sys-max-int", + "bad-python3-import", + "deprecated-string-function", + "deprecated-str-translate-call", + "deprecated-itertools-function", + "deprecated-types-field", + "next-method-defined", + "dict-items-not-iterating", + "dict-keys-not-iterating", + "dict-values-not-iterating", + "deprecated-operator-function", + "deprecated-urllib-function", + "xreadlines-attribute", + "deprecated-sys-function", + "exception-escape", + "comprehension-escape", + "logging-fstring-interpolation", + "logging-format-interpolation", + "logging-not-lazy", + "C", + "R", + "fixme", + "protected-access", + "no-member", + "unsubscriptable-object", + "raise-missing-from", + "isinstance-second-argument-not-valid-type", + "attribute-defined-outside-init", + "relative-beyond-top-level", +] +enable = ["c-extension-no-member"] +evaluation = "10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)" +msg-template = "{msg_id}:{symbol} [{line:0>3d}:{column:0>2d}] {obj}: {msg}" +output-format = "colorized" +score = true + +[tool.pylint.classes] +valid-metaclass-classmethod-first-arg = ["cls", "mcls", "kls"] + +[tool.pylint.logging] +logging-format-style = "old" # using %s formatter for logging (performance-related) [tool.pylint.miscellaneous] -notes="FIXME,XXX,TODO,NOTE" +notes = ["FIXME", "XXX", "TODO", "NOTE", "WARNING"] -[tool.pylint.typecheck] -ignored-classes="Namespace" -contextmanager-decorators="contextlib.contextmanager" +[tool.pylint.refactoring] +# specify functions that should not return +never-returning-functions = ["sys.exit"] -[tool.pylint.format] -indent-after-paren=4 -indent-string=' ' -max-line-length=100 -max-module-lines=1000 -single-line-class-stmt="no" -single-line-if-stmt="no" +[tool.pylint.spelling] +spelling-ignore-comment-directives = "fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:,pylint:,type:" -[tool.pylint.imports] -allow-wildcard-with-all="no" -analyse-fallback-blocks="no" +[tool.pylint.variables] +init-import = true -[tool.pylint.classes] -defining-attr-methods="__init__,__new__,setUp" -exclude-protected="_asdict,_fields,_replace,_source,_make" -valid-classmethod-first-arg="cls" -valid-metaclass-classmethod-first-arg="mcs" - -[tool.pylint.design] -# Maximum number of arguments for function / method -max-args=5 -# Maximum number of attributes for a class (see R0902). -max-attributes=7 -# Maximum number of boolean expressions in a if statement -max-bool-expr=5 -# Maximum number of branch for function / method body -max-branches=12 -# Maximum number of locals for function / method body -max-locals=15 -# Maximum number of parents for a class (see R0901). -max-parents=7 -# Maximum number of public methods for a class (see R0904). -max-public-methods=20 -# Maximum number of return / yield for function / method body -max-returns=6 -# Maximum number of statements in function / method body -max-statements=50 -# Minimum number of public methods for a class (see R0903). -min-public-methods=2 +[tool.pylint.typecheck] +contextmanager-decorators = [ + "contextlib.contextmanager", + "bentoml._internal.utils.cached_contextmanager", +] [tool.isort] profile = "black" +line_length = 88 length_sort = true force_single_line = true order_by_type = true force_alphabetical_sort_within_sections = true -skip_glob = ["typings/*", "docs/*"] +skip_glob = [ + "typings/*", + "test/*", + "**/*_pb2.py", + "**/*_pb2_grpc.py", + "venv/*", + "lib/*", +] [tool.pyright] pythonVersion = "3.10" -include = ["bentoml"] -exclude = ['bentoml/_version.py','bentoml/__main__.py'] -analysis.useLibraryCodeForTypes = true +include = ["bentoml/"] +exclude = [ + 'bentoml/_version.py', + 'bentoml/__main__.py', + 'bentoml/_internal/external_typing/', + 'bentoml/grpc/v1alpha1/', + '**/*_pb2.py', + "**/*_pb2_grpc.py", +] +useLibraryCodeForTypes = true strictListInference = true strictDictionaryInference = true strictSetInference = true diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt index cf6a6b9b4f..6a6a481f78 100644 --- a/requirements/dev-requirements.txt +++ b/requirements/dev-requirements.txt @@ -1,5 +1,4 @@ # Dev dependencies -r tests-requirements.txt -setup-cfg-fmt twine wheel diff --git a/requirements/tests-requirements.txt b/requirements/tests-requirements.txt index 18a506eb51..fa6822a0f3 100644 --- a/requirements/tests-requirements.txt +++ b/requirements/tests-requirements.txt @@ -8,11 +8,13 @@ pydantic pylint>=2.14.0 pytest-cov>=3.0.0 pytest>=6.2.0 +pytest-xdist pytest-asyncio pandas scikit-learn imageio>=2.5.0 -watchfiles>=0.15.0 pyarrow -build >=0.8.0 +build[virtualenv] >=0.8.0 yamllint +grpcio-tools>=1.41.0 +opentelemetry-test-utils==0.33b0 diff --git a/scripts/generate_grpc_stubs.sh b/scripts/generate_grpc_stubs.sh new file mode 100755 index 0000000000..ae859fbe11 --- /dev/null +++ b/scripts/generate_grpc_stubs.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash + +GIT_ROOT=$(git rev-parse --show-toplevel) +STUBS_GENERATOR="bentoml/stubs-generator" + +cd "$GIT_ROOT" || exit 1 + +main() { + local VERSION="${1:-v1alpha1}" + # Use inline heredoc for even faster build + # Keeping image as cache should be fine since we only want to generate the stubs. + if [[ $(docker images --filter=reference="$STUBS_GENERATOR" -q) == "" ]] || test "$(git diff --name-only --diff-filter=d -- "$0")"; then + docker buildx build --platform=linux/amd64 -t "$STUBS_GENERATOR" --load -f- . <=3.15.0, and 3.10 support are provided since protobuf>=3.19.0 + protobuf==3.19.4 + grpcio-tools==1.41 + mypy-protobuf>=3.3.0 +EOT + +RUN --mount=type=cache,target=/var/lib/apt \ + --mount=type=cache,target=/var/cache/apt \ + apt-get update && apt-get install -q -y --no-install-recommends --allow-remove-essential bash build-essential ca-certificates + +RUN --mount=type=cache,target=/root/.cache/pip pip install -r requirements.txt \ + && rm -rf /workspace/requirements.txt + +EOF + fi + + echo "Generating gRPC stubs..." + find bentoml/grpc/"$VERSION" -type f -name "*.proto" -exec docker run --rm -it -v "$GIT_ROOT":/workspace --platform=linux/amd64 "$STUBS_GENERATOR" python -m grpc_tools.protoc -I. --grpc_python_out=. --python_out=. --mypy_out=. --mypy_grpc_out=. "{}" \; +} + +if [ "${#}" -gt 1 ]; then + echo "$0 takes one optional argument. Usage: $0 [v1alpha2]" + exit 1 +fi +main "$@" diff --git a/scripts/release_pypi.sh b/scripts/release_pypi.sh index db0b6953ad..07fb37e08f 100755 --- a/scripts/release_pypi.sh +++ b/scripts/release_pypi.sh @@ -52,7 +52,11 @@ else fi echo "Generating PyPI source distribution..." -cd "$GIT_ROOT" +cd "$GIT_ROOT" || exit 1 + +# generate gRPC stubs +./scripts/generate_grpc_stubs.sh + python3 -m build -s -w # Use testpypi by default, run script with: "REPO=pypi release.sh" for diff --git a/setup.cfg b/setup.cfg index ca71c0eaa4..4a4fce4a84 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,6 +67,7 @@ install_requires = uvicorn watchfiles>=0.15.0 backports.cached-property;python_version<"3.8" + backports.shutil_copytree;python_version<"3.8" importlib-metadata;python_version<"3.8" python_requires = >=3.7 include_package_data = True @@ -76,6 +77,7 @@ zip_safe = False include = # include bentoml packages bentoml + bentoml.grpc* bentoml.testing bentoml._internal* # include bentoml_cli packages @@ -87,7 +89,23 @@ console_scripts = bentoml = bentoml_cli.cli:cli [options.extras_require] +grpc = + # Restrict maximum version due to breaking protobuf 4.21.0 changes + # (see https://github.com/protocolbuffers/protobuf/issues/10051) + protobuf>=3.5.0, <3.20 + # Lowest version that support 3.10 + grpcio>=1.41.0 + grpcio-health-checking + grpcio-reflection + opentelemetry-instrumentation-grpc==0.33b0 +tracing.jaeger = + opentelemetry-exporter-jaeger +tracing.zipkin = + opentelemetry-exporter-zipkin +tracing.otlp = + opentelemetry-exporter-otlp tracing = + # kept for compatibility opentelemetry-exporter-jaeger opentelemetry-exporter-zipkin opentelemetry-exporter-otlp @@ -105,30 +123,3 @@ keep_temp = false [sdist] formats = gztar - -[coverage:run] -omit = - bentoml/__main__.py - bentoml/_internal/types.py - bentoml/_internal/external_typing/* - bentoml/testing/* - bentoml/io.py - -[coverage:report] -show_missing = true -precision = 2 -omit = - bentoml/_internal/external_typing/* - bentoml/_internal/types.py - bentoml/testing/* - bentoml/__main__.py - bentoml/io.py -exclude_lines = - pragma: no cover - def __repr__ - raise AssertionError - raise NotImplementedError - raise MissingDependencyException - except ImportError - if __name__ == .__main__.: - if TYPE_CHECKING: diff --git a/tests/unit/_internal/io/test_numpy.py b/tests/unit/_internal/io/test_numpy.py index 811891af76..4f4d5765cf 100644 --- a/tests/unit/_internal/io/test_numpy.py +++ b/tests/unit/_internal/io/test_numpy.py @@ -32,7 +32,7 @@ def test_invalid_dtype(): generic = ExampleGeneric("asdf") with pytest.raises(BentoMLException) as e: _ = NumpyNdarray.from_sample(generic) # type: ignore (test exception) - assert "expects a numpy.array" in str(e.value) + assert "expects a 'numpy.array'" in str(e.value) @pytest.mark.parametrize("dtype, expected", [("float", "number"), (">U8", "integer")]) @@ -82,9 +82,7 @@ def test_numpy_openapi_responses(): def test_verify_numpy_ndarray(caplog: LogCaptureFixture): - partial_check = partial( - from_example._verify_ndarray, exception_cls=BentoMLException # type: ignore (test internal check) - ) + partial_check = partial(from_example.validate_array, exception_cls=BentoMLException) with pytest.raises(BentoMLException) as ex: partial_check(np.array(["asdf"])) @@ -94,10 +92,10 @@ def test_verify_numpy_ndarray(caplog: LogCaptureFixture): partial_check(np.array([[1]])) assert f'Expecting ndarray of shape "{from_example._shape}"' in str(e.value) # type: ignore (testing message) - # test cases whwere reshape is failed + # test cases where reshape is failed example = NumpyNdarray.from_sample(np.ones((2, 2, 3))) example._enforce_shape = False # type: ignore (test internal check) example._enforce_dtype = False # type: ignore (test internal check) with caplog.at_level(logging.DEBUG): - example._verify_ndarray(np.array("asdf")) + example.validate_array(np.array("asdf")) assert "Failed to reshape" in caplog.text