Skip to content

Commit

Permalink
add more test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
larme committed Nov 7, 2022
1 parent 21fc08e commit e3f8dd0
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/frameworks.yml
Expand Up @@ -381,7 +381,7 @@ jobs:
- name: Install dependencies
run: |
pip install .
pip install onnx onnxruntime
pip install onnx onnxruntime skl2onnx
pip install -r requirements/tests-requirements.txt
- name: Run tests and generate coverage report
Expand Down
152 changes: 149 additions & 3 deletions tests/integration/frameworks/models/onnx.py
Expand Up @@ -3,19 +3,34 @@
import os
import typing as t
import tempfile
from typing import Callable
from typing import TYPE_CHECKING

import onnx
import numpy as np
import torch
import sklearn
import torch.nn as nn
import onnxruntime as ort
from pytest import CallInfo
from skl2onnx import convert_sklearn
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx.common.data_types import Int64TensorType
from skl2onnx.common.data_types import StringTensorType

import bentoml

from . import FrameworkTestModel
from . import FrameworkTestModelInput as Input
from . import FrameworkTestModelConfiguration as Config

if TYPE_CHECKING:
import bentoml._internal.external_typing as ext

framework = bentoml.onnx

backward_compatible = True
Expand Down Expand Up @@ -114,12 +129,12 @@ def make_pytorch_onnx_model(tmpdir):


with tempfile.TemporaryDirectory() as tmpdir:
pytorch_model = make_pytorch_onnx_model(tmpdir)
onnx_pytorch_raw_model = make_pytorch_onnx_model(tmpdir)


onnx_pytorch_model = FrameworkTestModel(
name="onnx_pytorch_model",
model=pytorch_model,
model=onnx_pytorch_raw_model,
model_method_caller=method_caller,
model_signatures={"run": {"batchable": True}},
configurations=[
Expand All @@ -136,4 +151,135 @@ def make_pytorch_onnx_model(tmpdir):
),
],
)
models: list[FrameworkTestModel] = [onnx_pytorch_model]


# sklearn random forest with multiple outputs
def make_rf_onnx_model() -> tuple[
    onnx.ModelProto, tuple[ext.NpNDArray, tuple[ext.NpNDArray, ext.NpNDArray]]
]:
    """Train a random forest on the iris dataset and convert it to ONNX.

    Returns the converted model together with a sample input (two test
    rows) and the outputs the original sklearn model produces for that
    input: the predicted labels and the per-class probabilities.
    """
    dataset: sklearn.utils.Bunch = load_iris()
    features: ext.NpNDArray = dataset.data
    targets: ext.NpNDArray = dataset.target
    train_x, test_x, train_y, _ = train_test_split(features, targets)

    forest = RandomForestClassifier()
    forest.fit(train_x, train_y)

    # iris has 4 features per row, hence the [None, 4] input tensor shape
    onnx_model = t.cast(
        onnx.ModelProto,
        convert_sklearn(
            forest, initial_types=[("float_input", FloatTensorType([None, 4]))]
        ),
    )

    sample = t.cast("ext.NpNDArray", test_x[:2])
    outputs = (
        t.cast("ext.NpNDArray", forest.predict(sample)),
        t.cast("ext.NpNDArray", forest.predict_proba(sample)),
    )
    return (onnx_model, (sample, outputs))


# onnxruntime (and therefore the BentoML runner) does not return outputs in
# the same format as the original sklearn model, so we need to generate a
# checker that adapts onnxruntime's outputs before comparing them against
# the original model's outputs.
def gen_rf_output_checker(
    expected_output: tuple[ext.NpNDArray, ext.NpNDArray]
) -> t.Callable[[t.Any], bool]:
    """Return a predicate that validates onnxruntime outputs.

    ``expected_output`` is the (labels, probabilities) pair produced by the
    original sklearn model; the returned callable accepts onnxruntime's
    output pair and reports whether it matches.
    """
    exp_labels, exp_probs = expected_output

    def _check(out: tuple[ext.NpNDArray, list[dict[int, float]]]) -> bool:
        labels, prob_maps = out
        if not (labels == exp_labels).all():
            return False
        # onnxruntime emits one {class_id: probability} dict per sample;
        # densify it into a matrix (columns ordered by class id) before
        # the tolerance-based float comparison
        dense = np.array(
            [[mapping[cls] for cls in sorted(mapping)] for mapping in prob_maps]
        )
        return bool(np.isclose(dense, exp_probs, rtol=1e-3).all())

    return _check


# Build the random-forest ONNX model once at import time and keep the sample
# input plus the sklearn model's expected outputs for the checker below.
onnx_rf_raw_model, _expected_data = make_rf_onnx_model()
rf_input, rf_expected_output = _expected_data

onnx_rf_model = FrameworkTestModel(
    name="onnx_rf_model",
    model=onnx_rf_raw_model,
    model_method_caller=method_caller,
    model_signatures={"run": {"batchable": True}},
    configurations=[
        Config(
            test_inputs={
                "run": [
                    Input(
                        input_args=[rf_input],
                        # the outputs cannot be compared directly (onnxruntime
                        # returns probabilities as per-sample dicts), so the
                        # expectation is a checker callable instead of a value
                        expected=gen_rf_output_checker(rf_expected_output),
                    ),
                ],
            },
            check_model=check_model,
        ),
    ],
)


# sklearn label encoder testing int and string input types
LT = t.TypeVar("LT")


def make_le_onnx_model(
    labels: list[LT], tensor_type: type
) -> tuple[onnx.ModelProto, tuple[list[list[LT]], ext.NpNDArray]]:
    """Fit a LabelEncoder on *labels* and convert it to ONNX.

    Returns the converted model plus a sample input (the first two labels,
    each wrapped as a single-element row to match the [None, 1] tensor
    shape) and the encoder's expected output for that input.
    """
    encoder = LabelEncoder()
    encoder.fit(labels)

    schema = [("tensor_input", tensor_type([None, 1]))]
    converted = t.cast(
        onnx.ModelProto, convert_sklearn(encoder, initial_types=schema)
    )

    sample = [[labels[0]], [labels[1]]]
    encoded = t.cast("ext.NpNDArray", encoder.transform(sample))
    return (converted, (sample, encoded))


onnx_le_models = []
int_labels = [5, 2, 3]
str_labels = ["apple", "orange", "cat"]

# build one test model per (labels, ONNX tensor type) pair so both int64
# and string input tensors are exercised
for labels, tensor_type in [
    (int_labels, Int64TensorType),
    (str_labels, StringTensorType),
]:
    onnx_le_raw_model, expected_data = make_le_onnx_model(labels, tensor_type)
    le_input, le_expected_output = expected_data

    # NOTE: `expected_out` is bound as a default argument so that each loop
    # iteration's checker captures its own expected output instead of
    # sharing the last iteration's value (late-binding closure pitfall)
    def _check(
        out: ext.NpNDArray, expected_out: ext.NpNDArray = le_expected_output
    ) -> bool:
        # LabelEncoder's raw output has one less dim than the onnxruntime
        # output, so drop the trailing axis before comparing
        flat_out = np.squeeze(out, axis=1)
        return (expected_out == flat_out).all()

    onnx_le_model = FrameworkTestModel(
        name=f"onnx_le_model_{tensor_type.__name__.lower()}",
        model=onnx_le_raw_model,
        model_method_caller=method_caller,
        model_signatures={"run": {"batchable": True}},
        configurations=[
            Config(
                test_inputs={
                    "run": [
                        Input(
                            input_args=[le_input],
                            expected=_check,
                        ),
                    ],
                },
                check_model=check_model,
            ),
        ],
    )
    onnx_le_models.append(onnx_le_model)


# the full collection exported to the framework test harness
models: list[FrameworkTestModel] = [onnx_pytorch_model, onnx_rf_model] + onnx_le_models

0 comments on commit e3f8dd0

Please sign in to comment.