Skip to content

Commit

Permalink
BUG: fixed model serve fail with HTTP 400 on Bad Request. (#5003)
Browse files Browse the repository at this point in the history
* BUG: fixed model serve fail with HTTP 400 on Bad Request.

Signed-off-by: Andrei Batomunkuev <abatomunkuev@myseneca.ca>

* TEST: Updated a test case to reflect an update with HTTP errors fix
Signed-off-by: Andrei Batomunkuev <abatomunkuev@myseneca.ca>

* Updated assert statement, changed from MALFORMED_REQUEST to BAD_REQUEST
Signed-off-by: Andrei Batomunkuev <abatomunkuev@myseneca.ca>

* Updated tests: changed from MALFORMED request to BAD_REQUEST
Signed-off-by: Andrei Batomunkuev <abatomunkuev@myseneca.ca>

* BUG: updated all instances from MALFORMED_REQUEST to BAD_REQUEST
Signed-off-by: Andrei Batomunkuev <abatomunkuev@myseneca.ca>

* Changed error code when parsing tf_serving_input
Signed-off-by: Andrei Batomunkuev <abatomunkuev@myseneca.ca>

* Changed the error code when parsing tf_serving_input
Signed-off-by: Andrei Batomunkuev <abatomunkuev@myseneca.ca>

* Format

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Updated the value of the data argument that passes invalid json
Signed-off-by: Andrei Batomunkuev <abatomunkuev@myseneca.ca>

* fix test failures

Signed-off-by: harupy <hkawamura0130@gmail.com>

* fix assert message

Signed-off-by: harupy <hkawamura0130@gmail.com>

Co-authored-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: harupy <hkawamura0130@gmail.com>
  • Loading branch information
3 people committed Dec 3, 2021
1 parent 6a88bae commit f827fa4
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 18 deletions.
17 changes: 10 additions & 7 deletions mlflow/pyfunc/scoring_server/__init__.py
Expand Up @@ -41,7 +41,7 @@
from mlflow.pyfunc import load_model, PyFuncModel
except ImportError:
from mlflow.pyfunc import load_pyfunc as load_model
from mlflow.protos.databricks_pb2 import MALFORMED_REQUEST, BAD_REQUEST
from mlflow.protos.databricks_pb2 import BAD_REQUEST
from mlflow.server.handlers import catch_mlflow_exception

try:
Expand Down Expand Up @@ -87,7 +87,7 @@ def infer_and_parse_json_input(json_input, schema: Schema = None):
"Failed to parse input from JSON. Ensure that input is a valid JSON"
" formatted string."
),
error_code=MALFORMED_REQUEST,
error_code=BAD_REQUEST,
)

if isinstance(decoded_input, list):
Expand All @@ -97,7 +97,10 @@ def infer_and_parse_json_input(json_input, schema: Schema = None):
try:
return parse_tf_serving_input(decoded_input, schema=schema)
except MlflowException as ex:
_handle_serving_error(error_message=(ex.message), error_code=MALFORMED_REQUEST)
_handle_serving_error(
error_message=(ex.message),
error_code=BAD_REQUEST,
)
else:
return parse_json_input(json_input=json_input, orient="split", schema=schema)
else:
Expand All @@ -106,7 +109,7 @@ def infer_and_parse_json_input(json_input, schema: Schema = None):
"Failed to parse input from JSON. Ensure that input is a valid JSON"
" list or dictionary."
),
error_code=MALFORMED_REQUEST,
error_code=BAD_REQUEST,
)


Expand All @@ -129,7 +132,7 @@ def parse_json_input(json_input, orient="split", schema: Schema = None):
" produced using the `pandas.DataFrame.to_json(..., orient='{orient}')`"
" method.".format(orient=orient)
),
error_code=MALFORMED_REQUEST,
error_code=BAD_REQUEST,
)


Expand All @@ -148,7 +151,7 @@ def parse_csv_input(csv_input):
" a valid CSV-formatted Pandas DataFrame produced using the"
" `pandas.DataFrame.to_csv()` method."
),
error_code=MALFORMED_REQUEST,
error_code=BAD_REQUEST,
)


Expand All @@ -173,7 +176,7 @@ def parse_split_oriented_json_input_to_numpy(json_input):
" produced using the `pandas.DataFrame.to_json(..., orient='split')`"
" method."
),
error_code=MALFORMED_REQUEST,
error_code=BAD_REQUEST,
)


Expand Down
12 changes: 7 additions & 5 deletions tests/models/test_cli.py
Expand Up @@ -34,7 +34,7 @@
get_safe_port,
pyfunc_serve_and_score_model,
)
from mlflow.protos.databricks_pb2 import ErrorCode, MALFORMED_REQUEST
from mlflow.protos.databricks_pb2 import ErrorCode, BAD_REQUEST
from mlflow.pyfunc.scoring_server import (
CONTENT_TYPE_JSON_SPLIT_ORIENTED,
CONTENT_TYPE_JSON,
Expand Down Expand Up @@ -478,9 +478,11 @@ def _validate_with_rest_endpoint(scoring_proc, host_port, df, x, sk_model, enabl
# Try examples of bad input, verify we get a non-200 status code
for content_type in [CONTENT_TYPE_JSON_SPLIT_ORIENTED, CONTENT_TYPE_CSV, CONTENT_TYPE_JSON]:
scoring_response = endpoint.invoke(data="", content_type=content_type)
assert scoring_response.status_code == 500, (
"Expected server failure with error code 500, got response with status code %s "
"and body %s" % (scoring_response.status_code, scoring_response.text)
expected_status_code = 500 if enable_mlserver else 400
assert scoring_response.status_code == expected_status_code, (
"Expected server failure with error code %s, got response with status code %s "
"and body %s"
% (expected_status_code, scoring_response.status_code, scoring_response.text)
)

if enable_mlserver:
Expand All @@ -491,6 +493,6 @@ def _validate_with_rest_endpoint(scoring_proc, host_port, df, x, sk_model, enabl

scoring_response_dict = json.loads(scoring_response.content)
assert "error_code" in scoring_response_dict
assert scoring_response_dict["error_code"] == ErrorCode.Name(MALFORMED_REQUEST)
assert scoring_response_dict["error_code"] == ErrorCode.Name(BAD_REQUEST)
assert "message" in scoring_response_dict
assert "stack_trace" in scoring_response_dict
12 changes: 6 additions & 6 deletions tests/pyfunc/test_scoring_server.py
Expand Up @@ -15,7 +15,7 @@
import mlflow.pyfunc.scoring_server as pyfunc_scoring_server
import mlflow.sklearn
from mlflow.models import ModelSignature, infer_signature
from mlflow.protos.databricks_pb2 import ErrorCode, MALFORMED_REQUEST, BAD_REQUEST
from mlflow.protos.databricks_pb2 import ErrorCode, BAD_REQUEST
from mlflow.pyfunc import PythonModel
from mlflow.pyfunc.scoring_server import get_cmd
from mlflow.types import Schema, ColSpec, DataType
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_scoring_server_responds_to_invalid_json_input_with_stacktrace_and_error
)
response_json = json.loads(response.content)
assert "error_code" in response_json
assert response_json["error_code"] == ErrorCode.Name(MALFORMED_REQUEST)
assert response_json["error_code"] == ErrorCode.Name(BAD_REQUEST)
assert "message" in response_json
assert "stack_trace" in response_json

Expand All @@ -119,7 +119,7 @@ def test_scoring_server_responds_to_invalid_json_input_with_stacktrace_and_error
)
response_json = json.loads(response.content)
assert "error_code" in response_json
assert response_json["error_code"] == ErrorCode.Name(MALFORMED_REQUEST)
assert response_json["error_code"] == ErrorCode.Name(BAD_REQUEST)
assert "message" in response_json
assert "stack_trace" in response_json

Expand All @@ -138,7 +138,7 @@ def test_scoring_server_responds_to_malformed_json_input_with_stacktrace_and_err
)
response_json = json.loads(response.content)
assert "error_code" in response_json
assert response_json["error_code"] == ErrorCode.Name(MALFORMED_REQUEST)
assert response_json["error_code"] == ErrorCode.Name(BAD_REQUEST)
assert "message" in response_json
assert "stack_trace" in response_json

Expand All @@ -159,7 +159,7 @@ def test_scoring_server_responds_to_invalid_pandas_input_format_with_stacktrace_
)
response_json = json.loads(response.content)
assert "error_code" in response_json
assert response_json["error_code"] == ErrorCode.Name(MALFORMED_REQUEST)
assert response_json["error_code"] == ErrorCode.Name(BAD_REQUEST)
assert "message" in response_json
assert "stack_trace" in response_json

Expand Down Expand Up @@ -198,7 +198,7 @@ def test_scoring_server_responds_to_invalid_csv_input_with_stacktrace_and_error_
)
response_json = json.loads(response.content)
assert "error_code" in response_json
assert response_json["error_code"] == ErrorCode.Name(MALFORMED_REQUEST)
assert response_json["error_code"] == ErrorCode.Name(BAD_REQUEST)
assert "message" in response_json
assert "stack_trace" in response_json

Expand Down

0 comments on commit f827fa4

Please sign in to comment.