diff --git a/src/datasets/packaged_modules/csv/csv.py b/src/datasets/packaged_modules/csv/csv.py index c3e3f9b906a..275d9a558d7 100644 --- a/src/datasets/packaged_modules/csv/csv.py +++ b/src/datasets/packaged_modules/csv/csv.py @@ -1,6 +1,6 @@ import itertools from dataclasses import dataclass -from typing import List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import pandas as pd import pyarrow as pa @@ -33,6 +33,7 @@ class CsvConfig(datasets.BuilderConfig): prefix: Optional[str] = None mangle_dupe_cols: bool = True engine: Optional[str] = None + converters: Dict[Union[int, str], Callable[[Any], Any]] = None true_values: Optional[list] = None false_values: Optional[list] = None skipinitialspace: bool = False @@ -80,6 +81,7 @@ def read_csv_kwargs(self): prefix=self.prefix, mangle_dupe_cols=self.mangle_dupe_cols, engine=self.engine, + converters=self.converters, true_values=self.true_values, false_values=self.false_values, skipinitialspace=self.skipinitialspace, diff --git a/tests/packaged_modules/test_csv.py b/tests/packaged_modules/test_csv.py index ee360275736..6cfa5e4ca23 100644 --- a/tests/packaged_modules/test_csv.py +++ b/tests/packaged_modules/test_csv.py @@ -70,6 +70,22 @@ def csv_file_with_label(tmp_path): return str(filename) +@pytest.fixture +def csv_file_with_int_list(tmp_path): + filename = tmp_path / "csv_with_int_list.csv" + data = textwrap.dedent( + """\ + int_list + 1 2 3 + 4 5 6 + 7 8 9 + """ + ) + with open(filename, "w") as f: + f.write(data) + return str(filename) + + def test_csv_generate_tables_raises_error_with_malformed_csv(csv_file, malformed_csv_file, caplog): csv = Csv() generator = csv._generate_tables([[csv_file, malformed_csv_file]]) @@ -96,7 +112,6 @@ def test_csv_cast_image(csv_file_with_image): assert generated_content == [{"path": image_file, "bytes": None}] -@require_pil def test_csv_cast_label(csv_file_with_label): with open(csv_file_with_label, encoding="utf-8") as f: labels = f.read().splitlines()[1:] @@ -106,3 +121,12 @@ def test_csv_cast_label(csv_file_with_label): assert pa_table.schema.field("label").type == ClassLabel(names=["good", "bad"])() generated_content = pa_table.to_pydict()["label"] assert generated_content == [ClassLabel(names=["good", "bad"]).str2int(label) for label in labels] + + +def test_csv_convert_int_list(csv_file_with_int_list): + csv = Csv(encoding="utf-8", sep=",", converters={"int_list": lambda x: [int(i) for i in x.split()]}) + generator = csv._generate_tables([[csv_file_with_int_list]]) + pa_table = pa.concat_tables([table for _, table in generator]) + assert pa.types.is_list(pa_table.schema.field("int_list").type) + generated_content = pa_table.to_pydict()["int_list"] + assert generated_content == [[1, 2, 3], [4, 5, 6], [7, 8, 9]]