diff --git a/docs/source/loading.mdx b/docs/source/loading.mdx
index 54b209d0a9b..1d482e2a937 100644
--- a/docs/source/loading.mdx
+++ b/docs/source/loading.mdx
@@ -205,6 +205,16 @@ Load Python dictionaries with [`~Dataset.from_dict`]:
 >>> dataset = Dataset.from_dict(my_dict)
 ```
 
+### Python list of dictionaries
+
+Load a list of Python dictionaries with [`~Dataset.from_list`]:
+
+```py
+>>> from datasets import Dataset
+>>> my_list = [{"a": 1}, {"a": 2}, {"a": 3}]
+>>> dataset = Dataset.from_list(my_list)
+```
+
 ### Pandas DataFrame
 
 Load Pandas DataFrames with [`~Dataset.from_pandas`]:
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index bc2c5ab37ba..288254cbc9c 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -872,6 +872,33 @@ def from_dict(
         info.features = Features({col: ts.get_inferred_type() for col, ts in mapping.items()})
         return cls(pa_table, info=info, split=split)
 
+    @classmethod
+    def from_list(
+        cls,
+        mapping: List[dict],
+        features: Optional[Features] = None,
+        info: Optional[DatasetInfo] = None,
+        split: Optional[NamedSplit] = None,
+    ) -> "Dataset":
+        """
+        Convert a list of dicts to a :obj:`pyarrow.Table` to create a :class:`Dataset`.
+
+        Note that the keys of the first entry will be used to determine the dataset columns,
+        regardless of what is passed to ``features``.
+
+        Args:
+            mapping (:obj:`List[dict]`): A list of mappings of strings to row values.
+            features (:class:`Features`, optional): Dataset features.
+            info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc.
+            split (:class:`NamedSplit`, optional): Name of the dataset split.
+
+        Returns:
+            :class:`Dataset`
+        """
+        # For simplicity and consistency wrt OptimizedTypedSequence, we do not use InMemoryTable.from_pylist here
+        mapping = {k: [r.get(k) for r in mapping] for k in mapping[0]} if mapping else {}
+        return cls.from_dict(mapping, features, info, split)
+
     @staticmethod
     def from_csv(
         path_or_paths: Union[PathLike, List[PathLike]],
diff --git a/src/datasets/table.py b/src/datasets/table.py
index 64da5e5e28e..6ea551626b4 100644
--- a/src/datasets/table.py
+++ b/src/datasets/table.py
@@ -252,6 +252,19 @@ def to_pydict(self, *args, **kwargs):
         """
         return self.table.to_pydict(*args, **kwargs)
 
+    def to_pylist(self, *args, **kwargs):
+        """
+        Convert the Table to a list of rows / dictionaries.
+
+        Returns:
+            :obj:`list`
+        """
+        try:
+            return self.table.to_pylist(*args, **kwargs)
+        except AttributeError:  # pyarrow <7 does not have to_pylist, so we use to_pydict
+            pydict = self.table.to_pydict(*args, **kwargs)
+            return [{k: pydict[k][i] for k in pydict} for i in range(len(self.table))]
+
     def to_pandas(self, *args, **kwargs):
         """
         Convert to a pandas-compatible NumPy array or DataFrame, as appropriate
@@ -749,6 +762,28 @@ def from_pydict(cls, *args, **kwargs):
         """
         return cls(pa.Table.from_pydict(*args, **kwargs))
 
+    @classmethod
+    def from_pylist(cls, mapping, *args, **kwargs):
+        """
+        Construct a Table from a list of rows / dictionaries.
+
+        Args:
+            mapping (:obj:`List[dict]`):
+                A list of mappings of strings to row values.
+            schema (:obj:`Schema`, defaults to :obj:`None`):
+                If not passed, will be inferred from the Mapping values.
+            metadata (:obj:`Union[dict, Mapping]`, defaults to :obj:`None`):
+                Optional metadata for the schema (if inferred).
+
+        Returns:
+            :class:`datasets.table.Table`
+        """
+        try:
+            return cls(pa.Table.from_pylist(mapping, *args, **kwargs))
+        except AttributeError:  # pyarrow <7 does not have from_pylist, so we convert and use from_pydict
+            mapping = {k: [r.get(k) for r in mapping] for k in mapping[0]} if mapping else {}
+            return cls(pa.Table.from_pydict(mapping, *args, **kwargs))
+
     @classmethod
     def from_batches(cls, *args, **kwargs):
         """
diff --git a/tests/test_dataset_list.py b/tests/test_dataset_list.py
new file mode 100644
index 00000000000..1004ae3cd68
--- /dev/null
+++ b/tests/test_dataset_list.py
@@ -0,0 +1,47 @@
+from unittest import TestCase
+
+from datasets import Sequence, Value
+from datasets.arrow_dataset import Dataset
+
+
+class DatasetListTest(TestCase):
+    def _create_example_records(self):
+        return [
+            {"col_1": 3, "col_2": "a"},
+            {"col_1": 2, "col_2": "b"},
+            {"col_1": 1, "col_2": "c"},
+            {"col_1": 0, "col_2": "d"},
+        ]
+
+    def _create_example_dict(self):
+        data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]}
+        return Dataset.from_dict(data)
+
+    def test_create(self):
+        example_records = self._create_example_records()
+        dset = Dataset.from_list(example_records)
+        self.assertListEqual(dset.column_names, ["col_1", "col_2"])
+        for i, r in enumerate(dset):
+            self.assertDictEqual(r, example_records[i])
+
+    def test_list_dict_equivalent(self):
+        example_records = self._create_example_records()
+        dset = Dataset.from_list(example_records)
+        dset_from_dict = Dataset.from_dict({k: [r[k] for r in example_records] for k in example_records[0]})
+        self.assertEqual(dset.info, dset_from_dict.info)
+
+    def test_uneven_records(self):  # checks what happens with missing columns
+        uneven_records = [{"col_1": 1}, {"col_2": "x"}]
+        dset = Dataset.from_list(uneven_records)
+        self.assertDictEqual(dset[0], {"col_1": 1})
+        self.assertDictEqual(dset[1], {"col_1": None})  # NB: first record is used for columns
+
+    def test_variable_list_records(self):  # checks if the type can be inferred from the second record
+        list_records = [{"col_1": []}, {"col_1": [1, 2]}]
+        dset = Dataset.from_list(list_records)
+        self.assertEqual(dset.info.features["col_1"], Sequence(Value("int64")))
+
+    def test_create_empty(self):
+        dset = Dataset.from_list([])
+        self.assertEqual(len(dset), 0)
+        self.assertListEqual(dset.column_names, [])
diff --git a/tests/test_table.py b/tests/test_table.py
index 073c8b0482e..4bb31900dea 100644
--- a/tests/test_table.py
+++ b/tests/test_table.py
@@ -255,6 +255,13 @@ def test_in_memory_table_from_pydict(in_memory_pa_table):
     assert table.table == pa.Table.from_pydict(pydict)
 
 
+def test_in_memory_table_from_pylist(in_memory_pa_table):
+    pylist = InMemoryTable(in_memory_pa_table).to_pylist()
+    table = InMemoryTable.from_pylist(pylist)
+    assert isinstance(table, InMemoryTable)
+    assert pylist == table.to_pylist()
+
+
 def test_in_memory_table_from_batches(in_memory_pa_table):
     batches = list(in_memory_pa_table.to_batches())
     table = InMemoryTable.from_batches(batches)
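
For reference (not part of the patch): a minimal usage sketch of the new `Dataset.from_list` API, based on the docstring and the tests added above. The last expression shows the row-to-column transposition the patch performs before delegating to `Dataset.from_dict` (the same expression backs the pyarrow <7 fallback in `from_pylist`).

```py
>>> from datasets import Dataset
>>> records = [{"col_1": 3, "col_2": "a"}, {"col_1": 2, "col_2": "b"}]
>>> dset = Dataset.from_list(records)
>>> dset.column_names
['col_1', 'col_2']
>>> # The keys of the first record define the columns: keys missing from later
>>> # records come back as None, and keys absent from the first record are dropped.
>>> Dataset.from_list([{"col_1": 1}, {"col_2": "x"}])[1]
{'col_1': None}
>>> # Transposition applied internally before calling Dataset.from_dict:
>>> {k: [r.get(k) for r in records] for k in records[0]}
{'col_1': [3, 2], 'col_2': ['a', 'b']}
```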