From 8958ab66a80e947852fef685c6cf4c68e2cf9192 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Mon, 12 Sep 2022 19:09:27 +0200 Subject: [PATCH 1/5] Support parsing JSON lists --- src/datasets/packaged_modules/json/json.py | 34 ++++++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py index fa8a63e686a..eaf6da34531 100644 --- a/src/datasets/packaged_modules/json/json.py +++ b/src/datasets/packaged_modules/json/json.py @@ -15,6 +15,23 @@ logger = datasets.utils.logging.get_logger(__name__) +if datasets.config.PYARROW_VERSION.major >= 7: + + def pa_table_from_pylist(mapping): + return pa.Table.from_pylist(mapping) +else: + + def pa_table_from_pylist(mapping): + arrays = [] + names = [] + if mapping: + names = list(mapping[0].keys()) + for n in names: + v = [row[n] if n in row else None for row in mapping] + arrays.append(v) + return pa.Table.from_arrays(arrays, names) + + @dataclass class JsonConfig(datasets.BuilderConfig): """BuilderConfig for JSON.""" @@ -131,12 +148,17 @@ def _generate_tables(self, files): dataset = json.load(f) except json.JSONDecodeError: raise e - raise ValueError( - f"Not able to read records in the JSON file at {file}. " - f"You should probably indicate the field of the JSON file containing your records. " - f"This JSON file contain the following fields: {str(list(dataset.keys()))}. " - f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. " - ) from None + if isinstance(dataset, list): # list is the only supported sequence type in JSON + pa_table = pa_table_from_pylist(dataset) + yield file_idx, self._cast_table(pa_table) + break + else: + raise ValueError( + f"Not able to read records in the JSON file at {file}. " + f"You should probably indicate the field of the JSON file containing your records. " + f"This JSON file contain the following fields: {str(list(dataset.keys()))}. " + f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. " + ) from None # Uncomment for debugging (will print the Arrow table size and elements) # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) From ab250cfc13f3ba9f61118937ae7bda3e289d3d23 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Tue, 13 Sep 2022 14:13:56 +0200 Subject: [PATCH 2/5] Add error handling --- src/datasets/packaged_modules/json/json.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py index eaf6da34531..f6f89286d54 100644 --- a/src/datasets/packaged_modules/json/json.py +++ b/src/datasets/packaged_modules/json/json.py @@ -16,9 +16,10 @@ if datasets.config.PYARROW_VERSION.major >= 7: - + def pa_table_from_pylist(mapping): return pa.Table.from_pylist(mapping) + else: def pa_table_from_pylist(mapping): @@ -142,17 +143,23 @@ def _generate_tables(self, files): ) block_size *= 2 except pa.ArrowInvalid as e: - logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") try: with open(file, encoding="utf-8") as f: dataset = json.load(f) except json.JSONDecodeError: + logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") raise e - if isinstance(dataset, list): # list is the only supported sequence type in JSON - pa_table = pa_table_from_pylist(dataset) + # If possible, parse the file as a list of json objects and break the loop + if isinstance(dataset, list): # list is the only sequence type supported in JSON + try: + pa_table = pa_table_from_pylist(dataset) + except (pa.ArrowInvalid, AttributeError) as e: + logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") + raise ValueError(f"Not able to read records in the JSON file at {file}.") from None yield file_idx, self._cast_table(pa_table) break else: + logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") raise ValueError( f"Not able to read records in the JSON file at {file}. " f"You should probably indicate the field of the JSON file containing your records. " From c75758b427ae9cac0a99b4dfd3d2a5a089e33be5 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Fri, 16 Sep 2022 13:26:04 +0200 Subject: [PATCH 3/5] Minor improvements --- src/datasets/packaged_modules/json/json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py index f6f89286d54..6154232a65c 100644 --- a/src/datasets/packaged_modules/json/json.py +++ b/src/datasets/packaged_modules/json/json.py @@ -149,7 +149,7 @@ def _generate_tables(self, files): except json.JSONDecodeError: logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") raise e - # If possible, parse the file as a list of json objects and break the loop + # If possible, parse the file as a list of json objects and exit the loop if isinstance(dataset, list): # list is the only sequence type supported in JSON try: pa_table = pa_table_from_pylist(dataset) From 4d9c1224b9712e72b883707e039c7bee8f10cbfc Mon Sep 17 00:00:00 2001 From: mariosasko Date: Fri, 16 Sep 2022 13:26:20 +0200 Subject: [PATCH 4/5] Add tests --- tests/packaged_modules/test_csv.py | 2 +- tests/packaged_modules/test_json.py | 71 +++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 tests/packaged_modules/test_json.py diff --git a/tests/packaged_modules/test_csv.py b/tests/packaged_modules/test_csv.py index 1698807e46a..ee360275736 100644 --- a/tests/packaged_modules/test_csv.py +++ b/tests/packaged_modules/test_csv.py @@ -12,7 +12,7 @@ @pytest.fixture def csv_file(tmp_path): - filename = tmp_path / "malformed_file.csv" + filename = tmp_path / "file.csv" data = textwrap.dedent( """\ header1,header2 diff --git a/tests/packaged_modules/test_json.py b/tests/packaged_modules/test_json.py new file mode 100644 index 00000000000..09e5572828c --- /dev/null +++ b/tests/packaged_modules/test_json.py @@ -0,0 +1,71 @@ +import textwrap + +import pyarrow as pa +import pytest + +from datasets.packaged_modules.json.json import Json + + +@pytest.fixture +def jsonl_file(tmp_path): + filename = tmp_path / "file.jsonl" + data = textwrap.dedent( + """\ + {"col_1": 1, "col_2": 2} + {"col_1": 10, "col_2": 20} + """ + ) + with open(filename, "w") as f: + f.write(data) + return str(filename) + + +@pytest.fixture +def json_file_with_list_of_dicts(tmp_path): + filename = tmp_path / "file_with_list_of_dicts.json" + data = textwrap.dedent( + """\ + [ + {"col_1": 1, "col_2": 2}, + {"col_1": 10, "col_2": 20} + ] + """ + ) + with open(filename, "w") as f: + f.write(data) + return str(filename) + + +@pytest.fixture +def json_file_with_list_of_dicts_field(tmp_path): + filename = tmp_path / "file_with_list_of_dicts_field.json" + data = textwrap.dedent( + """\ + { + "field1": 1, + "field2": "aabb", + "field3": [ + {"col_1": 1, "col_2": 2}, + {"col_1": 10, "col_2": 20} + ] + } + """ + ) + with open(filename, "w") as f: + f.write(data) + return str(filename) + + +@pytest.mark.parametrize( + "file_fixture, config_kwargs", + [ + ("jsonl_file", {}), + ("json_file_with_list_of_dicts", {}), + ("json_file_with_list_of_dicts_field", {"field": "field3"}), + ], +) +def test_json_generate_tables(file_fixture, config_kwargs, request): + json = Json(**config_kwargs) + generator = json._generate_tables([[request.getfixturevalue(file_fixture)]]) + pa_table = pa.concat_tables([table for _, table in generator]) + assert pa_table.to_pydict() == {"col_1": [1, 10], "col_2": [2, 20]} From ff72c49978903d672120b4be64e116e0bbbd750c Mon Sep 17 00:00:00 2001 From: mariosasko Date: Tue, 20 Sep 2022 15:04:34 +0200 Subject: [PATCH 5/5] Comment --- src/datasets/packaged_modules/json/json.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py index 6154232a65c..152b35fdf9d 100644 --- a/src/datasets/packaged_modules/json/json.py +++ b/src/datasets/packaged_modules/json/json.py @@ -23,6 +23,7 @@ def pa_table_from_pylist(mapping): else: def pa_table_from_pylist(mapping): + # Copied from: https://github.com/apache/arrow/blob/master/python/pyarrow/table.pxi#L5193 arrays = [] names = [] if mapping: