diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py
index fa8a63e686a..152b35fdf9d 100644
--- a/src/datasets/packaged_modules/json/json.py
+++ b/src/datasets/packaged_modules/json/json.py
@@ -15,6 +15,25 @@
 logger = datasets.utils.logging.get_logger(__name__)
 
 
+if datasets.config.PYARROW_VERSION.major >= 7:
+
+    def pa_table_from_pylist(mapping):
+        return pa.Table.from_pylist(mapping)
+
+else:
+
+    def pa_table_from_pylist(mapping):
+        # Copied from: https://github.com/apache/arrow/blob/master/python/pyarrow/table.pxi#L5193
+        arrays = []
+        names = []
+        if mapping:
+            names = list(mapping[0].keys())
+        for n in names:
+            v = [row[n] if n in row else None for row in mapping]
+            arrays.append(v)
+        return pa.Table.from_arrays(arrays, names)
+
+
 @dataclass
 class JsonConfig(datasets.BuilderConfig):
     """BuilderConfig for JSON."""
@@ -125,18 +144,29 @@ def _generate_tables(self, files):
                             )
                             block_size *= 2
                     except pa.ArrowInvalid as e:
-                        logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
                         try:
                             with open(file, encoding="utf-8") as f:
                                 dataset = json.load(f)
                         except json.JSONDecodeError:
+                            logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
                             raise e
-                        raise ValueError(
-                            f"Not able to read records in the JSON file at {file}. "
-                            f"You should probably indicate the field of the JSON file containing your records. "
-                            f"This JSON file contain the following fields: {str(list(dataset.keys()))}. "
-                            f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. "
-                        ) from None
+                        # If possible, parse the file as a list of json objects and exit the loop
+                        if isinstance(dataset, list):  # list is the only sequence type supported in JSON
+                            try:
+                                pa_table = pa_table_from_pylist(dataset)
+                            except (pa.ArrowInvalid, AttributeError) as e:
+                                logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
+                                raise ValueError(f"Not able to read records in the JSON file at {file}.") from None
+                            yield file_idx, self._cast_table(pa_table)
+                            break
+                        else:
+                            logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
+                            raise ValueError(
+                                f"Not able to read records in the JSON file at {file}. "
+                                f"You should probably indicate the field of the JSON file containing your records. "
+                                f"This JSON file contain the following fields: {str(list(dataset.keys()))}. "
+                                f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. "
+                            ) from None
                     # Uncomment for debugging (will print the Arrow table size and elements)
                     # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
                     # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
diff --git a/tests/packaged_modules/test_csv.py b/tests/packaged_modules/test_csv.py
index 1698807e46a..ee360275736 100644
--- a/tests/packaged_modules/test_csv.py
+++ b/tests/packaged_modules/test_csv.py
@@ -12,7 +12,7 @@
 
 @pytest.fixture
 def csv_file(tmp_path):
-    filename = tmp_path / "malformed_file.csv"
+    filename = tmp_path / "file.csv"
     data = textwrap.dedent(
         """\
         header1,header2
diff --git a/tests/packaged_modules/test_json.py b/tests/packaged_modules/test_json.py
new file mode 100644
index 00000000000..09e5572828c
--- /dev/null
+++ b/tests/packaged_modules/test_json.py
@@ -0,0 +1,71 @@
+import textwrap
+
+import pyarrow as pa
+import pytest
+
+from datasets.packaged_modules.json.json import Json
+
+
+@pytest.fixture
+def jsonl_file(tmp_path):
+    filename = tmp_path / "file.jsonl"
+    data = textwrap.dedent(
+        """\
+        {"col_1": 1, "col_2": 2}
+        {"col_1": 10, "col_2": 20}
+        """
+    )
+    with open(filename, "w") as f:
+        f.write(data)
+    return str(filename)
+
+
+@pytest.fixture
+def json_file_with_list_of_dicts(tmp_path):
+    filename = tmp_path / "file_with_list_of_dicts.json"
+    data = textwrap.dedent(
+        """\
+        [
+            {"col_1": 1, "col_2": 2},
+            {"col_1": 10, "col_2": 20}
+        ]
+        """
+    )
+    with open(filename, "w") as f:
+        f.write(data)
+    return str(filename)
+
+
+@pytest.fixture
+def json_file_with_list_of_dicts_field(tmp_path):
+    filename = tmp_path / "file_with_list_of_dicts_field.json"
+    data = textwrap.dedent(
+        """\
+        {
+            "field1": 1,
+            "field2": "aabb",
+            "field3": [
+                {"col_1": 1, "col_2": 2},
+                {"col_1": 10, "col_2": 20}
+            ]
+        }
+        """
+    )
+    with open(filename, "w") as f:
+        f.write(data)
+    return str(filename)
+
+
+@pytest.mark.parametrize(
+    "file_fixture, config_kwargs",
+    [
+        ("jsonl_file", {}),
+        ("json_file_with_list_of_dicts", {}),
+        ("json_file_with_list_of_dicts_field", {"field": "field3"}),
+    ],
+)
+def test_json_generate_tables(file_fixture, config_kwargs, request):
+    json = Json(**config_kwargs)
+    generator = json._generate_tables([[request.getfixturevalue(file_fixture)]])
+    pa_table = pa.concat_tables([table for _, table in generator])
+    assert pa_table.to_pydict() == {"col_1": [1, 10], "col_2": [2, 20]}