Skip to content

Commit

Permalink
MLflow Schema enforcement should not cast object to pandas String (#5134
Browse files Browse the repository at this point in the history
)

* remove cast

Signed-off-by: Steven Chen <s.chen@databricks.com>

* tests

Signed-off-by: Steven Chen <s.chen@databricks.com>
  • Loading branch information
stevenchen-db committed Dec 3, 2021
1 parent e4a7df9 commit 6a88bae
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 14 deletions.
15 changes: 3 additions & 12 deletions mlflow/pyfunc/__init__.py
Expand Up @@ -326,18 +326,9 @@ def _enforce_mlflow_datatype(name, values: pandas.Series, t: DataType):
values = values.infer_objects()

if t == DataType.string and values.dtype == np.object:
# NB: strings are by default parsed and inferred as objects, but it is
# recommended to use StringDtype extension type if available. See
#
# `https://pandas.pydata.org/pandas-docs/stable/user_guide/text.html`
#
# for more detail.
try:
return values.astype(t.to_pandas(), errors="raise")
except ValueError:
raise MlflowException(
"Failed to convert column {0} from type {1} to {2}.".format(name, values.dtype, t)
)
# NB: the object can contain any type and we currently cannot cast to pandas Strings
# due to how None is cast
return values

# NB: Comparison of pandas and numpy data type fails when numpy data type is on the left hand
# side of the comparison operator. It works, however, if pandas type is on the left hand side.
Expand Down
Expand Up @@ -179,6 +179,9 @@ def test_column_schema_enforcement():
expected_types = dict(zip(input_schema.input_names(), input_schema.pandas_types()))
# MLflow datetime type in input_schema does not encode precision, so add it for assertions
expected_types["h"] = np.dtype("datetime64[ns]")
# np.object cannot be converted to pandas Strings at the moment
expected_types["f"] = np.object
expected_types["g"] = np.object
actual_types = res.dtypes.to_dict()
assert expected_types == actual_types

Expand Down
7 changes: 5 additions & 2 deletions tests/pyfunc/test_scoring_server.py
Expand Up @@ -542,15 +542,18 @@ def predict(self, context, model_input):
extra_args=["--no-conda"],
)
response_json = json.loads(response.content)
assert response_json == [[k, str(v)] for k, v in pandas_df_with_all_types.dtypes.items()]

# np.objects are not converted to pandas Strings at the moment
expected_types = {**pandas_df_with_all_types.dtypes, "string": np.dtype(object)}
assert response_json == [[k, str(v)] for k, v in expected_types.items()]
response = pyfunc_serve_and_score_model(
model_uri="runs:/{}/model".format(run.info.run_id),
data=json.dumps(pandas_df_with_all_types.to_dict(orient="records"), cls=NumpyEncoder),
content_type=pyfunc_scoring_server.CONTENT_TYPE_JSON_RECORDS_ORIENTED,
extra_args=["--no-conda"],
)
response_json = json.loads(response.content)
assert response_json == [[k, str(v)] for k, v in pandas_df_with_all_types.dtypes.items()]
assert response_json == [[k, str(v)] for k, v in expected_types.items()]


@pytest.mark.large
Expand Down

0 comments on commit 6a88bae

Please sign in to comment.