Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
  • Loading branch information
WeichenXu123 committed Nov 7, 2021
1 parent 85e056d commit 076d12d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion mlflow/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _handle_tensor_input(input_tensor: Union[np.ndarray, dict, csr_matrix, csc_m
elif isinstance(input_tensor, np.ndarray):
return {"inputs": input_tensor.tolist()}
else:
return input_tensor.toarray().tolist()
return {"inputs": input_tensor.toarray().tolist()}

def _handle_dataframe_input(input_ex):
if isinstance(input_ex, dict):
Expand Down
8 changes: 6 additions & 2 deletions tests/sklearn/test_sklearn_autolog.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import sklearn.pipeline
import sklearn.model_selection
from scipy.stats import uniform
from scipy.sparse import csr_matrix, csc_matrix

from mlflow.exceptions import MlflowException
from mlflow.models import Model
Expand Down Expand Up @@ -849,13 +850,16 @@ def test_parameter_search_handles_large_volume_of_metric_outputs():
assert len(child_run.data.metrics) >= metrics_size


@pytest.mark.parametrize("data_type", [pd.DataFrame, np.array])
@pytest.mark.parametrize("data_type", [pd.DataFrame, np.array, csr_matrix, csc_matrix])
def test_autolog_logs_signature_and_input_example(data_type):
mlflow.sklearn.autolog(log_input_examples=True, log_model_signatures=True)

X, y = get_iris()
X = data_type(X)
y = data_type(y)
if data_type in [csr_matrix, csc_matrix]:
y = np.array(y)
else:
y = data_type(y)
model = sklearn.linear_model.LinearRegression()

with mlflow.start_run() as run:
Expand Down

0 comments on commit 076d12d

Please sign in to comment.