From 61c84ddd1d0f9135240b049ae1c104c6a7a3a1c0 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Mon, 8 Nov 2021 15:44:30 +0800 Subject: [PATCH] add tests Signed-off-by: Weichen Xu --- tests/models/test_model_input_examples.py | 20 ++++++++++++++++++++ tests/types/test_schema.py | 17 +++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/tests/models/test_model_input_examples.py b/tests/models/test_model_input_examples.py index ea9be72ea54cd..e89feb9268285 100644 --- a/tests/models/test_model_input_examples.py +++ b/tests/models/test_model_input_examples.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import pytest +from scipy.sparse import csr_matrix, csc_matrix from mlflow.models.signature import infer_signature from mlflow.models.utils import _Example, _read_tensor_input_from_json @@ -44,6 +45,14 @@ def dict_of_ndarrays(): } +@pytest.fixture +def dict_of_sparse_matrix(): + return { + "csc": csc_matrix(np.arange(0, 12, 0.5).reshape(3, 8)), + "csr": csr_matrix(np.arange(0, 12, 0.5).reshape(3, 8)) + } + + def test_input_examples(pandas_df_with_all_types, dict_of_ndarrays): sig = infer_signature(pandas_df_with_all_types) # test setting example with data frame with all supported data types @@ -117,3 +126,14 @@ def test_input_examples(pandas_df_with_all_types, dict_of_ndarrays): filename = x.info["artifact_path"] parsed_df = _dataframe_from_json(tmp.path(filename)) assert example == parsed_df.to_dict(orient="records")[0] + + +def test_sparse_matrix_input_examples(dict_of_sparse_matrix): + for col in dict_of_sparse_matrix: + input_example = dict_of_sparse_matrix[col] + with TempDir() as tmp: + example = _Example(input_example) + example.save(tmp.path()) + filename = example.info["artifact_path"] + parsed_ary = _read_tensor_input_from_json(tmp.path(filename)) + assert np.array_equal(parsed_ary, input_example.toarray()) diff --git a/tests/types/test_schema.py b/tests/types/test_schema.py index 6682c482c4e74..d10cea3fe495d 100644 --- a/tests/types/test_schema.py +++ b/tests/types/test_schema.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import pytest +from scipy.sparse import csr_matrix, csc_matrix from mlflow.exceptions import MlflowException from mlflow.pyfunc import _enforce_tensor_spec @@ -251,6 +252,22 @@ def test_get_tensor_shape(dict_of_ndarrays): _infer_schema({"x": 1}) +@pytest.fixture +def dict_of_sparse_matrix(): + return { + "csc": csc_matrix(np.arange(0, 12, 0.5).reshape(3, 8)), + "csr": csr_matrix(np.arange(0, 12, 0.5).reshape(3, 8)) + } + + +def test_get_sparse_matrix_data_type_and_shape(dict_of_sparse_matrix): + for col in dict_of_sparse_matrix: + sparse_matrix = dict_of_sparse_matrix[col] + schema = _infer_schema(sparse_matrix) + schema.numpy_types() == ["float64"] + _get_tensor_shape(sparse_matrix) == (3, 8) + + def test_schema_inference_on_dictionary(dict_of_ndarrays): # test dictionary schema = _infer_schema(dict_of_ndarrays)