Add support for parsing JSON files in array form #4997

Merged · 6 commits · Sep 20, 2022
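This PR teaches the packaged `json` builder to parse JSON files whose top level is an array of objects, in addition to the already supported line-delimited JSON and JSON objects with a record field selected via `field`. A minimal usage sketch of what this enables (file names and contents below are illustrative, not part of the PR):

```python
# Sketch of the usage enabled by this PR; "data.json" and "nested.json" are hypothetical files.
from datasets import load_dataset

# data.json contains a top-level array, e.g. [{"col_1": 1, "col_2": 2}, {"col_1": 10, "col_2": 20}]
ds = load_dataset("json", data_files="data.json", split="train")
print(ds.column_names)  # ['col_1', 'col_2']

# Records nested under a key still use the existing `field` argument,
# e.g. nested.json = {"field3": [{"col_1": 1, "col_2": 2}, ...]}
ds_nested = load_dataset("json", data_files="nested.json", field="field3", split="train")
```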
44 changes: 37 additions & 7 deletions src/datasets/packaged_modules/json/json.py
```diff
@@ -15,6 +15,25 @@
 logger = datasets.utils.logging.get_logger(__name__)
 
 
+if datasets.config.PYARROW_VERSION.major >= 7:
+
+    def pa_table_from_pylist(mapping):
+        return pa.Table.from_pylist(mapping)
+
+else:
+
+    def pa_table_from_pylist(mapping):
+        # Copied from: https://github.com/apache/arrow/blob/master/python/pyarrow/table.pxi#L5193
+        arrays = []
+        names = []
+        if mapping:
+            names = list(mapping[0].keys())
+        for n in names:
+            v = [row[n] if n in row else None for row in mapping]
+            arrays.append(v)
+        return pa.Table.from_arrays(arrays, names)
+
+
 @dataclass
 class JsonConfig(datasets.BuilderConfig):
     """BuilderConfig for JSON."""
```
```diff
@@ -125,18 +144,29 @@ def _generate_tables(self, files):
                                         )
                                         block_size *= 2
                         except pa.ArrowInvalid as e:
-                            logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
                             try:
                                 with open(file, encoding="utf-8") as f:
                                     dataset = json.load(f)
                             except json.JSONDecodeError:
+                                logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
                                 raise e
-                            raise ValueError(
-                                f"Not able to read records in the JSON file at {file}. "
-                                f"You should probably indicate the field of the JSON file containing your records. "
-                                f"This JSON file contain the following fields: {str(list(dataset.keys()))}. "
-                                f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. "
-                            ) from None
+                            # If possible, parse the file as a list of json objects and exit the loop
+                            if isinstance(dataset, list):  # list is the only sequence type supported in JSON
+                                try:
+                                    pa_table = pa_table_from_pylist(dataset)
+                                except (pa.ArrowInvalid, AttributeError) as e:
+                                    logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
+                                    raise ValueError(f"Not able to read records in the JSON file at {file}.") from None
+                                yield file_idx, self._cast_table(pa_table)
+                                break
+                            else:
+                                logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
+                                raise ValueError(
+                                    f"Not able to read records in the JSON file at {file}. "
+                                    f"You should probably indicate the field of the JSON file containing your records. "
+                                    f"This JSON file contain the following fields: {str(list(dataset.keys()))}. "
+                                    f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. "
+                                ) from None
                         # Uncomment for debugging (will print the Arrow table size and elements)
                         # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
                         # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
```
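For context on the `pa_table_from_pylist` shim above: `pa.Table.from_pylist` was only introduced in PyArrow 7.0, so older PyArrow versions use the copied fallback, which takes column names from the first record and fills missing keys with null. A standalone sketch (not part of the diff) of the behavior both branches are expected to share:

```python
# Standalone sketch: both branches of pa_table_from_pylist should produce the same table.
import pyarrow as pa

rows = [
    {"col_1": 1, "col_2": 2},
    {"col_1": 10},  # missing keys become null
]

# PyArrow >= 7.0 path
table = pa.Table.from_pylist(rows)

# Fallback path: column names come from the first record only
names = list(rows[0].keys())
columns = [[row.get(name) for row in rows] for name in names]
fallback = pa.Table.from_arrays(columns, names)

assert table.to_pydict() == fallback.to_pydict() == {"col_1": [1, 10], "col_2": [2, None]}
```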
2 changes: 1 addition & 1 deletion tests/packaged_modules/test_csv.py
```diff
@@ -12,7 +12,7 @@
 
 @pytest.fixture
 def csv_file(tmp_path):
-    filename = tmp_path / "malformed_file.csv"
+    filename = tmp_path / "file.csv"
     data = textwrap.dedent(
         """\
         header1,header2
```
71 changes: 71 additions & 0 deletions tests/packaged_modules/test_json.py
@@ -0,0 +1,71 @@
```python
import textwrap

import pyarrow as pa
import pytest

from datasets.packaged_modules.json.json import Json


@pytest.fixture
def jsonl_file(tmp_path):
    filename = tmp_path / "file.jsonl"
    data = textwrap.dedent(
        """\
        {"col_1": 1, "col_2": 2}
        {"col_1": 10, "col_2": 20}
        """
    )
    with open(filename, "w") as f:
        f.write(data)
    return str(filename)


@pytest.fixture
def json_file_with_list_of_dicts(tmp_path):
    filename = tmp_path / "file_with_list_of_dicts.json"
    data = textwrap.dedent(
        """\
        [
            {"col_1": 1, "col_2": 2},
            {"col_1": 10, "col_2": 20}
        ]
        """
    )
    with open(filename, "w") as f:
        f.write(data)
    return str(filename)


@pytest.fixture
def json_file_with_list_of_dicts_field(tmp_path):
    filename = tmp_path / "file_with_list_of_dicts_field.json"
    data = textwrap.dedent(
        """\
        {
            "field1": 1,
            "field2": "aabb",
            "field3": [
                {"col_1": 1, "col_2": 2},
                {"col_1": 10, "col_2": 20}
            ]
        }
        """
    )
    with open(filename, "w") as f:
        f.write(data)
    return str(filename)


@pytest.mark.parametrize(
    "file_fixture, config_kwargs",
    [
        ("jsonl_file", {}),
        ("json_file_with_list_of_dicts", {}),
        ("json_file_with_list_of_dicts_field", {"field": "field3"}),
    ],
)
def test_json_generate_tables(file_fixture, config_kwargs, request):
    json = Json(**config_kwargs)
    generator = json._generate_tables([[request.getfixturevalue(file_fixture)]])
    pa_table = pa.concat_tables([table for _, table in generator])
    assert pa_table.to_pydict() == {"col_1": [1, 10], "col_2": [2, 20]}
```
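As a complementary sketch (not part of the PR's test suite), the new `except (pa.ArrowInvalid, AttributeError)` branch is what rejects a top-level array whose items are not JSON objects; something like the following would be expected to raise the shorter `ValueError`:

```python
# Standalone sketch, not part of the PR: exercising the new error branch for a
# top-level JSON array whose items are not objects. The file name is hypothetical.
import json
import tempfile
from pathlib import Path

import pytest

from datasets.packaged_modules.json.json import Json


def check_non_record_array_is_rejected():
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = Path(tmp_dir) / "not_records.json"
        path.write_text(json.dumps(["a", "b", "c"]))  # a list, but not of dicts
        generator = Json()._generate_tables([[str(path)]])
        # pa_table_from_pylist is expected to raise AttributeError here, which the
        # builder converts into the shorter ValueError added in this PR.
        with pytest.raises(ValueError):
            next(generator)


check_non_record_array_is_rejected()
```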