add api v1 compatibility
larme committed Nov 8, 2022
1 parent 9502db3 commit 65ade2a
Showing 1 changed file with 31 additions and 8 deletions.
39 changes: 31 additions & 8 deletions src/bentoml/_internal/frameworks/onnx.py
@@ -7,7 +7,6 @@
from typing import TYPE_CHECKING

import attr
import numpy as np

import bentoml
from bentoml import Tag
@@ -17,25 +16,20 @@
from bentoml.exceptions import BentoMLException
from bentoml.exceptions import MissingDependencyException

from ..types import LazyType
from ..utils.pkg import get_pkg_version
from ..utils.pkg import PackageNotFoundError
from ..utils.onnx import gen_input_casting_func
from ..runner.utils import Params

if TYPE_CHECKING:
    import torch  # noqa

    from bentoml.types import ModelSignature
    from bentoml.types import ModelSignatureDict

    from .. import external_typing as ext
    from ..external_typing import tensorflow as tf_ext  # noqa

    ProvidersType = list[str | tuple[str, dict[str, t.Any]]]
    from ..utils.onnx import ONNXArgType
    from ..utils.onnx import ONNXArgCastedType

ProvidersType = list[str | tuple[str, dict[str, t.Any]]]


try:
import onnx
@@ -112,6 +106,16 @@ def get(tag_like: str | Tag) -> bentoml.Model:
    return model


def _load_raw_model(bento_model: str | Tag | bentoml.Model) -> onnx.ModelProto:

    if not isinstance(bento_model, bentoml.Model):
        bento_model = get(bento_model)

    model_path = bento_model.path_of(MODEL_FILENAME)
    raw_model = onnx.load(model_path)
    return raw_model


def load_model(
    bento_model: str | Tag | bentoml.Model,
    *,
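As an aside (not part of the diff): a minimal sketch of the spec inference the new `_load_raw_model` helper above enables, assuming a local model.onnx file; the path and any tensor names here are illustrative. `onnx.load` returns a ModelProto, and `MessageToDict` flattens each graph input into a plain dict that can be stored as an input spec.

import onnx
from google.protobuf.json_format import MessageToDict

raw_model = onnx.load("model.onnx")  # hypothetical model file

# Each entry of raw_model.graph.input is an onnx.ValueInfoProto; MessageToDict
# turns it into a JSON-style dict, roughly:
#   {"name": "input_0", "type": {"tensorType": {"elemType": 1, "shape": {...}}}}
run_input_specs = [MessageToDict(inp) for inp in raw_model.graph.input]
print(run_input_specs)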
@@ -327,6 +331,25 @@ def get_runnable(bento_model: bentoml.Model) -> t.Type[bentoml.Runnable]:
    Private API: use :obj:`~bentoml.Model.to_runnable` instead.
    """

    # backward compatibility for v1, load raw model to infer
    # input_specs/output_specs for onnx model
    if bento_model.info.api_version == "v1":

        raw_model: onnx.ModelProto | None = None
        options = t.cast(ONNXOptions, bento_model.info.options)

        if not options.input_specs:
            raw_model = _load_raw_model(bento_model)
            run_input_specs = [MessageToDict(inp) for inp in raw_model.graph.input]
            input_specs = {"run": run_input_specs}
            bento_model = bento_model.with_options(input_specs=input_specs)

        if not options.output_specs:
            raw_model = raw_model or _load_raw_model(bento_model)
            run_output_specs = [MessageToDict(out) for out in raw_model.graph.output]
            output_specs = {"run": run_output_specs}
            bento_model = bento_model.with_options(output_specs=output_specs)

    class ONNXRunnable(bentoml.Runnable):
        SUPPORTED_RESOURCES = ("nvidia.com/gpu", "cpu")
        SUPPORTS_CPU_MULTI_THREADING = True
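Taken together, the change lets a model saved under the older "v1" api_version, whose stored options carry no input_specs/output_specs, still be turned into a runnable: the specs are backfilled once from the stored ONNX graph. A rough usage sketch, assuming a previously saved model under the hypothetical tag "my_onnx_model":

import bentoml

bento_model = bentoml.onnx.get("my_onnx_model:latest")  # hypothetical tag

# For a v1-era model with no stored specs, get_runnable() now infers them from
# the ONNX graph before building the Runnable class.
runner = bento_model.to_runner()
runner.init_local()  # instantiate in-process for a quick local check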
