Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Dataset.from_list #4890

Merged
merged 7 commits into from Sep 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/source/loading.mdx
Expand Up @@ -205,6 +205,16 @@ Load Python dictionaries with [`~Dataset.from_dict`]:
>>> dataset = Dataset.from_dict(my_dict)
```

### Python list of dictionaries

Load a list of Python dictionaries with [`~Dataset.from_list`]:

```py
>>> from datasets import Dataset
>>> my_list = [{"a": 1}, {"a": 2}, {"a": 3}]
>>> dataset = Dataset.from_list(my_list)
```

### Pandas DataFrame

Load Pandas DataFrames with [`~Dataset.from_pandas`]:
Expand Down
27 changes: 27 additions & 0 deletions src/datasets/arrow_dataset.py
Expand Up @@ -872,6 +872,33 @@ def from_dict(
info.features = Features({col: ts.get_inferred_type() for col, ts in mapping.items()})
return cls(pa_table, info=info, split=split)

@classmethod
def from_list(
    cls,
    mapping: List[dict],
    features: Optional[Features] = None,
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
) -> "Dataset":
    """
    Create a :class:`Dataset` from a list of row dictionaries.

    The rows are pivoted into a column-oriented dict which is then handed to :meth:`from_dict`.

    Note that only the keys of the first row determine the dataset columns, regardless of what is
    passed to ``features``: a row that lacks one of those keys contributes ``None`` for it, and keys
    that appear only in later rows are ignored. An empty list yields an empty column dict.

    Args:
        mapping (:obj:`List[dict]`): A list of rows, each mapping column names to that row's values.
        features (:class:`Features`, optional): Dataset features.
        info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc.
        split (:class:`NamedSplit`, optional): Name of the dataset split.

    Returns:
        :class:`Dataset`
    """
    # InMemoryTable.from_pylist is deliberately not used here: going through from_dict keeps things
    # simple and consistent wrt OptimizedTypedSequence.
    columns = {}
    if mapping:
        for column_name in mapping[0]:
            columns[column_name] = [row.get(column_name) for row in mapping]
    return cls.from_dict(columns, features, info, split)

@staticmethod
def from_csv(
path_or_paths: Union[PathLike, List[PathLike]],
Expand Down
35 changes: 35 additions & 0 deletions src/datasets/table.py
Expand Up @@ -252,6 +252,19 @@ def to_pydict(self, *args, **kwargs):
"""
return self.table.to_pydict(*args, **kwargs)

def to_pylist(self, *args, **kwargs):
    """
    Convert the Table to a list of rows, each row being a dict keyed by column name.

    All positional and keyword arguments are forwarded to the wrapped table's ``to_pylist`` when it
    exists, otherwise to its ``to_pydict``.

    Returns:
        :obj:`list`
    """
    # pyarrow <7 does not have to_pylist. Probe for the method instead of wrapping the call in
    # ``except AttributeError`` so that an AttributeError raised *inside* to_pylist propagates
    # rather than silently triggering the fallback.
    if hasattr(self.table, "to_pylist"):
        return self.table.to_pylist(*args, **kwargs)
    # Fallback: rebuild the rows from the column-oriented dict.
    pydict = self.table.to_pydict(*args, **kwargs)
    return [{name: values[i] for name, values in pydict.items()} for i in range(len(self.table))]

def to_pandas(self, *args, **kwargs):
"""
Convert to a pandas-compatible NumPy array or DataFrame, as appropriate
Expand Down Expand Up @@ -749,6 +762,28 @@ def from_pydict(cls, *args, **kwargs):
"""
return cls(pa.Table.from_pydict(*args, **kwargs))

@classmethod
def from_pylist(cls, mapping, *args, **kwargs):
    """
    Construct a Table from list of rows / dictionaries.

    Args:
        mapping (:obj:`List[dict]`):
            A list of rows, each one a mapping of column names to that row's values.
        schema (:obj:`Schema`, defaults to :obj:`None`):
            If not passed, will be inferred from the Mapping values
        metadata (:obj:`Union[dict, Mapping]`, default None):
            Optional metadata for the schema (if inferred).

    Returns:
        :class:`datasets.table.Table`:
    """
    # Look the method up explicitly instead of catching AttributeError around the whole call, so
    # that an AttributeError raised by pyarrow or by the ``cls`` constructor is not mistaken for
    # "from_pylist is missing" and silently retried through the fallback.
    native_from_pylist = getattr(pa.Table, "from_pylist", None)
    if native_from_pylist is not None:
        return cls(native_from_pylist(mapping, *args, **kwargs))
    # pyarrow <7 does not have from_pylist, so we convert and use from_pydict.
    # Column names are taken from the first row; rows missing a key contribute None.
    columns = {k: [r.get(k) for r in mapping] for k in mapping[0]} if mapping else {}
    return cls(pa.Table.from_pydict(columns, *args, **kwargs))

@classmethod
def from_batches(cls, *args, **kwargs):
"""
Expand Down
47 changes: 47 additions & 0 deletions tests/test_dataset_list.py
@@ -0,0 +1,47 @@
from unittest import TestCase

from datasets import Sequence, Value
from datasets.arrow_dataset import Dataset


class DatasetListTest(TestCase):
    """Tests for creating a ``Dataset`` from a list of row dictionaries with ``Dataset.from_list``."""

    def _create_example_records(self):
        """Return four sample rows with an integer ``col_1`` and a string ``col_2``."""
        return [{"col_1": number, "col_2": letter} for number, letter in zip((3, 2, 1, 0), "abcd")]

    def _create_example_dict(self):
        """Return a dataset built with ``from_dict`` from the column-oriented form of the sample rows."""
        columns = {
            "col_1": [3, 2, 1, 0],
            "col_2": ["a", "b", "c", "d"],
        }
        return Dataset.from_dict(columns)

    def test_create(self):
        """Column names and every row must match the input records."""
        records = self._create_example_records()
        dset = Dataset.from_list(records)
        self.assertListEqual(dset.column_names, ["col_1", "col_2"])
        for index, row in enumerate(dset):
            expected = records[index]
            self.assertDictEqual(row, expected)

    def test_list_dict_equivalent(self):
        """``from_list`` and ``from_dict`` on the same data must produce equal dataset info."""
        records = self._create_example_records()
        from_list_dset = Dataset.from_list(records)
        columns = {}
        for key in records[0]:
            columns[key] = [record[key] for record in records]
        from_dict_dset = Dataset.from_dict(columns)
        self.assertEqual(from_list_dset.info, from_dict_dset.info)

    def test_uneven_records(self):
        """Checks what happens with missing columns."""
        dset = Dataset.from_list([{"col_1": 1}, {"col_2": "x"}])
        first_row, second_row = dset[0], dset[1]
        self.assertDictEqual(first_row, {"col_1": 1})
        # NB: first record is used for columns
        self.assertDictEqual(second_row, {"col_1": None})

    def test_variable_list_records(self):
        """Checks if the type can be inferred from the second record."""
        dset = Dataset.from_list([{"col_1": []}, {"col_1": [1, 2]}])
        inferred_feature = dset.info.features["col_1"]
        self.assertEqual(inferred_feature, Sequence(Value("int64")))

    def test_create_empty(self):
        """An empty list gives a dataset with no rows and no columns."""
        empty_dset = Dataset.from_list([])
        self.assertEqual(len(empty_dset), 0)
        self.assertListEqual(empty_dset.column_names, [])
7 changes: 7 additions & 0 deletions tests/test_table.py
Expand Up @@ -255,6 +255,13 @@ def test_in_memory_table_from_pydict(in_memory_pa_table):
assert table.table == pa.Table.from_pydict(pydict)


def test_in_memory_table_from_pylist(in_memory_pa_table):
    """Round-trip a table through ``to_pylist`` / ``from_pylist`` and check type and contents."""
    rows = InMemoryTable(in_memory_pa_table).to_pylist()
    rebuilt = InMemoryTable.from_pylist(rows)
    assert isinstance(rebuilt, InMemoryTable)
    assert rows == rebuilt.to_pylist()


def test_in_memory_table_from_batches(in_memory_pa_table):
batches = list(in_memory_pa_table.to_batches())
table = InMemoryTable.from_batches(batches)
Expand Down