Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: fixed model serve fail with HTTP 400 on Bad Request. #5003

Merged
merged 12 commits into from Dec 3, 2021
17 changes: 10 additions & 7 deletions mlflow/pyfunc/scoring_server/__init__.py
Expand Up @@ -41,7 +41,7 @@
from mlflow.pyfunc import load_model, PyFuncModel
except ImportError:
from mlflow.pyfunc import load_pyfunc as load_model
from mlflow.protos.databricks_pb2 import MALFORMED_REQUEST, BAD_REQUEST
from mlflow.protos.databricks_pb2 import BAD_REQUEST
from mlflow.server.handlers import catch_mlflow_exception

try:
Expand Down Expand Up @@ -87,7 +87,7 @@ def infer_and_parse_json_input(json_input, schema: Schema = None):
"Failed to parse input from JSON. Ensure that input is a valid JSON"
" formatted string."
),
error_code=MALFORMED_REQUEST,
error_code=BAD_REQUEST,
)

if isinstance(decoded_input, list):
Expand All @@ -97,7 +97,10 @@ def infer_and_parse_json_input(json_input, schema: Schema = None):
try:
return parse_tf_serving_input(decoded_input, schema=schema)
except MlflowException as ex:
_handle_serving_error(error_message=(ex.message), error_code=MALFORMED_REQUEST)
_handle_serving_error(
error_message=(ex.message),
error_code=BAD_REQUEST,
)
else:
return parse_json_input(json_input=json_input, orient="split", schema=schema)
else:
Expand All @@ -106,7 +109,7 @@ def infer_and_parse_json_input(json_input, schema: Schema = None):
"Failed to parse input from JSON. Ensure that input is a valid JSON"
" list or dictionary."
),
error_code=MALFORMED_REQUEST,
error_code=BAD_REQUEST,
)


Expand All @@ -129,7 +132,7 @@ def parse_json_input(json_input, orient="split", schema: Schema = None):
" produced using the `pandas.DataFrame.to_json(..., orient='{orient}')`"
" method.".format(orient=orient)
),
error_code=MALFORMED_REQUEST,
error_code=BAD_REQUEST,
)


Expand All @@ -148,7 +151,7 @@ def parse_csv_input(csv_input):
" a valid CSV-formatted Pandas DataFrame produced using the"
" `pandas.DataFrame.to_csv()` method."
),
error_code=MALFORMED_REQUEST,
error_code=BAD_REQUEST,
)


Expand All @@ -173,7 +176,7 @@ def parse_split_oriented_json_input_to_numpy(json_input):
" produced using the `pandas.DataFrame.to_json(..., orient='split')`"
" method."
),
error_code=MALFORMED_REQUEST,
error_code=BAD_REQUEST,
)


Expand Down
12 changes: 7 additions & 5 deletions tests/models/test_cli.py
Expand Up @@ -34,7 +34,7 @@
get_safe_port,
pyfunc_serve_and_score_model,
)
from mlflow.protos.databricks_pb2 import ErrorCode, MALFORMED_REQUEST
from mlflow.protos.databricks_pb2 import ErrorCode, BAD_REQUEST
from mlflow.pyfunc.scoring_server import (
CONTENT_TYPE_JSON_SPLIT_ORIENTED,
CONTENT_TYPE_JSON,
Expand Down Expand Up @@ -478,9 +478,11 @@ def _validate_with_rest_endpoint(scoring_proc, host_port, df, x, sk_model, enabl
# Try examples of bad input, verify we get a non-200 status code
for content_type in [CONTENT_TYPE_JSON_SPLIT_ORIENTED, CONTENT_TYPE_CSV, CONTENT_TYPE_JSON]:
scoring_response = endpoint.invoke(data="", content_type=content_type)
assert scoring_response.status_code == 500, (
"Expected server failure with error code 500, got response with status code %s "
"and body %s" % (scoring_response.status_code, scoring_response.text)
expected_status_code = 500 if enable_mlserver else 400
assert scoring_response.status_code == expected_status_code, (
"Expected server failure with error code %s, got response with status code %s "
"and body %s"
% (expected_status_code, scoring_response.status_code, scoring_response.text)
)

if enable_mlserver:
Expand All @@ -491,6 +493,6 @@ def _validate_with_rest_endpoint(scoring_proc, host_port, df, x, sk_model, enabl

scoring_response_dict = json.loads(scoring_response.content)
assert "error_code" in scoring_response_dict
assert scoring_response_dict["error_code"] == ErrorCode.Name(MALFORMED_REQUEST)
assert scoring_response_dict["error_code"] == ErrorCode.Name(BAD_REQUEST)
assert "message" in scoring_response_dict
assert "stack_trace" in scoring_response_dict
12 changes: 6 additions & 6 deletions tests/pyfunc/test_scoring_server.py
Expand Up @@ -15,7 +15,7 @@
import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
import mlflow.sklearn
from mlflow.models import ModelSignature, infer_signature
from mlflow.protos.databricks_pb2 import ErrorCode, MALFORMED_REQUEST, BAD_REQUEST
from mlflow.protos.databricks_pb2 import ErrorCode, BAD_REQUEST
from mlflow.pyfunc import PythonModel
from mlflow.pyfunc.scoring_server import get_cmd
from mlflow.types import Schema, ColSpec, DataType
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_scoring_server_responds_to_invalid_json_input_with_stacktrace_and_error
)
response_json = json.loads(response.content)
assert "error_code" in response_json
assert response_json["error_code"] == ErrorCode.Name(MALFORMED_REQUEST)
assert response_json["error_code"] == ErrorCode.Name(BAD_REQUEST)
assert "message" in response_json
assert "stack_trace" in response_json

Expand All @@ -119,7 +119,7 @@ def test_scoring_server_responds_to_invalid_json_input_with_stacktrace_and_error
)
response_json = json.loads(response.content)
assert "error_code" in response_json
assert response_json["error_code"] == ErrorCode.Name(MALFORMED_REQUEST)
assert response_json["error_code"] == ErrorCode.Name(BAD_REQUEST)
assert "message" in response_json
assert "stack_trace" in response_json

Expand All @@ -138,7 +138,7 @@ def test_scoring_server_responds_to_malformed_json_input_with_stacktrace_and_err
)
response_json = json.loads(response.content)
assert "error_code" in response_json
assert response_json["error_code"] == ErrorCode.Name(MALFORMED_REQUEST)
assert response_json["error_code"] == ErrorCode.Name(BAD_REQUEST)
assert "message" in response_json
assert "stack_trace" in response_json

Expand All @@ -159,7 +159,7 @@ def test_scoring_server_responds_to_invalid_pandas_input_format_with_stacktrace_
)
response_json = json.loads(response.content)
assert "error_code" in response_json
assert response_json["error_code"] == ErrorCode.Name(MALFORMED_REQUEST)
assert response_json["error_code"] == ErrorCode.Name(BAD_REQUEST)
assert "message" in response_json
assert "stack_trace" in response_json

Expand Down Expand Up @@ -198,7 +198,7 @@ def test_scoring_server_responds_to_invalid_csv_input_with_stacktrace_and_error_
)
response_json = json.loads(response.content)
assert "error_code" in response_json
assert response_json["error_code"] == ErrorCode.Name(MALFORMED_REQUEST)
assert response_json["error_code"] == ErrorCode.Name(BAD_REQUEST)
assert "message" in response_json
assert "stack_trace" in response_json

Expand Down