diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
index 7e11288aa3..7ae205d5de 100644
--- a/.devcontainer/Dockerfile
+++ b/.devcontainer/Dockerfile
@@ -1,6 +1,6 @@
# syntax=docker/dockerfile:1.4-labs
-FROM python:3-bullseye
+FROM --platform=linux/amd64 python:3-bullseye
# [Option] Install zsh
ARG INSTALL_ZSH="true"
@@ -12,10 +12,8 @@ ARG ENABLE_NONROOT_DOCKER="true"
ARG USE_MOBY="true"
# [Option] Select CLI version
ARG CLI_VERSION="latest"
-
# Enable new "BUILDKIT" mode for Docker CLI
ENV DOCKER_BUILDKIT=1
-
ENV DEBIAN_FRONTEND=noninteractive
# Install needed packages and setup non-root user. Use a separate RUN statement to add your
@@ -23,16 +21,28 @@ ENV DEBIAN_FRONTEND=noninteractive
ARG USERNAME=automatic
ARG USER_UID=1000
ARG USER_GID=$USER_UID
-COPY library-scripts/*.sh /tmp/library-scripts/
-RUN --mount=type=cache,target=/var/cache/apt \
- --mount=type=cache,target=/var/lib/apt \
- apt-get update \
- && apt-get install -y build-essential software-properties-common vim \
+COPY .devcontainer/library-scripts/*.sh /tmp/library-scripts/
+
+RUN --mount=type=cache,target=/var/lib/apt \
+ --mount=type=cache,target=/var/cache/apt \
+ apt-get update -y \
+ # Remove imagemagick due to https://security-tracker.debian.org/tracker/CVE-2019-10131
+ && apt-get purge -y imagemagick imagemagick-6-common
+
+# install common packages
+RUN --mount=type=cache,target=/var/lib/apt \
+ --mount=type=cache,target=/var/cache/apt \
+ apt-get install -y build-essential software-properties-common vim \
&& /bin/bash /tmp/library-scripts/common-debian.sh "${INSTALL_ZSH}" "${USERNAME}" "${USER_UID}" "${USER_GID}" "${UPGRADE_PACKAGES}" "true" "true" \
# Use Docker script from script library to set things up
&& /bin/bash /tmp/library-scripts/docker-debian.sh "${ENABLE_NONROOT_DOCKER}" "/var/run/docker-host.sock" "/var/run/docker.sock" "${USERNAME}" "${USE_MOBY}" "${CLI_VERSION}" \
# Clean up
- && apt-get autoremove -y && apt-get clean -y && rm -rf /var/lib/apt/lists/* /tmp/library-scripts/
+ && rm -rf /var/lib/apt/lists/* /tmp/library-scripts/
+
+COPY requirements/*.txt /tmp/pip-tmp/
+RUN --mount=type=cache,target=/root/.cache/pip \
+ pip install --no-warn-script-location -r /tmp/pip-tmp/dev-requirements.txt -r /tmp/pip-tmp/docs-requirements.txt \
+ && rm -rf /tmp/pip-tmp
# Setting the ENTRYPOINT to docker-init.sh will configure non-root access to
# the Docker socket if "overrideCommand": false is set in devcontainer.json.
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 8c3364e937..3cc51c5f0b 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -3,6 +3,7 @@
{
"name": "BentoML",
"dockerFile": "Dockerfile",
+ "context": "..",
"containerEnv": {
"BENTOML_DEBUG": "True",
"BENTOML_BUNDLE_LOCAL_BUILD": "True",
@@ -22,15 +23,16 @@
// Configure properties specific to VS Code.
"vscode": {
"extensions": [
- "ms-azuretools.vscode-docker",
- "ms-python.vscode-pylance",
"ms-python.python",
+ "ms-python.vscode-pylance",
+ "ms-azuretools.vscode-docker",
"ms-vsliveshare.vsliveshare",
"ms-python.black-formatter",
"ms-python.pylint",
"samuelcolvin.jinjahtml",
"GitHub.copilot",
- "esbenp.prettier-vscode"
+ "esbenp.prettier-vscode",
+ "VisualStudioExptTeam.intellicode-api-usage-examples"
],
"settings": {
"files.watcherExclude": {
@@ -52,18 +54,14 @@
"[jsonc]": {
"editor.defaultFormatter": "esbenp.prettier-vscode"
},
- "editor.minimap.enabled": true,
+ "editor.minimap.enabled": false,
"editor.formatOnSave": true,
- "editor.wordWrapColumn": 88,
+ "editor.wordWrapColumn": 88
}
}
},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
- "forwardPorts": [3000, 8080, 9000, 9090],
- // Link some default configs to codespace container.
- "postCreateCommand": "bash ./.devcontainer/lifecycle/post-create",
+ "forwardPorts": [3000, 3001],
// install BentoML and tools
- "postStartCommand": "bash ./.devcontainer/lifecycle/post-start",
- // Comment out to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root.
- "remoteUser": "vscode"
+ "postStartCommand": "bash ./.devcontainer/lifecycle/post-start"
}
diff --git a/.devcontainer/lifecycle/post-create b/.devcontainer/lifecycle/post-create
deleted file mode 100755
index 5e72c9360a..0000000000
--- a/.devcontainer/lifecycle/post-create
+++ /dev/null
@@ -1,3 +0,0 @@
-#!/usr/bin/env bash
-
-pip install --user -r requirements/dev-requirements.txt
diff --git a/.devcontainer/lifecycle/post-start b/.devcontainer/lifecycle/post-start
index c36ed88eb7..6edd411cc3 100755
--- a/.devcontainer/lifecycle/post-start
+++ b/.devcontainer/lifecycle/post-start
@@ -7,4 +7,16 @@ git config --global pull.ff only
git fetch upstream --tags && git pull
# install editable wheels & tools for bentoml
-pip install --user -e ".[tracing]" --isolated
+pip install -e ".[tracing,grpc]" --verbose
+pip install -r requirements/dev-requirements.txt
+pip install -U "grpcio-tools>=1.41.0" "mypy-protobuf>=3.3.0"
+# generate stubs
+OPTS=(-I. --grpc_python_out=. --python_out=. --mypy_out=. --mypy_grpc_out=.)
+python -m grpc_tools.protoc "${OPTS[@]}" bentoml/grpc/v1alpha1/service.proto
+python -m grpc_tools.protoc "${OPTS[@]}" bentoml/grpc/v1alpha1/service_test.proto
+# uninstall broken protobuf typestubs
+pip uninstall -y types-protobuf
+
+# setup docker buildx
+docker buildx install
+docker buildx ls | grep bentoml-builder &>/dev/null || docker buildx create --use --name bentoml-builder --platform linux/amd64,linux/arm64 &>/dev/null
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 76f6387247..3a56683e15 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -257,9 +257,15 @@ jobs:
path: ${{ steps.cache-dir.outputs.dir }}
key: ${{ runner.os }}-tests-${{ hashFiles('requirements/tests-requirements.txt') }}
+ # Simulate ./scripts/generate_grpc_stubs.sh
+ - name: Generate gRPC stubs
+ run: |
+ pip install protobuf==3.19.4 "grpcio-tools==1.41"
+ find bentoml/grpc/v1alpha1 -type f -name "*.proto" -exec python -m grpc_tools.protoc -I. --grpc_python_out=. --python_out=. "{}" \;
+
- name: Install dependencies
run: |
- pip install -e .
+ pip install -e ".[grpc]"
pip install -r requirements/tests-requirements.txt
pip install -r tests/e2e/bento_server_general_features/requirements.txt
diff --git a/.gitignore b/.gitignore
index a2baefb2a7..69360490f2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -114,3 +114,10 @@ typings
# test files
catboost_info
+
+# ignore pyvenv
+pyvenv.cfg
+
+# generated stub that is included in distribution
+*_pb2*.py
+*_pb2*.pyi
diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
index a1370404be..15a034c82d 100644
--- a/DEVELOPMENT.md
+++ b/DEVELOPMENT.md
@@ -18,22 +18,26 @@ If you are interested in proposing a new feature, make sure to create a new feat
2. Fork the BentoML project on [GitHub](https://github.com/bentoml/BentoML).
3. Clone the source code from your fork of BentoML's GitHub repository:
+
```bash
git clone git@github.com:username/BentoML.git && cd BentoML
```
4. Add the BentoML upstream remote to your local BentoML clone:
+
```bash
git remote add upstream git@github.com:bentoml/BentoML.git
```
5. Configure git to pull from the upstream remote:
+
```bash
git switch main # ensure you're on the main branch
git branch --set-upstream-to=upstream/main
```
6. Install BentoML with pip in editable mode:
+
```bash
pip install -e .
```
@@ -41,11 +45,13 @@ If you are interested in proposing a new feature, make sure to create a new feat
This installs BentoML in an editable state. The changes you make will automatically be reflected without reinstalling BentoML.
7. Install the BentoML development requirements:
+
```bash
pip install -r ./requirements/dev-requirements.txt
```
8. Test the BentoML installation either with `bash`:
+
```bash
bentoml --version
```
@@ -62,65 +68,72 @@ If you are interested in proposing a new feature, make sure to create a new feat
with VS Code
1. Confirm that you have the following installed:
- - [Python3.7+](https://www.python.org/downloads/)
- - VS Code with the [Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python) and [Pylance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance) extensions
+
+ - [Python3.7+](https://www.python.org/downloads/)
+ - VS Code with the [Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python) and [Pylance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance) extensions
2. Fork the BentoML project on [GitHub](https://github.com/bentoml/BentoML).
3. Clone the GitHub repository:
- 1. Open the command palette with Ctrl+Shift+P and type in 'clone'.
- 2. Select 'Git: Clone(Recursive)'.
- 3. Clone BentoML.
+
+ 1. Open the command palette with Ctrl+Shift+P and type in 'clone'.
+ 2. Select 'Git: Clone(Recursive)'.
+ 3. Clone BentoML.
4. Add an BentoML upstream remote:
- 1. Open the command palette and enter 'add remote'.
- 2. Select 'Git: Add Remote'.
- 3. Press enter to select 'Add remote' from GitHub.
- 4. Enter https://github.com/bentoml/BentoML.git to select the BentoML repository.
- 5. Name your remote 'upstream'.
+
+ 1. Open the command palette and enter 'add remote'.
+ 2. Select 'Git: Add Remote'.
+ 3. Press enter to select 'Add remote' from GitHub.
+ 4. Enter https://github.com/bentoml/BentoML.git to select the BentoML repository.
+ 5. Name your remote 'upstream'.
5. Pull from the BentoML upstream remote to your main branch:
- 1. Open the command palette and enter 'checkout'.
- 2. Select 'Git: Checkout to...'
- 3. Choose 'main' to switch to the main branch.
- 4. Open the command palette again and enter 'pull from'.
- 5. Click on 'Git: Pull from...'
- 6. Select 'upstream'.
+
+ 1. Open the command palette and enter 'checkout'.
+ 2. Select 'Git: Checkout to...'
+ 3. Choose 'main' to switch to the main branch.
+ 4. Open the command palette again and enter 'pull from'.
+ 5. Click on 'Git: Pull from...'
+ 6. Select 'upstream'.
6. Open a new terminal by clicking the Terminal dropdown at the top of the window, followed by the 'New Terminal' option. Next, add a virtual environment with this command:
```bash
python -m venv .venv
```
7. Click yes if a popup suggests switching to the virtual environment. Otherwise, go through these steps:
- 1. Open any python file in the directory.
- 2. Select the interpreter selector on the blue status bar at the bottom of the editor.
- ![vscode-status-bar](https://user-images.githubusercontent.com/489344/166984038-75f1f4bd-c896-43ee-a7ee-1b57fda359a3.png)
-
- 3. Switch to the path that includes .venv from the dropdown at the top.
- ![vscode-select-venv](https://user-images.githubusercontent.com/489344/166984060-170d25f5-a91f-41d3-96f4-4db3c21df7c8.png)
+ 1. Open any python file in the directory.
+ 2. Select the interpreter selector on the blue status bar at the bottom of the editor.
+ ![vscode-status-bar](https://user-images.githubusercontent.com/489344/166984038-75f1f4bd-c896-43ee-a7ee-1b57fda359a3.png)
+
+ 3. Switch to the path that includes .venv from the dropdown at the top.
+ ![vscode-select-venv](https://user-images.githubusercontent.com/489344/166984060-170d25f5-a91f-41d3-96f4-4db3c21df7c8.png)
8. Update your PowerShell execution policies. Win+x followed by the 'a' key opens the admin Windows PowerShell. Enter the following command to allow the virtual environment activation script to run:
```
Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
```
-
+
## Making Changes
using the Command Line
1. Make sure you're on the main branch.
+
```bash
git switch main
```
2. Use the git pull command to retrieve content from the BentoML Github repository.
+
```bash
git pull
```
3. Create a new branch and switch to it.
+
```bash
git switch -c my-new-branch-name
```
@@ -128,11 +141,13 @@ If you are interested in proposing a new feature, make sure to create a new feat
4. Make your changes!
5. Use the git add command to save the state of files you have changed.
+
```bash
git add
```
6. Commit your changes.
+
```bash
git commit
```
@@ -141,59 +156,71 @@ If you are interested in proposing a new feature, make sure to create a new feat
```bash
git push
```
-
+
using VS Code
1. Switch to the main branch:
- 1. Open the command palette with Ctrl+Shift+P.
- 2. Search for 'Git: Checkout to...'
- 3. Select 'main'.
+
+ 1. Open the command palette with Ctrl+Shift+P.
+ 2. Search for 'Git: Checkout to...'
+ 3. Select 'main'.
2. Pull from the upstream remote:
- 1. Open the command palette.
- 2. Enter and select 'Git: Pull...'
- 3. Select 'upstream'.
+
+ 1. Open the command palette.
+ 2. Enter and select 'Git: Pull...'
+ 3. Select 'upstream'.
3. Create and change to a new branch:
- 1. Type in 'Git: Create Branch...' in the command palette.
- 2. Enter a branch name.
+
+ 1. Type in 'Git: Create Branch...' in the command palette.
+ 2. Enter a branch name.
4. Make your changes!
5. Stage all your changes:
- 1. Enter and select 'Git: Stage All Changes...' in the command palette.
+
+ 1. Enter and select 'Git: Stage All Changes...' in the command palette.
6. Commit your changes:
- 1. Open the command palette and enter 'Git: Commit'.
+
+ 1. Open the command palette and enter 'Git: Commit'.
7. Push your changes:
- 1. Enter and select 'Git: Push...' in the command palette.
+ 1. Enter and select 'Git: Push...' in the command palette.
-
## Run BentoML with verbose/debug logging
To view internal debug loggings for development, set the `BENTOML_DEBUG` environment variable to `TRUE`:
+
```bash
export BENTOML_DEBUG=TRUE
```
And/or use the `--verbose` option when running `bentoml` CLI command, e.g.:
+
```bash
bentoml get IrisClassifier --verbose
```
## Style check, auto-formatting, type-checking
-formatter: [black](https://github.com/psf/black), [isort](https://github.com/PyCQA/isort)
+formatter: [black](https://github.com/psf/black), [isort](https://github.com/PyCQA/isort), [buf](https://github.com/bufbuild/buf)
-linter: [pylint](https://pylint.org/)
+linter: [pylint](https://pylint.org/), [buf](https://github.com/bufbuild/buf)
type checker: [pyright](https://github.com/microsoft/pyright)
+We are using [buf](https://github.com/bufbuild/buf) for formatting and linting
+of our proto files. Configuration can be found [here](./bentoml/grpc/buf.yaml).
+Currently, we are running `buf` with docker, hence we kindly ask our developers
+to have docker available. Docker installation can be found [here](https://docs.docker.com/get-docker/).
+
Run linter/format script:
+
```bash
make format
@@ -201,23 +228,30 @@ make lint
```
Run type checker:
+
```bash
make type
```
+## Editing proto files
+
+The proto files for the BentoML gRPC service are located under [`bentoml/grpc`](./bentoml/grpc/).
+The generated python files are not checked in the git repository, and are instead generated via this [`script`](./scripts/generate_grpc_stubs.sh).
+If you edit the proto files, make sure to run `./scripts/generate_grpc_stubs.sh` to
+regenerate the proto stubs.
+
## Deploy with your changes
Test test out your changes in an actual BentoML model deployment, you can create a new Bento with your custom BentoML source repo:
1. Install custom BentoML in editable mode. e.g.:
- * git clone your bentoml fork
- * `pip install -e PATH_TO_THE_FORK`
+ - git clone your bentoml fork
+ - `pip install -e PATH_TO_THE_FORK`
2. Set env var `export BENTOML_BUNDLE_LOCAL_BUILD=True` and `export SETUPTOOLS_USE_DISTUTILS=stdlib`
- * make sure you have the latest setuptools installed: `pip install -U setuptools`
+ - make sure you have the latest setuptools installed: `pip install -U setuptools`
3. Build a new Bento with `bentoml build` in your project directory
-4. The new Bento will include a wheel file built from the BentoML source, and
-`bentoml containrize` will install it to override the default BentoML installation in base image
-
+4. The new Bento will include a wheel file built from the BentoML source, and
+ `bentoml containerize` will install it to override the default BentoML installation in base image
### Distribute a custom BentoML release for your team
@@ -236,18 +270,18 @@ description: "file: ./README.md"
include:
- "*.py"
python:
- packages:
- - pandas
- - git+https://github.com/{YOUR_GITHUB_USERNAME}/bentoml@{YOUR_REVISION}
+ packages:
+ - pandas
+ - git+https://github.com/{YOUR_GITHUB_USERNAME}/bentoml@{YOUR_REVISION}
docker:
- system_packages:
- - git
+ system_packages:
+ - git
```
-
## Testing
Make sure to install all test dependencies:
+
```bash
pip install -r requirements/tests-requirements.txt
```
@@ -257,12 +291,14 @@ pip install -r requirements/tests-requirements.txt
You can run unit tests in two ways:
Run all unit tests directly with pytest:
+
```bash
# GIT_ROOT=$(git rev-parse --show-toplevel)
-pytest tests/unit --cov=bentoml --cov-config="$GIT_ROOT"/setup.cfg
+pytest tests/unit --cov=bentoml --cov-config="$GIT_ROOT"/pyproject.toml
```
Run all unit tests via `./scripts/ci/run_tests.sh`:
+
```bash
./scripts/ci/run_tests.sh unit
@@ -273,6 +309,7 @@ make tests-unit
### Integration tests
Run given tests after defining a target under `scripts/ci/config.yml` with `run_tests.sh`:
+
```bash
# example: run Keras TF1 integration tests
./scripts/ci/run_tests.sh keras_tf1
@@ -281,18 +318,26 @@ Run given tests after defining a target under `scripts/ci/config.yml` with `run_
### E2E tests
```bash
-# example: run e2e tests to check for general features
-./scripts/ci/run_tests.sh general_features
+# example: run e2e tests to check for http general features
+./scripts/ci/run_tests.sh http_server
+```
+
+### Running the whole suite
+
+To run the whole test suite, minus frameworks integration, you can use:
+
+```bash
+make tests-suite
```
### Adding new test suite
-
+
If you are adding new ML framework support, it is recommended that you also add a separate test suite in our CI. Currently we are using GitHub Actions to manage our CI/CD workflow.
We recommend using [`nektos/act`](https://github.com/nektos/act) to run and test Actions locally.
-
The following tests script [run_tests.sh](./scripts/ci/run_tests.sh) can be used to run tests locally.
+
```bash
./scripts/ci/run_tests.sh -h
Running unit/integration tests with pytest and generate coverage reports. Make sure that given testcases is defined under ./scripts/ci/config.yml.
@@ -312,6 +357,7 @@ Example:
```
All tests are then defined under [config.yml](./scripts/ci/config.yml) where each field follows the following format:
+
```yaml
: &tmpl
root_test_dir: "tests/integration/frameworks"
@@ -324,24 +370,25 @@ All tests are then defined under [config.yml](./scripts/ci/config.yml) where eac
By default, each of our frameworks tests files with the format: `test__impl.py`. If `is_dir` set to `true` we will try to match the given `` under `root_test_dir` to run tests from.
-| Keys | Type | Defintions |
-|------|------|------------|
-|`root_test_dir`| ``| root directory to run a given tests |
-|`is_dir`| ``| whether `target` is a directory instead of a file |
-|`override_name_or_path`| ``| optional way to override a tests file name if doesn't match our convention |
-|`dependencies`| ``| define additional dependencies required to run the tests, accepts `requirements.txt` format |
-|`external_scripts`| ``| optional shell scripts that can be run on top of `./scripts/ci/run_tests.sh` for given testsuite |
-|`type_tests`| ``| define type of tests for given `target` |
+| Keys | Type | Defintions |
+| ----------------------- | --------------------------------------- | ------------------------------------------------------------------------------------------------ |
+| `root_test_dir` | `` | root directory to run a given tests |
+| `is_dir` | `` | whether `target` is a directory instead of a file |
+| `override_name_or_path` | `` | optional way to override a tests file name if doesn't match our convention |
+| `dependencies` | `` | define additional dependencies required to run the tests, accepts `requirements.txt` format |
+| `external_scripts` | `` | optional shell scripts that can be run on top of `./scripts/ci/run_tests.sh` for given testsuite |
+| `type_tests` | `` | define type of tests for given `target` |
When `type_tests` is set to `e2e`, `./scripts/ci/run_tests.sh` will change current directory into the given `root_test_dir`, and will run the testsuite from there.
The reason why we encourage developers to use the scripts in CI is that under the hood when we use pytest, we will create a custom report for the given tests. This report can then be used as carryforward flags on codecov for consistent reporting.
Example:
+
```yaml
# e2e tests
-general_features:
- root_test_dir: "tests/e2e/bento_server_general_features"
+http:
+ root_test_dir: "tests/e2e/bento_server_http"
is_dir: true
type_tests: "e2e"
dependencies:
@@ -359,20 +406,18 @@ pytorch_lightning:
Refer to [config.yml](./scripts/ci/config.yml) for more examples.
-
## Python tools ecosystem
-Currently, BentoML is [PEP518](https://www.python.org/dev/peps/pep-0518/) compatible via `setup.cfg` and `pyproject.toml`.
- We also define most of our config for Python tools where:
- - `pyproject.toml` contains config for `setuptools`, `black`, `pytest`, `pylint`, `isort`, `pyright`
- - `setup.cfg` contains metadata for `bentoml` library and `coverage`
+Currently, BentoML is [PEP518](https://www.python.org/dev/peps/pep-0518/) compatible. We define package configuration via [`pyproject.toml`][https://github.com/bentoml/bentoml/blob/main/pyproject.toml].
## Benchmark
+
BentoML has moved its benchmark to [`bentoml/benchmark`](https://github.com/bentoml/benchmark).
## Optional: git hooks
BentoML also provides git hooks that developers can install with:
+
```bash
make hooks
```
@@ -385,6 +430,7 @@ on how to create a pull request on github. Name your pull request
with one of the following prefixes, e.g. "feat: add support for
PyTorch". This is based on the [Conventional Commits
specification](https://www.conventionalcommits.org/en/v1.0.0/#summary)
+
- feat: (new feature for the user, not a new feature for build script)
- fix: (bug fix for the user, not a fix to a build script)
- docs: (changes to the documentation)
@@ -406,4 +452,3 @@ your pull request.
Refers to [BentoML Documentation Guide](./docs/README.md) for how to build and write
docs.
-
diff --git a/MANIFEST.in b/MANIFEST.in
index 1bcf3f99be..4117e26c8c 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,12 +1,17 @@
# All files tracked by Git are included in PyPI distribution
+include bentoml/grpc/**/*_pb2*.py
+include bentoml/grpc/**/*.pyi
# Files to exclude in PyPI distribution
exclude CONTRIBUTING.md GOVERNANCE.md CODE_OF_CONDUCT.md DEVELOPMENT.md SECURITY.md
exclude Makefile MANIFEST.in
-exclude *.yml
+exclude *.yml *.yaml
exclude .git*
+exclude bentoml/grpc/buf.yaml
+exclude bentoml/_internal/frameworks/FRAMEWORK_TEMPLATE_PY
# Directories to exclude in PyPI package
+prune .devcontainer
prune requirements
prune tests
prune typings
@@ -18,6 +23,7 @@ prune .git*
prune */__pycache__
prune */.DS_Store
prune */.ipynb_checkpoints
+prune */README*
# Patterns to exclude from any directory
global-exclude *.py[cod]
diff --git a/Makefile b/Makefile
index d6ce033305..274e0139ca 100644
--- a/Makefile
+++ b/Makefile
@@ -8,19 +8,25 @@ USE_GPU ?= false
help: ## Show all Makefile targets
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
+.PHONY: format format-proto lint lint-proto type style clean
format: ## Running code formatter: black and isort
@./scripts/tools/formatter.sh
+format-proto: ## Running proto formatter: buf
+ @echo "Formatting proto files..."
+ docker run --init --rm --volume $(GIT_ROOT):/workspace --workdir /workspace bufbuild/buf format --config "/workspace/bentoml/grpc/buf.yaml" -w --path "bentoml/grpc"
lint: ## Running lint checker: pylint
@./scripts/tools/linter.sh
+lint-proto: ## Running proto lint checker: buf
+ @echo "Linting proto files..."
+ docker run --init --rm --volume $(GIT_ROOT):/workspace --workdir /workspace bufbuild/buf lint --config "/workspace/bentoml/grpc/buf.yaml" --error-format msvs --path "bentoml/grpc"
type: ## Running type checker: pyright
@./scripts/tools/type_checker.sh
+style: format lint format-proto lint-proto ## Running formatter and linter
clean: ## Clean all generated files
@echo "Cleaning all generated files..."
@cd $(GIT_ROOT)/docs && make clean
@cd $(GIT_ROOT) || exit 1
@find . -type f -name '*.py[co]' -delete -o -type d -name __pycache__ -delete
-hooks: __check_defined_FORCE ## Install pre-defined hooks
- @./scripts/install_hooks.sh
ci-%:
@@ -33,8 +39,10 @@ ci-format: ci-black ci-isort ## Running format check in CI: black, isort
.PHONY: ci-lint
ci-lint: ci-pylint ## Running lint check in CI: pylint
+.PHONY: tests-suite
+tests-suite: tests-unit tests-http_server tests-grpc_server ## Running BentoML tests suite (unit, e2e, integration)
-tests-%: check-defined-USE_GPU check-defined-USE_VERBOSE
+tests-%:
$(eval type :=$(subst tests-, , $@))
$(eval RUN_ARGS:=$(wordlist 2,$(words $(MAKECMDGOALS)),$(MAKECMDGOALS)))
$(eval __positional:=$(foreach t, $(RUN_ARGS), -$(t)))
@@ -78,15 +86,3 @@ install-spellchecker-deps: ## Inform users to install enchant depending on their
@echo Make sure to install enchant from your distros package manager
@exit 1
endif
-
-check_defined = $(strip $(foreach 1,$1, $(call __check_defined,$1,$(strip $(value 2)))))
-__check_defined = \
- $(if $(value $1),, \
- $(error Undefined $1$(if $2, ($2))$(if $(value @), \
- required by target `$@`)))
-check-defined-% : __check_defined_FORCE
- $(eval $@_target := $(subst check-defined-, ,$@))
- @:$(call check_defined, $*, $@_target)
-
-.PHONY : __check_defined_FORCE
-__check_defined_FORCE:
diff --git a/bentoml/_internal/bento/build_config.py b/bentoml/_internal/bento/build_config.py
index bd58f6cbce..906fba6a72 100644
--- a/bentoml/_internal/bento/build_config.py
+++ b/bentoml/_internal/bento/build_config.py
@@ -64,7 +64,7 @@ def _convert_python_version(py_version: str | None) -> str | None:
target_python_version = f"{major}.{minor}"
if target_python_version != py_version:
logger.warning(
- "BentoML will install the latest python%s instead of the specified version %s. To use the exact python version, use a custom docker base image. See https://docs.bentoml.org/en/latest/concepts/bento.html#custom-base-image-advanced",
+ "BentoML will install the latest 'python%s' instead of the specified 'python%s'. To use the exact python version, use a custom docker base image. See https://docs.bentoml.org/en/latest/concepts/bento.html#custom-base-image-advanced",
target_python_version,
py_version,
)
@@ -165,20 +165,27 @@ def __attrs_post_init__(self):
if self.base_image is not None:
if self.distro is not None:
logger.warning(
- f"docker base_image {self.base_image} is used, 'distro={self.distro}' option is ignored.",
+ "docker base_image %s is used, 'distro=%s' option is ignored.",
+ self.base_image,
+ self.distro,
)
if self.python_version is not None:
logger.warning(
- f"docker base_image {self.base_image} is used, 'python={self.python_version}' option is ignored.",
+ "docker base_image %s is used, 'python=%s' option is ignored.",
+ self.base_image,
+ self.python_version,
)
if self.cuda_version is not None:
logger.warning(
- f"docker base_image {self.base_image} is used, 'cuda_version={self.cuda_version}' option is ignored.",
+ "docker base_image %s is used, 'cuda_version=%s' option is ignored.",
+ self.base_image,
+ self.cuda_version,
)
if self.system_packages:
logger.warning(
- f"docker base_image {self.base_image} is used, "
- f"'system_packages={self.system_packages}' option is ignored.",
+ "docker base_image %s is used, 'system_packages=%s' option is ignored.",
+ self.base_image,
+ self.system_packages,
)
if self.distro is not None and self.cuda_version is not None:
@@ -225,14 +232,14 @@ def write_to_bento(
try:
setup_script = resolve_user_filepath(self.setup_script, build_ctx)
except FileNotFoundError as e:
- raise InvalidArgument(f"Invalid setup_script file: {e}")
+ raise InvalidArgument(f"Invalid setup_script file: {e}") from None
if not os.access(setup_script, os.X_OK):
message = f"{setup_script} is not executable."
if not psutil.WINDOWS:
raise InvalidArgument(
f"{message} Ensure the script has a shebang line, then run 'chmod +x {setup_script}'."
- )
- raise InvalidArgument(message)
+ ) from None
+ raise InvalidArgument(message) from None
copy_file_to_fs_folder(
setup_script, bento_fs, docker_folder, "setup_script"
)
@@ -428,11 +435,14 @@ class PythonOptions:
def __attrs_post_init__(self):
if self.requirements_txt and self.packages:
logger.warning(
- f'Build option python: `requirements_txt="{self.requirements_txt}"` found, will ignore the option: `packages="{self.packages}"`.'
+ "Build option python: 'requirements_txt={self.requirements_txt}' found, will ignore the option: 'packages=%s'.",
+ self.requirements_txt,
+ self.packages,
)
if self.no_index and (self.index_url or self.extra_index_url):
logger.warning(
- f'Build option python: `no_index="{self.no_index}"` found, will ignore `index_url` and `extra_index_url` option when installing PyPI packages.'
+ "Build option python: 'no_index=%s' found, will ignore 'index_url' and 'extra_index_url' option when installing PyPI packages.",
+ self.no_index,
)
def is_empty(self) -> bool:
@@ -479,8 +489,10 @@ def write_to_bento(self, bento_fs: FS, build_ctx: str) -> None:
pip_args.extend(self.pip_args.split())
with bento_fs.open(fs.path.combine(py_folder, "install.sh"), "w") as f:
- args = " ".join(map(quote, pip_args)) if pip_args else ""
- install_script_content = (
+ args = ["--no-warn-script-location"]
+ if pip_args:
+ args.extend(pip_args)
+ install_sh = (
"""\
#!/usr/bin/env bash
set -exuo pipefail
@@ -488,8 +500,8 @@ def write_to_bento(self, bento_fs: FS, build_ctx: str) -> None:
# Parent directory https://stackoverflow.com/a/246128/8643197
BASEDIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]:-$0}"; )" &> /dev/null && pwd 2> /dev/null; )"
-PIP_ARGS=(--no-warn-script-location """
- + args
+PIP_ARGS=("""
+ + " ".join(map(quote, args))
+ """)
# BentoML by default generates two requirement files:
@@ -525,11 +537,11 @@ def write_to_bento(self, bento_fs: FS, build_ctx: str) -> None:
echo "WARNING: using BentoML version ${existing_bentoml_version}"
fi
else
-pip install bentoml=="$BENTOML_VERSION"
+ pip install bentoml=="$BENTOML_VERSION"
fi
"""
)
- f.write(install_script_content)
+ f.write(install_sh)
if self.requirements_txt is not None:
requirements_txt_file = resolve_user_filepath(
diff --git a/bentoml/_internal/bento/build_dev_bentoml_whl.py b/bentoml/_internal/bento/build_dev_bentoml_whl.py
index ae7c28aaae..2b57662314 100644
--- a/bentoml/_internal/bento/build_dev_bentoml_whl.py
+++ b/bentoml/_internal/bento/build_dev_bentoml_whl.py
@@ -1,18 +1,17 @@
from __future__ import annotations
import os
-import shutil
import logging
-import tempfile
-
-from bentoml.exceptions import BentoMLException
from ..utils.pkg import source_locations
+from ...exceptions import BentoMLException
+from ...exceptions import MissingDependencyException
from ..configuration import is_pypi_installed_bentoml
logger = logging.getLogger(__name__)
BENTOML_DEV_BUILD = "BENTOML_BUNDLE_LOCAL_BUILD"
+_exc_message = f"'{BENTOML_DEV_BUILD}=True', which requires the 'pypa/build' package. Install development dependencies with 'pip install -r requirements/dev-requirements.txt' and try again."
def build_bentoml_editable_wheel(target_path: str) -> None:
@@ -28,11 +27,27 @@ def build_bentoml_editable_wheel(target_path: str) -> None:
# skip this entirely if BentoML is installed from PyPI
return
+ try:
+ from build.env import IsolatedEnvBuilder
+
+ from build import ProjectBuilder
+ except ModuleNotFoundError as e:
+ raise MissingDependencyException(_exc_message) from e
+
# Find bentoml module path
module_location = source_locations("bentoml")
if not module_location:
raise BentoMLException("Could not find bentoml module location.")
+ try:
+ from importlib import import_module
+
+ _ = import_module("bentoml.grpc.v1alpha1.service_pb2")
+ except ModuleNotFoundError:
+ raise ModuleNotFoundError(
+ f"Generated stubs are not found. Make sure to run '{module_location}/scripts/generate_grpc_stubs.sh' beforehand to generate gRPC stubs."
+ ) from None
+
pyproject = os.path.abspath(os.path.join(module_location, "..", "pyproject.toml"))
# this is for BentoML developer to create Service containing custom development
@@ -42,17 +57,12 @@ def build_bentoml_editable_wheel(target_path: str) -> None:
logger.info(
"BentoML is installed in `editable` mode; building BentoML distribution with the local BentoML code base. The built wheel file will be included in the target bento."
)
- try:
- from build import ProjectBuilder
- except ModuleNotFoundError:
- raise BentoMLException(
- f"Environment variable {BENTOML_DEV_BUILD}=True detected, which requires the `pypa/build` package. Make sure to install all dev dependencies via `pip install -r requirements/dev-requirements.txt` and try again."
- )
-
- with tempfile.TemporaryDirectory() as dist_dir:
+ with IsolatedEnvBuilder() as env:
builder = ProjectBuilder(os.path.dirname(pyproject))
- builder.build("wheel", dist_dir)
- shutil.copytree(dist_dir, target_path)
+ builder.python_executable = env.executable
+ builder.scripts_dir = env.scripts_dir
+ env.install(builder.build_system_requires)
+ builder.build("wheel", target_path)
else:
logger.info(
"Custom BentoML build is detected. For a Bento to use the same build at serving time, add your custom BentoML build to the pip packages list, e.g. `packages=['git+https://github.com/bentoml/bentoml.git@13dfb36']`"
diff --git a/bentoml/_internal/bento/docker.py b/bentoml/_internal/bento/docker/__init__.py
similarity index 96%
rename from bentoml/_internal/bento/docker.py
rename to bentoml/_internal/bento/docker/__init__.py
index 433b25423e..0a7a954e3b 100644
--- a/bentoml/_internal/bento/docker.py
+++ b/bentoml/_internal/bento/docker/__init__.py
@@ -6,8 +6,8 @@
import attr
-from ...exceptions import InvalidArgument
-from ...exceptions import BentoMLException
+from bentoml.exceptions import InvalidArgument
+from bentoml.exceptions import BentoMLException
if TYPE_CHECKING:
P = t.ParamSpec("P")
@@ -109,7 +109,7 @@ class DistroSpec:
),
)
- supported_cuda_versions: t.List[str] = attr.field(
+ supported_cuda_versions: t.Optional[t.List[str]] = attr.field(
default=None,
validator=attr.validators.optional(
attr.validators.deep_iterable(
diff --git a/bentoml/_internal/bento/docker/entrypoint.sh b/bentoml/_internal/bento/docker/entrypoint.sh
index fd29ad0276..0f33ea19bf 100755
--- a/bentoml/_internal/bento/docker/entrypoint.sh
+++ b/bentoml/_internal/bento/docker/entrypoint.sh
@@ -10,19 +10,23 @@ _is_sourced() {
}
_main() {
- # for backwards compatibility with the yatai<1.0.0, adapting the old "yatai" command to the new "start" command
- if [ "${#}" -gt 0 ] && [ "${1}" = 'python' ] && [ "${2}" = '-m' ] && ([ "${3}" = 'bentoml._internal.server.cli.runner' ] || [ "${3}" = "bentoml._internal.server.cli.api_server" ]); then
+ # For backwards compatibility with the yatai<1.0.0, adapting the old "yatai" command to the new "start" command.
+ if [ "${#}" -gt 0 ] && [ "${1}" = 'python' ] && [ "${2}" = '-m' ] && { [ "${3}" = 'bentoml._internal.server.cli.runner' ] || [ "${3}" = "bentoml._internal.server.cli.api_server" ]; }; then # SC2235, use { } to avoid subshell overhead
if [ "${3}" = 'bentoml._internal.server.cli.runner' ]; then
set -- bentoml start-runner-server "${@:4}"
elif [ "${3}" = 'bentoml._internal.server.cli.api_server' ]; then
set -- bentoml start-http-server "${@:4}"
fi
- # if no arg or first arg looks like a flag
- elif [ -z "$@" ] || [ "${1:0:1}" = '-' ]; then
+ # If no arg or first arg looks like a flag.
+ elif [[ "$#" -eq 0 ]] || [[ "${1:0:1}" =~ '-' ]]; then
+ # This is provided for backwards compatibility with places where user may have
+ # discover this easter egg and use it in their scripts to run the container.
if [[ -v BENTOML_SERVE_COMPONENT ]]; then
echo "\$BENTOML_SERVE_COMPONENT is set! Calling 'bentoml start-*' instead"
if [ "${BENTOML_SERVE_COMPONENT}" = 'http_server' ]; then
- set -- bentoml start-rest-server "$@" "$BENTO_PATH"
+ set -- bentoml start-http-server "$@" "$BENTO_PATH"
+ elif [ "${BENTOML_SERVE_COMPONENT}" = 'grpc_server' ]; then
+ set -- bentoml start-grpc-server "$@" "$BENTO_PATH"
elif [ "${BENTOML_SERVE_COMPONENT}" = 'runner' ]; then
set -- bentoml start-runner-server "$@" "$BENTO_PATH"
fi
@@ -30,13 +34,21 @@ _main() {
set -- bentoml serve --production "$@" "$BENTO_PATH"
fi
fi
-
- # Overide the BENTOML_PORT if PORT env var is present. Used for Heroku **and Yatai**
+ # Overide the BENTOML_PORT if PORT env var is present. Used for Heroku and Yatai.
if [[ -v PORT ]]; then
echo "\$PORT is set! Overiding \$BENTOML_PORT with \$PORT ($PORT)"
export BENTOML_PORT=$PORT
fi
- exec "$@"
+ # Handle serve and start commands that is passed to the container.
+ # Assuming that serve and start commands are the first arguments
+ # Note that this is the recommended way going forward to run all bentoml containers.
+ if [ "${#}" -gt 0 ] && { [ "${1}" = 'serve' ] || [ "${1}" = 'serve-http' ] || [ "${1}" = 'serve-grpc' ] || [ "${1}" = 'start-http-server' ] || [ "${1}" = 'start-grpc-server' ] || [ "${1}" = 'start-runner-server' ]; }; then
+ exec bentoml "$@" "$BENTO_PATH"
+ else
+ # otherwise default to run whatever the command is
+ # This should allow running bash, sh, python, etc
+ exec "$@"
+ fi
}
if ! _is_sourced; then
diff --git a/bentoml/_internal/bento/docker/templates/base.j2 b/bentoml/_internal/bento/docker/templates/base.j2
index c6bf21e4ea..5c72e34024 100644
--- a/bentoml/_internal/bento/docker/templates/base.j2
+++ b/bentoml/_internal/bento/docker/templates/base.j2
@@ -2,7 +2,7 @@
{# users can use these values #}
{% import '_macros.j2' as common %}
{% set bento__entrypoint = bento__entrypoint | default(expands_bento_path("env", "docker", "entrypoint.sh", bento_path=bento__path)) %}
-# syntax = docker/dockerfile:1.4-labs
+# syntax = docker/dockerfile:1.4.3
#
# ===========================================
#
@@ -12,7 +12,7 @@
# Block SETUP_BENTO_BASE_IMAGE
{% block SETUP_BENTO_BASE_IMAGE %}
-FROM {{ __base_image__ }}
+FROM {{ __base_image__ }} as base-{{ __options__distro }}
ENV LANG=C.UTF-8
@@ -21,7 +21,6 @@ ENV LC_ALL=C.UTF-8
ENV PYTHONIOENCODING=UTF-8
ENV PYTHONUNBUFFERED=1
-
{% endblock %}
# Block SETUP_BENTO_USER
@@ -34,7 +33,6 @@ RUN groupadd -g $BENTO_USER_GID -o $BENTO_USER && useradd -m -u $BENTO_USER_UID
{% block SETUP_BENTO_ENVARS %}
{% if __options__env is not none %}
{% for key, value in __options__env.items() -%}
-
ENV {{ key }}={{ value }}
{% endfor -%}
{% endif -%}
@@ -46,25 +44,20 @@ ENV BENTOML_HOME={{ bento__home }}
RUN mkdir $BENTO_PATH && chown {{ bento__user }}:{{ bento__user }} $BENTO_PATH -R
WORKDIR $BENTO_PATH
-# init related components
COPY --chown={{ bento__user }}:{{ bento__user }} . ./
-
{% endblock %}
# Block SETUP_BENTO_COMPONENTS
{% block SETUP_BENTO_COMPONENTS %}
-
{% set __install_python_scripts__ = expands_bento_path("env", "python", "install.sh", bento_path=bento__path) %}
{% set __pip_cache__ = common.mount_cache("/root/.cache/pip") %}
# install python packages with install.sh
RUN {{ __pip_cache__ }} bash -euxo pipefail {{ __install_python_scripts__ }}
-
{% if __options__setup_script is not none %}
{% set __setup_script__ = expands_bento_path("env", "docker", "setup_script", bento_path=bento__path) %}
RUN chmod +x {{ __setup_script__ }}
RUN {{ __setup_script__ }}
{% endif %}
-
{% endblock %}
# Block SETUP_BENTO_ENTRYPOINT
@@ -72,6 +65,9 @@ RUN {{ __setup_script__ }}
# Default port for BentoServer
EXPOSE 3000
+# Expose Prometheus port
+EXPOSE {{ __prometheus_port__ }}
+
RUN chmod +x {{ bento__entrypoint }}
USER bentoml
diff --git a/bentoml/_internal/bento/gen.py b/bentoml/_internal/bento/gen.py
index 65db6a1a44..0c4935a02f 100644
--- a/bentoml/_internal/bento/gen.py
+++ b/bentoml/_internal/bento/gen.py
@@ -5,14 +5,16 @@
import logging
from sys import version_info
from typing import TYPE_CHECKING
+from dataclasses import asdict
+from dataclasses import dataclass
-import attr
from jinja2 import Environment
from jinja2.loaders import FileSystemLoader
+from ..utils import bentoml_cattr
from ..utils import resolve_user_filepath
from .docker import DistroSpec
-from ..configuration import CLEAN_BENTOML_VERSION
+from ..configuration.containers import BentoMLContainer
logger = logging.getLogger(__name__)
@@ -22,7 +24,7 @@
from .build_config import DockerOptions
TemplateFunc = t.Callable[[DockerOptions], t.Dict[str, t.Any]]
- GenericFunc = t.Callable[P, t.Any]
+ F = t.Callable[P, t.Any]
BENTO_UID_GID = 1034
BENTO_USER = "bentoml"
@@ -45,33 +47,34 @@ def expands_bento_path(*path: str, bento_path: str = BENTO_PATH) -> str:
return "/".join([bento_path, *path])
-J2_FUNCTION: dict[str, GenericFunc[t.Any]] = {
- "expands_bento_path": expands_bento_path,
-}
+J2_FUNCTION: dict[str, F[t.Any]] = {"expands_bento_path": expands_bento_path}
+
+to_preserved_field: t.Callable[[str], str] = lambda s: f"__{s}__"
+to_bento_field: t.Callable[[str], str] = lambda s: f"bento__{s}"
+to_options_field: t.Callable[[str], str] = lambda s: f"__options__{s}"
-@attr.frozen(on_setattr=None, eq=False, repr=False)
+@dataclass
class ReservedEnv:
base_image: str
- supported_architectures: list[str]
- bentoml_version: str = attr.field(default=CLEAN_BENTOML_VERSION)
- python_version: str = attr.field(
- default=f"{version_info.major}.{version_info.minor}"
- )
+ python_version: str = f"{version_info.major}.{version_info.minor}"
- def todict(self):
- return {f"__{k}__": v for k, v in attr.asdict(self).items()}
+ def asdict(self) -> dict[str, t.Any]:
+ return {
+ **{to_preserved_field(k): v for k, v in asdict(self).items()},
+ "__prometheus_port__": BentoMLContainer.grpc.metrics.port.get(),
+ }
-@attr.frozen(on_setattr=None, eq=False, repr=False)
+@dataclass
class CustomizableEnv:
- uid_gid: int = attr.field(default=BENTO_UID_GID)
- user: str = attr.field(default=BENTO_USER)
- home: str = attr.field(default=BENTO_HOME)
- path: str = attr.field(default=BENTO_PATH)
+ uid_gid: int = BENTO_UID_GID
+ user: str = BENTO_USER
+ home: str = BENTO_HOME
+ path: str = BENTO_PATH
- def todict(self) -> dict[str, str]:
- return {f"bento__{k}": v for k, v in attr.asdict(self).items()}
+ def asdict(self) -> dict[str, t.Any]:
+ return {to_bento_field(k): v for k, v in asdict(self).items()}
def get_templates_variables(
@@ -85,6 +88,10 @@ def get_templates_variables(
distro = options.distro
cuda_version = options.cuda_version
python_version = options.python_version
+
+ # these values will be set at with_defaults() if not provided
+ # so distro and python_version won't be None here.
+ assert distro and python_version
spec = DistroSpec.from_distro(
distro, cuda=cuda_version is not None, conda=use_conda
)
@@ -97,28 +104,32 @@ def get_templates_variables(
else:
python_version = python_version
base_image = spec.image.format(spec_version=python_version)
- supported_architecture = spec.supported_architectures
else:
base_image = options.base_image
- # TODO: allow user to specify supported architectures of the base image
- supported_architecture = ["amd64"]
logger.info(
- f"BentoML will not install Python to custom base images; ensure the base image '{base_image}' has Python installed."
+ "BentoML will not install Python to custom base images; ensure the base image '%s' has Python installed.",
+ base_image,
)
# environment returns are
- # __base_image__, __supported_architectures__, __bentoml_version__, __python_version_full__
+ # __base_image__, __python_version__, __prometheus_port__
# bento__uid_gid, bento__user, bento__home, bento__path
# __options__distros, __options__base_image, __options_env, __options_system_packages, __options_setup_script
return {
- **{f"__options__{k}": v for k, v in attr.asdict(options).items()},
- **CustomizableEnv().todict(),
- **ReservedEnv(base_image, supported_architecture).todict(),
+ **{
+ to_options_field(k): v
+ for k, v in bentoml_cattr.unstructure(options).items()
+ },
+ **CustomizableEnv().asdict(),
+ **ReservedEnv(base_image=base_image).asdict(),
}
def generate_dockerfile(
- options: DockerOptions, build_ctx: str, *, use_conda: bool
+ options: DockerOptions,
+ build_ctx: str,
+ *,
+ use_conda: bool,
) -> str:
"""
Generate a Dockerfile that containerize a Bento.
@@ -179,6 +190,7 @@ def generate_dockerfile(
if options.base_image is not None:
base = "base.j2"
else:
+ assert distro # distro will be set via 'with_defaults()'
spec = DistroSpec.from_distro(distro, cuda=use_cuda, conda=use_conda)
base = f"{spec.release_type}_{distro}.j2"
diff --git a/bentoml/_internal/configuration/__init__.py b/bentoml/_internal/configuration/__init__.py
index 8061ec6dcc..bbc1144108 100644
--- a/bentoml/_internal/configuration/__init__.py
+++ b/bentoml/_internal/configuration/__init__.py
@@ -5,6 +5,7 @@
from functools import lru_cache
from bentoml.exceptions import BentoMLException
+from bentoml.exceptions import BentoMLConfigException
try:
import importlib.metadata as importlib_metadata
@@ -29,6 +30,8 @@ class version_mod:
DEBUG_ENV_VAR = "BENTOML_DEBUG"
QUIET_ENV_VAR = "BENTOML_QUIET"
CONFIG_ENV_VAR = "BENTOML_CONFIG"
+# https://github.com/grpc/grpc/blob/master/doc/environment_variables.md
+GRPC_DEBUG_ENV_VAR = "GRPC_VERBOSITY"
def expand_env_var(env_var: str) -> str:
@@ -96,6 +99,7 @@ def get_bentoml_config_file_from_env() -> t.Optional[str]:
def set_debug_mode(enabled: bool) -> None:
os.environ[DEBUG_ENV_VAR] = str(enabled)
+ os.environ[GRPC_DEBUG_ENV_VAR] = "DEBUG"
logger.info(
f"{'Enabling' if enabled else 'Disabling'} debug mode for current BentoML session"
@@ -109,9 +113,9 @@ def get_debug_mode() -> bool:
def set_quiet_mode(enabled: bool) -> None:
- os.environ[DEBUG_ENV_VAR] = str(enabled)
-
# do not log setting quiet mode
+ os.environ[QUIET_ENV_VAR] = str(enabled)
+ os.environ[GRPC_DEBUG_ENV_VAR] = "NONE"
def get_quiet_mode() -> bool:
@@ -131,15 +135,15 @@ def load_global_config(bentoml_config_file: t.Optional[str] = None):
if bentoml_config_file:
if not bentoml_config_file.endswith((".yml", ".yaml")):
- raise Exception(
+ raise BentoMLConfigException(
"BentoML config file specified in ENV VAR does not end with `.yaml`: "
f"`BENTOML_CONFIG={bentoml_config_file}`"
- )
+ ) from None
if not os.path.isfile(bentoml_config_file):
raise FileNotFoundError(
"BentoML config file specified in ENV VAR not found: "
f"`BENTOML_CONFIG={bentoml_config_file}`"
- )
+ ) from None
bentoml_configuration = BentoMLConfiguration(
override_config_file=bentoml_config_file,
diff --git a/bentoml/_internal/configuration/containers.py b/bentoml/_internal/configuration/containers.py
index 35da7fdfe8..b83c728d89 100644
--- a/bentoml/_internal/configuration/containers.py
+++ b/bentoml/_internal/configuration/containers.py
@@ -54,20 +54,19 @@
"grpc",
"http",
)
-_check_sample_rate: t.Callable[[float], None] = (
- lambda sample_rate: logger.warning(
- "Tracing enabled, but sample_rate is unset or zero. No traces will be collected. "
- "Please refer to https://docs.bentoml.org/en/latest/guides/tracing.html for more details."
- )
- if sample_rate == 0.0
- else None
-)
_larger_than: t.Callable[[int | float], t.Callable[[int | float], bool]] = (
lambda target: lambda val: val > target
)
_larger_than_zero: t.Callable[[int | float], bool] = _larger_than(0)
+def _check_sample_rate(sample_rate: float) -> None:
+ if sample_rate == 0.0:
+ logger.warning(
+ "Tracing enabled, but sample_rate is unset or zero. No traces will be collected. Please refer to https://docs.bentoml.org/en/latest/guides/tracing.html for more details."
+ )
+
+
def _is_ip_address(addr: str) -> bool:
import socket
@@ -109,12 +108,9 @@ def _is_ip_address(addr: str) -> bool:
SCHEMA = Schema(
{
"api_server": {
- "port": And(int, _larger_than_zero),
- "host": And(str, _is_ip_address),
- "backlog": And(int, _larger_than(64)),
"workers": Or(And(int, _larger_than_zero), None),
"timeout": And(int, _larger_than_zero),
- "max_request_size": And(int, _larger_than_zero),
+ "backlog": And(int, _larger_than(64)),
Optional("ssl"): {
Optional("certfile"): Or(str, None),
Optional("keyfile"): Or(str, None),
@@ -143,14 +139,30 @@ def _is_ip_address(addr: str) -> bool:
"response_content_type": Or(bool, None),
},
},
- "cors": {
- "enabled": bool,
- "access_control_allow_origin": Or(str, None),
- "access_control_allow_credentials": Or(bool, None),
- "access_control_allow_headers": Or([str], str, None),
- "access_control_allow_methods": Or([str], str, None),
- "access_control_max_age": Or(int, None),
- "access_control_expose_headers": Or([str], str, None),
+ "http": {
+ "host": And(str, _is_ip_address),
+ "port": And(int, _larger_than_zero),
+ "cors": {
+ "enabled": bool,
+ "access_control_allow_origin": Or(str, None),
+ "access_control_allow_credentials": Or(bool, None),
+ "access_control_allow_headers": Or([str], str, None),
+ "access_control_allow_methods": Or([str], str, None),
+ "access_control_max_age": Or(int, None),
+ "access_control_expose_headers": Or([str], str, None),
+ },
+ },
+ "grpc": {
+ "host": And(str, _is_ip_address),
+ "port": And(int, _larger_than_zero),
+ "metrics": {
+ "port": And(int, _larger_than_zero),
+ "host": And(str, _is_ip_address),
+ },
+ "reflection": {"enabled": bool},
+ "max_concurrent_streams": Or(int, None),
+ "max_message_length": Or(int, None),
+ "maximum_concurrent_rpcs": Or(int, None),
},
},
"runners": {
@@ -193,6 +205,10 @@ def _is_ip_address(addr: str) -> bool:
}
)
+_WARNING_MESSAGE = (
+ "field 'api_server.%s' is deprecated and has been renamed to 'api_server.http.%s'"
+)
+
class BentoMLConfiguration:
def __init__(
@@ -224,7 +240,35 @@ def __init__(
f"Config file {override_config_file} not found"
)
with open(override_config_file, "rb") as f:
- override_config = yaml.safe_load(f)
+ override_config: dict[str, t.Any] = yaml.safe_load(f)
+
+ # compatibility layer with old configuration pre gRPC features
+ # api_server.[cors|port|host] -> api_server.http.$^
+ if "api_server" in override_config:
+ user_api_config = override_config["api_server"]
+ # max_request_size is deprecated
+ if "max_request_size" in user_api_config:
+ logger.warning(
+ "'api_server.max_request_size' is deprecated and has become obsolete."
+ )
+ user_api_config.pop("max_request_size")
+ # check if user are using older configuration
+ if "http" not in user_api_config:
+ user_api_config["http"] = {}
+ # then migrate these fields to newer configuration fields.
+ for field in ["port", "host", "cors"]:
+ if field in user_api_config:
+ old_field = user_api_config.pop(field)
+ user_api_config["http"][field] = old_field
+ logger.warning(_WARNING_MESSAGE, field, field)
+
+ config_merger.merge(override_config["api_server"], user_api_config)
+
+ assert all(
+ key not in override_config["api_server"]
+ for key in ["cors", "max_request_size", "host", "port"]
+ )
+
config_merger.merge(self.config, override_config)
global_runner_cfg = {k: self.config["runners"][k] for k in RUNNER_CFG_KEYS}
@@ -250,9 +294,13 @@ def __init__(
def override(self, keys: t.List[str], value: t.Any):
if keys is None:
- raise BentoMLConfigException("Configuration override key is None.")
+ raise BentoMLConfigException(
+ "Configuration override key is None."
+ ) from None
if len(keys) == 0:
- raise BentoMLConfigException("Configuration override key is empty.")
+ raise BentoMLConfigException(
+ "Configuration override key is empty."
+ ) from None
if value is None:
return
@@ -261,7 +309,7 @@ def override(self, keys: t.List[str], value: t.Any):
if key not in c:
raise BentoMLConfigException(
"Configuration override key is invalid, %s" % keys
- )
+ ) from None
c = c[key]
c[keys[-1]] = value
@@ -330,6 +378,9 @@ def session_id() -> str:
api_server_config = config.api_server
runners_config = config.runners
+ grpc = api_server_config.grpc
+ http = api_server_config.http
+
development_mode = providers.Static(True)
@providers.SingletonFactory
@@ -342,23 +393,20 @@ def serve_info() -> ServeInfo:
@providers.SingletonFactory
@staticmethod
def access_control_options(
- allow_origins: t.List[str] = Provide[
- config.api_server.cors.access_control_allow_origin
- ],
- allow_credentials: t.List[str] = Provide[
- config.api_server.cors.access_control_allow_credentials
- ],
- expose_headers: t.List[str] = Provide[
- config.api_server.cors.access_control_expose_headers
- ],
- allow_methods: t.List[str] = Provide[
- config.api_server.cors.access_control_allow_methods
- ],
- allow_headers: t.List[str] = Provide[
- config.api_server.cors.access_control_allow_headers
- ],
- max_age: int = Provide[config.api_server.cors.access_control_max_age],
- ) -> t.Dict[str, t.Union[t.List[str], int]]:
+ allow_origins: str | None = Provide[http.cors.access_control_allow_origin],
+ allow_credentials: bool
+ | None = Provide[http.cors.access_control_allow_credentials],
+ expose_headers: list[str]
+ | str
+ | None = Provide[http.cors.access_control_expose_headers],
+ allow_methods: list[str]
+ | str
+ | None = Provide[http.cors.access_control_allow_methods],
+ allow_headers: list[str]
+ | str
+ | None = Provide[http.cors.access_control_allow_headers],
+ max_age: int | None = Provide[http.cors.access_control_max_age],
+ ) -> dict[str, list[str] | str | int]:
kwargs = dict(
allow_origins=allow_origins,
allow_credentials=allow_credentials,
@@ -368,15 +416,15 @@ def access_control_options(
max_age=max_age,
)
- filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
+ filtered_kwargs: dict[str, list[str] | str | int] = {
+ k: v for k, v in kwargs.items() if v is not None
+ }
return filtered_kwargs
api_server_workers = providers.Factory[int](
lambda workers: workers or (multiprocessing.cpu_count() // 2) + 1,
- config.api_server.workers,
+ api_server_config.workers,
)
- service_port = config.api_server.port
- service_host = config.api_server.host
prometheus_multiproc_dir = providers.Factory[str](
os.path.join,
@@ -388,7 +436,7 @@ def access_control_options(
@staticmethod
def metrics_client(
multiproc_dir: str = Provide[prometheus_multiproc_dir],
- ) -> "PrometheusClient":
+ ) -> PrometheusClient:
from ..server.metrics.prometheus import PrometheusClient
return PrometheusClient(multiproc_dir=multiproc_dir)
@@ -414,6 +462,7 @@ def tracer_provider(
from opentelemetry.sdk.environment_variables import OTEL_SERVICE_NAME
from opentelemetry.sdk.environment_variables import OTEL_RESOURCE_ATTRIBUTES
+ from ...exceptions import InvalidArgument
from ..utils.telemetry import ParentBasedTraceIdRatio
if sample_rate is None:
@@ -442,15 +491,10 @@ def tracer_provider(
)
if tracer_type == "zipkin" and zipkin_server_url is not None:
- # pylint: disable=no-name-in-module # https://github.com/open-telemetry/opentelemetry-python-contrib/issues/290
- from opentelemetry.exporter.zipkin.json import (
- ZipkinExporter, # type: ignore (no opentelemetry types)
- )
+ from opentelemetry.exporter.zipkin.json import ZipkinExporter
- exporter = ZipkinExporter( # type: ignore (no opentelemetry types)
- endpoint=zipkin_server_url,
- )
- provider.add_span_processor(BatchSpanProcessor(exporter)) # type: ignore (no opentelemetry types)
+ exporter = ZipkinExporter(endpoint=zipkin_server_url)
+ provider.add_span_processor(BatchSpanProcessor(exporter))
_check_sample_rate(sample_rate)
return provider
elif (
@@ -458,16 +502,12 @@ def tracer_provider(
and jaeger_server_address is not None
and jaeger_server_port is not None
):
- # pylint: disable=no-name-in-module # https://github.com/open-telemetry/opentelemetry-python-contrib/issues/290
- from opentelemetry.exporter.jaeger.thrift import (
- JaegerExporter, # type: ignore (no opentelemetry types)
- )
+ from opentelemetry.exporter.jaeger.thrift import JaegerExporter
- exporter = JaegerExporter( # type: ignore (no opentelemetry types)
- agent_host_name=jaeger_server_address,
- agent_port=jaeger_server_port,
+ exporter = JaegerExporter(
+ agent_host_name=jaeger_server_address, agent_port=jaeger_server_port
)
- provider.add_span_processor(BatchSpanProcessor(exporter)) # type: ignore (no opentelemetry types)
+ provider.add_span_processor(BatchSpanProcessor(exporter))
_check_sample_rate(sample_rate)
return provider
elif (
@@ -476,21 +516,16 @@ def tracer_provider(
and otlp_server_url is not None
):
if otlp_server_protocol == "grpc":
- # pylint: disable=no-name-in-module # https://github.com/open-telemetry/opentelemetry-python-contrib/issues/290
- from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
- OTLPSpanExporter, # type: ignore (no opentelemetry types)
- )
+ from opentelemetry.exporter.otlp.proto.grpc import trace_exporter
elif otlp_server_protocol == "http":
- # pylint: disable=no-name-in-module # https://github.com/open-telemetry/opentelemetry-python-contrib/issues/290
- from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
- OTLPSpanExporter, # type: ignore (no opentelemetry types)
- )
-
- exporter = OTLPSpanExporter( # type: ignore (no opentelemetry types)
- endpoint=otlp_server_url,
- )
- provider.add_span_processor(BatchSpanProcessor(exporter)) # type: ignore (no opentelemetry types)
+ from opentelemetry.exporter.otlp.proto.http import trace_exporter
+ else:
+ raise InvalidArgument(
+ f"Invalid otlp protocol: {otlp_server_protocol}"
+ ) from None
+ exporter = trace_exporter.OTLPSpanExporter(endpoint=otlp_server_url)
+ provider.add_span_processor(BatchSpanProcessor(exporter))
_check_sample_rate(sample_rate)
return provider
else:
@@ -499,9 +534,7 @@ def tracer_provider(
@providers.SingletonFactory
@staticmethod
def tracing_excluded_urls(
- excluded_urls: t.Optional[t.Union[str, t.List[str]]] = Provide[
- config.tracing.excluded_urls
- ],
+ excluded_urls: str | list[str] | None = Provide[config.tracing.excluded_urls],
):
from opentelemetry.util.http import ExcludeList
from opentelemetry.util.http import parse_excluded_urls
@@ -520,7 +553,7 @@ def tracing_excluded_urls(
@providers.SingletonFactory
@staticmethod
def duration_buckets(
- metrics: dict[str, t.Any] = Provide[config.api_server.metrics]
+ metrics: dict[str, t.Any] = Provide[api_server_config.metrics]
) -> tuple[float, ...]:
"""
Returns a tuple of duration buckets in seconds. If not explicitly configured,
diff --git a/bentoml/_internal/configuration/default_configuration.yaml b/bentoml/_internal/configuration/default_configuration.yaml
index 4d7cba78f4..8fefc0a7f5 100644
--- a/bentoml/_internal/configuration/default_configuration.yaml
+++ b/bentoml/_internal/configuration/default_configuration.yaml
@@ -1,42 +1,53 @@
api_server:
- port: 3000
- host: 0.0.0.0
- backlog: 2048
workers: 1
timeout: 60
- max_request_size: 20971520
+ backlog: 2048
metrics:
- enabled: True
+ enabled: true
namespace: bentoml_api_server
logging:
access:
- enabled: True
- request_content_length: True
- request_content_type: True
- response_content_length: True
- response_content_type: True
- cors:
- enabled: False
- access_control_allow_origin: Null
- access_control_allow_credentials: Null
- access_control_allow_methods: Null
- access_control_allow_headers: Null
- access_control_max_age: Null
- access_control_expose_headers: Null
+ enabled: true
+ request_content_length: true
+ request_content_type: true
+ response_content_length: true
+ response_content_type: true
+ http:
+ host: 0.0.0.0
+ port: 3000
+ cors:
+ enabled: false
+ access_control_allow_origin: ~
+ access_control_allow_credentials: ~
+ access_control_allow_methods: ~
+ access_control_allow_headers: ~
+ access_control_max_age: ~
+ access_control_expose_headers: ~
+ grpc:
+ host: 0.0.0.0
+ port: 3000
+ max_concurrent_streams: ~
+ maximum_concurrent_rpcs: ~
+ max_message_length: -1
+ reflection:
+ enabled: false
+ metrics:
+ host: 0.0.0.0
+ port: 3001
runners:
batching:
- enabled: True
+ enabled: true
max_batch_size: 100
max_latency_ms: 10000
resources: ~
logging:
access:
- enabled: True
- request_content_length: True
- request_content_type: True
- response_content_length: True
- response_content_type: True
+ enabled: true
+ request_content_length: true
+ request_content_type: true
+ response_content_length: true
+ response_content_type: true
metrics:
enabled: True
namespace: bentoml_runner
@@ -44,16 +55,16 @@ runners:
tracing:
type: zipkin
- sample_rate: Null
- excluded_urls: Null
+ sample_rate: ~
+ excluded_urls: ~
zipkin:
- url: Null
+ url: ~
jaeger:
- address: Null
- port: Null
+ address: ~
+ port: ~
otlp:
- protocol: Null
- url: Null
+ protocol: ~
+ url: ~
logging:
formatting:
diff --git a/bentoml/_internal/external_typing/__init__.py b/bentoml/_internal/external_typing/__init__.py
index de2e56dcf1..9087543eb1 100644
--- a/bentoml/_internal/external_typing/__init__.py
+++ b/bentoml/_internal/external_typing/__init__.py
@@ -4,19 +4,22 @@
if TYPE_CHECKING:
from typing import Literal
- from pandas import Series as _PdSeries
+ F = t.Callable[..., t.Any]
+
+ from pandas import Series as PdSeries
from pandas import DataFrame as PdDataFrame
+ from pandas._typing import Dtype as PdDType
+ from pandas._typing import DtypeArg as PdDTypeArg
from pyarrow.plasma import ObjectID
from pyarrow.plasma import PlasmaClient
- PdSeries = _PdSeries[t.Any]
DataFrameOrient = Literal["split", "records", "index", "columns", "values", "table"]
SeriesOrient = Literal["split", "records", "index", "table"]
# numpy is always required by bentoml
from numpy import generic as NpGeneric
from numpy.typing import NDArray as _NDArray
- from numpy.typing import DTypeLike as NpDTypeLike # type: ignore (incomplete numpy types)
+ from numpy.typing import DTypeLike as NpDTypeLike
NpNDArray = _NDArray[t.Any]
@@ -30,9 +33,13 @@
from .starlette import ASGIReceive
from .starlette import AsgiMiddleware
+ WSGIApp = t.Callable[[F, t.Mapping[str, t.Any]], t.Iterable[bytes]]
+
__all__ = [
"PdSeries",
"PdDataFrame",
+ "PdDType",
+ "PdDTypeArg",
"DataFrameOrient",
"SeriesOrient",
"ObjectID",
@@ -51,4 +58,6 @@
"ASGISend",
"ASGIReceive",
"ASGIMessage",
+ # misc
+ "WSGIApp",
]
diff --git a/bentoml/_internal/io_descriptors/base.py b/bentoml/_internal/io_descriptors/base.py
index 09c669b698..4c090ba1aa 100644
--- a/bentoml/_internal/io_descriptors/base.py
+++ b/bentoml/_internal/io_descriptors/base.py
@@ -8,10 +8,11 @@
if TYPE_CHECKING:
from types import UnionType
- from typing_extensions import Self
from starlette.requests import Request
from starlette.responses import Response
+ from bentoml.grpc.types import ProtoField
+
from ..types import LazyType
from ..context import InferenceApiContext as Context
from ..service.openapi.specification import Schema
@@ -39,20 +40,12 @@ class IODescriptor(ABC, t.Generic[IOType]):
HTTP_METHODS = ["POST"]
- _init_str: str = ""
-
_mime_type: str
-
- def __new__(cls: t.Type[Self], *args: t.Any, **kwargs: t.Any) -> Self:
- self = super().__new__(cls)
- # default mime type is application/json
- self._mime_type = "application/json"
- self._init_str = cls.__qualname__
-
- return self
+ _rpc_content_type: str = "application/grpc"
+ _proto_fields: tuple[ProtoField]
def __repr__(self) -> str:
- return self._init_str
+ return self.__class__.__qualname__
@abstractmethod
def input_type(self) -> InputType:
@@ -83,3 +76,11 @@ async def to_http_response(
self, obj: IOType, ctx: Context | None = None
) -> Response:
...
+
+ @abstractmethod
+ async def from_proto(self, field: t.Any) -> IOType:
+ ...
+
+ @abstractmethod
+ async def to_proto(self, obj: IOType) -> t.Any:
+ ...
diff --git a/bentoml/_internal/io_descriptors/file.py b/bentoml/_internal/io_descriptors/file.py
index 5a21d0ab90..452088578c 100644
--- a/bentoml/_internal/io_descriptors/file.py
+++ b/bentoml/_internal/io_descriptors/file.py
@@ -13,6 +13,7 @@
from .base import IODescriptor
from ..types import FileLike
from ..utils.http import set_cookies
+from ...exceptions import BadInput
from ...exceptions import BentoMLException
from ..service.openapi import SUCCESS_DESCRIPTION
from ..service.openapi.specification import Schema
@@ -23,10 +24,17 @@
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
+ from bentoml.grpc.v1alpha1 import service_pb2 as pb
+
from ..context import InferenceApiContext as Context
FileKind: t.TypeAlias = t.Literal["binaryio", "textio"]
-FileType: t.TypeAlias = t.Union[io.IOBase, t.IO[bytes], FileLike[bytes]]
+else:
+ from bentoml.grpc.utils import import_generated_stubs
+
+ pb, _ = import_generated_stubs()
+
+FileType = t.Union[io.IOBase, t.IO[bytes], FileLike[bytes]]
class File(IODescriptor[FileType]):
@@ -100,16 +108,16 @@ async def predict(input_pdf: io.BytesIO[Any]) -> NDArray[Any]:
"""
+ _proto_fields = ("file",)
+
def __new__( # pylint: disable=arguments-differ # returning subclass from new
cls, kind: FileKind = "binaryio", mime_type: str | None = None
) -> File:
mime_type = mime_type if mime_type is not None else "application/octet-stream"
-
if kind == "binaryio":
res = object.__new__(BytesIOFile)
else:
raise ValueError(f"invalid File kind '{kind}'")
-
res._mime_type = mime_type
return res
@@ -134,11 +142,7 @@ def openapi_responses(self) -> OpenAPIResponse:
content={self._mime_type: MediaType(schema=self.openapi_schema())},
)
- async def to_http_response(
- self,
- obj: FileType,
- ctx: Context | None = None,
- ):
+ async def to_http_response(self, obj: FileType, ctx: Context | None = None):
if isinstance(obj, bytes):
body = obj
else:
@@ -155,6 +159,31 @@ async def to_http_response(
res = Response(body)
return res
+ async def to_proto(self, obj: FileType) -> pb.File:
+ from bentoml.grpc.utils import mimetype_to_filetype_pb_map
+
+ if isinstance(obj, bytes):
+ body = obj
+ else:
+ body = obj.read()
+
+ try:
+ kind = mimetype_to_filetype_pb_map()[self._mime_type]
+ except KeyError:
+ raise BadInput(
+ f"{self._mime_type} doesn't have a corresponding File 'kind'"
+ ) from None
+
+ return pb.File(kind=kind, content=body)
+
+ if TYPE_CHECKING:
+
+ async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
+ ...
+
+ async def from_http_request(self, request: Request) -> t.IO[bytes]:
+ ...
+
class BytesIOFile(File):
async def from_http_request(self, request: Request) -> t.IO[bytes]:
@@ -183,3 +212,29 @@ async def from_http_request(self, request: Request) -> t.IO[bytes]:
raise BentoMLException(
f"File should have Content-Type '{self._mime_type}' or 'multipart/form-data', got {content_type} instead"
)
+
+ async def from_proto(self, field: pb.File | bytes) -> FileLike[bytes]:
+ from bentoml.grpc.utils import filetype_pb_to_mimetype_map
+
+ mapping = filetype_pb_to_mimetype_map()
+ # check if the request message has the correct field
+ if isinstance(field, bytes):
+ content = field
+ else:
+ assert isinstance(field, pb.File)
+ if field.kind:
+ try:
+ mime_type = mapping[field.kind]
+ if mime_type != self._mime_type:
+ raise BadInput(
+ f"Inferred mime_type from 'kind' is '{mime_type}', while '{repr(self)}' is expecting '{self._mime_type}'",
+ )
+ except KeyError:
+ raise BadInput(
+ f"{field.kind} is not a valid File kind. Accepted file kind: {[names for names,_ in pb.File.FileType.items()]}",
+ ) from None
+ content = field.content
+ if not content:
+ raise BadInput("Content is empty!") from None
+
+ return FileLike[bytes](io.BytesIO(content), "")
diff --git a/bentoml/_internal/io_descriptors/image.py b/bentoml/_internal/io_descriptors/image.py
index af49128e0c..88502cea07 100644
--- a/bentoml/_internal/io_descriptors/image.py
+++ b/bentoml/_internal/io_descriptors/image.py
@@ -15,7 +15,6 @@
from ..utils.http import set_cookies
from ...exceptions import BadInput
from ...exceptions import InvalidArgument
-from ...exceptions import InternalServerError
from ..service.openapi import SUCCESS_DESCRIPTION
from ..service.openapi.specification import Schema
from ..service.openapi.specification import Response as OpenAPIResponse
@@ -25,8 +24,11 @@
if TYPE_CHECKING:
from types import UnionType
+ import PIL
import PIL.Image
+ from bentoml.grpc.v1alpha1 import service_pb2 as pb
+
from .. import external_typing as ext
from ..context import InferenceApiContext as Context
@@ -34,17 +36,21 @@
"1", "CMYK", "F", "HSV", "I", "L", "LAB", "P", "RGB", "RGBA", "RGBX", "YCbCr"
]
else:
+ from bentoml.grpc.utils import import_generated_stubs
# NOTE: pillow-simd only benefits users who want to do preprocessing
# TODO: add options for users to choose between simd and native mode
- _exc = f"'Pillow' is required to use {__name__}. Install with: 'pip install -U Pillow'."
+ _exc = "'Pillow' is required to use the Image IO descriptor. Install it with: 'pip install -U Pillow'."
PIL = LazyLoader("PIL", globals(), "PIL", exc_msg=_exc)
PIL.Image = LazyLoader("PIL.Image", globals(), "PIL.Image", exc_msg=_exc)
+ pb, _ = import_generated_stubs()
+
+
# NOTES: we will keep type in quotation to avoid backward compatibility
# with numpy < 1.20, since we will use the latest stubs from the main branch of numpy.
# that enable a new way to type hint an ndarray.
-ImageType: t.TypeAlias = t.Union["PIL.Image.Image", "ext.NpNDArray"]
+ImageType = t.Union["PIL.Image.Image", "ext.NpNDArray"]
DEFAULT_PIL_MODE = "RGB"
@@ -135,32 +141,26 @@ async def predict_image(f: Image) -> NDArray[Any]:
:obj:`Image`: IO Descriptor that either a :code:`PIL.Image.Image` or a :code:`np.ndarray` representing an image.
"""
- MIME_EXT_MAPPING: t.Dict[str, str] = {}
+ MIME_EXT_MAPPING: dict[str, str] = {}
+
+ _proto_fields = ("file",)
def __init__(
self,
pilmode: _Mode | None = DEFAULT_PIL_MODE,
mime_type: str = "image/jpeg",
):
- try:
- import PIL.Image
- except ImportError:
- raise InternalServerError(
- "`Pillow` is required to use {__name__}\n Instructions: `pip install -U Pillow`"
- )
PIL.Image.init()
self.MIME_EXT_MAPPING.update({v: k for k, v in PIL.Image.MIME.items()})
if mime_type.lower() not in self.MIME_EXT_MAPPING: # pragma: no cover
raise InvalidArgument(
- f"Invalid Image mime_type '{mime_type}', "
- f"Supported mime types are {', '.join(PIL.Image.MIME.values())} "
- )
+ f"Invalid Image mime_type '{mime_type}'. Supported mime types are {', '.join(PIL.Image.MIME.values())}."
+ ) from None
if pilmode is not None and pilmode not in PIL.Image.MODES: # pragma: no cover
raise InvalidArgument(
- f"Invalid Image pilmode '{pilmode}', "
- f"Supported PIL modes are {', '.join(PIL.Image.MODES)} "
- )
+ f"Invalid Image pilmode '{pilmode}'. Supported PIL modes are {', '.join(PIL.Image.MODES)}."
+ ) from None
self._mime_type = mime_type.lower()
self._pilmode: _Mode | None = pilmode
@@ -197,13 +197,12 @@ async def from_http_request(self, request: Request) -> ImageType:
bytes_ = await request.body()
else:
raise BadInput(
- f"{self.__class__.__name__} should get `multipart/form-data`, "
- f"`{self._mime_type}` or `image/*`, got {content_type} instead"
+ f"{self.__class__.__name__} should get 'multipart/form-data', '{self._mime_type}' or 'image/*', got '{content_type}' instead."
)
try:
return PIL.Image.open(io.BytesIO(bytes_))
- except PIL.UnidentifiedImageError:
- raise BadInput("Failed reading image file uploaded") from None
+ except PIL.UnidentifiedImageError as e:
+ raise BadInput(f"Failed reading image file uploaded: {e}") from None
async def to_http_response(
self, obj: ImageType, ctx: Context | None = None
@@ -213,10 +212,9 @@ async def to_http_response(
elif LazyType[PIL.Image.Image]("PIL.Image.Image").isinstance(obj):
image = obj
else:
- raise InternalServerError(
- f"Unsupported Image type received: {type(obj)}, `{self.__class__.__name__}`"
- " only supports `np.ndarray` and `PIL.Image`"
- )
+ raise BadInput(
+ f"Unsupported Image type received: '{type(obj)}', the Image IO descriptor only supports 'np.ndarray' and 'PIL.Image'."
+ ) from None
filename = f"output.{self._format.lower()}"
ret = io.BytesIO()
@@ -248,3 +246,52 @@ async def to_http_response(
media_type=self._mime_type,
headers={"content-disposition": content_disposition},
)
+
+ async def from_proto(self, field: pb.File | bytes) -> ImageType:
+ from bentoml.grpc.utils import filetype_pb_to_mimetype_map
+
+ mapping = filetype_pb_to_mimetype_map()
+ # check if the request message has the correct field
+ if isinstance(field, bytes):
+ content = field
+ else:
+ assert isinstance(field, pb.File)
+ if field.kind:
+ try:
+ mime_type = mapping[field.kind]
+ if mime_type != self._mime_type:
+ raise BadInput(
+ f"Inferred mime_type from 'kind' is '{mime_type}', while '{repr(self)}' is expecting '{self._mime_type}'",
+ )
+ except KeyError:
+ raise BadInput(
+ f"{field.kind} is not a valid File kind. Accepted file kind: {[names for names,_ in pb.File.FileType.items()]}",
+ ) from None
+ content = field.content
+ if not content:
+ raise BadInput("Content is empty!") from None
+
+ return PIL.Image.open(io.BytesIO(content))
+
+ async def to_proto(self, obj: ImageType) -> pb.File:
+ from bentoml.grpc.utils import mimetype_to_filetype_pb_map
+
+ if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(obj):
+ image = PIL.Image.fromarray(obj, mode=self._pilmode)
+ elif LazyType["PIL.Image.Image"]("PIL.Image.Image").isinstance(obj):
+ image = obj
+ else:
+ raise BadInput(
+ f"Unsupported Image type received: '{type(obj)}', the Image IO descriptor only supports 'np.ndarray' and 'PIL.Image'.",
+ ) from None
+ ret = io.BytesIO()
+ image.save(ret, format=self._format)
+
+ try:
+ kind = mimetype_to_filetype_pb_map()[self._mime_type]
+ except KeyError:
+ raise BadInput(
+ f"{self._mime_type} doesn't have a corresponding File 'kind'",
+ ) from None
+
+ return pb.File(kind=kind, content=ret.getvalue())
diff --git a/bentoml/_internal/io_descriptors/json.py b/bentoml/_internal/io_descriptors/json.py
index 648cf7de7e..c17bffc7ab 100644
--- a/bentoml/_internal/io_descriptors/json.py
+++ b/bentoml/_internal/io_descriptors/json.py
@@ -10,12 +10,14 @@
from starlette.requests import Request
from starlette.responses import Response
+from bentoml.exceptions import BadInput
+
from .base import IODescriptor
from ..types import LazyType
from ..utils import LazyLoader
from ..utils import bentoml_cattr
+from ..utils.pkg import pkg_version_info
from ..utils.http import set_cookies
-from ...exceptions import BadInput
from ..service.openapi import REF_PREFIX
from ..service.openapi import SUCCESS_DESCRIPTION
from ..service.openapi.specification import Schema
@@ -28,26 +30,28 @@
import pydantic
import pydantic.schema as schema
+ from google.protobuf import struct_pb2
from .. import external_typing as ext
from ..context import InferenceApiContext as Context
- _Serializable = ext.NpNDArray | ext.PdDataFrame | t.Type[pydantic.BaseModel] | type
else:
_exc_msg = "'pydantic' must be installed to use 'pydantic_model'. Install with 'pip install pydantic'."
pydantic = LazyLoader("pydantic", globals(), "pydantic", exc_msg=_exc_msg)
schema = LazyLoader("schema", globals(), "pydantic.schema", exc_msg=_exc_msg)
+ # lazy load our proto generated.
+ struct_pb2 = LazyLoader("struct_pb2", globals(), "google.protobuf.struct_pb2")
+ # lazy load numpy for processing ndarray.
+ np = LazyLoader("np", globals(), "numpy")
JSONType = t.Union[str, t.Dict[str, t.Any], "pydantic.BaseModel", None]
-MIME_TYPE_JSON = "application/json"
-
logger = logging.getLogger(__name__)
class DefaultJsonEncoder(json.JSONEncoder):
- def default(self, o: _Serializable) -> t.Any:
+ def default(self, o: type) -> t.Any:
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(o):
@@ -63,9 +67,8 @@ def default(self, o: _Serializable) -> t.Any:
if "__root__" in obj_dict:
obj_dict = obj_dict.get("__root__")
return obj_dict
- if attr.has(o): # type: ignore (trivial case)
+ if attr.has(o):
return bentoml_cattr.unstructure(o)
-
return super().default(o)
@@ -168,7 +171,9 @@ def classify(input_data: IrisFeatures) -> NDArray[Any]:
:obj:`JSON`: IO Descriptor that represents JSON format.
"""
- _mime_type: str = MIME_TYPE_JSON
+ _proto_fields = ("json",)
+ # default mime type is application/json
+ _mime_type = "application/json"
def __init__(
self,
@@ -177,7 +182,11 @@ def __init__(
validate_json: bool | None = None,
json_encoder: t.Type[json.JSONEncoder] = DefaultJsonEncoder,
):
- if pydantic_model:
+ if pydantic_model is not None:
+ if pkg_version_info("pydantic")[0] >= 2:
+ raise BadInput(
+ "pydantic 2.x is not yet supported. Add upper bound to 'pydantic': 'pip install \"pydantic<2\"'"
+ ) from None
assert issubclass(
pydantic_model, pydantic.BaseModel
), "'pydantic_model' must be a subclass of 'pydantic.BaseModel'."
@@ -236,12 +245,12 @@ async def from_http_request(self, request: Request) -> JSONType:
except json.JSONDecodeError as e:
raise BadInput(f"Invalid JSON input received: {e}") from None
- if self._pydantic_model is not None:
+ if self._pydantic_model:
try:
pydantic_model = self._pydantic_model.parse_obj(json_obj)
return pydantic_model
except pydantic.ValidationError as e:
- raise BadInput(f"Invalid JSON input received: {e}") from e
+ raise BadInput(f"Invalid JSON input received: {e}") from None
else:
return json_obj
@@ -269,11 +278,57 @@ async def to_http_response(
if ctx is not None:
res = Response(
json_str,
- media_type=MIME_TYPE_JSON,
+ media_type=self._mime_type,
headers=ctx.response.metadata, # type: ignore (bad starlette types)
status_code=ctx.response.status_code,
)
set_cookies(res, ctx.response.cookies)
return res
else:
- return Response(json_str, media_type=MIME_TYPE_JSON)
+ return Response(json_str, media_type=self._mime_type)
+
+ async def from_proto(self, field: struct_pb2.Value | bytes) -> JSONType:
+ from google.protobuf.json_format import MessageToDict
+
+ if isinstance(field, bytes):
+ content = field
+ if self._pydantic_model:
+ try:
+ return self._pydantic_model.parse_raw(content)
+ except pydantic.ValidationError as e:
+ raise BadInput(f"Invalid JSON input received: {e}") from None
+ try:
+ parsed = json.loads(content)
+ except json.JSONDecodeError as e:
+ raise BadInput(f"Invalid JSON input received: {e}") from None
+ else:
+ assert isinstance(field, struct_pb2.Value)
+ parsed = MessageToDict(field, preserving_proto_field_name=True)
+
+ if self._pydantic_model:
+ try:
+ return self._pydantic_model.parse_obj(parsed)
+ except pydantic.ValidationError as e:
+ raise BadInput(f"Invalid JSON input received: {e}") from None
+ return parsed
+
+ async def to_proto(self, obj: JSONType) -> struct_pb2.Value:
+ if LazyType["pydantic.BaseModel"]("pydantic.BaseModel").isinstance(obj):
+ obj = obj.dict()
+ msg = struct_pb2.Value()
+ # To handle None cases.
+ if obj is not None:
+ from google.protobuf.json_format import ParseDict
+
+ if isinstance(obj, (dict, str, list, float, int, bool)):
+ # ParseDict handles google.protobuf.Struct type
+ # directly if given object has a supported type
+ ParseDict(obj, msg)
+ else:
+ # If given object doesn't have a supported type, we will
+ # use given JSON encoder to convert it to dictionary
+ # and then parse it to google.protobuf.Struct.
+ # Note that if a custom JSON encoder is used, it mustn't
+ # take any arguments.
+ ParseDict(self._json_encoder().default(obj), msg)
+ return msg
diff --git a/bentoml/_internal/io_descriptors/multipart.py b/bentoml/_internal/io_descriptors/multipart.py
index c6f9190dd8..afc62190b4 100644
--- a/bentoml/_internal/io_descriptors/multipart.py
+++ b/bentoml/_internal/io_descriptors/multipart.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import typing as t
+import asyncio
from typing import TYPE_CHECKING
from starlette.requests import Request
@@ -21,11 +22,17 @@
if TYPE_CHECKING:
from types import UnionType
+ from bentoml.grpc.v1alpha1 import service_pb2 as pb
+
from ..types import LazyType
from ..context import InferenceApiContext as Context
+else:
+ from bentoml.grpc.utils import import_generated_stubs
+
+ pb, _ = import_generated_stubs()
-class Multipart(IODescriptor[t.Any]):
+class Multipart(IODescriptor[t.Dict[str, t.Any]]):
"""
:obj:`Multipart` defines API specification for the inputs/outputs of a Service, where inputs/outputs
of a Service can receive/send a **multipart** request/responses as specified in your API function signature.
@@ -153,14 +160,18 @@ async def predict(
:obj:`Multipart`: IO Descriptor that represents a Multipart request/response.
"""
+ _proto_fields = ("multipart",)
+ _mime_type = "multipart/form-data"
+
def __init__(self, **inputs: IODescriptor[t.Any]):
- for descriptor in inputs.values():
- if isinstance(descriptor, Multipart): # pragma: no cover
- raise InvalidArgument(
- "Multipart IO can not contain nested Multipart IO descriptor"
- )
- self._inputs: dict[str, t.Any] = inputs
- self._mime_type = "multipart/form-data"
+ if any(isinstance(descriptor, Multipart) for descriptor in inputs.values()):
+ raise InvalidArgument(
+ "Multipart IO can not contain nested Multipart IO descriptor"
+ ) from None
+ self._inputs = inputs
+
+ def __repr__(self) -> str:
+ return f"Multipart({','.join([f'{k}={v}' for k,v in zip(self._inputs, map(repr, self._inputs.values()))])})"
def input_type(
self,
@@ -171,7 +182,7 @@ def input_type(
if isinstance(inp_type, dict):
raise TypeError(
"A multipart descriptor cannot take a multi-valued I/O descriptor as input"
- )
+ ) from None
res[k] = inp_type
return res
@@ -202,22 +213,68 @@ async def from_http_request(self, request: Request) -> dict[str, t.Any]:
if ctype != b"multipart/form-data":
raise BentoMLException(
f"{self.__class__.__name__} only accepts `multipart/form-data` as Content-Type header, got {ctype} instead."
- )
-
- res: dict[str, t.Any] = dict()
- reqs = await populate_multipart_requests(request)
+ ) from None
- for k, i in self._inputs.items():
- req = reqs[k]
- v = await i.from_http_request(req)
- res[k] = v
- return res
+ to_populate = zip(
+ self._inputs.values(), (await populate_multipart_requests(request)).values()
+ )
+ reqs = await asyncio.gather(
+ *tuple(io_.from_http_request(req) for io_, req in to_populate)
+ )
+ return dict(zip(self._inputs, reqs))
async def to_http_response(
self, obj: dict[str, t.Any], ctx: Context | None = None
) -> Response:
- res_mapping: dict[str, Response] = {}
- for k, io_ in self._inputs.items():
- data = obj[k]
- res_mapping[k] = await io_.to_http_response(data, ctx)
- return await concat_to_multipart_response(res_mapping, ctx)
+ resps = await asyncio.gather(
+ *tuple(
+ io_.to_http_response(obj[key], ctx) for key, io_ in self._inputs.items()
+ )
+ )
+ return await concat_to_multipart_response(dict(zip(self._inputs, resps)), ctx)
+
+ def validate_input_mapping(self, field: t.MutableMapping[str, t.Any]) -> None:
+ if len(set(field) - set(self._inputs)) != 0:
+ raise InvalidArgument(
+ f"'{repr(self)}' accepts the following keys: {set(self._inputs)}. Given {field.__class__.__qualname__} has invalid fields: {set(field) - set(self._inputs)}",
+ ) from None
+
+ async def from_proto(self, field: pb.Multipart) -> dict[str, t.Any]:
+ if isinstance(field, bytes):
+ raise InvalidArgument(
+ f"cannot use 'serialized_bytes' with {self.__class__.__name__}"
+ ) from None
+ message = field.fields
+ self.validate_input_mapping(message)
+ reqs = await asyncio.gather(
+ *tuple(
+ io_.from_proto(getattr(input_pb, io_._proto_fields[0]))
+ for io_, input_pb in self.io_fields_mapping(message).items()
+ )
+ )
+ return dict(zip(message, reqs))
+
+ def io_fields_mapping(
+ self, message: t.MutableMapping[str, pb.Part]
+ ) -> dict[IODescriptor[t.Any], pb.Part]:
+ return {io_: part for io_, part in zip(self._inputs.values(), message.values())}
+
+ async def to_proto(self, obj: dict[str, t.Any]) -> pb.Multipart:
+ self.validate_input_mapping(obj)
+ resps = await asyncio.gather(
+ *tuple(
+ io_.to_proto(data)
+ for io_, data in zip(self._inputs.values(), obj.values())
+ )
+ )
+ return pb.Multipart(
+ fields={
+ key: pb.Part(
+ **{
+ io_._proto_fields[0]: resp
+ for io_, resp in zip(self._inputs.values(), resps)
+ }
+ )
+ for key in obj
+ }
+ )
diff --git a/bentoml/_internal/io_descriptors/numpy.py b/bentoml/_internal/io_descriptors/numpy.py
index 9ae561d47a..aa730a7dea 100644
--- a/bentoml/_internal/io_descriptors/numpy.py
+++ b/bentoml/_internal/io_descriptors/numpy.py
@@ -4,19 +4,20 @@
import typing as t
import logging
from typing import TYPE_CHECKING
+from functools import lru_cache
from starlette.requests import Request
from starlette.responses import Response
from .base import IODescriptor
-from .json import MIME_TYPE_JSON
from ..types import LazyType
+from ..utils import LazyLoader
from ..utils.http import set_cookies
from ...exceptions import BadInput
+from ...exceptions import InvalidArgument
from ...exceptions import BentoMLException
-from ...exceptions import InternalServerError
+from ...exceptions import UnprocessableEntity
from ..service.openapi import SUCCESS_DESCRIPTION
-from ..utils.lazy_loader import LazyLoader
from ..service.openapi.specification import Schema
from ..service.openapi.specification import Response as OpenAPIResponse
from ..service.openapi.specification import MediaType
@@ -25,18 +26,79 @@
if TYPE_CHECKING:
import numpy as np
+ from bentoml.grpc.v1alpha1 import service_pb2 as pb
+
from .. import external_typing as ext
from ..context import InferenceApiContext as Context
else:
+ from bentoml.grpc.utils import import_generated_stubs
+
+ pb, _ = import_generated_stubs()
np = LazyLoader("np", globals(), "numpy")
logger = logging.getLogger(__name__)
-def _is_matched_shape(
- left: t.Optional[t.Tuple[int, ...]],
- right: t.Optional[t.Tuple[int, ...]],
-) -> bool: # pragma: no cover
+# TODO: support the following types for for protobuf message:
+# - support complex64, complex128, object and struct types
+# - BFLOAT16, QINT32, QINT16, QUINT16, QINT8, QUINT8
+#
+# For int16, uint16, int8, uint8 -> specify types in NumpyNdarray + using int_values.
+#
+# For bfloat16, half (float16) -> specify types in NumpyNdarray + using float_values.
+#
+# for string_values, use dict[pb.NDArray.DType.ValueType, ext.NpDTypeLike]:
+ # pb.NDArray.Dtype -> np.dtype
+ return {
+ pb.NDArray.DTYPE_FLOAT: np.dtype("float32"),
+ pb.NDArray.DTYPE_DOUBLE: np.dtype("double"),
+ pb.NDArray.DTYPE_INT32: np.dtype("int32"),
+ pb.NDArray.DTYPE_INT64: np.dtype("int64"),
+ pb.NDArray.DTYPE_UINT32: np.dtype("uint32"),
+ pb.NDArray.DTYPE_UINT64: np.dtype("uint64"),
+ pb.NDArray.DTYPE_BOOL: np.dtype("bool"),
+ pb.NDArray.DTYPE_STRING: np.dtype(" dict[pb.NDArray.DType.ValueType, str]:
+ return {k: npdtype_to_fieldpb_map()[v] for k, v in dtypepb_to_npdtype_map().items()}
+
+
+@lru_cache(maxsize=1)
+def fieldpb_to_npdtype_map() -> dict[str, ext.NpDTypeLike]:
+ # str -> np.dtype
+ return {k: np.dtype(v) for k, v in FIELDPB_TO_NPDTYPE_NAME_MAP.items()}
+
+
+@lru_cache(maxsize=1)
+def npdtype_to_dtypepb_map() -> dict[ext.NpDTypeLike, pb.NDArray.DType.ValueType]:
+ # np.dtype -> pb.NDArray.Dtype
+ return {v: k for k, v in dtypepb_to_npdtype_map().items()}
+
+
+@lru_cache(maxsize=1)
+def npdtype_to_fieldpb_map() -> dict[ext.NpDTypeLike, str]:
+ # np.dtype -> str
+ return {v: k for k, v in fieldpb_to_npdtype_map().items()}
+
+
+def _is_matched_shape(left: tuple[int, ...], right: tuple[int, ...]) -> bool:
if (left is None) or (right is None):
return False
@@ -52,6 +114,7 @@ def _is_matched_shape(
return True
+# TODO: when updating docs, add examples with gRPCurl
class NumpyNdarray(IODescriptor["ext.NpNDArray"]):
"""
:obj:`NumpyNdarray` defines API specification for the inputs/outputs of a Service, where
@@ -135,6 +198,9 @@ async def predict(input_array: np.ndarray) -> np.ndarray:
:obj:`~bentoml._internal.io_descriptors.IODescriptor`: IO Descriptor that represents a :code:`np.ndarray`.
"""
+ _proto_fields = ("ndarray",)
+ _mime_type = "application/json"
+
def __init__(
self,
dtype: str | ext.NpDTypeLike | None = None,
@@ -143,15 +209,11 @@ def __init__(
enforce_shape: bool = False,
):
if dtype and not isinstance(dtype, np.dtype):
- # Convert from primitive type or type string, e.g.:
- # np.dtype(float)
- # np.dtype("float64")
+ # Convert from primitive type or type string, e.g.: np.dtype(float) or np.dtype("float64")
try:
dtype = np.dtype(dtype)
except TypeError as e:
- raise BentoMLException(
- f'NumpyNdarray: Invalid dtype "{dtype}": {e}'
- ) from e
+ raise UnprocessableEntity(f'Invalid dtype "{dtype}": {e}') from None
self._dtype = dtype
self._shape = shape
@@ -160,6 +222,15 @@ def __init__(
self._sample_input = None
+ if self._enforce_dtype and not self._dtype:
+ raise InvalidArgument(
+ "'dtype' must be specified when 'enforce_dtype=True'"
+ ) from None
+ if self._enforce_shape and not self._shape:
+ raise InvalidArgument(
+ "'shape' must be specified when 'enforce_shape=True'"
+ ) from None
+
def _openapi_types(self) -> str:
# convert numpy dtypes to openapi compatible types.
var_type = "integer"
@@ -195,7 +266,9 @@ def openapi_components(self) -> dict[str, t.Any] | None:
def openapi_example(self) -> t.Any:
if self.sample_input is not None:
if isinstance(self.sample_input, np.generic):
- raise BadInput("NumpyNdarray: sample_input must be a numpy array.")
+ raise BadInput(
+ "NumpyNdarray: sample_input must be a numpy array."
+ ) from None
return self.sample_input.tolist()
return
@@ -219,33 +292,33 @@ def openapi_responses(self) -> OpenAPIResponse:
},
)
- def _verify_ndarray(
- self, obj: ext.NpNDArray, exception_cls: t.Type[Exception] = BadInput
+ def validate_array(
+ self, arr: ext.NpNDArray, exception_cls: t.Type[Exception] = BadInput
) -> ext.NpNDArray:
- if self._dtype is not None and self._dtype != obj.dtype:
+ if self._dtype is not None and self._dtype != arr.dtype:
# ‘same_kind’ means only safe casts or casts within a kind, like float64
# to float32, are allowed.
- if np.can_cast(obj.dtype, self._dtype, casting="same_kind"):
- obj = obj.astype(self._dtype, casting="same_kind") # type: ignore
+ if np.can_cast(arr.dtype, self._dtype, casting="same_kind"):
+ arr = arr.astype(self._dtype, casting="same_kind") # type: ignore
else:
- msg = f'{self.__class__.__name__}: Expecting ndarray of dtype "{self._dtype}", but "{obj.dtype}" was received.'
+ msg = f'{self.__class__.__name__}: Expecting ndarray of dtype "{self._dtype}", but "{arr.dtype}" was received.'
if self._enforce_dtype:
- raise exception_cls(msg)
+ raise exception_cls(msg) from None
else:
logger.debug(msg)
- if self._shape is not None and not _is_matched_shape(self._shape, obj.shape):
- msg = f'{self.__class__.__name__}: Expecting ndarray of shape "{self._shape}", but "{obj.shape}" was received.'
+ if self._shape is not None and not _is_matched_shape(self._shape, arr.shape):
+ msg = f'{self.__class__.__name__}: Expecting ndarray of shape "{self._shape}", but "{arr.shape}" was received.'
if self._enforce_shape:
- raise exception_cls(msg)
+ raise exception_cls(msg) from None
try:
- obj = obj.reshape(self._shape)
+ arr = arr.reshape(self._shape)
except ValueError as e:
logger.debug(f"{msg} Failed to reshape: {e}.")
- return obj
+ return arr
- async def from_http_request(self, request: Request) -> "ext.NpNDArray":
+ async def from_http_request(self, request: Request) -> ext.NpNDArray:
"""
Process incoming requests and convert incoming objects to ``numpy.ndarray``.
@@ -262,7 +335,7 @@ async def from_http_request(self, request: Request) -> "ext.NpNDArray":
except ValueError:
res = np.array(obj)
- return self._verify_ndarray(res)
+ return self.validate_array(res)
async def to_http_response(self, obj: ext.NpNDArray, ctx: Context | None = None):
"""
@@ -276,18 +349,19 @@ async def to_http_response(self, obj: ext.NpNDArray, ctx: Context | None = None)
HTTP Response of type ``starlette.responses.Response``. This can
be accessed via cURL or any external web traffic.
"""
- obj = self._verify_ndarray(obj, InternalServerError)
+ obj = self.validate_array(obj)
+
if ctx is not None:
res = Response(
json.dumps(obj.tolist()),
- media_type=MIME_TYPE_JSON,
+ media_type=self._mime_type,
headers=ctx.response.metadata, # type: ignore (bad starlette types)
status_code=ctx.response.status_code,
)
set_cookies(res, ctx.response.cookies)
return res
else:
- return Response(json.dumps(obj.tolist()), media_type=MIME_TYPE_JSON)
+ return Response(json.dumps(obj.tolist()), media_type=self._mime_type)
@classmethod
def from_sample(
@@ -336,8 +410,8 @@ async def predict(input: NDArray[np.int16]) -> NDArray[Any]:
"""
if isinstance(sample_input, np.generic):
raise BentoMLException(
- "NumpyNdarray.from_sample() expects a numpy.array, not numpy.generic."
- )
+ "'NumpyNdarray.from_sample()' expects a 'numpy.array', not 'numpy.generic'."
+ ) from None
inst = cls(
dtype=sample_input.dtype,
@@ -348,3 +422,88 @@ async def predict(input: NDArray[np.int16]) -> NDArray[Any]:
inst.sample_input = sample_input
return inst
+
+ async def from_proto(self, field: pb.NDArray | bytes) -> ext.NpNDArray:
+ """
+ Process incoming protobuf request and convert it to ``numpy.ndarray``
+
+ Args:
+ request: Incoming RPC request message.
+ context: grpc.ServicerContext
+
+ Returns:
+ a ``numpy.ndarray`` object. This can then be used
+ inside users defined logics.
+ """
+ if isinstance(field, bytes):
+ if not self._dtype:
+ raise BadInput(
+ "'serialized_bytes' requires specifying 'dtype'."
+ ) from None
+
+ dtype: ext.NpDTypeLike = self._dtype
+ array = np.frombuffer(field, dtype=self._dtype)
+ else:
+ assert isinstance(field, pb.NDArray)
+ if field.dtype == pb.NDArray.DTYPE_UNSPECIFIED:
+ dtype = None
+ else:
+ try:
+ dtype = dtypepb_to_npdtype_map()[field.dtype]
+ except KeyError:
+ raise BadInput(f"{field.dtype} is invalid.") from None
+ if dtype is not None:
+ values_array = getattr(field, dtypepb_to_fieldpb_map()[field.dtype])
+ else:
+ fieldpb = [
+ f.name for f, _ in field.ListFields() if f.name.endswith("_values")
+ ]
+ if len(fieldpb) == 0:
+ # input message doesn't have any fields.
+ return np.empty(shape=field.shape or 0)
+ elif len(fieldpb) > 1:
+ # when there are more than two values provided in the proto.
+ raise BadInput(
+ f"Array contents can only be one of given values key. Use one of '{fieldpb}' instead.",
+ ) from None
+
+ dtype: ext.NpDTypeLike = fieldpb_to_npdtype_map()[fieldpb[0]]
+ values_array = getattr(field, fieldpb[0])
+ try:
+ array = np.array(values_array, dtype=dtype)
+ except ValueError:
+ array = np.array(values_array)
+
+ if field.shape:
+ array = np.reshape(array, field.shape)
+
+ return self.validate_array(array)
+
+ async def to_proto(self, obj: ext.NpNDArray) -> pb.NDArray:
+ """
+ Process given objects and convert it to grpc protobuf response.
+
+ Args:
+ obj: `np.ndarray` that will be serialized to protobuf
+ context: grpc.aio.ServicerContext from grpc.aio.Server
+ Returns:
+ `io_descriptor_pb2.Array`:
+ Protobuf representation of given `np.ndarray`
+ """
+ try:
+ obj = self.validate_array(obj)
+ except BadInput as e:
+ raise e from None
+
+ try:
+ fieldpb = npdtype_to_fieldpb_map()[obj.dtype]
+ dtypepb = npdtype_to_dtypepb_map()[obj.dtype]
+ return pb.NDArray(
+ dtype=dtypepb,
+ shape=tuple(obj.shape),
+ **{fieldpb: obj.ravel().tolist()},
+ )
+ except KeyError:
+ raise BadInput(
+ f"Unsupported dtype '{obj.dtype}' for response message.",
+ ) from None
diff --git a/bentoml/_internal/io_descriptors/pandas.py b/bentoml/_internal/io_descriptors/pandas.py
index b01256bfc5..135a44b0ce 100644
--- a/bentoml/_internal/io_descriptors/pandas.py
+++ b/bentoml/_internal/io_descriptors/pandas.py
@@ -4,19 +4,20 @@
import typing as t
import logging
import functools
-import importlib.util
from enum import Enum
from typing import TYPE_CHECKING
+from concurrent.futures import ThreadPoolExecutor
from starlette.requests import Request
from starlette.responses import Response
from .base import IODescriptor
-from .json import MIME_TYPE_JSON
from ..types import LazyType
+from ..utils.pkg import find_spec
from ..utils.http import set_cookies
from ...exceptions import BadInput
from ...exceptions import InvalidArgument
+from ...exceptions import UnprocessableEntity
from ...exceptions import MissingDependencyException
from ..service.openapi import SUCCESS_DESCRIPTION
from ..utils.lazy_loader import LazyLoader
@@ -28,15 +29,20 @@
if TYPE_CHECKING:
import pandas as pd
+ from bentoml.grpc.v1alpha1 import service_pb2 as pb
+
from .. import external_typing as ext
from ..context import InferenceApiContext as Context
else:
+ from bentoml.grpc.utils import import_generated_stubs
+
+ pb, _ = import_generated_stubs()
pd = LazyLoader(
"pd",
globals(),
"pandas",
- exc_msg="`pandas` is required to use PandasDataFrame or PandasSeries. Install with `pip install -U pandas`",
+ exc_msg='pandas" is required to use PandasDataFrame or PandasSeries. Install with "pip install -U pandas"',
)
logger = logging.getLogger(__name__)
@@ -45,9 +51,9 @@
# Check for parquet support
@functools.lru_cache(maxsize=1)
def get_parquet_engine() -> str:
- if importlib.util.find_spec("pyarrow") is not None:
+ if find_spec("pyarrow") is not None:
return "pyarrow"
- elif importlib.util.find_spec("fastparquet") is not None:
+ elif find_spec("fastparquet") is not None:
return "fastparquet"
else:
logger.warning(
@@ -72,9 +78,7 @@ def _openapi_types(item: str) -> str: # pragma: no cover
return "object"
-def _openapi_schema(
- dtype: bool | dict[str, t.Any] | None
-) -> Schema: # pragma: no cover
+def _openapi_schema(dtype: bool | ext.PdDTypeArg | None) -> Schema: # pragma: no cover
if isinstance(dtype, dict):
return Schema(
type="object",
@@ -111,15 +115,12 @@ def _infer_serialization_format_from_request(
return SerializationFormat.CSV
elif content_type:
logger.debug(
- "Unknown content-type (%s), falling back to %s serialization format.",
- content_type,
- default_format,
+ f"Unknown content-type ('{content_type}'), falling back to '{default_format}' serialization format.",
)
return default_format
else:
logger.debug(
- "Content-type not specified, falling back to %s serialization format.",
- default_format,
+ f"Content-type not specified, falling back to '{default_format}' serialization format.",
)
return default_format
@@ -203,7 +204,7 @@ def predict(input_arr):
- :obj:`split` - :code:`dict[str, Any]` ↦ {``idx`` ↠ ``[idx]``, ``columns`` ↠ ``[columns]``, ``data`` ↠ ``[values]``}
- :obj:`records` - :code:`list[Any]` ↦ [{``column`` ↠ ``value``}, ..., {``column`` ↠ ``value``}]
- :obj:`index` - :code:`dict[str, Any]` ↦ {``idx`` ↠ {``column`` ↠ ``value``}}
- - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` -> {``index`` ↠ ``value``}}
+ - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` ↠ {``index`` ↠ ``value``}}
- :obj:`values` - :code:`dict[str, Any]` ↦ Values arrays
columns: List of columns name that users wish to update.
apply_column_names: Whether to update incoming DataFrame columns. If :code:`apply_column_names=True`,
@@ -248,12 +249,14 @@ def predict(input_df: pd.DataFrame) -> pd.DataFrame:
:obj:`PandasDataFrame`: IO Descriptor that represents a :code:`pd.DataFrame`.
"""
+ _proto_fields = ("dataframe",)
+
def __init__(
self,
orient: ext.DataFrameOrient = "records",
- apply_column_names: bool = False,
columns: list[str] | None = None,
- dtype: bool | dict[str, t.Any] | None = None,
+ apply_column_names: bool = False,
+ dtype: bool | ext.PdDTypeArg | None = None,
enforce_dtype: bool = False,
shape: tuple[int, ...] | None = None,
enforce_shape: bool = False,
@@ -324,49 +327,21 @@ async def from_http_request(self, request: Request) -> ext.PdDataFrame:
_validate_serialization_format(serialization_format)
obj = await request.body()
- if self._enforce_dtype:
- if self._dtype is None:
- logger.warning(
- "`dtype` is None or undefined, while `enforce_dtype`=True"
- )
- # TODO(jiang): check dtype
-
if serialization_format is SerializationFormat.JSON:
+ assert not isinstance(self._dtype, bool)
res = pd.read_json(io.BytesIO(obj), dtype=self._dtype, orient=self._orient)
elif serialization_format is SerializationFormat.PARQUET:
res = pd.read_parquet(io.BytesIO(obj), engine=get_parquet_engine())
elif serialization_format is SerializationFormat.CSV:
+ assert not isinstance(self._dtype, bool)
res: ext.PdDataFrame = pd.read_csv(io.BytesIO(obj), dtype=self._dtype)
else:
raise InvalidArgument(
f"Unknown serialization format ({serialization_format})."
- )
+ ) from None
assert isinstance(res, pd.DataFrame)
-
- if self._apply_column_names:
- if self._columns is None:
- logger.warning(
- "`columns` is None or undefined, while `apply_column_names`=True"
- )
- elif len(self._columns) != res.shape[1]:
- raise BadInput(
- "length of `columns` does not match the columns of incoming data"
- )
- else:
- res.columns = pd.Index(self._columns)
- if self._enforce_shape:
- if self._shape is None:
- logger.warning(
- "`shape` is None or undefined, while `enforce_shape`=True"
- )
- else:
- assert all(
- left == right
- for left, right in zip(self._shape, res.shape) # type: ignore (shape type)
- if left != -1 and right != -1
- ), f"incoming has shape {res.shape} where enforced shape to be {self._shape}"
- return res
+ return self.validate_dataframe(res)
async def to_http_response(
self, obj: ext.PdDataFrame, ctx: Context | None = None
@@ -381,6 +356,7 @@ async def to_http_response(
HTTP Response of type `starlette.responses.Response`. This can
be accessed via cURL or any external web traffic.
"""
+ obj = self.validate_dataframe(obj)
# For the response it doesn't make sense to enforce the same serialization format as specified
# by the request's headers['content-type']. Instead we simply use the _default_format.
@@ -399,7 +375,7 @@ async def to_http_response(
else:
raise InvalidArgument(
f"Unknown serialization format ({serialization_format})."
- )
+ ) from None
if ctx is not None:
res = Response(
@@ -420,7 +396,7 @@ def from_sample(
orient: ext.DataFrameOrient = "records",
apply_column_names: bool = True,
enforce_shape: bool = True,
- enforce_dtype: bool = False,
+ enforce_dtype: bool = True,
default_format: t.Literal["json", "parquet", "csv"] = "json",
) -> PandasDataFrame:
"""
@@ -435,7 +411,7 @@ def from_sample(
- :obj:`split` - :code:`dict[str, Any]` ↦ {``idx`` ↠ ``[idx]``, ``columns`` ↠ ``[columns]``, ``data`` ↠ ``[values]``}
- :obj:`records` - :code:`list[Any]` ↦ [{``column`` ↠ ``value``}, ..., {``column`` ↠ ``value``}]
- :obj:`index` - :code:`dict[str, Any]` ↦ {``idx`` ↠ {``column`` ↠ ``value``}}
- - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` -> {``index`` ↠ ``value``}}
+ - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` ↠ {``index`` ↠ ``value``}}
- :obj:`values` - :code:`dict[str, Any]` ↦ Values arrays
apply_column_names: Update incoming DataFrame columns. ``columns`` must be specified at
function signature. If you don't want to enforce a specific columns
@@ -469,22 +445,144 @@ def from_sample(
@svc.api(input=input_spec, output=PandasDataFrame())
def predict(inputs: pd.DataFrame) -> pd.DataFrame: ...
"""
- columns = [str(x) for x in list(sample_input.columns)]
-
inst = cls(
orient=orient,
enforce_shape=enforce_shape,
shape=sample_input.shape,
apply_column_names=apply_column_names,
- columns=columns,
+ columns=[str(x) for x in list(sample_input.columns)],
enforce_dtype=enforce_dtype,
- dtype=None, # TODO: not breaking atm
+ dtype=True, # set to True to infer from given input
default_format=default_format,
)
inst.sample_input = sample_input
return inst
+ def validate_dataframe(
+ self, dataframe: ext.PdDataFrame, exception_cls: t.Type[Exception] = BadInput
+ ) -> ext.PdDataFrame:
+
+ if not LazyType["ext.PdDataFrame"]("pd.DataFrame").isinstance(dataframe):
+ raise InvalidArgument(
+ f"return object is not of type 'pd.DataFrame', got type '{type(dataframe)}' instead"
+ ) from None
+
+ # TODO: dtype check
+ # if self._dtype is not None and self._dtype != dataframe.dtypes:
+ # msg = f'{self.__class__.__name__}: Expecting DataFrame of dtype "{self._dtype}", but "{dataframe.dtypes}" was received.'
+ # if self._enforce_dtype:
+ # raise exception_cls(msg) from None
+
+ if self._columns is not None and len(self._columns) != dataframe.shape[1]:
+ msg = f"length of 'columns' ({len(self._columns)}) does not match the # of columns of incoming data."
+ if self._apply_column_names:
+ raise BadInput(msg) from None
+ else:
+ logger.debug(msg)
+ dataframe.columns = pd.Index(self._columns)
+
+ # TODO: convert from wide to long format (melt())
+ if self._shape is not None and self._shape != dataframe.shape:
+ msg = f'{self.__class__.__name__}: Expecting DataFrame of shape "{self._shape}", but "{dataframe.shape}" was received.'
+ if self._enforce_shape and not all(
+ left == right
+ for left, right in zip(self._shape, dataframe.shape)
+ if left != -1 and right != -1
+ ):
+ raise exception_cls(msg) from None
+
+ return dataframe
+
+ async def from_proto(self, field: pb.DataFrame | bytes) -> ext.PdDataFrame:
+ """
+ Process incoming protobuf request and convert it to ``pandas.DataFrame``
+
+ Args:
+ request: Incoming RPC request message.
+ context: grpc.ServicerContext
+
+ Returns:
+ a ``pandas.DataFrame`` object. This can then be used
+ inside users defined logics.
+ """
+ # TODO: support different serialization format
+ if isinstance(field, bytes):
+ # TODO: handle serialized_bytes for dataframe
+ raise NotImplementedError(
+ 'Currently not yet implemented. Use "dataframe" instead.'
+ )
+ else:
+ # note that there is a current bug where we don't check for
+ # dtype of given fields per Series to match with types of a given
+ # columns, hence, this would result in a wrong DataFrame that is not
+ # expected by our users.
+ assert isinstance(field, pb.DataFrame)
+ # columns orient: { column_name : {index : columns.series._value}}
+ if self._orient != "columns":
+ raise BadInput(
+ f"'dataframe' field currently only supports 'columns' orient. Make sure to set 'orient=columns' in {self.__class__.__name__}."
+ ) from None
+ data: list[t.Any] = []
+
+ def process_columns_contents(content: pb.Series) -> dict[str, t.Any]:
+ # To be use inside a ThreadPoolExecutor to handle
+ # large tabular data
+ if len(content.ListFields()) != 1:
+ raise BadInput(
+ f"Array contents can only be one of given values key. Use one of '{list(map(lambda f: f[0].name,content.ListFields()))}' instead."
+ ) from None
+ return {str(i): c for i, c in enumerate(content.ListFields()[0][1])}
+
+ with ThreadPoolExecutor(max_workers=10) as executor:
+ futures = executor.map(process_columns_contents, field.columns)
+ data.extend([i for i in list(futures)])
+ dataframe = pd.DataFrame(
+ dict(zip(field.column_names, data)),
+ columns=t.cast(t.List[str], field.column_names),
+ )
+ return self.validate_dataframe(dataframe)
+
+ async def to_proto(self, obj: ext.PdDataFrame) -> pb.DataFrame:
+ """
+ Process given objects and convert it to grpc protobuf response.
+
+ Args:
+ obj: ``pandas.DataFrame`` that will be serialized to protobuf
+ context: grpc.aio.ServicerContext from grpc.aio.Server
+ Returns:
+ ``service_pb2.Response``:
+ Protobuf representation of given ``pandas.DataFrame``
+ """
+ from bentoml._internal.io_descriptors.numpy import npdtype_to_fieldpb_map
+
+ # TODO: support different serialization format
+ obj = self.validate_dataframe(obj)
+ mapping = npdtype_to_fieldpb_map()
+ # note that this is not safe, since we are not checking the dtype of the series
+ # FIXME(aarnphm): validate and handle mix columns dtype
+ # currently we don't support ExtensionDtype
+ columns_name: list[str] = list(map(str, obj.columns))
+ not_supported: list[ext.PdDType] = list(
+ filter(
+ lambda x: x not in mapping,
+ map(lambda x: t.cast("ext.PdSeries", obj[x]).dtype, columns_name),
+ )
+ )
+ if len(not_supported) > 0:
+ raise UnprocessableEntity(
+ f'dtype in column "{obj.columns}" is not currently supported.'
+ ) from None
+ return pb.DataFrame(
+ column_names=columns_name,
+ columns=[
+ pb.Series(
+ **{mapping[t.cast("ext.NpDTypeLike", obj[col].dtype)]: obj[col]}
+ )
+ for col in columns_name
+ ],
+ )
+
class PandasSeries(IODescriptor["ext.PdSeries"]):
"""
@@ -551,7 +649,7 @@ def predict(input_arr):
- :obj:`split` - :code:`dict[str, Any]` ↦ {``idx`` ↠ ``[idx]``, ``columns`` ↠ ``[columns]``, ``data`` ↠ ``[values]``}
- :obj:`records` - :code:`list[Any]` ↦ [{``column`` ↠ ``value``}, ..., {``column`` ↠ ``value``}]
- :obj:`index` - :code:`dict[str, Any]` ↦ {``idx`` ↠ {``column`` ↠ ``value``}}
- - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` -> {``index`` ↠ ``value``}}
+ - :obj:`columns` - :code:`dict[str, Any]` ↦ {``column`` ↠ {``index`` ↠ ``value``}}
- :obj:`values` - :code:`dict[str, Any]` ↦ Values arrays
columns: List of columns name that users wish to update.
apply_column_names (`bool`, `optional`, default to :code:`False`):
@@ -573,8 +671,8 @@ def predict(input_arr):
from bentoml.io import PandasSeries
@svc.api(input=PandasSeries(shape=(51,10), enforce_shape=True), output=PandasSeries())
- def infer(input_df: pd.DataFrame) -> pd.DataFrame:
- # if input_df have shape (40,9), it will throw out errors
+ def infer(input_series: pd.Series) -> pd.Series:
+ # if input_series have shape (40,9), it will throw out errors
...
enforce_shape: Whether to enforce a certain shape. If ``enforce_shape=True`` then ``shape`` must be specified.
@@ -582,12 +680,13 @@ def infer(input_df: pd.DataFrame) -> pd.DataFrame:
:obj:`PandasSeries`: IO Descriptor that represents a :code:`pd.Series`.
"""
- _mime_type: str = MIME_TYPE_JSON
+ _proto_fields = ("series",)
+ _mime_type = "application/json"
def __init__(
self,
orient: ext.SeriesOrient = "records",
- dtype: bool | dict[str, t.Any] | None = None,
+ dtype: ext.PdDTypeArg | None = None,
enforce_dtype: bool = False,
shape: tuple[int, ...] | None = None,
enforce_shape: bool = False,
@@ -630,29 +729,13 @@ async def from_http_request(self, request: Request) -> ext.PdSeries:
a ``pd.Series`` object. This can then be used inside users defined logics.
"""
obj = await request.body()
- if self._enforce_dtype:
- if self._dtype is None:
- logger.warning(
- "`dtype` is None or undefined, while `enforce_dtype=True`"
- )
-
- # TODO(jiang): check dtypes when enforce_dtype is set
- res = pd.read_json(obj, typ="series", orient=self._orient, dtype=self._dtype)
-
- assert isinstance(res, pd.Series)
-
- if self._enforce_shape:
- if self._shape is None:
- logger.warning(
- "`shape` is None or undefined, while `enforce_shape`=True"
- )
- else:
- assert all(
- left == right
- for left, right in zip(self._shape, res.shape)
- if left != -1 and right != -1
- ), f"incoming has shape {res.shape} where enforced shape to be {self._shape}"
- return res
+ res: ext.PdSeries = pd.read_json(
+ obj,
+ typ="series",
+ orient=self._orient,
+ dtype=self._dtype,
+ )
+ return self.validate_series(res)
async def to_http_response(
self, obj: t.Any, ctx: Context | None = None
@@ -665,19 +748,43 @@ async def to_http_response(
Returns:
HTTP Response of type ``starlette.responses.Response``. This can be accessed via cURL or any external web traffic.
"""
- if not LazyType["ext.PdSeries"](pd.Series).isinstance(obj):
- raise InvalidArgument(
- f"return object is not of type `pd.Series`, got type {type(obj)} instead"
- )
-
+ obj = self.validate_series(obj)
if ctx is not None:
res = Response(
obj.to_json(orient=self._orient),
- media_type=MIME_TYPE_JSON,
+ media_type=self._mime_type,
headers=ctx.response.headers, # type: ignore (bad starlette types)
status_code=ctx.response.status_code,
)
set_cookies(res, ctx.response.cookies)
return res
else:
- return Response(obj.to_json(orient=self._orient), media_type=MIME_TYPE_JSON)
+ return Response(
+ obj.to_json(orient=self._orient), media_type=self._mime_type
+ )
+
+ def validate_series(
+ self, series: ext.PdSeries, exception_cls: t.Type[Exception] = BadInput
+ ) -> ext.PdSeries:
+ # TODO: dtype check
+ if not LazyType["ext.PdSeries"]("pd.Series").isinstance(series):
+ raise InvalidArgument(
+ f"return object is not of type 'pd.Series', got type '{type(series)}' instead"
+ ) from None
+ # TODO: convert from wide to long format (melt())
+ if self._shape is not None and self._shape != series.shape:
+ msg = f"{self.__class__.__name__}: Expecting Series of shape '{self._shape}', but '{series.shape}' was received."
+ if self._enforce_shape and not all(
+ left == right
+ for left, right in zip(self._shape, series.shape)
+ if left != -1 and right != -1
+ ):
+ raise exception_cls(msg) from None
+
+ return series
+
+ async def from_proto(self, field: pb.Series | bytes) -> ext.PdSeries:
+ raise NotImplementedError("Currently not yet implemented.")
+
+ async def to_proto(self, obj: ext.PdSeries) -> pb.Series:
+ raise NotImplementedError("Currently not yet implemented.")
diff --git a/bentoml/_internal/io_descriptors/text.py b/bentoml/_internal/io_descriptors/text.py
index db372250cb..c85ab74212 100644
--- a/bentoml/_internal/io_descriptors/text.py
+++ b/bentoml/_internal/io_descriptors/text.py
@@ -11,14 +11,18 @@
from .base import IODescriptor
from ..utils.http import set_cookies
from ..service.openapi import SUCCESS_DESCRIPTION
+from ..utils.lazy_loader import LazyLoader
+from ..service.openapi.specification import Schema
+from ..service.openapi.specification import Response as OpenAPIResponse
from ..service.openapi.specification import MediaType
+from ..service.openapi.specification import RequestBody
if TYPE_CHECKING:
- from ..context import InferenceApiContext as Context
+ from google.protobuf import wrappers_pb2
-from ..service.openapi.specification import Schema
-from ..service.openapi.specification import Response as OpenAPIResponse
-from ..service.openapi.specification import RequestBody
+ from ..context import InferenceApiContext as Context
+else:
+ wrappers_pb2 = LazyLoader("wrappers_pb2", globals(), "google.protobuf.wrappers_pb2")
MIME_TYPE = "text/plain"
@@ -86,13 +90,14 @@ def predict(text: str) -> str:
:obj:`Text`: IO Descriptor that represents strings type.
"""
+ _proto_fields = ("text",)
+ _mime_type = MIME_TYPE
+
def __init__(self, *args: t.Any, **kwargs: t.Any):
if args or kwargs:
raise BentoMLException(
- "'Text' is not designed to take any args or kwargs during initialization."
- )
-
- self._mime_type = MIME_TYPE
+ f"'{self.__class__.__name__}' is not designed to take any args or kwargs during initialization."
+ ) from None
def input_type(self) -> t.Type[str]:
return str
@@ -123,11 +128,21 @@ async def to_http_response(self, obj: str, ctx: Context | None = None) -> Respon
if ctx is not None:
res = Response(
obj,
- media_type=MIME_TYPE,
+ media_type=self._mime_type,
headers=ctx.response.metadata, # type: ignore (bad starlette types)
status_code=ctx.response.status_code,
)
set_cookies(res, ctx.response.cookies)
return res
else:
- return Response(obj, media_type=MIME_TYPE)
+ return Response(obj, media_type=self._mime_type)
+
+ async def from_proto(self, field: wrappers_pb2.StringValue | bytes) -> str:
+ if isinstance(field, bytes):
+ return field.decode("utf-8")
+ else:
+ assert isinstance(field, wrappers_pb2.StringValue)
+ return field.value
+
+ async def to_proto(self, obj: str) -> wrappers_pb2.StringValue:
+ return wrappers_pb2.StringValue(value=obj)
diff --git a/bentoml/_internal/runner/runner.py b/bentoml/_internal/runner/runner.py
index 488e5b804b..0b931f3e69 100644
--- a/bentoml/_internal/runner/runner.py
+++ b/bentoml/_internal/runner/runner.py
@@ -225,7 +225,7 @@ def init_local(self, quiet: bool = False) -> None:
init local runnable instance, for testing and debugging only
"""
if not quiet:
- logger.warning("'Runner.init_local' is for debugging and testing only")
+ logger.warning("'Runner.init_local' is for debugging and testing only.")
self._init_local()
diff --git a/bentoml/_internal/runner/runner_handle/remote.py b/bentoml/_internal/runner/runner_handle/remote.py
index 2d31fcf972..d557cb5a84 100644
--- a/bentoml/_internal/runner/runner_handle/remote.py
+++ b/bentoml/_internal/runner/runner_handle/remote.py
@@ -18,7 +18,8 @@
from ...runner.utils import PAYLOAD_META_HEADER
from ...configuration.containers import BentoMLContainer
-if TYPE_CHECKING: # pragma: no cover
+if TYPE_CHECKING:
+ import yarl
from aiohttp import BaseConnector
from aiohttp.client import ClientSession
@@ -84,7 +85,7 @@ def _get_conn(self) -> BaseConnector:
)
self._addr = f"http://{parsed.netloc}"
else:
- raise ValueError(f"Unsupported bind scheme: {parsed.scheme}")
+ raise ValueError(f"Unsupported bind scheme: {parsed.scheme}") from None
return self._conn
@property
@@ -99,10 +100,7 @@ def _client(
or self._client_cache.closed
or self._loop.is_closed()
):
- import yarl
- from opentelemetry.instrumentation.aiohttp_client import (
- create_trace_config, # type: ignore (missing type stubs)
- )
+ from opentelemetry.instrumentation.aiohttp_client import create_trace_config
def strip_query_params(url: yarl.URL) -> str:
return str(url.with_query(None))
@@ -145,7 +143,7 @@ async def async_run_method(
if not payload_params.map(lambda i: i.batch_size).all_equal():
raise ValueError(
"All batchable arguments must have the same batch size."
- )
+ ) from None
path = "" if __bentoml_method.name == "__call__" else __bentoml_method.name
async with self._client.post(
@@ -164,30 +162,26 @@ async def async_run_method(
if resp.status != 200:
raise RemoteException(
f"An exception occurred in remote runner {self._runner.name}: [{resp.status}] {body.decode()}"
- )
+ ) from None
try:
meta_header = resp.headers[PAYLOAD_META_HEADER]
except KeyError:
raise RemoteException(
- f"Bento payload decode error: {PAYLOAD_META_HEADER} header not set. "
- "An exception might have occurred in the remote server."
- f"[{resp.status}] {body.decode()}"
+ f"Bento payload decode error: {PAYLOAD_META_HEADER} header not set. An exception might have occurred in the remote server. [{resp.status}] {body.decode()}"
) from None
try:
content_type = resp.headers["Content-Type"]
except KeyError:
raise RemoteException(
- f"Bento payload decode error: Content-Type header not set. "
- "An exception might have occurred in the remote server."
- f"[{resp.status}] {body.decode()}"
+ f"Bento payload decode error: Content-Type header not set. An exception might have occurred in the remote server. [{resp.status}] {body.decode()}"
) from None
if not content_type.lower().startswith("application/vnd.bentoml."):
raise RemoteException(
f"Bento payload decode error: invalid Content-Type '{content_type}'."
- )
+ ) from None
if content_type == "application/vnd.bentoml.multiple_outputs":
payloads = pickle.loads(body)
@@ -200,7 +194,7 @@ async def async_run_method(
data=body, meta=json.loads(meta_header), container=container
)
except JSONDecodeError:
- raise ValueError(f"Bento payload decode error: {meta_header}")
+ raise ValueError(f"Bento payload decode error: {meta_header}") from None
return AutoContainer.from_payload(payload)
@@ -212,11 +206,14 @@ def run_method(
) -> R:
import anyio
- return anyio.from_thread.run( # type: ignore (pyright cannot infer the return type)
- self.async_run_method,
- __bentoml_method,
- *args,
- **kwargs,
+ return t.cast(
+ "R",
+ anyio.from_thread.run(
+ self.async_run_method,
+ __bentoml_method,
+ *args,
+ **kwargs,
+ ),
)
def __del__(self) -> None:
diff --git a/bentoml/_internal/server/__init__.py b/bentoml/_internal/server/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/bentoml/_internal/server/grpc/__init__.py b/bentoml/_internal/server/grpc/__init__.py
new file mode 100644
index 0000000000..1a92ea6fbb
--- /dev/null
+++ b/bentoml/_internal/server/grpc/__init__.py
@@ -0,0 +1,5 @@
+from .config import Config
+from .server import Server
+from .servicer import Servicer
+
+__all__ = ["Server", "Config", "Servicer"]
diff --git a/bentoml/_internal/server/grpc/config.py b/bentoml/_internal/server/grpc/config.py
new file mode 100644
index 0000000000..80735fedfb
--- /dev/null
+++ b/bentoml/_internal/server/grpc/config.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import sys
+import typing as t
+from typing import TYPE_CHECKING
+
+from simple_di import inject
+from simple_di import Provide
+
+from ...configuration.containers import BentoMLContainer
+
+if TYPE_CHECKING:
+ import grpc
+
+ from .servicer import Servicer
+
+
+class Config:
+ @inject
+ def __init__(
+ self,
+ servicer: Servicer,
+ bind_address: str,
+ enable_reflection: bool = Provide[BentoMLContainer.grpc.reflection.enabled],
+ max_message_length: int
+ | None = Provide[BentoMLContainer.grpc.max_message_length],
+ max_concurrent_streams: int
+ | None = Provide[BentoMLContainer.grpc.max_concurrent_streams],
+ maximum_concurrent_rpcs: int
+ | None = Provide[BentoMLContainer.grpc.maximum_concurrent_rpcs],
+ migration_thread_pool_workers: int = 1,
+ graceful_shutdown_timeout: float | None = None,
+ ) -> None:
+ self.servicer = servicer
+
+ # Note that the max_workers are used inside ThreadPoolExecutor.
+ # This ThreadPoolExecutor are used by aio.Server() to execute non-AsyncIO RPC handlers.
+ # Setting it to 1 makes it thread-safe for sync APIs.
+ self.migration_thread_pool_workers = migration_thread_pool_workers
+
+ # maximum_concurrent_rpcs defines the maximum number of concurrent RPCs this server
+ # will service before returning RESOURCE_EXHAUSTED status.
+ # Set to None will indicate no limit.
+ self.maximum_concurrent_rpcs = maximum_concurrent_rpcs
+
+ self.max_message_length = max_message_length
+
+ self.max_concurrent_streams = max_concurrent_streams
+
+ self.bind_address = bind_address
+ self.enable_reflection = enable_reflection
+ self.graceful_shutdown_timeout = graceful_shutdown_timeout
+
+ @property
+ def options(self) -> grpc.aio.ChannelArgumentType:
+ options: grpc.aio.ChannelArgumentType = []
+
+ if sys.platform != "win32":
+ # https://github.com/grpc/grpc/blob/master/include/grpc/impl/codegen/grpc_types.h#L294
+ # Eventhough GRPC_ARG_ALLOW_REUSEPORT is set to 1 by default, we want still
+ # want to explicitly set it to 1 so that we can spawn multiple gRPC servers in
+ # production settings.
+ options.append(("grpc.so_reuseport", 1))
+
+ if self.max_concurrent_streams:
+ options.append(("grpc.max_concurrent_streams", self.max_concurrent_streams))
+
+ if self.max_message_length:
+ options.extend(
+ (
+ ("grpc.max_message_length", self.max_message_length),
+ ("grpc.max_receive_message_length", self.max_message_length),
+ ("grpc.max_send_message_length", self.max_message_length),
+ )
+ )
+
+ return tuple(options)
+
+ @property
+ def handlers(self) -> t.Sequence[grpc.GenericRpcHandler] | None:
+ # Note that currently BentoML doesn't provide any specific
+ # handlers for gRPC. If users have any specific handlers,
+ # BentoML will pass it through to grpc.aio.Server
+ return self.servicer.bento_service.grpc_handlers
diff --git a/bentoml/_internal/server/grpc/server.py b/bentoml/_internal/server/grpc/server.py
new file mode 100644
index 0000000000..b95416f7ff
--- /dev/null
+++ b/bentoml/_internal/server/grpc/server.py
@@ -0,0 +1,158 @@
+from __future__ import annotations
+
+import typing as t
+import asyncio
+import logging
+from typing import TYPE_CHECKING
+
+from ...utils import LazyLoader
+from ...utils import cached_property
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+ import grpc
+ from grpc import aio
+ from grpc_health.v1 import health_pb2 as pb_health
+ from grpc_health.v1 import health_pb2_grpc as services_health
+ from typing_extensions import Self
+
+ from bentoml.grpc.v1alpha1 import service_pb2_grpc as services
+
+ from .config import Config
+else:
+ from bentoml.grpc.utils import import_grpc
+ from bentoml.grpc.utils import import_generated_stubs
+ from bentoml._internal.utils import LazyLoader
+
+ grpc, aio = import_grpc()
+ _, services = import_generated_stubs()
+ health_exception_msg = "'grpcio-health-checking' is required for using health checking endpoints. Install with 'pip install grpcio-health-checking'."
+ pb_health = LazyLoader(
+ "pb_health",
+ globals(),
+ "grpc_health.v1.health_pb2",
+ exc_msg=health_exception_msg,
+ )
+ services_health = LazyLoader(
+ "services_health",
+ globals(),
+ "grpc_health.v1.health_pb2_grpc",
+ exc_msg=health_exception_msg,
+ )
+
+
+class Server:
+ """An async implementation of a gRPC server."""
+
+ def __init__(self, config: Config):
+ self.config = config
+ self.servicer = config.servicer
+ self.loaded = False
+
+ @cached_property
+ def loop(self) -> asyncio.AbstractEventLoop:
+ return asyncio.get_event_loop()
+
+ def load(self) -> Self:
+ from concurrent.futures import ThreadPoolExecutor
+
+ assert not self.loaded
+ if not bool(self.servicer):
+ self.servicer.load()
+ assert self.servicer.loaded
+
+ self.server = aio.server(
+ migration_thread_pool=ThreadPoolExecutor(
+ max_workers=self.config.migration_thread_pool_workers
+ ),
+ options=self.config.options,
+ maximum_concurrent_rpcs=self.config.maximum_concurrent_rpcs,
+ handlers=self.config.handlers,
+ interceptors=self.servicer.interceptors_stack,
+ )
+ self.loaded = True
+
+ return self
+
+ def run(self) -> None:
+ if not self.loaded:
+ self.load()
+ assert self.loaded
+
+ try:
+ self.loop.run_until_complete(self.serve())
+ finally:
+ try:
+ self.loop.call_soon_threadsafe(
+ lambda: asyncio.ensure_future(self.shutdown())
+ )
+ except Exception as e: # pylint: disable=broad-except
+ raise RuntimeError(f"Server failed unexpectedly: {e}") from None
+
+ async def serve(self) -> None:
+ self.add_insecure_port(self.config.bind_address)
+ await self.startup()
+ await self.wait_for_termination()
+
+ async def startup(self) -> None:
+ from bentoml.exceptions import MissingDependencyException
+
+ # Running on_startup callback.
+ await self.servicer.startup()
+ # register bento servicer
+ services.add_BentoServiceServicer_to_server(
+ self.servicer.bento_servicer, self.server
+ )
+ services_health.add_HealthServicer_to_server(
+ self.servicer.health_servicer, self.server
+ )
+
+ service_names = self.servicer.service_names
+ # register custom servicer
+ for (
+ user_servicer,
+ add_servicer_fn,
+ user_service_names,
+ ) in self.servicer.mount_servicers:
+ add_servicer_fn(user_servicer(), self.server)
+ service_names += tuple(user_service_names)
+
+ if self.config.enable_reflection:
+ try:
+ # reflection is required for health checking to work.
+ from grpc_reflection.v1alpha import reflection
+ except ImportError:
+ raise MissingDependencyException(
+ "reflection is enabled, which requires 'grpcio-reflection' to be installed. Install with 'pip install grpcio-reflection'."
+ )
+ service_names += (reflection.SERVICE_NAME,)
+ reflection.enable_server_reflection(service_names, self.server)
+
+ # mark all services as healthy
+ for service in service_names:
+ await self.servicer.health_servicer.set(
+ service, pb_health.HealthCheckResponse.SERVING # type: ignore (no types available)
+ )
+ await self.server.start()
+
+ async def shutdown(self):
+ # Running on_startup callback.
+ await self.servicer.shutdown()
+ await self.server.stop(grace=self.config.graceful_shutdown_timeout)
+ await self.servicer.health_servicer.enter_graceful_shutdown()
+ self.loop.stop()
+
+ async def wait_for_termination(self, timeout: int | None = None) -> bool:
+ return await self.server.wait_for_termination(timeout=timeout)
+
+ def add_insecure_port(self, address: str) -> int:
+ return self.server.add_insecure_port(address)
+
+ def add_secure_port(self, address: str, credentials: grpc.ServerCredentials) -> int:
+ return self.server.add_secure_port(address, credentials)
+
+ def add_generic_rpc_handlers(
+ self, generic_rpc_handlers: t.Sequence[grpc.GenericRpcHandler]
+ ) -> None:
+ self.server.add_generic_rpc_handlers(generic_rpc_handlers)
diff --git a/bentoml/_internal/server/grpc/servicer.py b/bentoml/_internal/server/grpc/servicer.py
new file mode 100644
index 0000000000..cbce08b3a6
--- /dev/null
+++ b/bentoml/_internal/server/grpc/servicer.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+import sys
+import typing as t
+import asyncio
+import logging
+from typing import TYPE_CHECKING
+
+import anyio
+
+from bentoml.grpc.utils import grpc_status_code
+
+from ....exceptions import InvalidArgument
+from ....exceptions import BentoMLException
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+ from logging import _ExcInfoType as ExcInfoType # type: ignore (private warning)
+
+ import grpc
+ from grpc import aio
+ from grpc_health.v1 import health
+ from typing_extensions import Self
+
+ from bentoml.grpc.types import Interceptors
+ from bentoml.grpc.types import AddServicerFn
+ from bentoml.grpc.types import ServicerClass
+ from bentoml.grpc.types import BentoServicerContext
+ from bentoml.grpc.types import GeneratedProtocolMessageType
+ from bentoml.grpc.v1alpha1 import service_pb2 as pb
+ from bentoml.grpc.v1alpha1 import service_pb2_grpc as services
+
+ from ...service.service import Service
+
+else:
+ from bentoml.grpc.utils import import_grpc
+ from bentoml.grpc.utils import import_generated_stubs
+
+ from ...utils import LazyLoader
+
+ pb, services = import_generated_stubs()
+ grpc, aio = import_grpc()
+ health = LazyLoader(
+ "health",
+ globals(),
+ "grpc_health.v1.health",
+ exc_msg="'grpcio-health-checking' is required for using health checking endpoints. Install with 'pip install grpcio-health-checking'.",
+ )
+ containers = LazyLoader(
+ "containers", globals(), "google.protobuf.internal.containers"
+ )
+
+
+def log_exception(request: pb.Request, exc_info: ExcInfoType) -> None:
+ # gRPC will always send a POST request.
+ logger.error("Exception on /%s [POST]", request.api_name, exc_info=exc_info)
+
+
+class Servicer:
+ """Create an instance of gRPC Servicer."""
+
+ def __init__(
+ self: Self,
+ service: Service,
+ on_startup: t.Sequence[t.Callable[[], t.Any]] | None = None,
+ on_shutdown: t.Sequence[t.Callable[[], t.Any]] | None = None,
+ mount_servicers: t.Sequence[tuple[ServicerClass, AddServicerFn, list[str]]]
+ | None = None,
+ interceptors: Interceptors | None = None,
+ ) -> None:
+ self.bento_service = service
+
+ self.on_startup = [] if not on_startup else list(on_startup)
+ self.on_shutdown = [] if not on_shutdown else list(on_shutdown)
+ self.mount_servicers = [] if not mount_servicers else list(mount_servicers)
+ self.interceptors = [] if not interceptors else list(interceptors)
+ self.loaded = False
+
+ def load(self):
+ assert not self.loaded
+
+ self.interceptors_stack = self.build_interceptors_stack()
+
+ self.bento_servicer = create_bento_servicer(self.bento_service)
+
+ # Create a health check servicer. We use the non-blocking implementation
+ # to avoid thread starvation.
+ self.health_servicer = health.aio.HealthServicer()
+
+ self.service_names = tuple(
+ service.full_name for service in pb.DESCRIPTOR.services_by_name.values()
+ ) + (health.SERVICE_NAME,)
+ self.loaded = True
+
+ def build_interceptors_stack(self) -> list[aio.ServerInterceptor]:
+ return list(map(lambda x: x(), self.interceptors))
+
+ async def startup(self):
+ for handler in self.on_startup:
+ if is_async_iterable(handler):
+ await handler()
+ else:
+ handler()
+
+ async def shutdown(self):
+ for handler in self.on_shutdown:
+ if is_async_iterable(handler):
+ await handler()
+ else:
+ handler()
+
+ def __bool__(self):
+ return self.loaded
+
+
+def is_async_iterable(obj: t.Any) -> bool: # pragma: no cover
+ return asyncio.iscoroutinefunction(obj) or (
+ callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
+ )
+
+
+def create_bento_servicer(service: Service) -> services.BentoServiceServicer:
+ """
+ This is the actual implementation of BentoServicer.
+ Main inference entrypoint will be invoked via /bentoml.grpc..BentoService/Call
+ """
+ from ...io_descriptors.multipart import Multipart
+
+ class BentoServiceImpl(services.BentoServiceServicer):
+ """An asyncio implementation of BentoService servicer."""
+
+ async def Call( # type: ignore (no async types) # pylint: disable=invalid-overridden-method
+ self,
+ request: pb.Request,
+ context: BentoServicerContext,
+ ) -> pb.Response | None:
+ if request.api_name not in service.apis:
+ raise InvalidArgument(
+ f"given 'api_name' is not defined in {service.name}",
+ ) from None
+
+ api = service.apis[request.api_name]
+ response = pb.Response()
+
+ # NOTE: since IODescriptor._proto_fields is a tuple, the order is preserved.
+ # This is important so that we know the order of fields to process.
+ # We will use fields descriptor to determine how to process that request.
+ try:
+ # we will check if the given fields list contains a pb.Multipart.
+ field = request.WhichOneof("content")
+ if field is None:
+ raise InvalidArgument("Request cannot be empty.")
+ accepted_fields = api.input._proto_fields + ("serialized_bytes",)
+ if field not in accepted_fields:
+ raise InvalidArgument(
+ f"'{api.input.__class__.__name__}' accepts one of the following fields: '{', '.join(accepted_fields)}', and none of them are found in the request message.",
+ ) from None
+ input_ = await api.input.from_proto(getattr(request, field))
+ if asyncio.iscoroutinefunction(api.func):
+ if isinstance(api.input, Multipart):
+ output = await api.func(**input_)
+ else:
+ output = await api.func(input_)
+ else:
+ if isinstance(api.input, Multipart):
+ output = await anyio.to_thread.run_sync(api.func, **input_)
+ else:
+ output = await anyio.to_thread.run_sync(api.func, input_)
+ protos = await api.output.to_proto(output)
+ # TODO(aarnphm): support multiple proto fields
+ response = pb.Response(**{api.output._proto_fields[0]: protos})
+ except BentoMLException as e:
+ log_exception(request, sys.exc_info())
+ await context.abort(code=grpc_status_code(e), details=e.message)
+ except (RuntimeError, TypeError, NotImplementedError):
+ log_exception(request, sys.exc_info())
+ await context.abort(
+ code=grpc.StatusCode.INTERNAL,
+ details="A runtime error has occurred, see stacktrace from logs.",
+ )
+ except Exception: # pylint: disable=broad-except
+ log_exception(request, sys.exc_info())
+ await context.abort(
+ code=grpc.StatusCode.INTERNAL,
+ details="An error has occurred in BentoML user code when handling this request, find the error details in server logs.",
+ )
+ return response
+
+ return BentoServiceImpl()
diff --git a/bentoml/_internal/server/grpc_app.py b/bentoml/_internal/server/grpc_app.py
new file mode 100644
index 0000000000..c14756929b
--- /dev/null
+++ b/bentoml/_internal/server/grpc_app.py
@@ -0,0 +1,101 @@
+from __future__ import annotations
+
+import typing as t
+import logging
+from typing import TYPE_CHECKING
+from functools import partial
+
+from simple_di import inject
+from simple_di import Provide
+
+from ..configuration.containers import BentoMLContainer
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+
+ from bentoml.grpc.types import Interceptors
+
+ from ..service import Service
+ from .grpc.servicer import Servicer
+
+ OnStartup = list[t.Callable[[], None | t.Coroutine[t.Any, t.Any, None]]]
+
+
+class GRPCAppFactory:
+ """
+ GRPCApp creates an async gRPC API server based on APIs defined with a BentoService via BentoService#apis.
+ This is a light wrapper around GRPCServer with addition to `on_startup` and `on_shutdown` hooks.
+
+ Note that even though the code are similar with BaseAppFactory, gRPC protocol is different from ASGI.
+ """
+
+ @inject
+ def __init__(
+ self,
+ bento_service: Service,
+ *,
+ enable_metrics: bool = Provide[
+ BentoMLContainer.api_server_config.metrics.enabled
+ ],
+ ) -> None:
+ self.bento_service = bento_service
+ self.enable_metrics = enable_metrics
+
+ @property
+ def on_startup(self) -> OnStartup:
+ on_startup: OnStartup = [self.bento_service.on_grpc_server_startup]
+ if BentoMLContainer.development_mode.get():
+ for runner in self.bento_service.runners:
+ on_startup.append(partial(runner.init_local, quiet=True))
+ else:
+ for runner in self.bento_service.runners:
+ on_startup.append(runner.init_client)
+
+ return on_startup
+
+ @property
+ def on_shutdown(self) -> list[t.Callable[[], None]]:
+ on_shutdown = [self.bento_service.on_grpc_server_shutdown]
+ for runner in self.bento_service.runners:
+ on_shutdown.append(runner.destroy)
+
+ return on_shutdown
+
+ def __call__(self) -> Servicer:
+ from .grpc import Servicer
+
+ return Servicer(
+ self.bento_service,
+ on_startup=self.on_startup,
+ on_shutdown=self.on_shutdown,
+ mount_servicers=self.bento_service.mount_servicers,
+ interceptors=self.interceptors,
+ )
+
+ @property
+ def interceptors(self) -> Interceptors:
+ # Note that order of interceptors is important here.
+
+ from bentoml.grpc.interceptors.opentelemetry import (
+ AsyncOpenTelemetryServerInterceptor,
+ )
+
+ interceptors: Interceptors = [AsyncOpenTelemetryServerInterceptor]
+
+ if self.enable_metrics:
+ from bentoml.grpc.interceptors.prometheus import PrometheusServerInterceptor
+
+ interceptors.append(PrometheusServerInterceptor)
+
+ if BentoMLContainer.api_server_config.logging.access.enabled.get():
+ from bentoml.grpc.interceptors.access import AccessLogServerInterceptor
+
+ access_logger = logging.getLogger("bentoml.access")
+ if access_logger.getEffectiveLevel() <= logging.INFO:
+ interceptors.append(AccessLogServerInterceptor)
+
+ # add users-defined interceptors.
+ interceptors.extend(self.bento_service.interceptors)
+
+ return interceptors
diff --git a/bentoml/_internal/server/http/__init__.py b/bentoml/_internal/server/http/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/bentoml/_internal/server/access.py b/bentoml/_internal/server/http/access.py
similarity index 95%
rename from bentoml/_internal/server/access.py
rename to bentoml/_internal/server/http/access.py
index 972fa44c79..3c146abf26 100644
--- a/bentoml/_internal/server/access.py
+++ b/bentoml/_internal/server/http/access.py
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
import logging
from timeit import default_timer
from typing import TYPE_CHECKING
from contextvars import ContextVar
if TYPE_CHECKING:
- from .. import external_typing as ext
+ from ... import external_typing as ext
REQ_CONTENT_LENGTH = "REQUEST_CONTENT_LENGTH"
REQ_CONTENT_TYPE = "REQUEST_CONTENT_TYPE"
@@ -41,7 +43,7 @@ class AccessLogMiddleware:
def __init__(
self,
- app: "ext.ASGIApp",
+ app: ext.ASGIApp,
has_request_content_length: bool = False,
has_request_content_type: bool = False,
has_response_content_length: bool = False,
@@ -56,9 +58,9 @@ def __init__(
async def __call__(
self,
- scope: "ext.ASGIScope",
- receive: "ext.ASGIReceive",
- send: "ext.ASGISend",
+ scope: ext.ASGIScope,
+ receive: ext.ASGIReceive,
+ send: ext.ASGISend,
) -> None:
if not scope["type"].startswith("http"):
await self.app(scope, receive, send)
diff --git a/bentoml/_internal/server/instruments.py b/bentoml/_internal/server/http/instruments.py
similarity index 94%
rename from bentoml/_internal/server/instruments.py
rename to bentoml/_internal/server/http/instruments.py
index 06e8df2359..f692f62ca4 100644
--- a/bentoml/_internal/server/instruments.py
+++ b/bentoml/_internal/server/http/instruments.py
@@ -8,25 +8,23 @@
from simple_di import inject
from simple_di import Provide
-from ..utils.metrics import metric_name
-from ..configuration.containers import BentoMLContainer
+from ...context import component_context
+from ...utils.metrics import metric_name
+from ...configuration.containers import BentoMLContainer
if TYPE_CHECKING:
- from .. import external_typing as ext
- from ..server.metrics.prometheus import PrometheusClient
+ from ... import external_typing as ext
+ from ...server.metrics.prometheus import PrometheusClient
logger = logging.getLogger(__name__)
-START_TIME_VAR: "contextvars.ContextVar[float]" = contextvars.ContextVar(
- "START_TIME_VAR"
-)
-STATUS_VAR: "contextvars.ContextVar[int]" = contextvars.ContextVar("STATUS_VAR")
-from ..context import component_context
+START_TIME_VAR: contextvars.ContextVar[float] = contextvars.ContextVar("START_TIME_VAR")
+STATUS_VAR: contextvars.ContextVar[int] = contextvars.ContextVar("STATUS_VAR")
class HTTPTrafficMetricsMiddleware:
def __init__(
self,
- app: "ext.ASGIApp",
+ app: ext.ASGIApp,
namespace: str = "bentoml_api_server",
):
self.app = app
@@ -36,7 +34,7 @@ def __init__(
@inject
def _setup(
self,
- metrics_client: "PrometheusClient" = Provide[BentoMLContainer.metrics_client],
+ metrics_client: PrometheusClient = Provide[BentoMLContainer.metrics_client],
duration_buckets: tuple[float, ...] = Provide[
BentoMLContainer.duration_buckets
],
@@ -105,9 +103,9 @@ def _setup(
async def __call__(
self,
- scope: "ext.ASGIScope",
- receive: "ext.ASGIReceive",
- send: "ext.ASGISend",
+ scope: ext.ASGIScope,
+ receive: ext.ASGIReceive,
+ send: ext.ASGISend,
) -> None:
if not self._is_setup:
self._setup()
diff --git a/bentoml/_internal/server/service_app.py b/bentoml/_internal/server/http_app.py
similarity index 95%
rename from bentoml/_internal/server/service_app.py
rename to bentoml/_internal/server/http_app.py
index e20737a0ac..3e4e33ac9a 100644
--- a/bentoml/_internal/server/service_app.py
+++ b/bentoml/_internal/server/http_app.py
@@ -85,9 +85,9 @@ def log_exception(request: Request, exc_info: t.Any) -> None:
)
-class ServiceAppFactory(BaseAppFactory):
+class HTTPAppFactory(BaseAppFactory):
"""
- ServiceApp creates a REST API server based on APIs defined with a BentoService
+ HTTPApp creates a REST API server based on APIs defined with a BentoService
via BentoService#apis. Each InferenceAPI will become one
endpoint exposed on the REST server, and the RequestHandler defined on each
InferenceAPI object will be used to handle Request object before feeding the
@@ -98,10 +98,8 @@ class ServiceAppFactory(BaseAppFactory):
def __init__(
self,
bento_service: Service,
- enable_access_control: bool = Provide[
- BentoMLContainer.api_server_config.cors.enabled
- ],
- access_control_options: dict[str, list[str] | int] = Provide[
+ enable_access_control: bool = Provide[BentoMLContainer.http.cors.enabled],
+ access_control_options: dict[str, list[str] | str | int] = Provide[
BentoMLContainer.access_control_options
],
enable_metrics: bool = Provide[
@@ -217,14 +215,14 @@ def middlewares(self) -> list[Middleware]:
# metrics middleware
if self.enable_metrics:
- from .instruments import HTTPTrafficMetricsMiddleware
+ from .http.instruments import HTTPTrafficMetricsMiddleware
middlewares.append(Middleware(HTTPTrafficMetricsMiddleware))
# otel middleware
- import opentelemetry.instrumentation.asgi as otel_asgi # type: ignore
+ import opentelemetry.instrumentation.asgi as otel_asgi
- def client_request_hook(span: Span, _scope: dict[str, t.Any]) -> None:
+ def client_request_hook(span: Span, _: dict[str, t.Any]) -> None:
if span is not None:
trace_context.request_id = span.context.span_id
@@ -241,7 +239,7 @@ def client_request_hook(span: Span, _scope: dict[str, t.Any]) -> None:
access_log_config = BentoMLContainer.api_server_config.logging.access
if access_log_config.enabled.get():
- from .access import AccessLogMiddleware
+ from .http.access import AccessLogMiddleware
access_logger = logging.getLogger("bentoml.access")
if access_logger.getEffectiveLevel() <= logging.INFO:
diff --git a/bentoml/_internal/server/metrics/prometheus.py b/bentoml/_internal/server/metrics/prometheus.py
index f18a5084af..ce2230e066 100644
--- a/bentoml/_internal/server/metrics/prometheus.py
+++ b/bentoml/_internal/server/metrics/prometheus.py
@@ -1,10 +1,17 @@
-# type: ignore[reportMissingTypeStubs]
+from __future__ import annotations
+
import os
import sys
import typing as t
import logging
+from typing import TYPE_CHECKING
from functools import partial
+if TYPE_CHECKING:
+ from prometheus_client.metrics_core import Metric
+
+ from ... import external_typing as ext
+
logger = logging.getLogger(__name__)
@@ -13,7 +20,7 @@ def __init__(
self,
*,
multiproc: bool = True,
- multiproc_dir: t.Optional[str] = None,
+ multiproc_dir: str | None = None,
):
"""
Set up multiproc_dir for prometheus to work in multiprocess mode,
@@ -87,6 +94,9 @@ def start_http_server(self, port: int, addr: str = "") -> None:
registry=self.registry,
)
+ def make_wsgi_app(self) -> ext.WSGIApp:
+ return self.prometheus_client.make_wsgi_app(registry=self.registry) # type: ignore (unfinished prometheus types)
+
def generate_latest(self):
if self.multiproc:
registry = self.prometheus_client.CollectorRegistry()
@@ -95,6 +105,13 @@ def generate_latest(self):
else:
return self.prometheus_client.generate_latest()
+ def text_string_to_metric_families(self) -> t.Generator[Metric, None, None]:
+ from prometheus_client.parser import text_string_to_metric_families
+
+ yield from text_string_to_metric_families(
+ self.generate_latest().decode("utf-8")
+ )
+
@property
def CONTENT_TYPE_LATEST(self) -> str:
return self.prometheus_client.CONTENT_TYPE_LATEST
diff --git a/bentoml/_internal/server/runner_app.py b/bentoml/_internal/server/runner_app.py
index f463590ba6..97a59f5395 100644
--- a/bentoml/_internal/server/runner_app.py
+++ b/bentoml/_internal/server/runner_app.py
@@ -183,13 +183,13 @@ def client_request_hook(span: Span, _scope: t.Dict[str, t.Any]) -> None:
)
if self.enable_metrics:
- from .instruments import RunnerTrafficMetricsMiddleware
+ from .http.instruments import RunnerTrafficMetricsMiddleware
middlewares.append(Middleware(RunnerTrafficMetricsMiddleware))
access_log_config = BentoMLContainer.runners_config.logging.access
if access_log_config.enabled.get():
- from .access import AccessLogMiddleware
+ from .http.access import AccessLogMiddleware
access_logger = logging.getLogger("bentoml.access")
if access_logger.getEffectiveLevel() <= logging.INFO:
diff --git a/bentoml/_internal/service/service.py b/bentoml/_internal/service/service.py
index 3c654438e2..864be97f18 100644
--- a/bentoml/_internal/service/service.py
+++ b/bentoml/_internal/service/service.py
@@ -3,6 +3,7 @@
import typing as t
import logging
from typing import TYPE_CHECKING
+from functools import partial
import attr
@@ -16,13 +17,19 @@
from ..io_descriptors import IODescriptor
if TYPE_CHECKING:
+ import grpc
+
+ from bentoml.grpc.types import AddServicerFn
+ from bentoml.grpc.types import ServicerClass
+
from .. import external_typing as ext
from ..bento import Bento
+ from ..server.grpc.servicer import Servicer
from .openapi.specification import OpenAPISpecification
+else:
+ from bentoml.grpc.utils import import_grpc
- WSGI_APP = t.Callable[
- [t.Callable[..., t.Any], t.Mapping[str, t.Any]], t.Iterable[bytes]
- ]
+ grpc, _ = import_grpc()
logger = logging.getLogger(__name__)
@@ -82,6 +89,7 @@ class Service:
runners: t.List[Runner]
models: t.List[Model]
+ # starlette related
mount_apps: t.List[t.Tuple[ext.ASGIApp, str, str]] = attr.field(
init=False, factory=list
)
@@ -89,6 +97,16 @@ class Service:
t.Tuple[t.Type[ext.AsgiMiddleware], t.Dict[str, t.Any]]
] = attr.field(init=False, factory=list)
+ # gRPC related
+ mount_servicers: list[tuple[ServicerClass, AddServicerFn, list[str]]] = attr.field(
+ init=False, factory=list
+ )
+ interceptors: list[partial[grpc.aio.ServerInterceptor]] = attr.field(
+ init=False, factory=list
+ )
+ grpc_handlers: list[grpc.GenericRpcHandler] = attr.field(init=False, factory=list)
+
+ # list of APIs from @svc.api
apis: t.Dict[str, InferenceAPI] = attr.field(init=False, factory=dict)
# Tag/Bento are only set when the service was loaded from a bento
@@ -197,11 +215,23 @@ def on_asgi_app_startup(self) -> None:
def on_asgi_app_shutdown(self) -> None:
pass
+ def on_grpc_server_startup(self) -> None:
+ pass
+
+ def on_grpc_server_shutdown(self) -> None:
+ pass
+
+ @property
+ def grpc_servicer(self) -> Servicer:
+ from ..server.grpc_app import GRPCAppFactory
+
+ return GRPCAppFactory(self)()
+
@property
def asgi_app(self) -> "ext.ASGIApp":
- from ..server.service_app import ServiceAppFactory
+ from ..server.http_app import HTTPAppFactory
- return ServiceAppFactory(self)()
+ return HTTPAppFactory(self)()
def mount_asgi_app(
self, app: "ext.ASGIApp", path: str = "/", name: t.Optional[str] = None
@@ -209,7 +239,7 @@ def mount_asgi_app(
self.mount_apps.append((app, path, name)) # type: ignore
def mount_wsgi_app(
- self, app: WSGI_APP, path: str = "/", name: t.Optional[str] = None
+ self, app: ext.WSGIApp, path: str = "/", name: t.Optional[str] = None
) -> None:
# TODO: Migrate to a2wsgi
from starlette.middleware.wsgi import WSGIMiddleware
@@ -217,10 +247,46 @@ def mount_wsgi_app(
self.mount_apps.append((WSGIMiddleware(app), path, name)) # type: ignore
def add_asgi_middleware(
- self, middleware_cls: t.Type["ext.AsgiMiddleware"], **options: t.Any
+ self, middleware_cls: t.Type[ext.AsgiMiddleware], **options: t.Any
) -> None:
self.middlewares.append((middleware_cls, options))
+ def mount_grpc_servicer(
+ self,
+ servicer_cls: ServicerClass,
+ add_servicer_fn: AddServicerFn,
+ service_names: list[str],
+ ) -> None:
+ self.mount_servicers.append((servicer_cls, add_servicer_fn, service_names))
+
+ def add_grpc_interceptor(
+ self, interceptor_cls: t.Type[grpc.aio.ServerInterceptor], **options: t.Any
+ ) -> None:
+ from bentoml.exceptions import BadInput
+
+ if not issubclass(interceptor_cls, grpc.aio.ServerInterceptor):
+ if isinstance(interceptor_cls, partial):
+ if options:
+ logger.debug(
+ "'%s' is a partial class, hence '%s' will be ignored.",
+ interceptor_cls,
+ options,
+ )
+ if not issubclass(interceptor_cls.func, grpc.aio.ServerInterceptor):
+ raise BadInput(
+ "'partial' class is not a subclass of 'grpc.aio.ServerInterceptor'."
+ )
+ self.interceptors.append(interceptor_cls)
+ else:
+ raise BadInput(
+ f"{interceptor_cls} is not a subclass of 'grpc.aio.ServerInterceptor'."
+ )
+
+ self.interceptors.append(partial(interceptor_cls, **options))
+
+ def add_grpc_handlers(self, handlers: list[grpc.GenericRpcHandler]) -> None:
+ self.grpc_handlers.extend(handlers)
+
def on_load_bento(svc: Service, bento: Bento):
object.__setattr__(svc, "bento", bento)
diff --git a/bentoml/_internal/utils/__init__.py b/bentoml/_internal/utils/__init__.py
index a240d04c9f..71ad30f259 100644
--- a/bentoml/_internal/utils/__init__.py
+++ b/bentoml/_internal/utils/__init__.py
@@ -6,6 +6,8 @@
import random
import socket
import typing as t
+import inspect
+import logging
import functools
import contextlib
from typing import overload
@@ -56,48 +58,41 @@
"validate_or_create_dir",
"display_path_under_home",
"rich_console",
+ "experimental",
]
+_EXPERIMENTAL_APIS: set[str] = set()
-@overload
-def kwargs_transformers(
- func: GenericFunction[t.Concatenate[str, bool, t.Iterable[str], P]],
- *,
- transformer: GenericFunction[t.Any],
-) -> GenericFunction[t.Concatenate[str, t.Iterable[str], bool, P]]:
- ...
+def _warn_experimental(f: t.Any):
+ api_name = f.__name__ if inspect.isfunction(f) else repr(f)
+ if api_name not in _EXPERIMENTAL_APIS:
+ _EXPERIMENTAL_APIS.add(api_name)
+ msg = "'%s' is an EXPERIMENTAL API and is currently not yet stable. Proceed with caution!"
+ logger = logging.getLogger(f.__module__)
+ logger.warning(msg, api_name)
-@overload
-def kwargs_transformers(
- func: None = None, *, transformer: GenericFunction[t.Any]
-) -> GenericFunction[t.Any]:
- ...
+def experimental(f: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
+ @functools.wraps(f)
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
+ _warn_experimental(f)
+ return f(*args, **kwargs)
-def kwargs_transformers(
- _func: t.Callable[..., t.Any] | None = None,
- *,
- transformer: GenericFunction[t.Any],
-) -> GenericFunction[t.Any]:
- def decorator(func: GenericFunction[t.Any]) -> t.Callable[P, t.Any]:
- @functools.wraps(func)
- def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
- return func(*args, **{k: transformer(v) for k, v in kwargs.items()})
+ return wrapper
- return wrapper
- if _func is None:
- return decorator
- return decorator(_func)
+def add_experimental_docstring(f: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
+ f.__doc__ = "[EXPERIMENTAL] " + (f.__doc__ if f.__doc__ is not None else "")
+ return f
-@t.overload
+@overload
def first_not_none(*args: T | None, default: T) -> T:
...
-@t.overload
+@overload
def first_not_none(*args: T | None) -> T | None:
...
@@ -179,13 +174,27 @@ def _(*args: P.args, **kwargs: P.kwargs) -> t.Optional[_T_co]:
@contextlib.contextmanager
def reserve_free_port(
host: str = "localhost",
+ port: int | None = None,
prefix: t.Optional[str] = None,
max_retry: int = 50,
+ enable_so_reuseport: bool = False,
) -> t.Iterator[int]:
"""
detect free port and reserve until exit the context
"""
+ import psutil
+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ if enable_so_reuseport:
+ if psutil.WINDOWS:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ elif psutil.MACOS or psutil.FREEBSD:
+ sock.setsockopt(socket.SOL_SOCKET, 0x10000, 1) # SO_REUSEPORT_LB
+ else:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+
+ if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
+ raise RuntimeError("Failed to set SO_REUSEPORT.") from None
if prefix is not None:
prefix_num = int(prefix) * 10 ** (5 - len(prefix))
suffix_range = min(65535 - prefix_num, 10 ** (5 - len(prefix)))
@@ -199,13 +208,17 @@ def reserve_free_port(
continue
else:
raise RuntimeError(
- f"cannot find free port with prefix {prefix} after {max_retry} retries"
- )
+ f"Cannot find free port with prefix {prefix} after {max_retry} retries."
+ ) from None
else:
- sock.bind((host, 0))
- port = sock.getsockname()[1]
- yield port
- sock.close()
+ if port:
+ sock.bind((host, port))
+ else:
+ sock.bind((host, 0))
+ try:
+ yield sock.getsockname()[1]
+ finally:
+ sock.close()
def copy_file_to_fs_folder(
@@ -359,8 +372,6 @@ def __call__(
@contextlib.contextmanager
@functools.wraps(func)
def _func(*args: "P.args", **kwargs: "P.kwargs") -> t.Any:
- import inspect
-
bound_args = inspect.signature(func).bind(*args, **kwargs)
bound_args.apply_defaults()
if self._cache_key_template:
diff --git a/bentoml/_internal/utils/buildx.py b/bentoml/_internal/utils/buildx.py
index 34bb1d9139..53248be33a 100644
--- a/bentoml/_internal/utils/buildx.py
+++ b/bentoml/_internal/utils/buildx.py
@@ -140,8 +140,8 @@ def build(
build_args: dict[str, str] | None,
build_context: dict[str, str] | None,
builder: str | None,
- cache_from: str | list[str] | dict[str, str] | None,
- cache_to: str | list[str] | dict[str, str] | None,
+ cache_from: str | t.Iterable[str] | dict[str, str] | None,
+ cache_to: str | t.Iterable[str] | dict[str, str] | None,
cgroup_parent: str | None,
file: PathType | None,
iidfile: PathType | None,
@@ -150,14 +150,14 @@ def build(
metadata_file: PathType | None,
network: str | None,
no_cache: bool,
- no_cache_filter: list[str] | None,
+ no_cache_filter: t.Iterable[str] | None,
output: str | dict[str, str] | None,
- platform: str | list[str] | None,
+ platform: str | t.Iterable[str] | None,
progress: t.Literal["auto", "tty", "plain"],
pull: bool,
push: bool,
quiet: bool,
- secrets: str | list[str] | None,
+ secrets: str | t.Iterable[str] | None,
shm_size: str | int | None,
rm: bool,
ssh: str | None,
diff --git a/bentoml/_internal/utils/platform.py b/bentoml/_internal/utils/platform.py
deleted file mode 100644
index 0811724bbf..0000000000
--- a/bentoml/_internal/utils/platform.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import signal
-import subprocess
-from typing import TYPE_CHECKING
-
-import psutil
-
-if TYPE_CHECKING:
- import typing as t
-
-
-def kill_subprocess_tree(p: "subprocess.Popen[t.Any]") -> None:
- """
- Tell the process to terminate and kill all of its children. Availabe both on Windows and Linux.
- Note: It will return immediately rather than wait for the process to terminate.
-
- Args:
- p: subprocess.Popen object
- """
- if psutil.WINDOWS:
- subprocess.call(["taskkill", "/F", "/T", "/PID", str(p.pid)])
- else:
- p.terminate()
-
-
-def cancel_subprocess(p: "subprocess.Popen[t.Any]") -> None:
- if psutil.WINDOWS:
- p.send_signal(signal.CTRL_C_EVENT) # type: ignore
- else:
- p.send_signal(signal.SIGINT)
diff --git a/bentoml/bentos.py b/bentoml/bentos.py
index 27d2c7c3bf..57585062a1 100644
--- a/bentoml/bentos.py
+++ b/bentoml/bentos.py
@@ -1,3 +1,4 @@
+# pylint: disable=unused-argument
"""
User facing python APIs for managing local bentos and build new bentos
"""
@@ -5,10 +6,14 @@
from __future__ import annotations
import os
+import sys
import typing as t
import logging
+import tempfile
+import contextlib
import subprocess
from typing import TYPE_CHECKING
+from functools import partial
from simple_di import inject
from simple_di import Provide
@@ -22,10 +27,14 @@
from ._internal.bento.build_config import BentoBuildConfig
from ._internal.configuration.containers import BentoMLContainer
+if sys.version_info >= (3, 8):
+ from shutil import copytree
+else:
+ from backports.shutil_copytree import copytree
+
if TYPE_CHECKING:
from ._internal.bento import BentoStore
from ._internal.types import PathType
- from ._internal.models import ModelStore
logger = logging.getLogger(__name__)
@@ -261,7 +270,6 @@ def build(
version: t.Optional[str] = None,
build_ctx: t.Optional[str] = None,
_bento_store: "BentoStore" = Provide[BentoMLContainer.bento_store],
- _model_store: "ModelStore" = Provide[BentoMLContainer.model_store],
) -> "Bento":
"""
User-facing API for building a Bento. The available build options are identical to the keys of a
@@ -287,7 +295,6 @@ def build(
version: Override the default auto generated version str
build_ctx: Build context directory, when used as
_bento_store: save Bento created to this BentoStore
- _model_store: pull Models required from this ModelStore
Returns:
Bento: a Bento instance representing the materialized Bento saved in BentoStore
@@ -355,7 +362,6 @@ def build_bentofile(
version: t.Optional[str] = None,
build_ctx: t.Optional[str] = None,
_bento_store: "BentoStore" = Provide[BentoMLContainer.bento_store],
- _model_store: "ModelStore" = Provide[BentoMLContainer.model_store],
) -> "Bento":
"""
Build a Bento base on options specified in a bentofile.yaml file.
@@ -368,7 +374,6 @@ def build_bentofile(
version: Override the default auto generated version str
build_ctx: Build context directory, when used as
_bento_store: save Bento created to this BentoStore
- _model_store: pull Models required from this ModelStore
"""
try:
bentofile = resolve_user_filepath(bentofile, build_ctx)
@@ -388,18 +393,70 @@ def build_bentofile(
return bento
+@contextlib.contextmanager
+def construct_dockerfile(
+ bento: Bento,
+ *,
+ features: t.Sequence[str] | None = None,
+ docker_final_stage: str | None = None,
+) -> t.Generator[tuple[str, str], None, None]:
+ dockerfile_path = os.path.join("env", "docker", "Dockerfile")
+ final_instruction = ""
+ if features is not None:
+ features = [l for s in map(lambda x: x.split(","), features) for l in s]
+ if not all(f in FEATURES for f in features):
+ raise InvalidArgument(
+ f"Available features are: {FEATURES}. Invalid fields from provided: {set(features) - set(FEATURES)}"
+ )
+ final_instruction += f"""\
+RUN --mount=type=cache,target=/root/.cache/pip pip install bentoml[{','.join(features)}]
+"""
+ if docker_final_stage:
+ final_instruction += f"""\
+{docker_final_stage}
+"""
+ with open(bento.path_of(dockerfile_path), "r") as f:
+ FINAL_DOCKERFILE = f"""\
+{f.read()}
+FROM base-{bento.info.docker.distro}
+# Additional instructions for final image.
+{final_instruction}
+"""
+ if final_instruction != "":
+ with tempfile.TemporaryDirectory("bento-tmp") as tmpdir:
+ copytree(bento.path, tmpdir, dirs_exist_ok=True)
+ with open(os.path.join(tmpdir, dockerfile_path), "w") as dockerfile:
+ dockerfile.write(FINAL_DOCKERFILE)
+ yield tmpdir, dockerfile.name
+ else:
+ yield bento.path, dockerfile_path
+
+
+# Sync with BentoML extra dependencies
+FEATURES = [
+ "tracing",
+ "grpc",
+ "tracing.zipkin",
+ "tracing.jaeger",
+ "tracing.otlp",
+]
+
+
@inject
def containerize(
tag: Tag | str,
- docker_image_tag: t.List[str] | None = None,
+ docker_image_tag: t.Iterable[str] | None = None,
*,
+ # containerize options
+ features: t.Sequence[str] | None = None,
+ # docker options
add_host: dict[str, str] | None = None,
allow: t.List[str] | None = None,
build_args: dict[str, str] | None = None,
build_context: dict[str, str] | None = None,
builder: str | None = None,
- cache_from: str | t.List[str] | dict[str, str] | None = None,
- cache_to: str | t.List[str] | dict[str, str] | None = None,
+ cache_from: str | t.Iterable[str] | dict[str, str] | None = None,
+ cache_to: str | t.Iterable[str] | dict[str, str] | None = None,
cgroup_parent: str | None = None,
iidfile: PathType | None = None,
labels: dict[str, str] | None = None,
@@ -407,14 +464,14 @@ def containerize(
metadata_file: PathType | None = None,
network: str | None = None,
no_cache: bool = False,
- no_cache_filter: t.List[str] | None = None,
+ no_cache_filter: t.Iterable[str] | None = None,
output: str | dict[str, str] | None = None,
- platform: str | t.List[str] | None = None,
+ platform: str | t.Iterable[str] | None = None,
progress: t.Literal["auto", "tty", "plain"] = "auto",
pull: bool = False,
push: bool = False,
quiet: bool = False,
- secrets: str | t.List[str] | None = None,
+ secrets: str | t.Iterable[str] | None = None,
shm_size: str | int | None = None,
rm: bool = False,
ssh: str | None = None,
@@ -423,6 +480,8 @@ def containerize(
_bento_store: "BentoStore" = Provide[BentoMLContainer.bento_store],
) -> bool:
+ import psutil
+
from bentoml._internal.utils import buildx
env = {"DOCKER_BUILDKIT": "1", "DOCKER_SCAN_SUGGEST": "false"}
@@ -431,61 +490,58 @@ def containerize(
if not docker_image_tag:
docker_image_tag = [str(bento.tag)]
- dockerfile_path = os.path.join("env", "docker", "Dockerfile")
-
logger.info(f"Building docker image for {bento}...")
- try:
- buildx.build(
- subprocess_env=env,
- cwd=bento.path,
- file=dockerfile_path,
- tags=docker_image_tag,
- add_host=add_host,
- allow=allow,
- build_args=build_args,
- build_context=build_context,
- builder=builder,
- cache_from=cache_from,
- cache_to=cache_to,
- cgroup_parent=cgroup_parent,
- iidfile=iidfile,
- labels=labels,
- load=load,
- metadata_file=metadata_file,
- network=network,
- no_cache=no_cache,
- no_cache_filter=no_cache_filter,
- output=output,
- platform=platform,
- progress=progress,
- pull=pull,
- push=push,
- quiet=quiet,
- secrets=secrets,
- shm_size=shm_size,
- rm=rm,
- ssh=ssh,
- target=target,
- ulimit=ulimit,
+ if platform and not psutil.LINUX and platform != "linux/amd64":
+ logger.warning(
+ 'Current platform is set to "%s". To avoid issue, we recommend you to build the container with x86_64 (amd64): "bentoml containerize %s --platform linux/amd64"',
+ ",".join(platform),
+ str(bento.tag),
)
+ run_buildx = partial(
+ buildx.build,
+ subprocess_env=env,
+ tags=docker_image_tag,
+ add_host=add_host,
+ allow=allow,
+ build_args=build_args,
+ build_context=build_context,
+ builder=builder,
+ cache_from=cache_from,
+ cache_to=cache_to,
+ cgroup_parent=cgroup_parent,
+ iidfile=iidfile,
+ labels=labels,
+ load=load,
+ metadata_file=metadata_file,
+ network=network,
+ no_cache=no_cache,
+ no_cache_filter=no_cache_filter,
+ output=output,
+ platform=platform,
+ progress=progress,
+ pull=pull,
+ push=push,
+ quiet=quiet,
+ secrets=secrets,
+ shm_size=shm_size,
+ rm=rm,
+ ssh=ssh,
+ target=target,
+ ulimit=ulimit,
+ )
+ clean_context = contextlib.ExitStack()
+ required = clean_context.enter_context(
+ construct_dockerfile(bento, features=features)
+ )
+ try:
+ build_path, dockerfile_path = required
+ run_buildx(cwd=build_path, file=dockerfile_path)
+ return True
except subprocess.CalledProcessError as e:
logger.error(f"Failed building docker image: {e}")
- if platform != "linux/amd64":
- logger.debug(
- f"""If you run into the following error: "failed to solve: pull access denied, repository does not exist or may require authorization: server message: insufficient_scope: authorization failed". This means Docker doesn't have context of your build platform {platform}. By default BentoML will set target build platform to the current machine platform via `uname -m`. Try again by specifying to build x86_64 (amd64) platform: bentoml containerize {str(bento.tag)} --platform linux/amd64"""
- )
return False
- else:
- logger.info(
- 'Successfully built docker image for "%s" with tags "%s"',
- str(bento.tag),
- ",".join(docker_image_tag),
- )
- logger.info(
- 'To run your newly built Bento container, use one of the above tags, and pass it to "docker run". i.e: "docker run -it --rm -p 3000:3000 %s"',
- docker_image_tag[0],
- )
- return True
+ finally:
+ clean_context.close()
__all__ = [
diff --git a/bentoml/exceptions.py b/bentoml/exceptions.py
index 34924e7b28..3e9e292e9b 100644
--- a/bentoml/exceptions.py
+++ b/bentoml/exceptions.py
@@ -73,6 +73,14 @@ class NotFound(BentoMLException):
error_code = HTTPStatus.NOT_FOUND
+class UnprocessableEntity(BentoMLException):
+ """
+ Raise when API server receiving unprocessable entity request
+ """
+
+ error_code = HTTPStatus.UNPROCESSABLE_ENTITY
+
+
class TooManyRequests(BentoMLException):
"""
Raise when incoming requests exceeds the capacity of a server
diff --git a/bentoml/grpc/__init__.py b/bentoml/grpc/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/bentoml/grpc/buf.yaml b/bentoml/grpc/buf.yaml
new file mode 100644
index 0000000000..4e26d9abb8
--- /dev/null
+++ b/bentoml/grpc/buf.yaml
@@ -0,0 +1,22 @@
+version: v1
+lint:
+ use:
+ - DEFAULT
+ - COMMENT_ENUM
+ - COMMENT_MESSAGE
+ - COMMENT_RPC
+ - COMMENT_SERVICE
+ except:
+ - DIRECTORY_SAME_PACKAGE
+ - RPC_REQUEST_STANDARD_NAME
+ - RPC_RESPONSE_STANDARD_NAME
+ ignore_only:
+ DEFAULT:
+ - bentoml/grpc/v1alpha1/service_test.proto
+ ENUM_VALUE_PREFIX:
+ - bentoml/grpc/v1alpha1/service.proto
+ enum_zero_value_suffix: _UNSPECIFIED
+ rpc_allow_same_request_response: true
+ rpc_allow_google_protobuf_empty_requests: true
+ rpc_allow_google_protobuf_empty_responses: true
+ service_suffix: Service
diff --git a/bentoml/grpc/interceptors/__init__.py b/bentoml/grpc/interceptors/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/bentoml/grpc/interceptors/access.py b/bentoml/grpc/interceptors/access.py
new file mode 100644
index 0000000000..a89cc0f0a1
--- /dev/null
+++ b/bentoml/grpc/interceptors/access.py
@@ -0,0 +1,100 @@
+from __future__ import annotations
+
+import typing as t
+import logging
+import functools
+from timeit import default_timer
+from typing import TYPE_CHECKING
+
+from bentoml.grpc.utils import to_http_status
+from bentoml.grpc.utils import wrap_rpc_handler
+from bentoml.grpc.utils import GRPC_CONTENT_TYPE
+
+if TYPE_CHECKING:
+ import grpc
+ from grpc import aio
+ from grpc.aio._typing import MetadataType # pylint: disable=unused-import
+
+ from bentoml.grpc.types import Request
+ from bentoml.grpc.types import Response
+ from bentoml.grpc.types import RpcMethodHandler
+ from bentoml.grpc.types import AsyncHandlerMethod
+ from bentoml.grpc.types import HandlerCallDetails
+ from bentoml.grpc.types import BentoServicerContext
+ from bentoml.grpc.v1alpha1 import service_pb2 as pb
+else:
+ from bentoml.grpc.utils import import_grpc
+ from bentoml.grpc.utils import import_generated_stubs
+
+ pb, _ = import_generated_stubs()
+ grpc, aio = import_grpc()
+
+
+class AccessLogServerInterceptor(aio.ServerInterceptor):
+ """
+ An asyncio interceptor for access logging.
+ """
+
+ async def intercept_service(
+ self,
+ continuation: t.Callable[[HandlerCallDetails], t.Awaitable[RpcMethodHandler]],
+ handler_call_details: HandlerCallDetails,
+ ) -> RpcMethodHandler:
+ logger = logging.getLogger("bentoml.access")
+ handler = await continuation(handler_call_details)
+ method_name = handler_call_details.method
+
+ if handler and (handler.response_streaming or handler.request_streaming):
+ return handler
+
+ def wrapper(behaviour: AsyncHandlerMethod[Response]):
+ @functools.wraps(behaviour)
+ async def new_behaviour(
+ request: Request, context: BentoServicerContext
+ ) -> Response | t.Awaitable[Response]:
+ content_type = GRPC_CONTENT_TYPE
+ trailing_metadata: MetadataType | None = context.trailing_metadata()
+ if trailing_metadata:
+ trailing = dict(trailing_metadata)
+ content_type = trailing.get("content-type", GRPC_CONTENT_TYPE)
+
+ response = pb.Response()
+ start = default_timer()
+ try:
+ response = await behaviour(request, context)
+ except Exception as e: # pylint: disable=broad-except
+ context.set_code(grpc.StatusCode.INTERNAL)
+ context.set_details(str(e))
+ finally:
+ latency = max(default_timer() - start, 0) * 1000
+
+ req = [
+ "scheme=http", # TODO: support https when ssl is added
+ f"path={method_name}",
+ f"type={content_type}",
+ f"size={request.ByteSize()}",
+ ]
+
+ # Note that in order AccessLogServerInterceptor to work, the
+ # interceptor must be added to the server after AsyncOpenTeleServerInterceptor
+ # and PrometheusServerInterceptor.
+ typed_context_code = t.cast(grpc.StatusCode, context.code())
+ resp = [
+ f"http_status={to_http_status(typed_context_code)}",
+ f"grpc_status={typed_context_code.value[0]}",
+ f"type={content_type}",
+ f"size={response.ByteSize()}",
+ ]
+
+ logger.info(
+ "%s (%s) (%s) %.3fms",
+ context.peer(),
+ ",".join(req),
+ ",".join(resp),
+ latency,
+ )
+ return response
+
+ return new_behaviour
+
+ return wrap_rpc_handler(wrapper, handler)
diff --git a/bentoml/grpc/interceptors/opentelemetry.py b/bentoml/grpc/interceptors/opentelemetry.py
new file mode 100644
index 0000000000..b36b19a12c
--- /dev/null
+++ b/bentoml/grpc/interceptors/opentelemetry.py
@@ -0,0 +1,277 @@
+from __future__ import annotations
+
+import typing as t
+import logging
+import functools
+from typing import TYPE_CHECKING
+from contextlib import asynccontextmanager
+
+from simple_di import inject
+from simple_di import Provide
+from opentelemetry import trace
+from opentelemetry.context import attach
+from opentelemetry.context import detach
+from opentelemetry.propagate import extract
+from opentelemetry.trace.status import Status
+from opentelemetry.trace.status import StatusCode
+from opentelemetry.semconv.trace import SpanAttributes
+
+from bentoml.grpc.utils import wrap_rpc_handler
+from bentoml.grpc.utils import GRPC_CONTENT_TYPE
+from bentoml.grpc.utils import parse_method_name
+from bentoml._internal.utils.pkg import get_pkg_version
+from bentoml._internal.configuration.containers import BentoMLContainer
+
+if TYPE_CHECKING:
+ import grpc
+ from grpc import aio
+ from grpc.aio._typing import MetadataKey
+ from grpc.aio._typing import MetadataType
+ from grpc.aio._typing import MetadataValue
+ from opentelemetry.trace import Span
+ from opentelemetry.sdk.trace import TracerProvider
+
+ from bentoml.grpc.types import Request
+ from bentoml.grpc.types import Response
+ from bentoml.grpc.types import RpcMethodHandler
+ from bentoml.grpc.types import AsyncHandlerMethod
+ from bentoml.grpc.types import HandlerCallDetails
+ from bentoml.grpc.types import BentoServicerContext
+else:
+ from bentoml.grpc.utils import import_grpc
+
+ grpc, aio = import_grpc()
+
+logger = logging.getLogger(__name__)
+
+
+class _OpenTelemetryServicerContext(aio.ServicerContext["Request", "Response"]):
+ def __init__(self, servicer_context: BentoServicerContext, active_span: Span):
+ self._servicer_context = servicer_context
+ self._active_span = active_span
+ self._code = grpc.StatusCode.OK
+ self._details = ""
+ super().__init__()
+
+ def __getattr__(self, attr: str) -> t.Any:
+ return getattr(self._servicer_context, attr)
+
+ async def read(self) -> Request:
+ return await self._servicer_context.read()
+
+ async def write(self, message: Response) -> None:
+ return await self._servicer_context.write(message)
+
+ def trailing_metadata(self) -> aio.Metadata:
+ return self._servicer_context.trailing_metadata() # type: ignore (unfinished type)
+
+ def auth_context(self) -> t.Mapping[str, t.Iterable[bytes]]:
+ return self._servicer_context.auth_context()
+
+ def peer_identity_key(self) -> str | None:
+ return self._servicer_context.peer_identity_key()
+
+ def peer_identities(self) -> t.Iterable[bytes] | None:
+ return self._servicer_context.peer_identities()
+
+ def peer(self) -> str:
+ return self._servicer_context.peer()
+
+ def disable_next_message_compression(self) -> None:
+ self._servicer_context.disable_next_message_compression()
+
+ def set_compression(self, compression: grpc.Compression) -> None:
+ return self._servicer_context.set_compression(compression)
+
+ def invocation_metadata(self) -> aio.Metadata | None:
+ return self._servicer_context.invocation_metadata()
+
+ def set_trailing_metadata(self, trailing_metadata: MetadataType) -> None:
+ self._servicer_context.set_trailing_metadata(trailing_metadata)
+
+ async def send_initial_metadata(self, initial_metadata: MetadataType) -> None:
+ return await self._servicer_context.send_initial_metadata(initial_metadata)
+
+ async def abort(
+ self,
+ code: grpc.StatusCode,
+ details: str = "",
+ trailing_metadata: MetadataType = tuple(),
+ ) -> None:
+ self._code = code
+ self._details = details
+ self._active_span.set_attribute(
+ SpanAttributes.RPC_GRPC_STATUS_CODE, code.value[0]
+ )
+ self._active_span.set_status(
+ Status(status_code=StatusCode.ERROR, description=f"{code}:{details}")
+ )
+ return await self._servicer_context.abort(
+ code, details=details, trailing_metadata=trailing_metadata
+ )
+
+ def set_code(self, code: grpc.StatusCode) -> None:
+ self._code = code
+ details = self._details or code.value[1]
+ self._active_span.set_attribute(
+ SpanAttributes.RPC_GRPC_STATUS_CODE, code.value[0]
+ )
+ if code != grpc.StatusCode.OK:
+ self._active_span.set_status(
+ Status(status_code=StatusCode.ERROR, description=f"{code}:{details}")
+ )
+ return self._servicer_context.set_code(code)
+
+ def code(self) -> grpc.StatusCode:
+ return self._code
+
+ def set_details(self, details: str) -> None:
+ self._details = details
+ if self._code != grpc.StatusCode.OK:
+ self._active_span.set_status(
+ Status(
+ status_code=StatusCode.ERROR, description=f"{self._code}:{details}"
+ )
+ )
+ return self._servicer_context.set_details(details)
+
+ def details(self) -> str:
+ return self._details
+
+
+# Since opentelemetry doesn't provide an async implementation for the server interceptor,
+# we will need to create an async implementation ourselves.
+# By doing this we will have more control over how to handle span and context propagation.
+#
+# Until there is a solution upstream, this implementation is sufficient for our needs.
+class AsyncOpenTelemetryServerInterceptor(aio.ServerInterceptor):
+ @inject
+ def __init__(
+ self,
+ *,
+ tracer_provider: TracerProvider = Provide[BentoMLContainer.tracer_provider],
+ schema_url: str | None = None,
+ ):
+ self._tracer = tracer_provider.get_tracer(
+ "opentelemetry.instrumentation.grpc",
+ get_pkg_version("opentelemetry-instrumentation-grpc"),
+ schema_url=schema_url,
+ )
+
+ @asynccontextmanager
+ async def set_remote_context(
+ self, servicer_context: BentoServicerContext
+ ) -> t.AsyncGenerator[None, None]:
+ metadata = servicer_context.invocation_metadata()
+ if metadata:
+ md: dict[MetadataKey, MetadataValue] = {m.key: m.value for m in metadata}
+ ctx = extract(md)
+ token = attach(ctx)
+ try:
+ yield
+ finally:
+ detach(token)
+ else:
+ yield
+
+ def start_span(
+ self,
+ method_name: str,
+ context: BentoServicerContext,
+ set_status_on_exception: bool = False,
+ ) -> t.ContextManager[Span]:
+ attributes: dict[str, str | bytes] = {
+ SpanAttributes.RPC_SYSTEM: "grpc",
+ SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[0],
+ }
+
+ # method_name shouldn't be none, otherwise
+ # it will never reach this point.
+ method_rpc, _ = parse_method_name(method_name)
+ attributes.update(
+ {
+ SpanAttributes.RPC_METHOD: method_rpc.method,
+ SpanAttributes.RPC_SERVICE: method_rpc.fully_qualified_service,
+ }
+ )
+
+ # add some attributes from the metadata
+ metadata = context.invocation_metadata()
+ if metadata:
+ dct: dict[str, str | bytes] = dict(metadata)
+ if "user-agent" in dct:
+ attributes["rpc.user_agent"] = dct["user-agent"]
+
+ # get trailing metadata
+ trailing_metadata: MetadataType | None = context.trailing_metadata()
+ if trailing_metadata:
+ trailing = dict(trailing_metadata)
+ attributes["rpc.content_type"] = trailing.get(
+ "content-type", GRPC_CONTENT_TYPE
+ )
+
+ # Split up the peer to keep with how other telemetry sources
+ # do it. This looks like:
+ # * ipv6:[::1]:57284
+ # * ipv4:127.0.0.1:57284
+ # * ipv4:10.2.1.1:57284,127.0.0.1:57284
+ #
+ # the process ip and port would be [::1] 57284
+ try:
+ ipv4_addr = context.peer().split(",")[0]
+ ip, port = ipv4_addr.split(":", 1)[1].rsplit(":", 1)
+ attributes.update(
+ {
+ SpanAttributes.NET_PEER_IP: ip,
+ SpanAttributes.NET_PEER_PORT: port,
+ }
+ )
+ # other telemetry sources add this, so we will too
+ if ip in ("[::1]", "127.0.0.1"):
+ attributes[SpanAttributes.NET_PEER_NAME] = "localhost"
+ except IndexError:
+ logger.warning(f"Failed to parse peer address '{context.peer()}'")
+
+ return self._tracer.start_as_current_span(
+ name=method_name,
+ kind=trace.SpanKind.SERVER,
+ attributes=attributes,
+ set_status_on_exception=set_status_on_exception,
+ )
+
+ async def intercept_service(
+ self,
+ continuation: t.Callable[[HandlerCallDetails], t.Awaitable[RpcMethodHandler]],
+ handler_call_details: HandlerCallDetails,
+ ) -> RpcMethodHandler:
+ handler = await continuation(handler_call_details)
+ method_name = handler_call_details.method
+
+ # Currently not support streaming RPCs.
+ if handler and (handler.response_streaming or handler.request_streaming):
+ return handler
+
+ def wrapper(behaviour: AsyncHandlerMethod[Response]):
+ @functools.wraps(behaviour)
+ async def new_behaviour(
+ request: Request, context: BentoServicerContext
+ ) -> Response | t.Awaitable[Response]:
+
+ async with self.set_remote_context(context):
+ with self.start_span(method_name, context) as span:
+ # wrap context
+ wrapped_context = _OpenTelemetryServicerContext(context, span)
+
+ # And now we run the actual RPC.
+ try:
+ return await behaviour(request, wrapped_context)
+ except Exception as e:
+ # We are interested in uncaught exception, otherwise
+ # it will be handled by gRPC.
+ if type(e) != Exception:
+ span.record_exception(e)
+ raise e
+
+ return new_behaviour
+
+ return wrap_rpc_handler(wrapper, handler)
diff --git a/bentoml/grpc/interceptors/prometheus.py b/bentoml/grpc/interceptors/prometheus.py
new file mode 100644
index 0000000000..62424b7d5e
--- /dev/null
+++ b/bentoml/grpc/interceptors/prometheus.py
@@ -0,0 +1,152 @@
+from __future__ import annotations
+
+import typing as t
+import logging
+import functools
+import contextvars
+from timeit import default_timer
+from typing import TYPE_CHECKING
+
+from simple_di import inject
+from simple_di import Provide
+
+from bentoml.grpc.utils import to_http_status
+from bentoml.grpc.utils import wrap_rpc_handler
+from bentoml._internal.context import component_context
+from bentoml._internal.configuration.containers import BentoMLContainer
+
+START_TIME_VAR: contextvars.ContextVar[float] = contextvars.ContextVar("START_TIME_VAR")
+
+if TYPE_CHECKING:
+ import grpc
+ from grpc import aio
+
+ from bentoml.grpc.types import Request
+ from bentoml.grpc.types import Response
+ from bentoml.grpc.types import RpcMethodHandler
+ from bentoml.grpc.types import AsyncHandlerMethod
+ from bentoml.grpc.types import HandlerCallDetails
+ from bentoml.grpc.types import BentoServicerContext
+ from bentoml.grpc.v1alpha1 import service_pb2 as pb
+ from bentoml._internal.server.metrics.prometheus import PrometheusClient
+else:
+ from bentoml.grpc.utils import import_grpc
+ from bentoml.grpc.utils import import_generated_stubs
+
+ pb, _ = import_generated_stubs()
+ grpc, aio = import_grpc()
+
+
+logger = logging.getLogger(__name__)
+
+
+class PrometheusServerInterceptor(aio.ServerInterceptor):
+ """
+ An async interceptor for Prometheus metrics.
+ """
+
+ def __init__(self, *, namespace: str = "bentoml_api_server"):
+ self._is_setup = False
+ self.namespace = namespace
+
+ @inject
+ def _setup(
+ self,
+ metrics_client: PrometheusClient = Provide[BentoMLContainer.metrics_client],
+ duration_buckets: tuple[float, ...] = Provide[
+ BentoMLContainer.duration_buckets
+ ],
+ ): # pylint: disable=attribute-defined-outside-init
+ self.metrics_request_duration = metrics_client.Histogram(
+ namespace=self.namespace,
+ name="request_duration_seconds",
+ documentation="API GRPC request duration in seconds",
+ labelnames=[
+ "api_name",
+ "service_name",
+ "service_version",
+ "http_response_code",
+ ],
+ buckets=duration_buckets,
+ )
+ self.metrics_request_total = metrics_client.Counter(
+ namespace=self.namespace,
+ name="request_total",
+ documentation="Total number of GRPC requests",
+ labelnames=[
+ "api_name",
+ "service_name",
+ "service_version",
+ "http_response_code",
+ ],
+ )
+ self.metrics_request_in_progress = metrics_client.Gauge(
+ namespace=self.namespace,
+ name="request_in_progress",
+ documentation="Total number of GRPC requests in progress now",
+ labelnames=["api_name", "service_name", "service_version"],
+ multiprocess_mode="livesum",
+ )
+ self._is_setup = True
+
+ async def intercept_service(
+ self,
+ continuation: t.Callable[[HandlerCallDetails], t.Awaitable[RpcMethodHandler]],
+ handler_call_details: HandlerCallDetails,
+ ) -> RpcMethodHandler:
+ if not self._is_setup:
+ self._setup()
+
+ handler = await continuation(handler_call_details)
+
+ if handler and (handler.response_streaming or handler.request_streaming):
+ return handler
+
+ START_TIME_VAR.set(default_timer())
+
+ def wrapper(behaviour: AsyncHandlerMethod[Response]):
+ @functools.wraps(behaviour)
+ async def new_behaviour(
+ request: Request, context: BentoServicerContext
+ ) -> Response | t.Awaitable[Response]:
+ if not isinstance(request, pb.Request):
+ return await behaviour(request, context)
+
+ api_name = request.api_name
+
+ # instrument request total count
+ self.metrics_request_total.labels(
+ api_name=api_name,
+ service_name=component_context.bento_name,
+ service_version=component_context.bento_version,
+ http_response_code=to_http_status(
+ t.cast(grpc.StatusCode, context.code())
+ ),
+ ).inc()
+
+ # instrument request duration
+ assert START_TIME_VAR.get() != 0
+ total_time = max(default_timer() - START_TIME_VAR.get(), 0)
+ self.metrics_request_duration.labels( # type: ignore (unfinished prometheus types)
+ api_name=api_name,
+ service_name=component_context.bento_name,
+ service_version=component_context.bento_version,
+ http_response_code=to_http_status(
+ t.cast(grpc.StatusCode, context.code())
+ ),
+ ).observe(
+ total_time
+ )
+ START_TIME_VAR.set(0)
+ # instrument request in progress
+ with self.metrics_request_in_progress.labels(
+ api_name=api_name,
+ service_version=component_context.bento_version,
+ service_name=component_context.bento_name,
+ ).track_inprogress():
+ response = await behaviour(request, context)
+ return response
+
+ return new_behaviour
+
+ return wrap_rpc_handler(wrapper, handler)
diff --git a/bentoml/grpc/types.py b/bentoml/grpc/types.py
new file mode 100644
index 0000000000..9fa1dceee3
--- /dev/null
+++ b/bentoml/grpc/types.py
@@ -0,0 +1,108 @@
+# pragma: no cover
+"""
+Specific types for BentoService gRPC server.
+"""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ import typing as t
+ from functools import partial
+
+ import grpc
+ from grpc import aio
+
+ from bentoml.grpc.v1alpha1.service_pb2 import Request
+ from bentoml.grpc.v1alpha1.service_pb2 import Response
+ from bentoml.grpc.v1alpha1.service_pb2_grpc import BentoServiceServicer
+
+ P = t.TypeVar("P")
+
+ BentoServicerContext = aio.ServicerContext[Request, Response]
+
+ RequestDeserializerFn = t.Callable[[Request | None], object] | None
+ ResponseSerializerFn = t.Callable[[bytes], Response | None] | None
+
+ HandlerMethod = t.Callable[[Request, BentoServicerContext], P]
+ AsyncHandlerMethod = t.Callable[[Request, BentoServicerContext], t.Awaitable[P]]
+
+ class RpcMethodHandler(
+ t.NamedTuple(
+ "RpcMethodHandler",
+ request_streaming=bool,
+ response_streaming=bool,
+ request_deserializer=RequestDeserializerFn,
+ response_serializer=ResponseSerializerFn,
+ unary_unary=t.Optional[HandlerMethod[Response]],
+ unary_stream=t.Optional[HandlerMethod[Response]],
+ stream_unary=t.Optional[HandlerMethod[Response]],
+ stream_stream=t.Optional[HandlerMethod[Response]],
+ ),
+ grpc.RpcMethodHandler,
+ ):
+ """An implementation of a single RPC method."""
+
+ request_streaming: bool
+ response_streaming: bool
+ request_deserializer: RequestDeserializerFn
+ response_serializer: ResponseSerializerFn
+ unary_unary: t.Optional[HandlerMethod[Response]]
+ unary_stream: t.Optional[HandlerMethod[Response]]
+ stream_unary: t.Optional[HandlerMethod[Response]]
+ stream_stream: t.Optional[HandlerMethod[Response]]
+
+ class HandlerCallDetails(
+ t.NamedTuple(
+ "HandlerCallDetails", method=str, invocation_metadata=aio.Metadata
+ ),
+ grpc.HandlerCallDetails,
+ ):
+ """Describes an RPC that has just arrived for service.
+
+ Attributes:
+ method: The method name of the RPC.
+ invocation_metadata: A sequence of metadatum, a key-value pair included in the HTTP header.
+ An example is: ``('binary-metadata-bin', b'\\x00\\xFF')``
+ """
+
+ method: str
+ invocation_metadata: aio.Metadata
+
+ # Servicer types
+ ServicerImpl = t.TypeVar("ServicerImpl")
+ Servicer = t.Annotated[ServicerImpl, object]
+ ServicerClass = t.Type[Servicer[t.Any]]
+ AddServicerFn = t.Callable[[Servicer[t.Any], aio.Server | grpc.Server], None]
+
+ # accepted proto fields
+ ProtoField = t.Annotated[
+ str,
+ t.Literal[
+ "dataframe",
+ "file",
+ "json",
+ "ndarray",
+ "series",
+ "text",
+ "multipart",
+ "serialized_bytes",
+ ],
+ ]
+
+ Interceptors = list[
+ t.Callable[[], aio.ServerInterceptor] | partial[aio.ServerInterceptor]
+ ]
+
+ # types defined for client interceptors
+ BentoUnaryUnaryCall = aio.UnaryUnaryCall[Request, Response]
+
+ __all__ = [
+ "Request",
+ "Response",
+ "BentoServicerContext",
+ "BentoServiceServicer",
+ "HandlerCallDetails",
+ "RpcMethodHandler",
+ "BentoUnaryUnaryCall",
+ ]
diff --git a/bentoml/grpc/utils/__init__.py b/bentoml/grpc/utils/__init__.py
new file mode 100644
index 0000000000..4f252146cc
--- /dev/null
+++ b/bentoml/grpc/utils/__init__.py
@@ -0,0 +1,200 @@
+from __future__ import annotations
+
+import typing as t
+import logging
+from http import HTTPStatus
+from typing import TYPE_CHECKING
+from functools import lru_cache
+from dataclasses import dataclass
+
+from bentoml._internal.utils.lazy_loader import LazyLoader
+
+if TYPE_CHECKING:
+ import types
+ from enum import Enum
+
+ import grpc
+ from google.protobuf import descriptor as descriptor_mod
+
+ from bentoml.exceptions import BentoMLException
+ from bentoml.grpc.types import RpcMethodHandler
+ from bentoml.grpc.v1alpha1 import service_pb2 as pb
+
+ # We need this here so that __all__ is detected due to lazy import
+ def import_generated_stubs(
+ version: str = "v1alpha1",
+ ) -> tuple[types.ModuleType, types.ModuleType]:
+ ...
+
+ def import_grpc() -> tuple[types.ModuleType, types.ModuleType]:
+ ...
+
+else:
+ from bentoml.grpc.utils._import_hook import import_grpc
+ from bentoml.grpc.utils._import_hook import import_generated_stubs
+
+ pb, _ = import_generated_stubs()
+ grpc, _ = import_grpc()
+ descriptor_mod = LazyLoader(
+ "descriptor_mod", globals(), "google.protobuf.descriptor"
+ )
+
+__all__ = [
+ "grpc_status_code",
+ "parse_method_name",
+ "to_http_status",
+ "GRPC_CONTENT_TYPE",
+ "import_generated_stubs",
+ "import_grpc",
+]
+
+logger = logging.getLogger(__name__)
+
+# content-type is always application/grpc
+GRPC_CONTENT_TYPE = "application/grpc"
+
+
+def get_field_by_name(
+ descriptor: descriptor_mod.FieldDescriptor | descriptor_mod.Descriptor,
+ field: str,
+) -> descriptor_mod.FieldDescriptor:
+ if isinstance(descriptor, descriptor_mod.FieldDescriptor):
+ # descriptor is a FieldDescriptor
+ return descriptor.message_type.fields_by_name[field]
+ elif isinstance(descriptor, descriptor_mod.Descriptor):
+ # descriptor is a Descriptor
+ return descriptor.fields_by_name[field]
+ else:
+ raise NotImplementedError(f"Type {type(descriptor)} is not yet supported.")
+
+
+def is_map_field(field: descriptor_mod.FieldDescriptor) -> bool:
+ return (
+ field.type == descriptor_mod.FieldDescriptor.TYPE_MESSAGE
+ and field.message_type.has_options
+ and field.message_type.GetOptions().map_entry
+ )
+
+
+@lru_cache(maxsize=1)
+def http_status_to_grpc_status_map() -> dict[Enum, grpc.StatusCode]:
+ # Maps HTTP status code to grpc.StatusCode
+ from http import HTTPStatus
+
+ return {
+ HTTPStatus.OK: grpc.StatusCode.OK,
+ HTTPStatus.UNAUTHORIZED: grpc.StatusCode.UNAUTHENTICATED,
+ HTTPStatus.FORBIDDEN: grpc.StatusCode.PERMISSION_DENIED,
+ HTTPStatus.NOT_FOUND: grpc.StatusCode.UNIMPLEMENTED,
+ HTTPStatus.TOO_MANY_REQUESTS: grpc.StatusCode.UNAVAILABLE,
+ HTTPStatus.BAD_GATEWAY: grpc.StatusCode.UNAVAILABLE,
+ HTTPStatus.SERVICE_UNAVAILABLE: grpc.StatusCode.UNAVAILABLE,
+ HTTPStatus.GATEWAY_TIMEOUT: grpc.StatusCode.DEADLINE_EXCEEDED,
+ HTTPStatus.BAD_REQUEST: grpc.StatusCode.INVALID_ARGUMENT,
+ HTTPStatus.INTERNAL_SERVER_ERROR: grpc.StatusCode.INTERNAL,
+ HTTPStatus.UNPROCESSABLE_ENTITY: grpc.StatusCode.FAILED_PRECONDITION,
+ }
+
+
+@lru_cache(maxsize=1)
+def grpc_status_to_http_status_map() -> dict[grpc.StatusCode, Enum]:
+ return {v: k for k, v in http_status_to_grpc_status_map().items()}
+
+
+@lru_cache(maxsize=1)
+def filetype_pb_to_mimetype_map() -> dict[pb.File.FileType.ValueType, str]:
+ return {
+ pb.File.FILE_TYPE_CSV: "text/csv",
+ pb.File.FILE_TYPE_PLAINTEXT: "text/plain",
+ pb.File.FILE_TYPE_JSON: "application/json",
+ pb.File.FILE_TYPE_BYTES: "application/octet-stream",
+ pb.File.FILE_TYPE_PDF: "application/pdf",
+ pb.File.FILE_TYPE_PNG: "image/png",
+ pb.File.FILE_TYPE_JPEG: "image/jpeg",
+ pb.File.FILE_TYPE_GIF: "image/gif",
+ pb.File.FILE_TYPE_TIFF: "image/tiff",
+ pb.File.FILE_TYPE_BMP: "image/bmp",
+ pb.File.FILE_TYPE_WEBP: "image/webp",
+ pb.File.FILE_TYPE_SVG: "image/svg+xml",
+ }
+
+
+@lru_cache(maxsize=1)
+def mimetype_to_filetype_pb_map() -> dict[str, pb.File.FileType.ValueType]:
+ return {v: k for k, v in filetype_pb_to_mimetype_map().items()}
+
+
+def grpc_status_code(err: BentoMLException) -> grpc.StatusCode:
+ """
+ Convert BentoMLException.error_code to grpc.StatusCode.
+ """
+ return http_status_to_grpc_status_map().get(err.error_code, grpc.StatusCode.UNKNOWN)
+
+
+def to_http_status(status_code: grpc.StatusCode) -> int:
+ """
+ Convert grpc.StatusCode to HTTPStatus.
+ """
+ status = grpc_status_to_http_status_map().get(
+ status_code, HTTPStatus.INTERNAL_SERVER_ERROR
+ )
+
+ return status.value
+
+
+@dataclass
+class MethodName:
+ """
+ Represents a gRPC method name.
+
+ Attributes:
+ package: This is defined by `package foo.bar`, designation in the protocol buffer definition
+ service: service name in protocol buffer definition (eg: service SearchService { ... })
+ method: method name
+ """
+
+ package: str = ""
+ service: str = ""
+ method: str = ""
+
+ @property
+ def fully_qualified_service(self):
+ """return the service name prefixed with package"""
+ return f"{self.package}.{self.service}" if self.package else self.service
+
+
+def parse_method_name(method_name: str) -> tuple[MethodName, bool]:
+ """
+ Infers the grpc service and method name from the handler_call_details.
+ e.g. /package.ServiceName/MethodName
+ """
+ method = method_name.split("/", maxsplit=2)
+ # sanity check for method.
+ if len(method) != 3:
+ return MethodName(), False
+ _, package_service, method = method
+ *packages, service = package_service.rsplit(".", maxsplit=1)
+ package = packages[0] if packages else ""
+ return MethodName(package, service, method), True
+
+
+def wrap_rpc_handler(
+ wrapper: t.Callable[..., t.Any],
+ handler: RpcMethodHandler | None,
+) -> RpcMethodHandler | None:
+ if not handler:
+ return None
+ if not handler.request_streaming and not handler.response_streaming:
+ assert handler.unary_unary
+ return handler._replace(unary_unary=wrapper(handler.unary_unary))
+ elif not handler.request_streaming and handler.response_streaming:
+ assert handler.unary_stream
+ return handler._replace(unary_stream=wrapper(handler.unary_stream))
+ elif handler.request_streaming and not handler.response_streaming:
+ assert handler.stream_unary
+ return handler._replace(stream_unary=wrapper(handler.stream_unary))
+ elif handler.request_streaming and handler.response_streaming:
+ assert handler.stream_stream
+ return handler._replace(stream_stream=wrapper(handler.stream_stream))
+ else:
+ raise RuntimeError(f"RPC method handler {handler} does not exist.") from None
diff --git a/bentoml/grpc/utils/_import_hook.py b/bentoml/grpc/utils/_import_hook.py
new file mode 100644
index 0000000000..285e0abec0
--- /dev/null
+++ b/bentoml/grpc/utils/_import_hook.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from pathlib import Path
+
+if TYPE_CHECKING:
+ import types
+
+
+def import_generated_stubs(
+ version: str = "v1alpha1",
+ file: str = "service.proto",
+) -> tuple[types.ModuleType, types.ModuleType]:
+ """
+ Import generated stubs.
+
+ Args:
+ version: The version of the proto file to import.
+ file: The name of the proto file to import.
+
+ Returns:
+ A tuple of the generated stubs for the proto file.
+
+ Examples:
+
+ .. code-block:: python
+
+ from bentoml.grpc.utils import import_generated_stubs
+
+ # given proto file bentoml/grpc/v1alpha2/service.proto exists
+ pb, services = import_generated_stubs(version="v1alpha2", file="service.proto")
+ """
+ # generate git root from this file's path
+ from bentoml._internal.utils import LazyLoader
+
+ GIT_ROOT = Path(__file__).parent.parent.parent.parent
+
+ exception_message = f"Generated stubs for '{version}/{file}' are missing. To generate stubs, run '{GIT_ROOT}/scripts/generate_grpc_stubs.sh'"
+ file = file.split(".")[0]
+
+ service_pb2 = LazyLoader(
+ f"{file}_pb2",
+ globals(),
+ f"bentoml.grpc.{version}.{file}_pb2",
+ exc_msg=exception_message,
+ )
+ service_pb2_grpc = LazyLoader(
+ f"{file}_pb2_grpc",
+ globals(),
+ f"bentoml.grpc.{version}.{file}_pb2_grpc",
+ exc_msg=exception_message,
+ )
+ return service_pb2, service_pb2_grpc
+
+
+def import_grpc() -> tuple[types.ModuleType, types.ModuleType]:
+ from bentoml._internal.utils import LazyLoader
+
+ exception_message = "'grpcio' is required for gRPC support. Install with 'pip install bentoml[grpc]'."
+ grpc = LazyLoader(
+ "grpc",
+ globals(),
+ "grpc",
+ exc_msg=exception_message,
+ )
+ aio = LazyLoader("aio", globals(), "grpc.aio", exc_msg=exception_message)
+ return grpc, aio
diff --git a/bentoml/grpc/v1alpha1/__init__.py b/bentoml/grpc/v1alpha1/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/bentoml/grpc/v1alpha1/service.proto b/bentoml/grpc/v1alpha1/service.proto
new file mode 100644
index 0000000000..c166ee715c
--- /dev/null
+++ b/bentoml/grpc/v1alpha1/service.proto
@@ -0,0 +1,279 @@
+syntax = "proto3";
+
+package bentoml.grpc.v1alpha1;
+
+import "google/protobuf/struct.proto";
+import "google/protobuf/wrappers.proto";
+
+// cc_enable_arenas pre-allocate memory for given message to improve speed. (C++ only)
+option cc_enable_arenas = true;
+option go_package = "github.com/bentoml/grpc/v1alpha1";
+option java_multiple_files = true;
+option java_outer_classname = "ServiceProto";
+option java_package = "com.bentoml.grpc.v1alpha1";
+option objc_class_prefix = "SVC";
+option py_generic_services = true;
+
+// a gRPC BentoServer.
+service BentoService {
+ // Call handles methodcaller of given API entrypoint.
+ rpc Call(Request) returns (Response) {}
+}
+
+// Request message for incoming Call.
+message Request {
+ // api_name defines the API entrypoint to call.
+ // api_name is the name of the function defined in bentoml.Service.
+ // Example:
+ //
+ // @svc.api(input=NumpyNdarray(), output=File())
+ // def predict(input: NDArray[float]) -> bytes:
+ // ...
+ //
+ // api_name is "predict" in this case.
+ string api_name = 1;
+
+ oneof content {
+ // NDArray represents a n-dimensional array of arbitrary type.
+ NDArray ndarray = 3;
+
+ // DataFrame represents any tabular data type. We are using
+ // DataFrame as a trivial representation for tabular type.
+ DataFrame dataframe = 5;
+
+ // Series portrays a series of values. This can be used for
+ // representing Series types in tabular data.
+ Series series = 6;
+
+ // File represents for any arbitrary file type. This can be
+ // plaintext, image, video, audio, etc.
+ File file = 7;
+
+ // Text represents a string inputs.
+ google.protobuf.StringValue text = 8;
+
+ // JSON is represented by using google.protobuf.Value.
+ // see https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/struct.proto
+ google.protobuf.Value json = 9;
+
+ // Multipart represents a multipart message.
+ // It comprises of a mapping from given type name to a subset of aforementioned types.
+ Multipart multipart = 10;
+
+ // serialized_bytes is for data serialized in BentoML's internal serialization format.
+ bytes serialized_bytes = 2;
+ }
+
+ // Tensor is similiar to ndarray but with a name
+ // We are reserving it for now for future use.
+ // repeated Tensor tensors = 4;
+ reserved 4, 11 to 13;
+}
+
+// Request message for incoming Call.
+message Response {
+ oneof content {
+ // NDArray represents a n-dimensional array of arbitrary type.
+ NDArray ndarray = 1;
+
+ // DataFrame represents any tabular data type. We are using
+ // DataFrame as a trivial representation for tabular type.
+ DataFrame dataframe = 3;
+
+ // Series portrays a series of values. This can be used for
+ // representing Series types in tabular data.
+ Series series = 5;
+
+ // File represents for any arbitrary file type. This can be
+ // plaintext, image, video, audio, etc.
+ File file = 6;
+
+ // Text represents a string inputs.
+ google.protobuf.StringValue text = 7;
+
+ // JSON is represented by using google.protobuf.Value.
+ // see https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/struct.proto
+ google.protobuf.Value json = 8;
+
+ // Multipart represents a multipart message.
+ // It comprises of a mapping from given type name to a subset of aforementioned types.
+ Multipart multipart = 9;
+
+ // serialized_bytes is for data serialized in BentoML's internal serialization format.
+ bytes serialized_bytes = 2;
+ }
+ // Tensor is similiar to ndarray but with a name
+ // We are reserving it for now for future use.
+ // repeated Tensor tensors = 4;
+ reserved 4, 10 to 13;
+}
+
+// Part represents possible value types for multipart message.
+// These are the same as the types in Request message.
+message Part {
+ oneof representation {
+ // NDArray represents a n-dimensional array of arbitrary type.
+ NDArray ndarray = 1;
+
+ // DataFrame represents any tabular data type. We are using
+ // DataFrame as a trivial representation for tabular type.
+ DataFrame dataframe = 3;
+
+ // Series portrays a series of values. This can be used for
+ // representing Series types in tabular data.
+ Series series =5;
+
+ // File represents for any arbitrary file type. This can be
+ // plaintext, image, video, audio, etc.
+ File file = 6;
+
+ // Text represents a string inputs.
+ google.protobuf.StringValue text = 7;
+
+ // JSON is represented by using google.protobuf.Value.
+ // see https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/struct.proto
+ google.protobuf.Value json = 8;
+ }
+
+ // Tensor is similiar to ndarray but with a name
+ // We are reserving it for now for future use.
+ // Tensor tensors = 4;
+ reserved 2, 9 to 13;
+}
+
+// Multipart represents a multipart message.
+// It comprises of a mapping from given type name to a subset of aforementioned types.
+message Multipart {
+ map fields = 1;
+}
+
+// File represents for any arbitrary file type. This can be
+// plaintext, image, video, audio, etc.
+message File {
+ // FileType represents possible file type to be handled by BentoML.
+ // Currently, we only support plaintext (Text()), image (Image()), and file (File()).
+ // TODO: support audio and video streaming file types.
+ enum FileType {
+ FILE_TYPE_UNSPECIFIED = 0;
+
+ // file types
+ FILE_TYPE_CSV = 1;
+ FILE_TYPE_PLAINTEXT = 2;
+ FILE_TYPE_JSON = 3;
+ FILE_TYPE_BYTES = 4;
+ FILE_TYPE_PDF = 5;
+
+ // image types
+ FILE_TYPE_PNG = 6;
+ FILE_TYPE_JPEG = 7;
+ FILE_TYPE_GIF = 8;
+ FILE_TYPE_BMP = 9;
+ FILE_TYPE_TIFF = 10;
+ FILE_TYPE_WEBP = 11;
+ FILE_TYPE_SVG = 12;
+ }
+
+ // optional type of file, let it be csv, text, parquet, etc.
+ optional FileType kind = 1;
+
+ // contents of file as bytes.
+ bytes content = 2;
+}
+
+// DataFrame represents any tabular data type. We are using
+// DataFrame as a trivial representation for tabular type.
+// This message carries given implementation of tabular data based on given orientation.
+// TODO: support index, records, etc.
+message DataFrame {
+ // columns name
+ repeated string column_names = 1;
+
+ // columns orient.
+ // { column ↠ { index ↠ value } }
+ repeated Series columns = 2;
+}
+
+// Series portrays a series of values. This can be used for
+// representing Series types in tabular data.
+message Series {
+ // A bool parameter value
+ repeated bool bool_values = 1 [packed = true];
+
+ // A float parameter value
+ repeated float float_values = 2 [packed = true];
+
+ // A int32 parameter value
+ repeated int32 int32_values = 3 [packed = true];
+
+ // A int64 parameter value
+ repeated int64 int64_values = 6 [packed = true];
+
+ // A string parameter value
+ repeated string string_values = 5;
+
+ // represents a double parameter value.
+ repeated double double_values = 4 [packed = true];
+}
+
+// NDArray represents a n-dimensional array of arbitrary type.
+message NDArray {
+ // Represents data type of a given array.
+ enum DType {
+ // Represents a None type.
+ DTYPE_UNSPECIFIED = 0;
+
+ // Represents an float type.
+ DTYPE_FLOAT = 1;
+
+ // Represents an double type.
+ DTYPE_DOUBLE = 2;
+
+ // Represents a bool type.
+ DTYPE_BOOL = 3;
+
+ // Represents an int32 type.
+ DTYPE_INT32 = 4;
+
+ // Represents an int64 type.
+ DTYPE_INT64 = 5;
+
+ // Represents a uint32 type.
+ DTYPE_UINT32 = 6;
+
+ // Represents a uint64 type.
+ DTYPE_UINT64 = 7;
+
+ // Represents a string type.
+ DTYPE_STRING = 8;
+ }
+
+ // DTYPE is the data type of given array
+ DType dtype = 1;
+
+ // shape is the shape of given array.
+ repeated int32 shape = 2;
+
+ // represents a string parameter value.
+ repeated string string_values = 5;
+
+ // represents a float parameter value.
+ repeated float float_values = 3 [packed = true];
+
+ // represents a double parameter value.
+ repeated double double_values = 4 [packed = true];
+
+ // represents a bool parameter value.
+ repeated bool bool_values = 6 [packed = true];
+
+ // represents a int32 parameter value.
+ repeated int32 int32_values = 7 [packed = true];
+
+ // represents a int64 parameter value.
+ repeated int64 int64_values = 8 [packed = true];
+
+ // represents a uint32 parameter value.
+ repeated uint32 uint32_values = 9 [packed = true];
+
+ // represents a uint64 parameter value.
+ repeated uint64 uint64_values = 10 [packed = true];
+}
diff --git a/bentoml/grpc/v1alpha1/service_test.proto b/bentoml/grpc/v1alpha1/service_test.proto
new file mode 100644
index 0000000000..8fccaac169
--- /dev/null
+++ b/bentoml/grpc/v1alpha1/service_test.proto
@@ -0,0 +1,24 @@
+syntax = "proto3";
+
+package bentoml.testing.v1alpha1;
+
+option cc_enable_arenas = true;
+option go_package = "github.com/bentoml/testing/v1alpha1";
+option optimize_for = SPEED;
+option py_generic_services = true;
+
+// Represents a request for TestService.
+message ExecuteRequest {
+ string input = 1;
+}
+
+// Represents a response from TestService.
+message ExecuteResponse {
+ string output = 1;
+}
+
+// Use for testing interceptors per RPC call.
+service TestService {
+ // Unary API
+ rpc Execute(ExecuteRequest) returns (ExecuteResponse);
+}
diff --git a/bentoml/serve.py b/bentoml/serve.py
index 0d2541cf8f..fe34741577 100644
--- a/bentoml/serve.py
+++ b/bentoml/serve.py
@@ -9,27 +9,35 @@
import logging
import tempfile
import contextlib
+from typing import TYPE_CHECKING
from pathlib import Path
+from functools import partial
import psutil
from simple_di import inject
from simple_di import Provide
-from bentoml import load
-
-from ._internal.log import SERVER_LOGGING_CONFIG
-from ._internal.utils import reserve_free_port
-from ._internal.resource import CpuResource
-from ._internal.utils.uri import path_to_uri
-from ._internal.utils.circus import create_standalone_arbiter
-from ._internal.utils.analytics import track_serve
+from ._internal.utils import experimental
from ._internal.configuration.containers import BentoMLContainer
+if TYPE_CHECKING:
+ from circus.watcher import Watcher
+
+
logger = logging.getLogger(__name__)
+PROMETHEUS_MESSAGE = (
+ 'Prometheus metrics for %s BentoServer from "%s" can be accessed at %s.'
+)
SCRIPT_RUNNER = "bentoml_cli.worker.runner"
SCRIPT_API_SERVER = "bentoml_cli.worker.http_api_server"
SCRIPT_DEV_API_SERVER = "bentoml_cli.worker.http_dev_api_server"
+SCRIPT_GRPC_API_SERVER = "bentoml_cli.worker.grpc_api_server"
+SCRIPT_GRPC_DEV_API_SERVER = "bentoml_cli.worker.grpc_dev_api_server"
+SCRIPT_GRPC_PROMETHEUS_SERVER = "bentoml_cli.worker.grpc_prometheus_server"
+
+API_SERVER_NAME = "_bento_api_server"
+PROMETHEUS_SERVER_NAME = "_prometheus_server"
@inject
@@ -75,12 +83,90 @@ def ensure_prometheus_dir(
return alternative
+def create_watcher(
+ name: str,
+ args: list[str],
+ *,
+ use_sockets: bool = True,
+ **kwargs: t.Any,
+) -> Watcher:
+ from circus.watcher import Watcher
+
+ return Watcher(
+ name=name,
+ cmd=sys.executable,
+ args=args,
+ copy_env=True,
+ stop_children=True,
+ use_sockets=use_sockets,
+ **kwargs,
+ )
+
+
+def log_grpcui_instruction(port: int) -> None:
+ # logs instruction on how to start gRPCUI
+ docker_run = partial(
+ "docker run -it --rm {network_args} fullstorydev/grpcui -plaintext {platform_deps}:{port}".format,
+ port=port,
+ )
+ message = "To use gRPC UI, run the following command: '%s', followed by opening 'http://0.0.0.0:8080' in your browser of choice."
+
+ linux_instruction = docker_run(
+ platform_deps="0.0.0.0", network_args="--network=host"
+ )
+ mac_win_instruction = docker_run(
+ platform_deps="host.docker.internal", network_args="-p 8080:8080"
+ )
+
+ if os.path.exists("/.dockerenv"):
+ logger.info(
+ "Detected running Bento inside an OCI container. In order to use gRPC UI, do as follows: If your local machine are either MacOS or Windows , then use '%s'. Otherwise use '%s'.",
+ mac_win_instruction,
+ linux_instruction,
+ )
+ elif psutil.WINDOWS or psutil.MACOS:
+ logger.info(message, mac_win_instruction)
+ elif psutil.LINUX:
+ logger.info(message, linux_instruction)
+
+
+def construct_ssl_args(
+ ssl_certfile: str | None,
+ ssl_keyfile: str | None,
+ ssl_keyfile_password: str | None,
+ ssl_version: int | None,
+ ssl_cert_reqs: int | None,
+ ssl_ca_certs: str | None,
+ ssl_ciphers: str | None,
+) -> list[str]:
+ args: list[str] = []
+
+ # Add optional SSL args if they exist
+ if ssl_certfile:
+ args.extend(["--ssl-certfile", str(ssl_certfile)])
+ if ssl_keyfile:
+ args.extend(["--ssl-keyfile", str(ssl_keyfile)])
+ if ssl_keyfile_password:
+ args.extend(["--ssl-keyfile-password", ssl_keyfile_password])
+ if ssl_ca_certs:
+ args.extend(["--ssl-ca-certs", str(ssl_ca_certs)])
+
+ # match with default uvicorn values.
+ if ssl_version:
+ args.extend(["--ssl-version", str(ssl_version)])
+ if ssl_cert_reqs:
+ args.extend(["--ssl-cert-reqs", str(ssl_cert_reqs)])
+ if ssl_ciphers:
+ args.extend(["--ssl-ciphers", ssl_ciphers])
+ return args
+
+
@inject
-def serve_development(
+def serve_http_development(
bento_identifier: str,
working_dir: str,
- port: int = Provide[BentoMLContainer.api_server_config.port],
- host: str = Provide[BentoMLContainer.api_server_config.host],
+ port: int = Provide[BentoMLContainer.http.port],
+ host: str = Provide[BentoMLContainer.http.host],
backlog: int = Provide[BentoMLContainer.api_server_config.backlog],
bentoml_home: str = Provide[BentoMLContainer.bentoml_home],
ssl_certfile: str | None = Provide[BentoMLContainer.api_server_config.ssl.certfile],
@@ -94,70 +180,61 @@ def serve_development(
ssl_ciphers: str | None = Provide[BentoMLContainer.api_server_config.ssl.ciphers],
reload: bool = False,
) -> None:
- working_dir = os.path.realpath(os.path.expanduser(working_dir))
- svc = load(bento_identifier, working_dir=working_dir) # verify service loading
+ from circus.sockets import CircusSocket
- from circus.sockets import CircusSocket # type: ignore
- from circus.watcher import Watcher # type: ignore
-
- prometheus_dir = ensure_prometheus_dir()
+ from bentoml import load
- watchers: t.List[Watcher] = []
+ from ._internal.log import SERVER_LOGGING_CONFIG
+ from ._internal.utils.circus import create_standalone_arbiter
+ from ._internal.utils.analytics import track_serve
- circus_sockets: t.List[CircusSocket] = []
- circus_sockets.append(
- CircusSocket(
- name="_bento_api_server",
- host=host,
- port=port,
- backlog=backlog,
- )
- )
+ working_dir = os.path.realpath(os.path.expanduser(working_dir))
+ svc = load(bento_identifier, working_dir=working_dir)
- args: list[str | int] = [
- "-m",
- SCRIPT_DEV_API_SERVER,
- bento_identifier,
- "--fd",
- "$(circus.sockets._bento_api_server)",
- "--working-dir",
- working_dir,
- "--prometheus-dir",
- prometheus_dir,
- ]
+ prometheus_dir = ensure_prometheus_dir()
- # Add optional SSL args if they exist
- if ssl_certfile:
- args.extend(["--ssl-certfile", str(ssl_certfile)])
- if ssl_keyfile:
- args.extend(["--ssl-keyfile", str(ssl_keyfile)])
- if ssl_keyfile_password:
- args.extend(["--ssl-keyfile-password", ssl_keyfile_password])
- if ssl_ca_certs:
- args.extend(["--ssl-ca-certs", str(ssl_ca_certs)])
+ watchers: list[Watcher] = []
- # match with default uvicorn values.
- if ssl_version:
- args.extend(["--ssl-version", int(ssl_version)])
- if ssl_cert_reqs:
- args.extend(["--ssl-cert-reqs", int(ssl_cert_reqs)])
- if ssl_ciphers:
- args.extend(["--ssl-ciphers", ssl_ciphers])
+ circus_sockets: list[CircusSocket] = [
+ CircusSocket(name=API_SERVER_NAME, host=host, port=port, backlog=backlog)
+ ]
watchers.append(
- Watcher(
+ create_watcher(
name="dev_api_server",
- cmd=sys.executable,
- args=args,
- copy_env=True,
- stop_children=True,
- use_sockets=True,
+ args=[
+ "-m",
+ SCRIPT_DEV_API_SERVER,
+ bento_identifier,
+ "--fd",
+ f"$(circus.sockets.{API_SERVER_NAME})",
+ "--working-dir",
+ working_dir,
+ "--prometheus-dir",
+ prometheus_dir,
+ *construct_ssl_args(
+ ssl_certfile=ssl_certfile,
+ ssl_keyfile=ssl_keyfile,
+ ssl_keyfile_password=ssl_keyfile_password,
+ ssl_version=ssl_version,
+ ssl_cert_reqs=ssl_cert_reqs,
+ ssl_ca_certs=ssl_ca_certs,
+ ssl_ciphers=ssl_ciphers,
+ ),
+ ],
working_dir=working_dir,
# we don't want to close stdin for child process in case user use debugger.
# See https://circus.readthedocs.io/en/latest/for-ops/configuration/
close_child_stdin=False,
)
)
+ if BentoMLContainer.api_server_config.metrics.enabled.get():
+ logger.info(
+ PROMETHEUS_MESSAGE,
+ "HTTP",
+ bento_identifier,
+ f"http://{host}:{port}/metrics",
+ )
plugins = []
if reload:
@@ -169,7 +246,7 @@ def serve_development(
"--reload is passed. BentoML will watch file changes based on 'bentofile.yaml' and '.bentoignore' respectively."
)
- # initialize dictionary with {} is faster than using dict()
+ # NOTE: {} is faster than dict()
plugins = [
# reloader plugin
{
@@ -178,6 +255,7 @@ def serve_development(
"bentoml_home": bentoml_home,
},
]
+
arbiter = create_standalone_arbiter(
watchers,
sockets=circus_sockets,
@@ -190,8 +268,11 @@ def serve_development(
with track_serve(svc, production=False):
arbiter.start(
cb=lambda _: logger.info( # type: ignore
- f'Starting development BentoServer from "{bento_identifier}" '
- f"running on http://{host}:{port} (Press CTRL+C to quit)"
+ 'Starting development %s BentoServer from "%s" running on http://%s:%d (Press CTRL+C to quit)',
+ "HTTP",
+ bento_identifier,
+ host,
+ port,
),
)
@@ -200,11 +281,11 @@ def serve_development(
@inject
-def serve_production(
+def serve_http_production(
bento_identifier: str,
working_dir: str,
- port: int = Provide[BentoMLContainer.api_server_config.port],
- host: str = Provide[BentoMLContainer.api_server_config.host],
+ port: int = Provide[BentoMLContainer.http.port],
+ host: str = Provide[BentoMLContainer.http.host],
backlog: int = Provide[BentoMLContainer.api_server_config.backlog],
api_workers: int | None = None,
ssl_certfile: str | None = Provide[BentoMLContainer.api_server_config.ssl.certfile],
@@ -217,11 +298,18 @@ def serve_production(
ssl_ca_certs: str | None = Provide[BentoMLContainer.api_server_config.ssl.ca_certs],
ssl_ciphers: str | None = Provide[BentoMLContainer.api_server_config.ssl.ciphers],
) -> None:
+ from bentoml import load
+
+ from ._internal.utils import reserve_free_port
+ from ._internal.resource import CpuResource
+ from ._internal.utils.uri import path_to_uri
+ from ._internal.utils.circus import create_standalone_arbiter
+ from ._internal.utils.analytics import track_serve
+
working_dir = os.path.realpath(os.path.expanduser(working_dir))
svc = load(bento_identifier, working_dir=working_dir, standalone_load=True)
from circus.sockets import CircusSocket # type: ignore
- from circus.watcher import Watcher # type: ignore
watchers: t.List[Watcher] = []
circus_socket_map: t.Dict[str, CircusSocket] = {}
@@ -245,9 +333,8 @@ def serve_production(
)
watchers.append(
- Watcher(
+ create_watcher(
name=f"runner_{runner.name}",
- cmd=sys.executable,
args=[
"-m",
SCRIPT_RUNNER,
@@ -265,10 +352,7 @@ def serve_production(
"--prometheus-dir",
prometheus_dir,
],
- copy_env=True,
- stop_children=True,
working_dir=working_dir,
- use_sockets=True,
numprocesses=runner.scheduled_worker_count,
)
)
@@ -289,9 +373,8 @@ def serve_production(
)
watchers.append(
- Watcher(
+ create_watcher(
name=f"runner_{runner.name}",
- cmd=sys.executable,
args=[
"-m",
SCRIPT_RUNNER,
@@ -304,15 +387,10 @@ def serve_production(
working_dir,
"--no-access-log",
"--worker-id",
- "$(circus.wid)",
+ "$(CIRCUS.WID)",
"--worker-env-map",
json.dumps(runner.scheduled_worker_env_map),
- "--prometheus-dir",
- prometheus_dir,
],
- copy_env=True,
- stop_children=True,
- use_sockets=True,
working_dir=working_dir,
numprocesses=runner.scheduled_worker_count,
)
@@ -323,64 +401,57 @@ def serve_production(
else:
raise NotImplementedError("Unsupported platform: {}".format(sys.platform))
- logger.debug("Runner map: %s", runner_bind_map)
+ logger.debug(f"Runner map: {runner_bind_map}")
- circus_socket_map["_bento_api_server"] = CircusSocket(
- name="_bento_api_server",
+ circus_socket_map[API_SERVER_NAME] = CircusSocket(
+ name=API_SERVER_NAME,
host=host,
port=port,
backlog=backlog,
)
- args: list[str | int] = [
- "-m",
- SCRIPT_API_SERVER,
- bento_identifier,
- "--fd",
- "$(circus.sockets._bento_api_server)",
- "--runner-map",
- json.dumps(runner_bind_map),
- "--working-dir",
- working_dir,
- "--backlog",
- f"{backlog}",
- "--worker-id",
- "$(CIRCUS.WID)",
- "--prometheus-dir",
- prometheus_dir,
- ]
-
- # Add optional SSL args if they exist
- if ssl_certfile:
- args.extend(["--ssl-certfile", str(ssl_certfile)])
- if ssl_keyfile:
- args.extend(["--ssl-keyfile", str(ssl_keyfile)])
- if ssl_keyfile_password:
- args.extend(["--ssl-keyfile-password", ssl_keyfile_password])
- if ssl_ca_certs:
- args.extend(["--ssl-ca-certs", str(ssl_ca_certs)])
-
- # match with default uvicorn values.
- if ssl_version:
- args.extend(["--ssl-version", int(ssl_version)])
- if ssl_cert_reqs:
- args.extend(["--ssl-cert-reqs", int(ssl_cert_reqs)])
- if ssl_ciphers:
- args.extend(["--ssl-ciphers", ssl_ciphers])
-
watchers.append(
- Watcher(
+ create_watcher(
name="api_server",
- cmd=sys.executable,
- args=args,
- copy_env=True,
- numprocesses=api_workers or math.ceil(CpuResource.from_system()),
- stop_children=True,
- use_sockets=True,
+ args=[
+ "-m",
+ SCRIPT_API_SERVER,
+ bento_identifier,
+ "--fd",
+ f"$(circus.sockets.{API_SERVER_NAME})",
+ "--runner-map",
+ json.dumps(runner_bind_map),
+ "--working-dir",
+ working_dir,
+ "--backlog",
+ f"{backlog}",
+ "--worker-id",
+ "$(CIRCUS.WID)",
+ "--prometheus-dir",
+ prometheus_dir,
+ *construct_ssl_args(
+ ssl_certfile=ssl_certfile,
+ ssl_keyfile=ssl_keyfile,
+ ssl_keyfile_password=ssl_keyfile_password,
+ ssl_version=ssl_version,
+ ssl_cert_reqs=ssl_cert_reqs,
+ ssl_ca_certs=ssl_ca_certs,
+ ssl_ciphers=ssl_ciphers,
+ ),
+ ],
working_dir=working_dir,
+ numprocesses=api_workers or math.ceil(CpuResource.from_system()),
)
)
+ if BentoMLContainer.api_server_config.metrics.enabled.get():
+ logger.info(
+ PROMETHEUS_MESSAGE,
+ "HTTP",
+ bento_identifier,
+ f"http://{host}:{port}/metrics",
+ )
+
arbiter = create_standalone_arbiter(
watchers=watchers,
sockets=list(circus_socket_map.values()),
@@ -390,8 +461,400 @@ def serve_production(
try:
arbiter.start(
cb=lambda _: logger.info( # type: ignore
- f'Starting production BentoServer from "{bento_identifier}" '
- f"running on http://{host}:{port} (Press CTRL+C to quit)"
+ 'Starting production %s BentoServer from "%s" running on http://%s:%d (Press CTRL+C to quit)',
+ "HTTP",
+ bento_identifier,
+ host,
+ port,
+ ),
+ )
+ finally:
+ if uds_path is not None:
+ shutil.rmtree(uds_path)
+
+
+@experimental
+@inject
+def serve_grpc_development(
+ bento_identifier: str,
+ working_dir: str,
+ port: int = Provide[BentoMLContainer.grpc.port],
+ host: str = Provide[BentoMLContainer.grpc.host],
+ bentoml_home: str = Provide[BentoMLContainer.bentoml_home],
+ reload: bool = False,
+ reflection: bool = Provide[BentoMLContainer.grpc.reflection.enabled],
+ max_concurrent_streams: int
+ | None = Provide[BentoMLContainer.grpc.max_concurrent_streams],
+ backlog: int = Provide[BentoMLContainer.api_server_config.backlog],
+) -> None:
+ from circus.sockets import CircusSocket
+
+ from bentoml import load
+
+ from ._internal.log import SERVER_LOGGING_CONFIG
+ from ._internal.utils import reserve_free_port
+ from ._internal.utils.circus import create_standalone_arbiter
+ from ._internal.utils.analytics import track_serve
+
+ working_dir = os.path.realpath(os.path.expanduser(working_dir))
+ svc = load(bento_identifier, working_dir=working_dir)
+
+ prometheus_dir = ensure_prometheus_dir()
+
+ watchers: list[Watcher] = []
+
+ circus_sockets: list[CircusSocket] = []
+
+ if not reflection:
+ logger.info(
+ "'reflection' is disabled by default. Tools such as gRPCUI or grpcurl relies on server reflection. To use those, pass '--enable-reflection' to the CLI."
+ )
+ else:
+ log_grpcui_instruction(port)
+
+ with contextlib.ExitStack() as port_stack:
+ api_port = port_stack.enter_context(
+ reserve_free_port(host, port=port, enable_so_reuseport=True)
+ )
+
+ args = [
+ "-m",
+ SCRIPT_GRPC_DEV_API_SERVER,
+ bento_identifier,
+ "--host",
+ host,
+ "--port",
+ str(api_port),
+ "--working-dir",
+ working_dir,
+ ]
+
+ if reflection:
+ args.append("--enable-reflection")
+ if max_concurrent_streams:
+ args.extend(
+ [
+ "--max-concurrent-streams",
+ str(max_concurrent_streams),
+ ]
+ )
+
+ # use circus_sockets. CircusSocket support for SO_REUSEPORT
+ watchers.append(
+ create_watcher(
+ name="grpc_dev_api_server",
+ args=args,
+ use_sockets=False,
+ working_dir=working_dir,
+ # we don't want to close stdin for child process in case user use debugger.
+ # See https://circus.readthedocs.io/en/latest/for-ops/configuration/
+ close_child_stdin=False,
+ )
+ )
+
+ if BentoMLContainer.api_server_config.metrics.enabled.get():
+ metrics_host = BentoMLContainer.grpc.metrics.host.get()
+ metrics_port = BentoMLContainer.grpc.metrics.port.get()
+
+ circus_sockets.append(
+ CircusSocket(
+ name=PROMETHEUS_SERVER_NAME,
+ host=metrics_host,
+ port=metrics_port,
+ backlog=backlog,
+ )
+ )
+
+ watchers.append(
+ create_watcher(
+ name="prom_server",
+ args=[
+ "-m",
+ SCRIPT_GRPC_PROMETHEUS_SERVER,
+ "--fd",
+ f"$(circus.sockets.{PROMETHEUS_SERVER_NAME})",
+ "--prometheus-dir",
+ prometheus_dir,
+ "--backlog",
+ f"{backlog}",
+ ],
+ working_dir=working_dir,
+ numprocesses=1,
+ singleton=True,
+ # we don't want to close stdin for child process in case user use debugger.
+ # See https://circus.readthedocs.io/en/latest/for-ops/configuration/
+ close_child_stdin=False,
+ )
+ )
+
+ logger.info(
+ PROMETHEUS_MESSAGE,
+ "gRPC",
+ bento_identifier,
+ f"http://{metrics_host}:{metrics_port}",
+ )
+
+ plugins = []
+
+ if reload:
+ if sys.platform == "win32":
+ logger.warning(
+ "Due to circus limitations, output from the reloader plugin will not be shown on Windows."
+ )
+ logger.debug(
+ "--reload is passed. BentoML will watch file changes based on 'bentofile.yaml' and '.bentoignore' respectively."
+ )
+
+ # NOTE: {} is faster than dict()
+ plugins = [
+ # reloader plugin
+ {
+ "use": "bentoml._internal.utils.circus.watchfilesplugin.ServiceReloaderPlugin",
+ "working_dir": working_dir,
+ "bentoml_home": bentoml_home,
+ },
+ ]
+
+ arbiter = create_standalone_arbiter(
+ watchers,
+ sockets=circus_sockets,
+ plugins=plugins,
+ debug=True if sys.platform != "win32" else False,
+ loggerconfig=SERVER_LOGGING_CONFIG,
+ loglevel="ERROR",
+ )
+
+ with track_serve(svc, production=False):
+ arbiter.start(
+ cb=lambda _: logger.info( # type: ignore
+ 'Starting development %s BentoServer from "%s" running on http://%s:%d (Press CTRL+C to quit)',
+ "gRPC",
+ bento_identifier,
+ host,
+ port,
+ ),
+ )
+
+
+@experimental
+@inject
+def serve_grpc_production(
+ bento_identifier: str,
+ working_dir: str,
+ port: int = Provide[BentoMLContainer.grpc.port],
+ host: str = Provide[BentoMLContainer.grpc.host],
+ backlog: int = Provide[BentoMLContainer.api_server_config.backlog],
+ api_workers: int | None = None,
+ reflection: bool = Provide[BentoMLContainer.grpc.reflection.enabled],
+ max_concurrent_streams: int
+ | None = Provide[BentoMLContainer.grpc.max_concurrent_streams],
+) -> None:
+ from bentoml import load
+ from bentoml.exceptions import UnprocessableEntity
+
+ from ._internal.utils import reserve_free_port
+ from ._internal.resource import CpuResource
+ from ._internal.utils.uri import path_to_uri
+ from ._internal.utils.circus import create_standalone_arbiter
+ from ._internal.utils.analytics import track_serve
+
+ working_dir = os.path.realpath(os.path.expanduser(working_dir))
+ svc = load(bento_identifier, working_dir=working_dir, standalone_load=True)
+
+ from circus.sockets import CircusSocket # type: ignore
+
+ watchers: list[Watcher] = []
+ circus_socket_map: dict[str, CircusSocket] = {}
+ runner_bind_map: dict[str, str] = {}
+ uds_path = None
+
+ prometheus_dir = ensure_prometheus_dir()
+
+ # Check whether users are running --grpc on windows
+ # also raising warning if users running on MacOS or FreeBSD
+ if psutil.WINDOWS:
+ raise UnprocessableEntity(
+ "'grpc' is not supported on Windows with '--production'. The reason being SO_REUSEPORT socket option is only available on UNIX system, and gRPC implementation depends on this behaviour."
+ )
+ if psutil.MACOS or psutil.FREEBSD:
+ logger.warning(
+ "Due to gRPC implementation on exposing SO_REUSEPORT, '--production' behaviour on %s is not correct. We recommend to containerize BentoServer as a Linux container instead.",
+ "MacOS" if psutil.MACOS else "FreeBSD",
+ )
+
+ if psutil.POSIX:
+ # use AF_UNIX sockets for Circus
+ uds_path = tempfile.mkdtemp()
+ for runner in svc.runners:
+ sockets_path = os.path.join(uds_path, f"{id(runner)}.sock")
+ assert len(sockets_path) < MAX_AF_UNIX_PATH_LENGTH
+
+ runner_bind_map[runner.name] = path_to_uri(sockets_path)
+ circus_socket_map[runner.name] = CircusSocket(
+ name=runner.name,
+ path=sockets_path,
+ backlog=backlog,
+ )
+
+ watchers.append(
+ create_watcher(
+ name=f"runner_{runner.name}",
+ args=[
+ "-m",
+ SCRIPT_RUNNER,
+ bento_identifier,
+ "--runner-name",
+ runner.name,
+ "--fd",
+ f"$(circus.sockets.{runner.name})",
+ "--working-dir",
+ working_dir,
+ "--worker-id",
+ "$(CIRCUS.WID)",
+ "--worker-env-map",
+ json.dumps(runner.scheduled_worker_env_map),
+ ],
+ working_dir=working_dir,
+ numprocesses=runner.scheduled_worker_count,
+ )
+ )
+
+ elif psutil.WINDOWS:
+ # Windows doesn't (fully) support AF_UNIX sockets
+ with contextlib.ExitStack() as port_stack:
+ for runner in svc.runners:
+ runner_port = port_stack.enter_context(reserve_free_port())
+ runner_host = "127.0.0.1"
+
+ runner_bind_map[runner.name] = f"tcp://{runner_host}:{runner_port}"
+ circus_socket_map[runner.name] = CircusSocket(
+ name=runner.name,
+ host=runner_host,
+ port=runner_port,
+ backlog=backlog,
+ )
+
+ watchers.append(
+ create_watcher(
+ name=f"runner_{runner.name}",
+ args=[
+ "-m",
+ SCRIPT_RUNNER,
+ bento_identifier,
+ "--runner-name",
+ runner.name,
+ "--fd",
+ f"$(circus.sockets.{runner.name})",
+ "--working-dir",
+ working_dir,
+ "--no-access-log",
+ "--worker-id",
+ "$(circus.wid)",
+ "--worker-env-map",
+ json.dumps(runner.scheduled_worker_env_map),
+ "--prometheus-dir",
+ prometheus_dir,
+ ],
+ working_dir=working_dir,
+ numprocesses=runner.scheduled_worker_count,
+ )
+ )
+ # reserve one more to avoid conflicts
+ port_stack.enter_context(reserve_free_port())
+ else:
+ raise NotImplementedError("Unsupported platform: {}".format(sys.platform))
+
+ logger.debug(f"Runner map: {runner_bind_map}")
+
+ with contextlib.ExitStack() as port_stack:
+ api_port = port_stack.enter_context(
+ reserve_free_port(host, port=port, enable_so_reuseport=True)
+ )
+ args = [
+ "-m",
+ SCRIPT_GRPC_API_SERVER,
+ bento_identifier,
+ "--host",
+ host,
+ "--port",
+ str(api_port),
+ "--runner-map",
+ json.dumps(runner_bind_map),
+ "--working-dir",
+ working_dir,
+ "--worker-id",
+ "$(CIRCUS.WID)",
+ ]
+ if reflection:
+ args.append("--enable-reflection")
+
+ if max_concurrent_streams:
+ args.extend(
+ [
+ "--max-concurrent-streams",
+ str(max_concurrent_streams),
+ ]
+ )
+
+ watchers.append(
+ create_watcher(
+ name="grpc_api_server",
+ args=args,
+ use_sockets=False,
+ working_dir=working_dir,
+ numprocesses=api_workers or math.ceil(CpuResource.from_system()),
+ )
+ )
+
+ if BentoMLContainer.api_server_config.metrics.enabled.get():
+ metrics_host = BentoMLContainer.grpc.metrics.host.get()
+ metrics_port = BentoMLContainer.grpc.metrics.port.get()
+
+ circus_socket_map[PROMETHEUS_SERVER_NAME] = CircusSocket(
+ name=PROMETHEUS_SERVER_NAME,
+ host=metrics_host,
+ port=metrics_port,
+ backlog=backlog,
+ )
+
+ watchers.append(
+ create_watcher(
+ name="prom_server",
+ args=[
+ "-m",
+ SCRIPT_GRPC_PROMETHEUS_SERVER,
+ "--fd",
+ f"$(circus.sockets.{PROMETHEUS_SERVER_NAME})",
+ "--prometheus-dir",
+ prometheus_dir,
+ "--backlog",
+ f"{backlog}",
+ ],
+ working_dir=working_dir,
+ numprocesses=1,
+ singleton=True,
+ )
+ )
+
+ logger.info(
+ PROMETHEUS_MESSAGE,
+ "gRPC",
+ bento_identifier,
+ f"http://{metrics_host}:{metrics_port}",
+ )
+ arbiter = create_standalone_arbiter(
+ watchers=watchers, sockets=list(circus_socket_map.values())
+ )
+
+ with track_serve(svc, production=True):
+ try:
+ arbiter.start(
+ cb=lambda _: logger.info( # type: ignore
+ 'Starting production %s BentoServer from "%s" running on http://%s:%d (Press CTRL+C to quit)',
+ "gRPC",
+ bento_identifier,
+ host,
+ port,
),
)
finally:
diff --git a/bentoml/start.py b/bentoml/start.py
index fc327c48d1..1541f61e89 100644
--- a/bentoml/start.py
+++ b/bentoml/start.py
@@ -4,80 +4,26 @@
import sys
import json
import math
-import shutil
import typing as t
import logging
-import tempfile
import contextlib
-from pathlib import Path
from simple_di import inject
from simple_di import Provide
-from bentoml import load
-
-from ._internal.utils import reserve_free_port
-from ._internal.resource import CpuResource
-from ._internal.utils.circus import create_standalone_arbiter
-from ._internal.utils.analytics import track_serve
from ._internal.configuration.containers import BentoMLContainer
logger = logging.getLogger(__name__)
SCRIPT_RUNNER = "bentoml_cli.worker.runner"
SCRIPT_API_SERVER = "bentoml_cli.worker.http_api_server"
-SCRIPT_DEV_API_SERVER = "bentoml_cli.worker.http_dev_api_server"
+SCRIPT_GRPC_API_SERVER = "bentoml_cli.worker.grpc_api_server"
+SCRIPT_GRPC_PROMETHEUS_SERVER = "bentoml_cli.worker.grpc_prometheus_server"
API_SERVER = "api_server"
RUNNER = "runner"
-@inject
-def ensure_prometheus_dir(
- directory: str = Provide[BentoMLContainer.prometheus_multiproc_dir],
- clean: bool = True,
- use_alternative: bool = True,
-) -> str:
- try:
- path = Path(directory)
- if path.exists():
- if not path.is_dir() or any(path.iterdir()):
- if clean:
- shutil.rmtree(str(path))
- path.mkdir()
- return str(path.absolute())
- else:
- raise RuntimeError(
- "Prometheus multiproc directory {} is not empty".format(path)
- )
- else:
- return str(path.absolute())
- else:
- path.mkdir(parents=True)
- return str(path.absolute())
- except shutil.Error as e:
- if not use_alternative:
- raise RuntimeError(
- f"Failed to clean the prometheus multiproc directory {directory}: {e}"
- )
- except OSError as e:
- if not use_alternative:
- raise RuntimeError(
- f"Failed to create the prometheus multiproc directory {directory}: {e}"
- )
- assert use_alternative
- alternative = tempfile.mkdtemp()
- logger.warning(
- f"Failed to ensure the prometheus multiproc directory {directory}, "
- f"using alternative: {alternative}",
- )
- BentoMLContainer.prometheus_multiproc_dir.set(alternative)
- return alternative
-
-
-MAX_AF_UNIX_PATH_LENGTH = 103
-
-
@inject
def start_runner_server(
bento_identifier: str,
@@ -90,6 +36,13 @@ def start_runner_server(
"""
Experimental API for serving a BentoML runner.
"""
+ from bentoml import load
+
+ from .serve import ensure_prometheus_dir
+ from ._internal.utils import reserve_free_port
+ from ._internal.utils.circus import create_standalone_arbiter
+ from ._internal.utils.analytics import track_serve
+
working_dir = os.path.realpath(os.path.expanduser(working_dir))
svc = load(bento_identifier, working_dir=working_dir, standalone_load=True)
@@ -98,7 +51,6 @@ def start_runner_server(
watchers: t.List[Watcher] = []
circus_socket_map: t.Dict[str, CircusSocket] = {}
- uds_path = None
ensure_prometheus_dir()
@@ -147,25 +99,19 @@ def start_runner_server(
f"Runner {runner_name} not found in the service: `{bento_identifier}`, "
f"available runners: {[r.name for r in svc.runners]}"
)
-
arbiter = create_standalone_arbiter(
watchers=watchers,
sockets=list(circus_socket_map.values()),
)
-
with track_serve(svc, production=True, component=RUNNER):
- try:
- arbiter.start(
- cb=lambda _: logger.info( # type: ignore
- 'Starting RunnerServer from "%s"\n running on http://%s:%s (Press CTRL+C to quit)',
- bento_identifier,
- host,
- port,
- ),
- )
- finally:
- if uds_path is not None:
- shutil.rmtree(uds_path)
+ arbiter.start(
+ cb=lambda _: logger.info( # type: ignore
+ 'Starting RunnerServer from "%s" running on http://%s:%s (Press CTRL+C to quit)',
+ bento_identifier,
+ host,
+ port,
+ ),
+ )
@inject
@@ -187,14 +133,24 @@ def start_http_server(
ssl_ca_certs: str | None = Provide[BentoMLContainer.api_server_config.ssl.ca_certs],
ssl_ciphers: str | None = Provide[BentoMLContainer.api_server_config.ssl.ciphers],
) -> None:
+ from bentoml import load
+
+ from .serve import create_watcher
+ from .serve import API_SERVER_NAME
+ from .serve import construct_ssl_args
+ from .serve import PROMETHEUS_MESSAGE
+ from .serve import ensure_prometheus_dir
+ from ._internal.resource import CpuResource
+ from ._internal.utils.circus import create_standalone_arbiter
+ from ._internal.utils.analytics import track_serve
+
working_dir = os.path.realpath(os.path.expanduser(working_dir))
svc = load(bento_identifier, working_dir=working_dir, standalone_load=True)
runner_requirements = {runner.name for runner in svc.runners}
if not runner_requirements.issubset(set(runner_map)):
raise ValueError(
- f"{bento_identifier} requires runners {runner_requirements}, but only "
- f"{set(runner_map)} are provided"
+ f"{bento_identifier} requires runners {runner_requirements}, but only {set(runner_map)} are provided."
)
from circus.sockets import CircusSocket # type: ignore
@@ -202,81 +158,202 @@ def start_http_server(
watchers: t.List[Watcher] = []
circus_socket_map: t.Dict[str, CircusSocket] = {}
- uds_path = None
prometheus_dir = ensure_prometheus_dir()
logger.debug("Runner map: %s", runner_map)
- circus_socket_map["_bento_api_server"] = CircusSocket(
- name="_bento_api_server",
+ circus_socket_map[API_SERVER_NAME] = CircusSocket(
+ name=API_SERVER_NAME,
host=host,
port=port,
backlog=backlog,
)
- args: list[str | int] = [
- "-m",
- SCRIPT_API_SERVER,
- bento_identifier,
- "--fd",
- "$(circus.sockets._bento_api_server)",
- "--runner-map",
- json.dumps(runner_map),
- "--working-dir",
- working_dir,
- "--backlog",
- f"{backlog}",
- "--worker-id",
- "$(CIRCUS.WID)",
- "--prometheus-dir",
- prometheus_dir,
- ]
-
- # Add optional SSL args if they exist
- if ssl_certfile:
- args.extend(["--ssl-certfile", str(ssl_certfile)])
- if ssl_keyfile:
- args.extend(["--ssl-keyfile", str(ssl_keyfile)])
- if ssl_keyfile_password:
- args.extend(["--ssl-keyfile-password", ssl_keyfile_password])
- if ssl_ca_certs:
- args.extend(["--ssl-ca-certs", str(ssl_ca_certs)])
-
- # match with default uvicorn values.
- if ssl_version:
- args.extend(["--ssl-version", int(ssl_version)])
- if ssl_cert_reqs:
- args.extend(["--ssl-cert-reqs", int(ssl_cert_reqs)])
- if ssl_ciphers:
- args.extend(["--ssl-ciphers", ssl_ciphers])
-
watchers.append(
- Watcher(
- name=API_SERVER,
- cmd=sys.executable,
- args=args,
- copy_env=True,
- numprocesses=api_workers or math.ceil(CpuResource.from_system()),
- stop_children=True,
- use_sockets=True,
+ create_watcher(
+ name="api_server",
+ args=[
+ "-m",
+ SCRIPT_API_SERVER,
+ bento_identifier,
+ "--fd",
+ f"$(circus.sockets.{API_SERVER_NAME})",
+ "--runner-map",
+ json.dumps(runner_map),
+ "--working-dir",
+ working_dir,
+ "--backlog",
+ f"{backlog}",
+ "--worker-id",
+ "$(CIRCUS.WID)",
+ "--prometheus-dir",
+ prometheus_dir,
+ *construct_ssl_args(
+ ssl_certfile=ssl_certfile,
+ ssl_keyfile=ssl_keyfile,
+ ssl_keyfile_password=ssl_keyfile_password,
+ ssl_version=ssl_version,
+ ssl_cert_reqs=ssl_cert_reqs,
+ ssl_ca_certs=ssl_ca_certs,
+ ssl_ciphers=ssl_ciphers,
+ ),
+ ],
working_dir=working_dir,
+ numprocesses=api_workers or math.ceil(CpuResource.from_system()),
)
)
+ if BentoMLContainer.api_server_config.metrics.enabled.get():
+ logger.info(
+ PROMETHEUS_MESSAGE,
+ "HTTP",
+ bento_identifier,
+ f"http://{host}:{port}/metrics",
+ )
arbiter = create_standalone_arbiter(
watchers=watchers,
sockets=list(circus_socket_map.values()),
)
-
with track_serve(svc, production=True, component=API_SERVER):
- try:
- arbiter.start(
- cb=lambda _: logger.info( # type: ignore
- f'Starting bare Bento API server from "{bento_identifier}" '
- f"running on http://{host}:{port} (Press CTRL+C to quit)"
- ),
+ arbiter.start(
+ cb=lambda _: logger.info( # type: ignore
+ 'Starting bare %s BentoServer from "%s" running on http://%s:%d (Press CTRL+C to quit)',
+ "HTTP",
+ bento_identifier,
+ host,
+ port,
+ ),
+ )
+
+
+@inject
+def start_grpc_server(
+ bento_identifier: str,
+ runner_map: dict[str, str],
+ working_dir: str,
+ port: int = Provide[BentoMLContainer.grpc.port],
+ host: str = Provide[BentoMLContainer.grpc.host],
+ backlog: int = Provide[BentoMLContainer.api_server_config.backlog],
+ api_workers: int | None = None,
+ reflection: bool = Provide[BentoMLContainer.grpc.reflection.enabled],
+ max_concurrent_streams: int
+ | None = Provide[BentoMLContainer.grpc.max_concurrent_streams],
+) -> None:
+ from bentoml import load
+
+ from .serve import create_watcher
+ from .serve import PROMETHEUS_MESSAGE
+ from .serve import ensure_prometheus_dir
+ from .serve import PROMETHEUS_SERVER_NAME
+ from ._internal.utils import reserve_free_port
+ from ._internal.resource import CpuResource
+ from ._internal.utils.circus import create_standalone_arbiter
+ from ._internal.utils.analytics import track_serve
+
+ working_dir = os.path.realpath(os.path.expanduser(working_dir))
+ svc = load(bento_identifier, working_dir=working_dir, standalone_load=True)
+
+ runner_requirements = {runner.name for runner in svc.runners}
+ if not runner_requirements.issubset(set(runner_map)):
+ raise ValueError(
+ f"{bento_identifier} requires runners {runner_requirements}, but only {set(runner_map)} are provided."
+ )
+
+ from circus.sockets import CircusSocket # type: ignore
+ from circus.watcher import Watcher # type: ignore
+
+ watchers: list[Watcher] = []
+ circus_socket_map: dict[str, CircusSocket] = {}
+ prometheus_dir = ensure_prometheus_dir()
+ logger.debug("Runner map: %s", runner_map)
+ with contextlib.ExitStack() as port_stack:
+ api_port = port_stack.enter_context(
+ reserve_free_port(host=host, port=port, enable_so_reuseport=True)
+ )
+
+ args = [
+ "-m",
+ SCRIPT_GRPC_API_SERVER,
+ bento_identifier,
+ "--host",
+ host,
+ "--port",
+ str(api_port),
+ "--runner-map",
+ json.dumps(runner_map),
+ "--working-dir",
+ working_dir,
+ "--worker-id",
+ "$(CIRCUS.WID)",
+ ]
+ if reflection:
+ args.append("--enable-reflection")
+
+ if max_concurrent_streams:
+ args.extend(
+ [
+ "--max-concurrent-streams",
+ str(max_concurrent_streams),
+ ]
)
- finally:
- if uds_path is not None:
- shutil.rmtree(uds_path)
+
+ watchers.append(
+ create_watcher(
+ name="grpc_api_server",
+ args=args,
+ use_sockets=False,
+ working_dir=working_dir,
+ numprocesses=api_workers or math.ceil(CpuResource.from_system()),
+ )
+ )
+
+ if BentoMLContainer.api_server_config.metrics.enabled.get():
+ metrics_host = BentoMLContainer.grpc.metrics.host.get()
+ metrics_port = BentoMLContainer.grpc.metrics.port.get()
+
+ circus_socket_map[PROMETHEUS_SERVER_NAME] = CircusSocket(
+ name=PROMETHEUS_SERVER_NAME,
+ host=metrics_host,
+ port=metrics_port,
+ backlog=backlog,
+ )
+
+ watchers.append(
+ create_watcher(
+ name="prom_server",
+ args=[
+ "-m",
+ SCRIPT_GRPC_PROMETHEUS_SERVER,
+ "--fd",
+ f"$(circus.sockets.{PROMETHEUS_SERVER_NAME})",
+ "--prometheus-dir",
+ prometheus_dir,
+ "--backlog",
+ f"{backlog}",
+ ],
+ working_dir=working_dir,
+ numprocesses=1,
+ singleton=True,
+ )
+ )
+
+ logger.info(
+ PROMETHEUS_MESSAGE,
+ "gRPC",
+ bento_identifier,
+ f"http://{metrics_host}:{metrics_port}",
+ )
+ arbiter = create_standalone_arbiter(
+ watchers=watchers, sockets=list(circus_socket_map.values())
+ )
+ with track_serve(svc, production=True, component=API_SERVER):
+ arbiter.start(
+ cb=lambda _: logger.info( # type: ignore
+ 'Starting bare %s BentoServer from "%s" running on http://%s:%d (Press CTRL+C to quit)',
+ "gRPC",
+ bento_identifier,
+ host,
+ port,
+ ),
+ )
diff --git a/bentoml/testing/server.py b/bentoml/testing/server.py
index d520e0a88d..8553f0b69f 100644
--- a/bentoml/testing/server.py
+++ b/bentoml/testing/server.py
@@ -18,10 +18,11 @@
from typing import TYPE_CHECKING
from contextlib import contextmanager
+import psutil
+
from .._internal.tag import Tag
from .._internal.utils import reserve_free_port
from .._internal.utils import cached_contextmanager
-from .._internal.utils.platform import kill_subprocess_tree
logger = logging.getLogger("bentoml")
@@ -75,6 +76,19 @@ async def async_request(
return r.status, Headers(headers), r_body
+def kill_subprocess_tree(p: subprocess.Popen[t.Any]) -> None:
+ """
+ Tell the process to terminate and kill all of its children. Availabe both on Windows and Linux.
+ Note: It will return immediately rather than wait for the process to terminate.
+ Args:
+ p: subprocess.Popen object
+ """
+ if psutil.WINDOWS:
+ subprocess.call(["taskkill", "/F", "/T", "/PID", str(p.pid)])
+ else:
+ p.terminate()
+
+
def _wait_until_api_server_ready(
host_url: str,
timeout: float,
diff --git a/bentoml_cli/containerize.py b/bentoml_cli/containerize.py
index e6964ca23e..bf5d6959f7 100644
--- a/bentoml_cli/containerize.py
+++ b/bentoml_cli/containerize.py
@@ -3,14 +3,11 @@
import sys
import typing as t
import logging
+from typing import TYPE_CHECKING
-import click
-
-from bentoml.bentos import containerize as containerize_bento
-from bentoml._internal.utils import kwargs_transformers
-from bentoml._internal.utils.docker import validate_tag
-
-logger = logging.getLogger("bentoml")
+if TYPE_CHECKING:
+ F = t.Callable[..., t.Any]
+ from click import Group
def containerize_transformer(
@@ -23,7 +20,15 @@ def containerize_transformer(
return value
-def add_containerize_command(cli: click.Group) -> None:
+def add_containerize_command(cli: Group) -> None:
+ import click
+
+ from bentoml.bentos import FEATURES
+ from bentoml.bentos import containerize as containerize_bento
+ from bentoml_cli.utils import kwargs_transformers
+ from bentoml._internal.utils.docker import validate_tag
+ from bentoml._internal.configuration.containers import BentoMLContainer
+
@cli.command()
@click.argument("bento_tag", type=click.STRING)
@click.option(
@@ -160,35 +165,43 @@ def add_containerize_command(cli: click.Group) -> None:
@click.option(
"--ulimit", type=click.STRING, default=None, help="Ulimit options (default [])."
)
+ @click.option(
+ "--enable-features",
+ multiple=True,
+ nargs=1,
+ metavar="[features,]",
+ help=f"Enable additional BentoML features. Available features are: {', '.join(FEATURES)}.",
+ )
@kwargs_transformers(transformer=containerize_transformer)
def containerize( # type: ignore
bento_tag: str,
- docker_image_tag: list[str],
- add_host: t.Iterable[str],
- allow: t.Iterable[str],
- build_arg: t.List[str],
- build_context: t.List[str],
+ docker_image_tag: tuple[str],
+ add_host: tuple[str],
+ allow: tuple[str],
+ build_arg: tuple[str],
+ build_context: tuple[str],
builder: str,
- cache_from: t.List[str],
- cache_to: t.List[str],
+ cache_from: tuple[str],
+ cache_to: tuple[str],
cgroup_parent: str,
iidfile: str,
- label: t.List[str],
+ label: tuple[str],
load: bool,
network: str,
metadata_file: str,
no_cache: bool,
- no_cache_filter: t.List[str],
- output: t.List[str],
- platform: t.List[str],
+ no_cache_filter: tuple[str],
+ output: tuple[str],
+ platform: tuple[str],
progress: t.Literal["auto", "tty", "plain"],
pull: bool,
push: bool,
- secret: t.List[str],
+ secret: tuple[str],
shm_size: str,
ssh: str,
target: str,
ulimit: str,
+ enable_features: tuple[str],
) -> None:
"""Containerizes given Bento into a ready-to-use Docker image.
@@ -217,11 +230,11 @@ def containerize( # type: ignore
By doing so, BentoML will leverage Docker Buildx features such as multi-node
builds for cross-platform images, Full BuildKit capabilities with all of the
familiar UI from 'docker build'.
-
- We also pass all given args for 'docker buildx' through 'bentoml containerize' with ease.
"""
from bentoml._internal.utils import buildx
+ logger = logging.getLogger("bentoml")
+
# run health check whether buildx is install locally
buildx.health()
@@ -253,7 +266,7 @@ def containerize( # type: ignore
key, value = label_str.split("=")
labels[key] = value
- output_ = None
+ output_: dict[str, t.Any] | None = None
if output:
output_ = {}
for arg in output:
@@ -276,6 +289,9 @@ def containerize( # type: ignore
exit_code = not containerize_bento(
bento_tag,
docker_image_tag=docker_image_tag,
+ # containerize options
+ features=enable_features,
+ # docker options
add_host=add_hosts,
allow=allow_,
build_args=build_args,
@@ -291,7 +307,7 @@ def containerize( # type: ignore
network=network,
no_cache=no_cache,
no_cache_filter=no_cache_filter,
- output=output_, # type: ignore
+ output=output_,
platform=platform,
progress=progress,
pull=pull,
@@ -303,4 +319,17 @@ def containerize( # type: ignore
target=target,
ulimit=ulimit,
)
+ if not exit_code:
+ grpc_metrics_port = BentoMLContainer.grpc.metrics.port.get()
+ logger.info(
+ 'Successfully built docker image for "%s" with tags "%s"',
+ str(bento_tag),
+ ",".join(docker_image_tag),
+ )
+ logger.info(
+ 'To run your newly built Bento container, use one of the above tags, and pass it to "docker run". i.e: "docker run -it --rm -p 3000:3000 %s". To use gRPC, pass "-e BENTOML_USE_GRPC=true -p %s:%s" to "docker run".',
+ docker_image_tag[0],
+ grpc_metrics_port,
+ grpc_metrics_port,
+ )
sys.exit(exit_code)
diff --git a/bentoml_cli/serve.py b/bentoml_cli/serve.py
index e34f54abe8..cde029fdbc 100644
--- a/bentoml_cli/serve.py
+++ b/bentoml_cli/serve.py
@@ -15,7 +15,7 @@ def add_serve_command(cli: click.Group) -> None:
from bentoml._internal.log import configure_server_logging
from bentoml._internal.configuration.containers import BentoMLContainer
- @cli.command()
+ @cli.command(aliases=["serve-http"])
@click.argument("bento", type=click.STRING, default=".")
@click.option(
"--production",
@@ -26,9 +26,10 @@ def add_serve_command(cli: click.Group) -> None:
show_default=True,
)
@click.option(
+ "-p",
"--port",
type=click.INT,
- default=BentoMLContainer.service_port.get(),
+ default=BentoMLContainer.http.port.get(),
help="The port to listen on for the REST api server",
envvar="BENTOML_PORT",
show_default=True,
@@ -36,9 +37,10 @@ def add_serve_command(cli: click.Group) -> None:
@click.option(
"--host",
type=click.STRING,
- default=BentoMLContainer.service_host.get(),
- help="The host to bind for the REST api server [defaults: 127.0.0.1(dev), 0.0.0.0(production)]",
+ default=BentoMLContainer.http.host.get(),
+ help="The host to bind for the REST api server",
envvar="BENTOML_HOST",
+ show_default=True,
)
@click.option(
"--api-workers",
@@ -46,6 +48,7 @@ def add_serve_command(cli: click.Group) -> None:
default=None,
help="Specify the number of API server workers to start. Default to number of available CPU cores in production mode",
envvar="BENTOML_API_WORKERS",
+ show_default=True,
)
@click.option(
"--backlog",
@@ -74,42 +77,49 @@ def add_serve_command(cli: click.Group) -> None:
type=str,
default=None,
help="SSL certificate file",
+ show_default=True,
)
@click.option(
"--ssl-keyfile",
type=str,
default=None,
help="SSL key file",
+ show_default=True,
)
@click.option(
"--ssl-keyfile-password",
type=str,
default=None,
help="SSL keyfile password",
+ show_default=True,
)
@click.option(
"--ssl-version",
type=int,
default=None,
help="SSL version to use (see stdlib 'ssl' module)",
+ show_default=True,
)
@click.option(
"--ssl-cert-reqs",
type=int,
default=None,
help="Whether client certificate is required (see stdlib 'ssl' module)",
+ show_default=True,
)
@click.option(
"--ssl-ca-certs",
type=str,
default=None,
help="CA certificates file",
+ show_default=True,
)
@click.option(
"--ssl-ciphers",
type=str,
default=None,
help="Ciphers to use (see stdlib 'ssl' module)",
+ show_default=True,
)
def serve( # type: ignore (unused warning)
bento: str,
@@ -128,36 +138,38 @@ def serve( # type: ignore (unused warning)
ssl_ca_certs: str | None,
ssl_ciphers: str | None,
) -> None:
- """Start a :code:`BentoServer` from a given ``BENTO`` 🍱
+ """Start a HTTP BentoServer from a given 🍱
- ``BENTO`` is the serving target, it can be the import as:
- - the import path of a :code:`bentoml.Service` instance
- - a tag to a Bento in local Bento store
- - a folder containing a valid `bentofile.yaml` build file with a `service` field, which provides the import path of a :code:`bentoml.Service` instance
- - a path to a built Bento (for internal & debug use only)
+ \b
+ BENTO is the serving target, it can be the import as:
+ - the import path of a 'bentoml.Service' instance
+ - a tag to a Bento in local Bento store
+ - a folder containing a valid 'bentofile.yaml' build file with a 'service' field, which provides the import path of a 'bentoml.Service' instance
+ - a path to a built Bento (for internal & debug use only)
e.g.:
\b
Serve from a bentoml.Service instance source code (for development use only):
- :code:`bentoml serve fraud_detector.py:svc`
+ 'bentoml serve fraud_detector.py:svc'
\b
Serve from a Bento built in local store:
- :code:`bentoml serve fraud_detector:4tht2icroji6zput3suqi5nl2`
- :code:`bentoml serve fraud_detector:latest`
+ 'bentoml serve fraud_detector:4tht2icroji6zput3suqi5nl2'
+ 'bentoml serve fraud_detector:latest'
\b
Serve from a Bento directory:
- :code:`bentoml serve ./fraud_detector_bento`
+ 'bentoml serve ./fraud_detector_bento'
\b
- If :code:`--reload` is provided, BentoML will detect code and model store changes during development, and restarts the service automatically.
+ If '--reload' is provided, BentoML will detect code and model store changes during development, and restarts the service automatically.
- The `--reload` flag will:
- - be default, all file changes under `--working-dir` (default to current directory) will trigger a restart
- - when specified, respect :obj:`include` and :obj:`exclude` under :obj:`bentofile.yaml` as well as the :obj:`.bentoignore` file in `--working-dir`, for code and file changes
- - all model store changes will also trigger a restart (new model saved or existing model removed)
+ \b
+ The '--reload' flag will:
+ - be default, all file changes under '--working-dir' (default to current directory) will trigger a restart
+ - when specified, respect 'include' and 'exclude' under 'bentofile.yaml' as well as the '.bentoignore' file in '--working-dir', for code and file changes
+ - all model store changes will also trigger a restart (new model saved or existing model removed)
"""
configure_server_logging()
@@ -170,9 +182,9 @@ def serve( # type: ignore (unused warning)
"'--reload' is not supported with '--production'; ignoring"
)
- from bentoml.serve import serve_production
+ from bentoml.serve import serve_http_production
- serve_production(
+ serve_http_production(
bento,
working_dir=working_dir,
port=port,
@@ -188,13 +200,13 @@ def serve( # type: ignore (unused warning)
ssl_ciphers=ssl_ciphers,
)
else:
- from bentoml.serve import serve_development
+ from bentoml.serve import serve_http_development
- serve_development(
+ serve_http_development(
bento,
working_dir=working_dir,
port=port,
- host=DEFAULT_DEV_SERVER_HOST if host is None else host,
+ host=DEFAULT_DEV_SERVER_HOST if not host else host,
reload=reload,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
@@ -204,3 +216,155 @@ def serve( # type: ignore (unused warning)
ssl_ca_certs=ssl_ca_certs,
ssl_ciphers=ssl_ciphers,
)
+
+ from bentoml._internal.utils import add_experimental_docstring
+
+ @cli.command(name="serve-grpc")
+ @click.argument("bento", type=click.STRING, default=".")
+ @click.option(
+ "--production",
+ type=click.BOOL,
+ help="Run the BentoServer in production mode",
+ is_flag=True,
+ default=False,
+ show_default=True,
+ )
+ @click.option(
+ "-p",
+ "--port",
+ type=click.INT,
+ default=BentoMLContainer.grpc.port.get(),
+ help="The port to listen on for the REST api server",
+ envvar="BENTOML_PORT",
+ show_default=True,
+ )
+ @click.option(
+ "--host",
+ type=click.STRING,
+ default=BentoMLContainer.grpc.host.get(),
+ help="The host to bind for the gRPC server",
+ envvar="BENTOML_HOST",
+ show_default=True,
+ )
+ @click.option(
+ "--api-workers",
+ type=click.INT,
+ default=None,
+ help="Specify the number of API server workers to start. Default to number of available CPU cores in production mode",
+ envvar="BENTOML_API_WORKERS",
+ show_default=True,
+ )
+ @click.option(
+ "--reload",
+ type=click.BOOL,
+ is_flag=True,
+ help="Reload Service when code changes detected, this is only available in development mode",
+ default=False,
+ show_default=True,
+ )
+ @click.option(
+ "--backlog",
+ type=click.INT,
+ default=BentoMLContainer.api_server_config.backlog.get(),
+ help="The maximum number of pending connections.",
+ show_default=True,
+ )
+ @click.option(
+ "--working-dir",
+ type=click.Path(),
+ help="When loading from source code, specify the directory to find the Service instance",
+ default=".",
+ show_default=True,
+ )
+ @click.option(
+ "--enable-reflection",
+ is_flag=True,
+ default=BentoMLContainer.grpc.reflection.enabled.get(),
+ type=click.BOOL,
+ help="Enable reflection.",
+ show_default=True,
+ )
+ @click.option(
+ "--max-concurrent-streams",
+ default=BentoMLContainer.grpc.max_concurrent_streams.get(),
+ type=click.INT,
+ help="Maximum number of concurrent incoming streams to allow on a http2 connection.",
+ show_default=True,
+ )
+ @add_experimental_docstring
+ def serve_grpc( # type: ignore (unused warning)
+ bento: str,
+ production: bool,
+ port: int,
+ host: str,
+ api_workers: int | None,
+ backlog: int,
+ reload: bool,
+ working_dir: str,
+ enable_reflection: bool,
+ max_concurrent_streams: int | None,
+ ):
+ """Start a gRPC BentoServer from a given 🍱
+
+ \b
+ BENTO is the serving target, it can be the import as:
+ - the import path of a 'bentoml.Service' instance
+ - a tag to a Bento in local Bento store
+ - a folder containing a valid 'bentofile.yaml' build file with a 'service' field, which provides the import path of a 'bentoml.Service' instance
+ - a path to a built Bento (for internal & debug use only)
+
+ e.g.:
+
+ \b
+ Serve from a bentoml.Service instance source code (for development use only):
+ 'bentoml serve-grpc fraud_detector.py:svc'
+
+ \b
+ Serve from a Bento built in local store:
+ 'bentoml serve-grpc fraud_detector:4tht2icroji6zput3suqi5nl2'
+ 'bentoml serve-grpc fraud_detector:latest'
+
+ \b
+ Serve from a Bento directory:
+ 'bentoml serve-grpc ./fraud_detector_bento'
+
+ If '--reload' is provided, BentoML will detect code and model store changes during development, and restarts the service automatically.
+
+ \b
+ The '--reload' flag will:
+ - be default, all file changes under '--working-dir' (default to current directory) will trigger a restart
+ - when specified, respect 'include' and 'exclude' under 'bentofile.yaml' as well as the '.bentoignore' file in '--working-dir', for code and file changes
+ - all model store changes will also trigger a restart (new model saved or existing model removed)
+ """
+ configure_server_logging()
+ if production:
+ if reload:
+ logger.warning(
+ "'--reload' is not supported with '--production'; ignoring"
+ )
+
+ from bentoml.serve import serve_grpc_production
+
+ serve_grpc_production(
+ bento,
+ working_dir=working_dir,
+ port=port,
+ host=host,
+ backlog=backlog,
+ api_workers=api_workers,
+ max_concurrent_streams=max_concurrent_streams,
+ reflection=enable_reflection,
+ )
+ else:
+ from bentoml.serve import serve_grpc_development
+
+ serve_grpc_development(
+ bento,
+ working_dir=working_dir,
+ port=port,
+ backlog=backlog,
+ reload=reload,
+ host=DEFAULT_DEV_SERVER_HOST if not host else host,
+ max_concurrent_streams=max_concurrent_streams,
+ reflection=enable_reflection,
+ )
diff --git a/bentoml_cli/start.py b/bentoml_cli/start.py
index 786f5edb3d..6fd61bdd2f 100644
--- a/bentoml_cli/start.py
+++ b/bentoml_cli/start.py
@@ -12,7 +12,7 @@
def add_start_command(cli: click.Group) -> None:
- from bentoml._internal.log import configure_server_logging
+ from bentoml._internal.utils import add_experimental_docstring
from bentoml._internal.configuration.containers import BentoMLContainer
@cli.command(hidden=True)
@@ -41,7 +41,7 @@ def add_start_command(cli: click.Group) -> None:
@click.option(
"--port",
type=click.INT,
- default=BentoMLContainer.service_port.get(),
+ default=BentoMLContainer.http.port.get(),
help="The port to listen on for the REST api server",
envvar="BENTOML_PORT",
show_default=True,
@@ -49,7 +49,7 @@ def add_start_command(cli: click.Group) -> None:
@click.option(
"--host",
type=click.STRING,
- default=BentoMLContainer.service_host.get(),
+ default=BentoMLContainer.http.host.get(),
help="The host to bind for the REST api server [defaults: 127.0.0.1(dev), 0.0.0.0(production)]",
envvar="BENTOML_HOST",
)
@@ -60,6 +60,13 @@ def add_start_command(cli: click.Group) -> None:
help="The maximum number of pending connections.",
show_default=True,
)
+ @click.option(
+ "--api-workers",
+ type=click.INT,
+ default=None,
+ help="Specify the number of API server workers to start. Default to number of available CPU cores in production mode",
+ envvar="BENTOML_API_WORKERS",
+ )
@click.option(
"--working-dir",
type=click.Path(),
@@ -109,6 +116,7 @@ def add_start_command(cli: click.Group) -> None:
default=None,
help="Ciphers to use (see stdlib 'ssl' module)",
)
+ @add_experimental_docstring
def start_http_server( # type: ignore (unused warning)
bento: str,
remote_runner: list[str] | None,
@@ -118,6 +126,7 @@ def start_http_server( # type: ignore (unused warning)
host: str,
backlog: int,
working_dir: str,
+ api_workers: int | None,
ssl_certfile: str | None,
ssl_keyfile: str | None,
ssl_keyfile_password: str | None,
@@ -126,7 +135,9 @@ def start_http_server( # type: ignore (unused warning)
ssl_ca_certs: str | None,
ssl_ciphers: str | None,
) -> None:
- configure_server_logging()
+ """
+ Start a HTTP API server standalone. This will be used inside Yatai.
+ """
if sys.path[0] != working_dir:
sys.path.insert(0, working_dir)
@@ -155,6 +166,7 @@ def start_http_server( # type: ignore (unused warning)
port=port,
host=host,
backlog=backlog,
+ api_workers=api_workers,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
ssl_keyfile_password=ssl_keyfile_password,
@@ -183,7 +195,7 @@ def start_http_server( # type: ignore (unused warning)
@click.option(
"--port",
type=click.INT,
- default=BentoMLContainer.service_port.get(),
+ default=BentoMLContainer.http.port.get(),
help="The port to listen on for the REST api server",
envvar="BENTOML_PORT",
show_default=True,
@@ -191,7 +203,7 @@ def start_http_server( # type: ignore (unused warning)
@click.option(
"--host",
type=click.STRING,
- default=BentoMLContainer.service_host.get(),
+ default=BentoMLContainer.http.host.get(),
help="The host to bind for the REST api server [defaults: 127.0.0.1(dev), 0.0.0.0(production)]",
envvar="BENTOML_HOST",
)
@@ -209,6 +221,7 @@ def start_http_server( # type: ignore (unused warning)
default=".",
show_default=True,
)
+ @add_experimental_docstring
def start_runner_server( # type: ignore (unused warning)
bento: str,
runner_name: str,
@@ -218,7 +231,9 @@ def start_runner_server( # type: ignore (unused warning)
backlog: int,
working_dir: str,
) -> None:
- configure_server_logging()
+ """
+ Start Runner server standalone. This will be used inside Yatai.
+ """
if sys.path[0] != working_dir:
sys.path.insert(0, working_dir)
@@ -238,3 +253,95 @@ def start_runner_server( # type: ignore (unused warning)
host=host,
backlog=backlog,
)
+
+ @cli.command(hidden=True)
+ @click.argument("bento", type=click.STRING, default=".")
+ @click.option(
+ "--remote-runner",
+ type=click.STRING,
+ multiple=True,
+ envvar="BENTOML_SERVE_RUNNER_MAP",
+ help="JSON string of runners map",
+ )
+ @click.option(
+ "--port",
+ type=click.INT,
+ default=BentoMLContainer.grpc.port.get(),
+ help="The port to listen on for the gRPC server",
+ envvar="BENTOML_PORT",
+ show_default=True,
+ )
+ @click.option(
+ "--host",
+ type=click.STRING,
+ default=BentoMLContainer.grpc.host.get(),
+ help="The host to bind for the gRPC server (defaults: 0.0.0.0)",
+ envvar="BENTOML_HOST",
+ )
+ @click.option(
+ "--backlog",
+ type=click.INT,
+ default=BentoMLContainer.api_server_config.backlog.get(),
+ help="The maximum number of pending connections.",
+ show_default=True,
+ )
+ @click.option(
+ "--working-dir",
+ type=click.Path(),
+ help="When loading from source code, specify the directory to find the Service instance",
+ default=".",
+ show_default=True,
+ )
+ @click.option(
+ "--api-workers",
+ type=click.INT,
+ default=None,
+ help="Specify the number of API server workers to start. Default to number of available CPU cores in production mode",
+ envvar="BENTOML_API_WORKERS",
+ )
+ @click.option(
+ "--enable-reflection",
+ is_flag=True,
+ default=BentoMLContainer.grpc.reflection.enabled.get(),
+ type=click.BOOL,
+ help="Enable reflection.",
+ )
+ @click.option(
+ "--max-concurrent-streams",
+ default=BentoMLContainer.grpc.max_concurrent_streams.get(),
+ type=click.INT,
+ help="Maximum number of concurrent incoming streams to allow on a http2 connection.",
+ )
+ @add_experimental_docstring
+ def start_grpc_server( # type: ignore (unused warning)
+ bento: str,
+ remote_runner: list[str] | None,
+ port: int,
+ host: str,
+ backlog: int,
+ api_workers: int | None,
+ working_dir: str,
+ enable_reflection: bool,
+ max_concurrent_streams: int | None,
+ ) -> None:
+ """
+ Start a gRPC API server standalone. This will be used inside Yatai.
+ """
+ if sys.path[0] != working_dir:
+ sys.path.insert(0, working_dir)
+
+ from bentoml.start import start_grpc_server
+
+ runner_map = dict([s.split("=", maxsplit=2) for s in remote_runner or []])
+ logger.info(" Using remote runners: %s", runner_map)
+ start_grpc_server(
+ bento,
+ runner_map=runner_map,
+ working_dir=working_dir,
+ port=port,
+ host=host,
+ backlog=backlog,
+ api_workers=api_workers,
+ reflection=enable_reflection,
+ max_concurrent_streams=max_concurrent_streams,
+ )
diff --git a/bentoml_cli/utils.py b/bentoml_cli/utils.py
index 62593975a9..bc60ed9dde 100644
--- a/bentoml_cli/utils.py
+++ b/bentoml_cli/utils.py
@@ -15,6 +15,8 @@
from bentoml.exceptions import BentoMLException
from bentoml._internal.log import configure_logging
+from bentoml._internal.configuration import DEBUG_ENV_VAR
+from bentoml._internal.configuration import QUIET_ENV_VAR
from bentoml._internal.configuration import get_debug_mode
from bentoml._internal.configuration import set_debug_mode
from bentoml._internal.configuration import set_quiet_mode
@@ -28,6 +30,7 @@
from click import Command
from click import Context
from click import Parameter
+ from click import HelpFormatter
P = t.ParamSpec("P")
@@ -49,9 +52,39 @@ def __call__( # pylint: disable=no-method-argument
logger = logging.getLogger("bentoml")
+def kwargs_transformers(
+ _func: F[t.Any] | None = None,
+ *,
+ transformer: F[t.Any],
+) -> F[t.Any]:
+ def decorator(func: F[t.Any]) -> t.Callable[P, t.Any]:
+ @functools.wraps(func)
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
+ return func(*args, **{k: transformer(v) for k, v in kwargs.items()})
+
+ return wrapper
+
+ if _func is None:
+ return decorator
+ return decorator(_func)
+
+
class BentoMLCommandGroup(click.Group):
- """Click command class customized for BentoML CLI, allow specifying a default
+ """
+ Click command class customized for BentoML CLI, allow specifying a default
command for each group defined.
+
+ This command groups will also introduce support for aliases for commands.
+
+ Example:
+
+ .. code-block:: python
+
+ @click.group(cls=BentoMLCommandGroup)
+ def cli(): ...
+
+ @cli.command(aliases=["serve-http"])
+ def serve(): ...
"""
NUMBER_OF_COMMON_PARAMS = 3
@@ -65,6 +98,7 @@ def bentoml_common_params(func: F[P]) -> WrappedCLI[bool, bool]:
"--quiet",
is_flag=True,
default=False,
+ envvar=QUIET_ENV_VAR,
help="Suppress all warnings and info logs",
)
@click.option(
@@ -72,6 +106,7 @@ def bentoml_common_params(func: F[P]) -> WrappedCLI[bool, bool]:
"--debug",
is_flag=True,
default=False,
+ envvar=DEBUG_ENV_VAR,
help="Generate debug information",
)
@click.option(
@@ -173,10 +208,17 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> t.Any:
return t.cast("ClickFunctionWrapper[t.Any]", wrapper)
+ def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
+ super(BentoMLCommandGroup, self).__init__(*args, **kwargs)
+ # these two dictionaries will store known aliases for commands and groups
+ self._commands: dict[str, list[str]] = {}
+ self._aliases: dict[str, str] = {}
+
def command(self, *args: t.Any, **kwargs: t.Any) -> t.Callable[[F[P]], Command]:
if "context_settings" not in kwargs:
kwargs["context_settings"] = {}
kwargs["context_settings"]["max_content_width"] = 120
+ aliases = kwargs.pop("aliases", None)
def wrapper(func: F[P]) -> Command:
# add common parameters to command.
@@ -191,10 +233,48 @@ def wrapper(func: F[P]) -> Command:
wrapped.__click_params__[-self.NUMBER_OF_COMMON_PARAMS :]
+ wrapped.__click_params__[: -self.NUMBER_OF_COMMON_PARAMS]
)
- return super(BentoMLCommandGroup, self).command(*args, **kwargs)(wrapped)
+ cmd = super(BentoMLCommandGroup, self).command(*args, **kwargs)(wrapped)
+ # add aliases to a given commands if it is specified.
+ if aliases is not None:
+ assert cmd.name
+ self._commands[cmd.name] = aliases
+ self._aliases.update({alias: cmd.name for alias in aliases})
+ return cmd
return wrapper
+ def resolve_alias(self, cmd_name: str):
+ return self._aliases[cmd_name] if cmd_name in self._aliases else cmd_name
+
+ def get_command(self, ctx: Context, cmd_name: str) -> Command | None:
+ cmd_name = self.resolve_alias(cmd_name)
+ return super(BentoMLCommandGroup, self).get_command(ctx, cmd_name)
+
+ def format_commands(self, ctx: Context, formatter: HelpFormatter) -> None:
+ rows: list[tuple[str, str]] = []
+ sub_commands = self.list_commands(ctx)
+
+ max_len = max(len(cmd) for cmd in sub_commands)
+ limit = formatter.width - 6 - max_len
+
+ for sub_command in sub_commands:
+ cmd = self.get_command(ctx, sub_command)
+ if cmd is None:
+ continue
+ # If the command is hidden, then we skip it.
+ if hasattr(cmd, "hidden") and cmd.hidden:
+ continue
+ if sub_command in self._commands:
+ aliases = ",".join(sorted(self._commands[sub_command]))
+ sub_command = "%s (%s)" % (sub_command, aliases)
+ # this cmd_help is available since click>=7
+ # BentoML requires click>=7.
+ cmd_help = cmd.get_short_help_str(limit)
+ rows.append((sub_command, cmd_help))
+ if rows:
+ with formatter.section("Commands"):
+ formatter.write_dl(rows)
+
def resolve_command(
self, ctx: Context, args: list[str]
) -> tuple[str | None, Command | None, list[str]]:
@@ -237,13 +317,9 @@ def unparse_click_params(
Unparse click call to a list of arguments. Used to modify some parameters and
restore to system command. The goal is to unpack cases where parameters can be parsed multiple times.
- Refers to ./buildx.py for examples of this usage. This is also used to unparse parameters for running API server.
-
Args:
- params (`dict[str, t.Any]`):
- The dictionary of the parameters that is parsed from click.Context.
- command_params (`list[click.Parameter]`):
- The list of paramters (Arguments/Options) that is part of a given command.
+ params: The dictionary of the parameters that is parsed from click.Context.
+ command_params: The list of paramters (Arguments/Options) that is part of a given command.
Returns:
Unparsed list of arguments that can be redirected to system commands.
diff --git a/bentoml_cli/worker/grpc_api_server.py b/bentoml_cli/worker/grpc_api_server.py
new file mode 100644
index 0000000000..be6a35ee49
--- /dev/null
+++ b/bentoml_cli/worker/grpc_api_server.py
@@ -0,0 +1,95 @@
+from __future__ import annotations
+
+import json
+import typing as t
+
+import click
+
+
+@click.command()
+@click.argument("bento_identifier", type=click.STRING, required=False, default=".")
+@click.option("--host", type=click.STRING, required=False, default=None)
+@click.option("--port", type=click.INT, required=False, default=None)
+@click.option(
+ "--runner-map",
+ type=click.STRING,
+ envvar="BENTOML_RUNNER_MAP",
+ help="JSON string of runners map, default sets to envars `BENTOML_RUNNER_MAP`",
+)
+@click.option(
+ "--working-dir",
+ type=click.Path(exists=True),
+ help="Working directory for the API server",
+)
+@click.option(
+ "--worker-id",
+ required=False,
+ type=click.INT,
+ default=None,
+ help="If set, start the server as a bare worker with the given worker ID. Otherwise start a standalone server with a supervisor process.",
+)
+@click.option(
+ "--enable-reflection",
+ type=click.BOOL,
+ is_flag=True,
+ help="Enable reflection.",
+ default=False,
+)
+@click.option(
+ "--max-concurrent-streams",
+ type=click.INT,
+ help="Maximum number of concurrent incoming streams to allow on a HTTP2 connection.",
+ default=None,
+)
+def main(
+ bento_identifier: str,
+ host: str,
+ port: int,
+ runner_map: str | None,
+ working_dir: str | None,
+ worker_id: int | None,
+ enable_reflection: bool,
+ max_concurrent_streams: int | None,
+):
+ """
+ Start BentoML API server.
+ \b
+ This is an internal API, users should not use this directly. Instead use `bentoml serve-grpc [--options]`
+ """
+
+ import bentoml
+ from bentoml._internal.log import configure_server_logging
+ from bentoml._internal.context import component_context
+ from bentoml._internal.configuration.containers import BentoMLContainer
+
+ component_context.component_type = "grpc_api_server"
+ component_context.component_index = worker_id
+ configure_server_logging()
+
+ BentoMLContainer.development_mode.set(False)
+ if runner_map is not None:
+ BentoMLContainer.remote_runner_mapping.set(json.loads(runner_map))
+
+ svc = bentoml.load(bento_identifier, working_dir=working_dir, standalone_load=True)
+
+ # setup context
+ if svc.tag is None:
+ component_context.bento_name = f"*{svc.__class__.__name__}"
+ component_context.bento_version = "not available"
+ else:
+ component_context.bento_name = svc.tag.name
+ component_context.bento_version = svc.tag.version
+
+ from bentoml._internal.server import grpc
+
+ grpc_options: dict[str, t.Any] = {"enable_reflection": enable_reflection}
+ if max_concurrent_streams:
+ grpc_options["max_concurrent_streams"] = int(max_concurrent_streams)
+
+ grpc.Server(
+ grpc.Config(svc.grpc_servicer, bind_address=f"{host}:{port}", **grpc_options)
+ ).run()
+
+
+if __name__ == "__main__":
+ main() # pylint: disable=no-value-for-parameter
diff --git a/bentoml_cli/worker/grpc_dev_api_server.py b/bentoml_cli/worker/grpc_dev_api_server.py
new file mode 100644
index 0000000000..b672e5605d
--- /dev/null
+++ b/bentoml_cli/worker/grpc_dev_api_server.py
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+import typing as t
+
+import click
+
+
+@click.command()
+@click.argument("bento_identifier", type=click.STRING, required=False, default=".")
+@click.option("--host", type=click.STRING, required=False, default=None)
+@click.option("--port", type=click.INT, required=False, default=None)
+@click.option("--working-dir", required=False, type=click.Path(), default=None)
+@click.option(
+ "--enable-reflection",
+ type=click.BOOL,
+ is_flag=True,
+ help="Enable reflection.",
+ default=False,
+)
+@click.option(
+ "--max-concurrent-streams",
+ type=int,
+ help="Maximum number of concurrent incoming streams to allow on a http2 connection.",
+ default=None,
+)
+def main(
+ bento_identifier: str,
+ host: str,
+ port: int,
+ working_dir: str | None,
+ enable_reflection: bool,
+ max_concurrent_streams: int | None,
+):
+ import psutil
+
+ from bentoml import load
+ from bentoml._internal.log import configure_server_logging
+ from bentoml._internal.context import component_context
+ from bentoml._internal.configuration.containers import BentoMLContainer
+
+ component_context.component_type = "grpc_dev_api_server"
+ configure_server_logging()
+
+ svc = load(bento_identifier, working_dir=working_dir, standalone_load=True)
+ if not port:
+ port = BentoMLContainer.grpc.port.get()
+ if not host:
+ host = BentoMLContainer.grpc.host.get()
+
+ # setup context
+ if svc.tag is None:
+ component_context.bento_name = f"*{svc.__class__.__name__}"
+ component_context.bento_version = "not available"
+ else:
+ component_context.bento_name = svc.tag.name
+ component_context.bento_version = svc.tag.version
+ if psutil.WINDOWS:
+ import asyncio
+
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # type: ignore
+
+ from bentoml._internal.server import grpc
+
+ grpc_options: dict[str, t.Any] = {"enable_reflection": enable_reflection}
+ if max_concurrent_streams:
+ grpc_options["max_concurrent_streams"] = int(max_concurrent_streams)
+
+ grpc.Server(
+ grpc.Config(svc.grpc_servicer, bind_address=f"{host}:{port}", **grpc_options)
+ ).run()
+
+
+if __name__ == "__main__":
+ main() # pylint: disable=no-value-for-parameter
diff --git a/bentoml_cli/worker/grpc_prometheus_server.py b/bentoml_cli/worker/grpc_prometheus_server.py
new file mode 100644
index 0000000000..47a1d8a4e6
--- /dev/null
+++ b/bentoml_cli/worker/grpc_prometheus_server.py
@@ -0,0 +1,96 @@
+from __future__ import annotations
+
+import typing as t
+import logging
+from typing import TYPE_CHECKING
+
+import click
+
+if TYPE_CHECKING:
+ from bentoml._internal import external_typing as ext
+
+logger = logging.getLogger("bentoml")
+
+
+class GenerateLatestMiddleware:
+ def __init__(self, app: ext.ASGIApp):
+ from bentoml._internal.configuration.containers import BentoMLContainer
+
+ self.app = app
+ self.metrics_client = BentoMLContainer.metrics_client.get()
+
+ async def __call__(
+ self, scope: ext.ASGIScope, receive: ext.ASGIReceive, send: ext.ASGISend
+ ) -> None:
+ assert scope["type"] == "http"
+ assert scope["path"] == "/"
+
+ from starlette.responses import Response
+
+ return await Response(
+ self.metrics_client.generate_latest(),
+ status_code=200,
+ media_type=self.metrics_client.CONTENT_TYPE_LATEST,
+ )(scope, receive, send)
+
+
+@click.command()
+@click.option("--fd", type=click.INT, required=True)
+@click.option("--backlog", type=click.INT, default=2048)
+@click.option(
+ "--prometheus-dir",
+ type=click.Path(exists=True),
+ help="Required by prometheus to pass the metrics in multi-process mode",
+)
+def main(fd: int, backlog: int, prometheus_dir: str | None):
+ """
+ Start a standalone Prometheus server to use with gRPC.
+ \b
+ This is an internal API, users should not use this directly. Instead use 'bentoml serve-grpc'.
+ Prometheus then can be accessed at localhost:9090
+ """
+
+ import socket
+
+ import psutil
+ import uvicorn
+ from starlette.middleware import Middleware
+ from starlette.applications import Starlette
+ from starlette.middleware.wsgi import WSGIMiddleware # TODO: a2wsgi
+
+ from bentoml._internal.log import configure_server_logging
+ from bentoml._internal.context import component_context
+ from bentoml._internal.configuration import get_debug_mode
+ from bentoml._internal.configuration.containers import BentoMLContainer
+
+ component_context.component_type = "prom_server"
+
+ configure_server_logging()
+
+ BentoMLContainer.development_mode.set(False)
+ metrics_client = BentoMLContainer.metrics_client.get()
+ if prometheus_dir is not None:
+ BentoMLContainer.prometheus_multiproc_dir.set(prometheus_dir)
+
+ # create a ASGI app that wraps around the default HTTP prometheus server.
+ prom_app = Starlette(
+ debug=get_debug_mode(), middleware=[Middleware(GenerateLatestMiddleware)]
+ )
+ prom_app.mount("/", WSGIMiddleware(metrics_client.make_wsgi_app()))
+ sock = socket.socket(fileno=fd)
+
+ uvicorn_options: dict[str, t.Any] = {
+ "backlog": backlog,
+ "log_config": None,
+ "workers": 1,
+ }
+ if psutil.WINDOWS:
+ uvicorn_options["loop"] = "asyncio"
+ import asyncio
+
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # type: ignore
+ uvicorn.Server(uvicorn.Config(prom_app, **uvicorn_options)).run(sockets=[sock])
+
+
+if __name__ == "__main__":
+ main() # pylint: disable=no-value-for-parameter
diff --git a/codecov.yml b/codecov.yml
index f3f83382f6..6171b4a18a 100644
--- a/codecov.yml
+++ b/codecov.yml
@@ -1,5 +1,5 @@
github_checks:
- annotations: false
+ annotations: false
comment:
layout: "reach, diff, files"
behavior: default
@@ -22,7 +22,7 @@ ignore:
coverage:
precision: 2
round: down
- range: "80...100"
+ range: "80...100"
status:
default_rules:
flag_coverage_not_uploaded_behavior: exclude
@@ -63,7 +63,8 @@ coverage:
- transformers
- xgboost
- unit-tests
- - e2e-tests
+ - e2e-tests-http
+ - e2e-tests-grpc
catboost:
target: auto
threshold: 10%
@@ -78,12 +79,12 @@ coverage:
target: auto
threshold: 10%
flags:
- - easyocr
+ - easyocr
evalml:
target: auto
threshold: 10%
flags:
- - evalml
+ - evalml
fastai:
target: auto
threshold: 10%
@@ -208,7 +209,8 @@ coverage:
target: auto
threshold: 10%
flags:
- - e2e-tests
+ - e2e-tests-http
+ - e2e-tests-grpc
unit:
target: auto
threshold: 10%
@@ -328,11 +330,16 @@ flags:
carryforward: true
paths:
- bentoml/_internal/frameworks/xgboost.py
- e2e-tests:
+ e2e-tests-http:
carryforward: true
paths:
- "bentoml/**/*"
- bentoml/models.py
+ e2e-tests-grpc:
+ carryforward: true
+ paths:
+ - "bentoml/**/*"
+ - bentoml/grpc/utils.py
unit-tests:
carryforward: true
paths:
diff --git a/pyproject.toml b/pyproject.toml
index 73cf80c53d..9dc787a2dc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,10 +1,5 @@
[build-system]
-requires = [
- # sync with setup.py until we discard non-pep-517/518
- "wheel",
- "setuptools>=59.0",
- "setuptools-scm[toml]>=6.2.3",
-]
+requires = ["setuptools<=60", "setuptools_scm[toml]>=6.2", "wheel"]
build-backend = "setuptools.build_meta"
[tool.setuptools_scm]
@@ -13,6 +8,43 @@ git_describe_command = "git describe --dirty --tags --long --first-parent"
version_scheme = "post-release"
fallback_version = "0.0.0"
+[tool.coverage.paths]
+source = ["bentoml"]
+
+[tool.coverage.run]
+branch = true
+source = ["bentoml", "bentoml_cli"]
+omit = [
+ "bentoml/**/*_pb2.py",
+ "bentoml/__main__.py",
+ "bentoml/_internal/types.py",
+ "bentoml/_internal/external_typing/*",
+ "bentoml/testing/*",
+ "bentoml/io.py",
+]
+
+[tool.coverage.report]
+show_missing = true
+precision = 2
+omit = [
+ "*/bentoml/**/*_pb2*.py",
+ "*/bentoml/_internal/external_typing/*",
+ "*/bentoml/_internal/types.py",
+ "*/bentoml/testing/*",
+ '*/bentoml/__main__.py',
+ "*/bentoml/io.py",
+]
+exclude_lines = [
+ "pragma: no cover",
+ "def __repr__",
+ "raise AssertionError",
+ "raise NotImplementedError",
+ "raise MissingDependencyException",
+ "except ImportError",
+ "if __name__ == .__main__.:",
+ "if TYPE_CHECKING:",
+]
+
[tool.black]
line-length = 88
exclude = '''
@@ -24,104 +56,177 @@ exclude = '''
| \.venv
| _build
| build
+ | venv
+ | lib
| dist
| typings
+ | bentoml/grpc
)/
| bentoml/_version.py
)
'''
+extend-exclude = "(_pb2.py$|_pb2_grpc.py$)"
[tool.pytest.ini_options]
-addopts = "-rfEX -p pytester -p no:warnings"
-python_files = ["test_*.py","*_test.py"]
+addopts = "-rfEX -p pytester -p no:warnings -x --capture=tee-sys"
+python_files = ["test_*.py", "*_test.py"]
testpaths = ["tests"]
-markers = ["gpus"]
-
-[tool.pylint.master]
-ignore=".ipynb_checkpoints,typings,bentoml/_internal/external_typing"
-ignore-paths=".*.pyi"
-unsafe-load-any-extension="no"
-extension-pkg-whitelist="numpy,tensorflow,torch,paddle,keras,pydantic"
-jobs=4
-persistent="yes"
-suggestion-mode="yes"
-max-line-length=88
-
-[tool.pylint.messages_control]
-disable="import-error,print-statement,parameter-unpacking,unpacking-in-except,old-raise-syntax,backtick,import-star-module-level,raw-checker-failed,bad-inline-option,locally-disabled,file-ignored,suppressed-message,useless-suppression,deprecated-pragma,apply-builtin,basestring-builtin,buffer-builtin,cmp-builtin,coerce-builtin,execfile-builtin,file-builtin,long-builtin,raw_input-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,no-absolute-import,old-division,dict-iter-method,dict-view-method,next-method-called,metaclass-assignment,indexing-exception,raising-string,reload-builtin,oct-method,hex-method,nonzero-method,cmp-method,input-builtin,round-builtin,intern-builtin,unichr-builtin,map-builtin-not-iterating,zip-builtin-not-iterating,range-builtin-not-iterating,filter-builtin-not-iterating,using-cmp-argument,div-method,idiv-method,rdiv-method,exception-message-attribute,invalid-str-codec,sys-max-int,bad-python3-import,deprecated-string-function,deprecated-str-translate-call,deprecated-itertools-function,deprecated-types-field,next-method-defined,dict-items-not-iterating,dict-keys-not-iterating,dict-values-not-iterating,deprecated-operator-function,deprecated-urllib-function,xreadlines-attribute,deprecated-sys-function,exception-escape,comprehension-escape,logging-fstring-interpolation,logging-format-interpolation,logging-not-lazy,C,R,fixme,protected-access,no-member,unsubscriptable-object,raise-missing-from,isinstance-second-argument-not-valid-type,attribute-defined-outside-init,relative-beyond-top-level"
-enable="c-extension-no-member"
-
-[tool.pylint.reports]
-evaluation="10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)"
-msg-template="{msg_id}:{symbol} [{line:0>3d}:{column:0>2d}] {obj}: {msg}"
-output-format="colorized"
-reports="no"
-score="yes"
+markers = ["gpus", "disable-tf-eager-execution"]
-[tool.pylint.refactoring]
-max-nested-blocks=5
-never-returning-functions="optparse.Values,sys.exit"
+[tool.pylint.main]
+recursive = true
+extension-pkg-allow-list = [
+ "numpy",
+ "tensorflow",
+ "torch",
+ "paddle",
+ "onnxruntime",
+ "onnx",
+ "pydantic.schema",
+]
+ignore-paths = ["typings", "bentoml/_internal/external_typing", "bentoml/grpc"]
+disable = [
+ "import-error",
+ "print-statement",
+ "parameter-unpacking",
+ "unpacking-in-except",
+ "old-raise-syntax",
+ "backtick",
+ "raw-checker-failed",
+ "bad-inline-option",
+ "locally-disabled",
+ "file-ignored",
+ "suppressed-message",
+ "useless-suppression",
+ "deprecated-pragma",
+ "apply-builtin",
+ "basestring-builtin",
+ "buffer-builtin",
+ "cmp-builtin",
+ "coerce-builtin",
+ "execfile-builtin",
+ "file-builtin",
+ "long-builtin",
+ "raw_input-builtin",
+ "reduce-builtin",
+ "standarderror-builtin",
+ "coerce-method",
+ "delslice-method",
+ "getslice-method",
+ "setslice-method",
+ "no-absolute-import",
+ "old-division",
+ "dict-iter-method",
+ "dict-view-method",
+ "next-method-called",
+ "metaclass-assignment",
+ "indexing-exception",
+ "raising-string",
+ "reload-builtin",
+ "oct-method",
+ "hex-method",
+ "nonzero-method",
+ "cmp-method",
+ "input-builtin",
+ "round-builtin",
+ "intern-builtin",
+ "unichr-builtin",
+ "map-builtin-not-iterating",
+ "zip-builtin-not-iterating",
+ "range-builtin-not-iterating",
+ "filter-builtin-not-iterating",
+ "using-cmp-argument",
+ "exception-message-attribute",
+ "invalid-str-codec",
+ "sys-max-int",
+ "bad-python3-import",
+ "deprecated-string-function",
+ "deprecated-str-translate-call",
+ "deprecated-itertools-function",
+ "deprecated-types-field",
+ "next-method-defined",
+ "dict-items-not-iterating",
+ "dict-keys-not-iterating",
+ "dict-values-not-iterating",
+ "deprecated-operator-function",
+ "deprecated-urllib-function",
+ "xreadlines-attribute",
+ "deprecated-sys-function",
+ "exception-escape",
+ "comprehension-escape",
+ "logging-fstring-interpolation",
+ "logging-format-interpolation",
+ "logging-not-lazy",
+ "C",
+ "R",
+ "fixme",
+ "protected-access",
+ "no-member",
+ "unsubscriptable-object",
+ "raise-missing-from",
+ "isinstance-second-argument-not-valid-type",
+ "attribute-defined-outside-init",
+ "relative-beyond-top-level",
+]
+enable = ["c-extension-no-member"]
+evaluation = "10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)"
+msg-template = "{msg_id}:{symbol} [{line:0>3d}:{column:0>2d}] {obj}: {msg}"
+output-format = "colorized"
+score = true
+
+[tool.pylint.classes]
+valid-metaclass-classmethod-first-arg = ["cls", "mcls", "kls"]
+
+[tool.pylint.logging]
+logging-format-style = "old" # using %s formatter for logging (performance-related)
[tool.pylint.miscellaneous]
-notes="FIXME,XXX,TODO,NOTE"
+notes = ["FIXME", "XXX", "TODO", "NOTE", "WARNING"]
-[tool.pylint.typecheck]
-ignored-classes="Namespace"
-contextmanager-decorators="contextlib.contextmanager"
+[tool.pylint.refactoring]
+# specify functions that should not return
+never-returning-functions = ["sys.exit"]
-[tool.pylint.format]
-indent-after-paren=4
-indent-string=' '
-max-line-length=100
-max-module-lines=1000
-single-line-class-stmt="no"
-single-line-if-stmt="no"
+[tool.pylint.spelling]
+spelling-ignore-comment-directives = "fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:,pylint:,type:"
-[tool.pylint.imports]
-allow-wildcard-with-all="no"
-analyse-fallback-blocks="no"
+[tool.pylint.variables]
+init-import = true
-[tool.pylint.classes]
-defining-attr-methods="__init__,__new__,setUp"
-exclude-protected="_asdict,_fields,_replace,_source,_make"
-valid-classmethod-first-arg="cls"
-valid-metaclass-classmethod-first-arg="mcs"
-
-[tool.pylint.design]
-# Maximum number of arguments for function / method
-max-args=5
-# Maximum number of attributes for a class (see R0902).
-max-attributes=7
-# Maximum number of boolean expressions in a if statement
-max-bool-expr=5
-# Maximum number of branch for function / method body
-max-branches=12
-# Maximum number of locals for function / method body
-max-locals=15
-# Maximum number of parents for a class (see R0901).
-max-parents=7
-# Maximum number of public methods for a class (see R0904).
-max-public-methods=20
-# Maximum number of return / yield for function / method body
-max-returns=6
-# Maximum number of statements in function / method body
-max-statements=50
-# Minimum number of public methods for a class (see R0903).
-min-public-methods=2
+[tool.pylint.typecheck]
+contextmanager-decorators = [
+ "contextlib.contextmanager",
+ "bentoml._internal.utils.cached_contextmanager",
+]
[tool.isort]
profile = "black"
+line_length = 88
length_sort = true
force_single_line = true
order_by_type = true
force_alphabetical_sort_within_sections = true
-skip_glob = ["typings/*", "docs/*"]
+skip_glob = [
+ "typings/*",
+ "test/*",
+ "**/*_pb2.py",
+ "**/*_pb2_grpc.py",
+ "venv/*",
+ "lib/*",
+]
[tool.pyright]
pythonVersion = "3.10"
-include = ["bentoml"]
-exclude = ['bentoml/_version.py','bentoml/__main__.py']
-analysis.useLibraryCodeForTypes = true
+include = ["bentoml/"]
+exclude = [
+ 'bentoml/_version.py',
+ 'bentoml/__main__.py',
+ 'bentoml/_internal/external_typing/',
+ 'bentoml/grpc/v1alpha1/',
+ '**/*_pb2.py',
+ "**/*_pb2_grpc.py",
+]
+useLibraryCodeForTypes = true
strictListInference = true
strictDictionaryInference = true
strictSetInference = true
diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt
index cf6a6b9b4f..6a6a481f78 100644
--- a/requirements/dev-requirements.txt
+++ b/requirements/dev-requirements.txt
@@ -1,5 +1,4 @@
# Dev dependencies
-r tests-requirements.txt
-setup-cfg-fmt
twine
wheel
diff --git a/requirements/tests-requirements.txt b/requirements/tests-requirements.txt
index 18a506eb51..fa6822a0f3 100644
--- a/requirements/tests-requirements.txt
+++ b/requirements/tests-requirements.txt
@@ -8,11 +8,13 @@ pydantic
pylint>=2.14.0
pytest-cov>=3.0.0
pytest>=6.2.0
+pytest-xdist
pytest-asyncio
pandas
scikit-learn
imageio>=2.5.0
-watchfiles>=0.15.0
pyarrow
-build >=0.8.0
+build[virtualenv] >=0.8.0
yamllint
+grpcio-tools>=1.41.0
+opentelemetry-test-utils==0.33b0
diff --git a/scripts/generate_grpc_stubs.sh b/scripts/generate_grpc_stubs.sh
new file mode 100755
index 0000000000..ae859fbe11
--- /dev/null
+++ b/scripts/generate_grpc_stubs.sh
@@ -0,0 +1,49 @@
+#!/usr/bin/env bash
+
+GIT_ROOT=$(git rev-parse --show-toplevel)
+STUBS_GENERATOR="bentoml/stubs-generator"
+
+cd "$GIT_ROOT" || exit 1
+
+main() {
+ local VERSION="${1:-v1alpha1}"
+ # Use inline heredoc for even faster build
+ # Keeping image as cache should be fine since we only want to generate the stubs.
+ if [[ $(docker images --filter=reference="$STUBS_GENERATOR" -q) == "" ]] || test "$(git diff --name-only --diff-filter=d -- "$0")"; then
+ docker buildx build --platform=linux/amd64 -t "$STUBS_GENERATOR" --load -f- . <=3.15.0, and 3.10 support are provided since protobuf>=3.19.0
+ protobuf==3.19.4
+ grpcio-tools==1.41
+ mypy-protobuf>=3.3.0
+EOT
+
+RUN --mount=type=cache,target=/var/lib/apt \
+ --mount=type=cache,target=/var/cache/apt \
+ apt-get update && apt-get install -q -y --no-install-recommends --allow-remove-essential bash build-essential ca-certificates
+
+RUN --mount=type=cache,target=/root/.cache/pip pip install -r requirements.txt \
+ && rm -rf /workspace/requirements.txt
+
+EOF
+ fi
+
+ echo "Generating gRPC stubs..."
+ find bentoml/grpc/"$VERSION" -type f -name "*.proto" -exec docker run --rm -it -v "$GIT_ROOT":/workspace --platform=linux/amd64 "$STUBS_GENERATOR" python -m grpc_tools.protoc -I. --grpc_python_out=. --python_out=. --mypy_out=. --mypy_grpc_out=. "{}" \;
+}
+
+if [ "${#}" -gt 1 ]; then
+ echo "$0 takes one optional argument. Usage: $0 [v1alpha2]"
+ exit 1
+fi
+main "$@"
diff --git a/scripts/release_pypi.sh b/scripts/release_pypi.sh
index db0b6953ad..07fb37e08f 100755
--- a/scripts/release_pypi.sh
+++ b/scripts/release_pypi.sh
@@ -52,7 +52,11 @@ else
fi
echo "Generating PyPI source distribution..."
-cd "$GIT_ROOT"
+cd "$GIT_ROOT" || exit 1
+
+# generate gRPC stubs
+./scripts/generate_grpc_stubs.sh
+
python3 -m build -s -w
# Use testpypi by default, run script with: "REPO=pypi release.sh" for
diff --git a/setup.cfg b/setup.cfg
index ca71c0eaa4..4a4fce4a84 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -67,6 +67,7 @@ install_requires =
uvicorn
watchfiles>=0.15.0
backports.cached-property;python_version<"3.8"
+ backports.shutil_copytree;python_version<"3.8"
importlib-metadata;python_version<"3.8"
python_requires = >=3.7
include_package_data = True
@@ -76,6 +77,7 @@ zip_safe = False
include =
# include bentoml packages
bentoml
+ bentoml.grpc*
bentoml.testing
bentoml._internal*
# include bentoml_cli packages
@@ -87,7 +89,23 @@ console_scripts =
bentoml = bentoml_cli.cli:cli
[options.extras_require]
+grpc =
+ # Restrict maximum version due to breaking protobuf 4.21.0 changes
+ # (see https://github.com/protocolbuffers/protobuf/issues/10051)
+ protobuf>=3.5.0, <3.20
+ # Lowest version that support 3.10
+ grpcio>=1.41.0
+ grpcio-health-checking
+ grpcio-reflection
+ opentelemetry-instrumentation-grpc==0.33b0
+tracing.jaeger =
+ opentelemetry-exporter-jaeger
+tracing.zipkin =
+ opentelemetry-exporter-zipkin
+tracing.otlp =
+ opentelemetry-exporter-otlp
tracing =
+ # kept for compatibility
opentelemetry-exporter-jaeger
opentelemetry-exporter-zipkin
opentelemetry-exporter-otlp
@@ -105,30 +123,3 @@ keep_temp = false
[sdist]
formats = gztar
-
-[coverage:run]
-omit =
- bentoml/__main__.py
- bentoml/_internal/types.py
- bentoml/_internal/external_typing/*
- bentoml/testing/*
- bentoml/io.py
-
-[coverage:report]
-show_missing = true
-precision = 2
-omit =
- bentoml/_internal/external_typing/*
- bentoml/_internal/types.py
- bentoml/testing/*
- bentoml/__main__.py
- bentoml/io.py
-exclude_lines =
- pragma: no cover
- def __repr__
- raise AssertionError
- raise NotImplementedError
- raise MissingDependencyException
- except ImportError
- if __name__ == .__main__.:
- if TYPE_CHECKING:
diff --git a/tests/unit/_internal/io/test_numpy.py b/tests/unit/_internal/io/test_numpy.py
index 811891af76..4f4d5765cf 100644
--- a/tests/unit/_internal/io/test_numpy.py
+++ b/tests/unit/_internal/io/test_numpy.py
@@ -32,7 +32,7 @@ def test_invalid_dtype():
generic = ExampleGeneric("asdf")
with pytest.raises(BentoMLException) as e:
_ = NumpyNdarray.from_sample(generic) # type: ignore (test exception)
- assert "expects a numpy.array" in str(e.value)
+ assert "expects a 'numpy.array'" in str(e.value)
@pytest.mark.parametrize("dtype, expected", [("float", "number"), (">U8", "integer")])
@@ -82,9 +82,7 @@ def test_numpy_openapi_responses():
def test_verify_numpy_ndarray(caplog: LogCaptureFixture):
- partial_check = partial(
- from_example._verify_ndarray, exception_cls=BentoMLException # type: ignore (test internal check)
- )
+ partial_check = partial(from_example.validate_array, exception_cls=BentoMLException)
with pytest.raises(BentoMLException) as ex:
partial_check(np.array(["asdf"]))
@@ -94,10 +92,10 @@ def test_verify_numpy_ndarray(caplog: LogCaptureFixture):
partial_check(np.array([[1]]))
assert f'Expecting ndarray of shape "{from_example._shape}"' in str(e.value) # type: ignore (testing message)
- # test cases whwere reshape is failed
+ # test cases where reshape is failed
example = NumpyNdarray.from_sample(np.ones((2, 2, 3)))
example._enforce_shape = False # type: ignore (test internal check)
example._enforce_dtype = False # type: ignore (test internal check)
with caplog.at_level(logging.DEBUG):
- example._verify_ndarray(np.array("asdf"))
+ example.validate_array(np.array("asdf"))
assert "Failed to reshape" in caplog.text