From 9cf25c0526322bb67fbb520d6f1642bd54eb3d15 Mon Sep 17 00:00:00 2001 From: Harutaka Kawamura Date: Thu, 16 Dec 2021 18:32:05 +0900 Subject: [PATCH] Fix can_parse_as_json (#5177) Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- mlflow/utils/rest_utils.py | 9 ++++----- tests/utils/test_rest_utils.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/mlflow/utils/rest_utils.py b/mlflow/utils/rest_utils.py index f701a0e0bac89..542b4a0016c9f 100644 --- a/mlflow/utils/rest_utils.py +++ b/mlflow/utils/rest_utils.py @@ -147,10 +147,9 @@ def http_request( raise MlflowException("API request to %s failed with exception %s" % (url, e)) -def _can_parse_as_json(string): +def _can_parse_as_json_object(string): try: - json.loads(string) - return True + return isinstance(json.loads(string), dict) except Exception: return False @@ -166,7 +165,7 @@ def http_request_safe(host_creds, endpoint, method, **kwargs): def verify_rest_response(response, endpoint): """Verify the return code and format, raise exception if the request was not successful.""" if response.status_code != 200: - if _can_parse_as_json(response.text): + if _can_parse_as_json_object(response.text): raise RestException(json.loads(response.text)) else: base_msg = "API request to endpoint %s failed with error code " "%s != 200" % ( @@ -177,7 +176,7 @@ def verify_rest_response(response, endpoint): # Skip validation for endpoints (e.g. DBFS file-download API) which may return a non-JSON # response - if endpoint.startswith(_REST_API_PATH_PREFIX) and not _can_parse_as_json(response.text): + if endpoint.startswith(_REST_API_PATH_PREFIX) and not _can_parse_as_json_object(response.text): base_msg = ( "API request to endpoint was successful but the response body was not " "in a valid JSON format" diff --git a/tests/utils/test_rest_utils.py b/tests/utils/test_rest_utils.py index 1df629cc316d9..3290c90af9b52 100644 --- a/tests/utils/test_rest_utils.py +++ b/tests/utils/test_rest_utils.py @@ -13,6 +13,7 @@ _DEFAULT_HEADERS, call_endpoint, call_endpoints, + _can_parse_as_json_object, ) from mlflow.protos.service_pb2 import GetRun from mlflow.protos.databricks_pb2 import ENDPOINT_NOT_FOUND, ErrorCode @@ -312,3 +313,12 @@ def test_numpy_encoder_fail(): with pytest.raises(TypeError, match="not JSON serializable"): ne = NumpyEncoder() ne.default(test_number) + + +def test_can_parse_as_json_object(): + assert _can_parse_as_json_object("{}") + assert _can_parse_as_json_object('{"a": "b"}') + assert _can_parse_as_json_object('{"a": {"b": "c"}}') + assert not _can_parse_as_json_object("[0, 1, 2]") + assert not _can_parse_as_json_object('"abc"') + assert not _can_parse_as_json_object("123")