Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
  • Loading branch information
WeichenXu123 committed Nov 8, 2021
1 parent 646b60d commit 61c84dd
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/models/test_model_input_examples.py
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
import pytest
from scipy.sparse import csr_matrix, csc_matrix

from mlflow.models.signature import infer_signature
from mlflow.models.utils import _Example, _read_tensor_input_from_json
Expand Down Expand Up @@ -44,6 +45,14 @@ def dict_of_ndarrays():
}


@pytest.fixture
def dict_of_sparse_matrix():
return {
"csc": csc_matrix(np.arange(0, 12, 0.5).reshape(3, 8)),
"csr": csr_matrix(np.arange(0, 12, 0.5).reshape(3, 8))
}


def test_input_examples(pandas_df_with_all_types, dict_of_ndarrays):
sig = infer_signature(pandas_df_with_all_types)
# test setting example with data frame with all supported data types
Expand Down Expand Up @@ -117,3 +126,14 @@ def test_input_examples(pandas_df_with_all_types, dict_of_ndarrays):
filename = x.info["artifact_path"]
parsed_df = _dataframe_from_json(tmp.path(filename))
assert example == parsed_df.to_dict(orient="records")[0]


def test_sparse_matrix_input_examples(dict_of_sparse_matrix):
for col in dict_of_sparse_matrix:
input_example = dict_of_sparse_matrix[col]
with TempDir() as tmp:
example = _Example(input_example)
example.save(tmp.path())
filename = example.info["artifact_path"]
parsed_ary = _read_tensor_input_from_json(tmp.path(filename))
assert np.array_equal(parsed_ary, input_example.toarray())
17 changes: 17 additions & 0 deletions tests/types/test_schema.py
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
import pytest
from scipy.sparse import csr_matrix, csc_matrix

from mlflow.exceptions import MlflowException
from mlflow.pyfunc import _enforce_tensor_spec
Expand Down Expand Up @@ -251,6 +252,22 @@ def test_get_tensor_shape(dict_of_ndarrays):
_infer_schema({"x": 1})


@pytest.fixture
def dict_of_sparse_matrix():
return {
"csc": csc_matrix(np.arange(0, 12, 0.5).reshape(3, 8)),
"csr": csr_matrix(np.arange(0, 12, 0.5).reshape(3, 8))
}


def test_get_sparse_matrix_data_type_and_shape(dict_of_sparse_matrix):
for col in dict_of_sparse_matrix:
sparse_matrix = dict_of_sparse_matrix[col]
schema = _infer_schema(sparse_matrix)
schema.numpy_types() == ["float64"]
_get_tensor_shape(sparse_matrix) == (3, 8)


def test_schema_inference_on_dictionary(dict_of_ndarrays):
# test dictionary
schema = _infer_schema(dict_of_ndarrays)
Expand Down

0 comments on commit 61c84dd

Please sign in to comment.