perf: refine onnx implementation #3166

Merged
merged 5 commits on Nov 8, 2022
2 changes: 1 addition & 1 deletion .github/workflows/frameworks.yml
@@ -381,7 +381,7 @@ jobs:
- name: Install dependencies
run: |
pip install .
pip install onnx onnxruntime
pip install onnx onnxruntime skl2onnx
pip install -r requirements/tests-requirements.txt

- name: Run tests and generate coverage report
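For context, skl2onnx is presumably added here so the framework test job can convert a scikit-learn model into an onnx.ModelProto before exercising bentoml.onnx; a minimal sketch under that assumption (the model, data, and tag name below are illustrative, not taken from the test suite):

import numpy as np
import bentoml
from sklearn.linear_model import LogisticRegression
from skl2onnx import to_onnx

# Train a tiny scikit-learn model and convert it to an onnx.ModelProto.
X = np.random.rand(32, 4).astype(np.float32)
y = (X.sum(axis=1) > 2.0).astype(np.int64)
clf = LogisticRegression().fit(X, y)
onnx_model = to_onnx(clf, X[:1])

# Save it through the module this PR refines.
bentoml.onnx.save_model("sklearn_onnx_demo", onnx_model)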
127 changes: 79 additions & 48 deletions src/bentoml/_internal/frameworks/onnx.py
@@ -7,7 +7,6 @@
from typing import TYPE_CHECKING

import attr
import numpy as np

import bentoml
from bentoml import Tag
@@ -17,26 +16,25 @@
from bentoml.exceptions import BentoMLException
from bentoml.exceptions import MissingDependencyException

from ..types import LazyType
from ..utils.pkg import get_pkg_version
from ..utils.pkg import PackageNotFoundError
from ..runner.utils import Params
from ..utils.onnx import gen_input_casting_func

if TYPE_CHECKING:
import torch # noqa

from bentoml.types import ModelSignature
from bentoml.types import ModelSignatureDict

from .. import external_typing as ext
from ..external_typing import tensorflow as tf_ext # noqa
from ..utils.onnx import ONNXArgType
from ..utils.onnx import ONNXArgCastedType

ProvidersType = list[str | tuple[str, dict[str, t.Any]]]
ONNXArgType = ext.NpNDArray | ext.PdDataFrame | torch.Tensor | tf_ext.Tensor


try:
import onnx
from google.protobuf.json_format import MessageToDict

except ImportError: # pragma: no cover
raise MissingDependencyException(
"onnx is required in order to use module 'bentoml.onnx', install onnx with 'pip install onnx'. For more information, refer to https://onnx.ai/get-started.html"
@@ -51,7 +49,7 @@

MODULE_NAME = "bentoml.onnx"
MODEL_FILENAME = "saved_model.onnx"
API_VERSION = "v1"
API_VERSION = "v2"

logger = logging.getLogger(__name__)

@@ -76,6 +74,8 @@ def flatten_list(lst: t.List[t.Any]) -> t.List[str]: # pragma: no cover
class ONNXOptions(ModelOptions):
"""Options for the ONNX model"""

input_specs: dict[str, list[dict[str, t.Any]]] = attr.field(factory=dict)
output_specs: dict[str, list[dict[str, t.Any]]] = attr.field(factory=dict)
providers: t.Optional[list[str]] = attr.field(default=None)
session_options: t.Optional["ort.SessionOptions"] = attr.field(default=None)

@@ -106,6 +106,16 @@ def get(tag_like: str | Tag) -> bentoml.Model:
return model


def _load_raw_model(bento_model: str | Tag | bentoml.Model) -> onnx.ModelProto:

if not isinstance(bento_model, bentoml.Model):
bento_model = get(bento_model)

model_path = bento_model.path_of(MODEL_FILENAME)
raw_model = onnx.load(model_path)
return raw_model


def load_model(
bento_model: str | Tag | bentoml.Model,
*,
@@ -162,7 +172,7 @@ def save_model(
name: str,
model: onnx.ModelProto,
*,
signatures: dict[str, ModelSignatureDict | ModelSignature] | None = None,
signatures: dict[str, ModelSignatureDict] | dict[str, ModelSignature] | None = None,
labels: dict[str, str] | None = None,
custom_objects: dict[str, t.Any] | None = None,
external_modules: t.List[ModuleType] | None = None,
@@ -265,6 +275,9 @@ def forward(self, x, bias):
), "Failed to find onnxruntime package version."

assert _onnxruntime_version is not None, "onnxruntime is not installed"
if not isinstance(model, onnx.ModelProto):
raise TypeError(f"Given model ({model}) is not a onnx.ModelProto.")

context = ModelContext(
framework_name="onnx",
framework_versions={
@@ -289,15 +302,12 @@ def forward(self, x, bias):
f"Provided method names {[m for m in provided_methods if m != 'run']} are invalid. 'bentoml.onnx' will load ONNX model into an 'onnxruntime.InferenceSession' for inference, so the only supported method name is 'run'."
)

if not isinstance(model, onnx.ModelProto):
raise TypeError(f"Given model ({model}) is not a onnx.ModelProto.")

if len(model.graph.output) > 1:
logger.warning(
"The model you are attempting to save has more than one output. The ONNX runner will only return the first output."
)
run_input_specs = [MessageToDict(inp) for inp in model.graph.input]
run_output_specs = [MessageToDict(out) for out in model.graph.output]
input_specs = {"run": run_input_specs}
output_specs = {"run": run_output_specs}

options = ONNXOptions()
options = ONNXOptions(input_specs=input_specs, output_specs=output_specs)

with bentoml.models.create(
name,
@@ -321,6 +331,25 @@ def get_runnable(bento_model: bentoml.Model) -> t.Type[bentoml.Runnable]:
Private API: use :obj:`~bentoml.Model.to_runnable` instead.
"""

# backward compatibility for v1, load raw model to infer
# input_specs/output_specs for onnx model
if bento_model.info.api_version == "v1":

raw_model: onnx.ModelProto | None = None
options = t.cast(ONNXOptions, bento_model.info.options)

if not options.input_specs:
raw_model = _load_raw_model(bento_model)
run_input_specs = [MessageToDict(inp) for inp in raw_model.graph.input]
input_specs = {"run": run_input_specs}
bento_model = bento_model.with_options(input_specs=input_specs)

if not options.output_specs:
raw_model = raw_model or _load_raw_model(bento_model)
run_output_specs = [MessageToDict(out) for out in raw_model.graph.output]
output_specs = {"run": run_output_specs}
bento_model = bento_model.with_options(output_specs=output_specs)

class ONNXRunnable(bentoml.Runnable):
SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
SUPPORTS_CPU_MULTI_THREADING = True
@@ -376,48 +405,50 @@ def __init__(self):
for method_name in bento_model.info.signatures:
self.predict_fns[method_name] = getattr(self.model, method_name)

def _mapping(item: ONNXArgType) -> ext.NpNDArray:

# currently ort only support np.float32
# https://onnxruntime.ai/docs/api/python/auto_examples/plot_common_errors.html
# TODO @larme: what dtype of input to use if we use FP16? ref: onnxruntime issue 11768
if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(item):
if item.dtype != np.float32:
item = item.astype(np.float32, copy=False)
elif LazyType["ext.PdDataFrame"]("pandas.DataFrame").isinstance(item):
item = item.to_numpy(dtype=np.float32)
elif LazyType["tf.Tensor"]("tensorflow.Tensor").isinstance(item):
item = np.array(memoryview(item)).astype(np.float32, copy=False)
elif LazyType["torch.Tensor"]("torch.Tensor").isinstance(item):
item = item.numpy().astype(np.float32, copy=False)
def add_runnable_method(
method_name: str,
signatures: ModelSignature,
input_specs: list[dict[str, t.Any]],
output_specs: list[dict[str, t.Any]],
):

casting_funcs = [gen_input_casting_func(spec) for spec in input_specs]

if len(output_specs) > 1:

def _process_output(outs):
return tuple(outs)

else:
raise TypeError(
"'run' of ONNXRunnable only takes 'numpy.ndarray' or 'pd.DataFrame', 'tf.Tensor', or 'torch.Tensor' as input parameters."
)
# This cast is safe since we have already checked the type of item
return t.cast("ext.NpNDArray", item)

def add_runnable_method(method_name: str, options: ModelSignature):
def _process_output(outs):
return outs[0]

def _run(self: ONNXRunnable, *args: ONNXArgType) -> t.Any:
params = Params["ONNXArgType"](*args)
params = params.map(_mapping)
casted_args = [
casting_funcs[idx](args[idx]) for idx in range(len(casting_funcs))
]

input_names: dict[str, ext.NpNDArray] = {
i.name: val for i, val in zip(self.model.get_inputs(), params.args)
input_names: dict[str, ONNXArgCastedType] = {
i.name: val for i, val in zip(self.model.get_inputs(), casted_args)
}
output_names: list[str] = [o.name for o in self.model.get_outputs()]
return self.predict_fns[method_name](output_names, input_names)[0]
raw_outs = self.predict_fns[method_name](output_names, input_names)
return _process_output(raw_outs)

ONNXRunnable.add_method(
_run,
name=method_name,
batchable=options.batchable,
batch_dim=options.batch_dim,
input_spec=options.input_spec,
output_spec=options.output_spec,
batchable=signatures.batchable,
batch_dim=signatures.batch_dim,
input_spec=signatures.input_spec,
output_spec=signatures.output_spec,
)

for method_name, options in bento_model.info.signatures.items():
add_runnable_method(method_name, options)
for method_name, signatures in bento_model.info.signatures.items():
options = t.cast(ONNXOptions, bento_model.info.options)
input_specs = options.input_specs[method_name]
output_specs = options.output_specs[method_name]
add_runnable_method(method_name, signatures, input_specs, output_specs)

return ONNXRunnable
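
With API_VERSION bumped to "v2", the graph's input and output signatures are captured at save time via MessageToDict and stored on ONNXOptions, so get_runnable only reloads the raw ModelProto for older "v1" models. A rough sketch of what input_specs/output_specs hold for a single-input, single-output float32 model (field names follow MessageToDict's camelCase convention; the tensor names and shape values are illustrative):

input_specs = {
    "run": [
        {
            "name": "input",
            "type": {
                "tensorType": {
                    "elemType": 1,  # onnx.TensorProto.FLOAT
                    "shape": {"dim": [{"dimParam": "N"}, {"dimValue": "4"}]},
                }
            },
        }
    ]
}
output_specs = {
    "run": [
        {"name": "output", "type": {"tensorType": {"elemType": 1}}},
    ]
}

These per-method specs are what gen_input_casting_func consumes when the runnable methods are built.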
160 changes: 160 additions & 0 deletions src/bentoml/_internal/utils/onnx.py
@@ -0,0 +1,160 @@
from __future__ import annotations

import typing as t
import logging
from typing import TYPE_CHECKING

from bentoml.exceptions import BentoMLException

from ..types import LazyType
from .lazy_loader import LazyLoader

if TYPE_CHECKING:
import onnx
import torch

from .. import external_typing as ext
from ..external_typing import tensorflow as tf_ext # noqa

ONNXArgTensorType = (
ext.NpNDArray
| ext.PdDataFrame
| torch.Tensor
| tf_ext.Tensor
| list[int | float | str]
)
ONNXArgSequenceType = list["ONNXArgType"]
ONNXArgMapKeyType = int | str
ONNXArgMapType = dict[ONNXArgMapKeyType, "ONNXArgType"]
ONNXArgType = ONNXArgMapType | ONNXArgTensorType | ONNXArgSequenceType

ONNXArgCastedType = (
ext.NpNDArray
| list["ONNXArgCastedType"]
| dict[ONNXArgMapKeyType, "ONNXArgCastedType"]
)
ONNXArgCastingFuncType = t.Callable[[ONNXArgType], ONNXArgCastedType]
ONNXArgCastingFuncGeneratorType = t.Callable[
[dict[str, t.Any]], t.Callable[[ONNXArgType], ONNXArgCastedType]
]

else:
np = LazyLoader("np", globals(), "numpy")
onnx = LazyLoader(
"onnx",
globals(),
"onnx",
exc_msg="`onnx` is required to use bentoml.onnx module.",
)

logger = logging.getLogger(__name__)

TENSORPROTO_ELEMENT_TYPE_TO_NUMPY_TYPE: dict[int, str] = {
onnx.TensorProto.FLOAT: "float32", # 1
onnx.TensorProto.UINT8: "uint8", # 2
onnx.TensorProto.INT8: "int8", # 3
onnx.TensorProto.UINT16: "uint16", # 4
onnx.TensorProto.INT16: "int16", # 5
onnx.TensorProto.INT32: "int32", # 6
onnx.TensorProto.INT64: "int64", # 7
onnx.TensorProto.STRING: "str", # 8 or "unicode"?
onnx.TensorProto.BOOL: "bool", # 9
onnx.TensorProto.FLOAT16: "float16", # 10
onnx.TensorProto.DOUBLE: "double", # 11
onnx.TensorProto.UINT32: "uint32", # 12
onnx.TensorProto.UINT64: "uint64", # 13
onnx.TensorProto.COMPLEX64: "csingle", # 14
onnx.TensorProto.COMPLEX128: "cdouble", # 15
# onnx.TensorProto.BFLOAT16: None, # 16
}


CASTING_FUNC_DISPATCHER: dict[str, ONNXArgCastingFuncGeneratorType] = {
# type -> casting function generator
}


def gen_input_casting_func(spec: dict[str, t.Any]) -> ONNXArgCastingFuncType:
return _gen_input_casting_func(spec["type"])


def _gen_input_casting_func(sig: dict[str, t.Any]) -> ONNXArgCastingFuncType:

input_types = list(sig.keys())
if len(input_types) != 1:
raise BentoMLException(
"onnx model input type dictionary should have only one key!"
)
input_type = input_types[0]
input_spec = sig[input_type]
return CASTING_FUNC_DISPATCHER[input_type](input_spec)


def _gen_input_casting_func_for_tensor(
sig: dict[str, t.Any]
) -> t.Callable[[ONNXArgTensorType], ext.NpNDArray]:

elem_type = sig["elemType"]
to_dtype = TENSORPROTO_ELEMENT_TYPE_TO_NUMPY_TYPE[elem_type]

def _mapping(item: ONNXArgTensorType) -> ext.NpNDArray:
if LazyType["ext.NpNDArray"]("numpy.ndarray").isinstance(item):
item = item.astype(to_dtype, copy=False)
elif isinstance(item, list):
item = np.array(item).astype(to_dtype, copy=False)
elif LazyType["ext.PdDataFrame"]("pandas.DataFrame").isinstance(item):
item = item.to_numpy(dtype=to_dtype)
elif LazyType["tf.Tensor"]("tensorflow.Tensor").isinstance(item):
item = np.array(memoryview(item)).astype(to_dtype, copy=False) # type: ignore
elif LazyType["torch.Tensor"]("torch.Tensor").isinstance(item):
item = item.numpy().astype(to_dtype, copy=False)
else:
raise TypeError(
"`run` of ONNXRunnable only takes `numpy.ndarray`, `pd.DataFrame`, `tf.Tensor`, `torch.Tensor` or a list as input Tensor type"
)
return t.cast("ext.NpNDArray", item)

return _mapping


CASTING_FUNC_DISPATCHER["tensorType"] = t.cast(
"ONNXArgCastingFuncGeneratorType", _gen_input_casting_func_for_tensor
)


def _gen_input_casting_func_for_map(
sig: dict[str, t.Any]
) -> t.Callable[[ONNXArgMapType], dict[ONNXArgMapKeyType, ONNXArgCastedType]]:

map_value_sig = t.cast(dict[str, t.Any], sig["valueType"])
value_casting_func = _gen_input_casting_func(map_value_sig)

def _mapping(item: ONNXArgMapType) -> dict[ONNXArgMapKeyType, t.Any]:
new_item = {k: value_casting_func(v) for k, v in item.items()}
return new_item

return _mapping


CASTING_FUNC_DISPATCHER["mapType"] = t.cast(
"ONNXArgCastingFuncGeneratorType", _gen_input_casting_func_for_map
)


def _gen_input_casting_func_for_sequence(
sig: dict[str, t.Any]
) -> t.Callable[[ONNXArgSequenceType], list[t.Any]]:

seq_elem_sig = t.cast(dict[str, t.Any], sig["elemType"])
elem_casting_func = _gen_input_casting_func(seq_elem_sig)

def _mapping(item: ONNXArgSequenceType) -> list[t.Any]:
new_item = list(elem_casting_func(elem) for elem in item)
return new_item

return _mapping


CASTING_FUNC_DISPATCHER["sequenceType"] = t.cast(
"ONNXArgCastingFuncGeneratorType", _gen_input_casting_func_for_sequence
)
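
The dispatcher above recurses over the MessageToDict'd type spec: tensor inputs are cast to the matching numpy dtype, sequence inputs cast each element, and map inputs cast each value. A minimal usage sketch (the spec literal is a hypothetical float32 tensor input, not taken from a real model):

import numpy as np

from bentoml._internal.utils.onnx import gen_input_casting_func

# Hypothetical spec, shaped like MessageToDict(model.graph.input[0]).
spec = {
    "name": "input",
    "type": {"tensorType": {"elemType": 1}},  # 1 -> float32
}

cast = gen_input_casting_func(spec)
arr = cast(np.arange(6, dtype=np.int64).reshape(2, 3))
print(arr.dtype)  # float32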