Add oversampling strategies to interleave datasets #4831

38 changes: 32 additions & 6 deletions docs/source/process.mdx
@@ -489,12 +489,6 @@ Separate datasets can be concatenated if they share the same column types. Conca
>>> bert_dataset = concatenate_datasets([bookcorpus, wiki])
```

<Tip>

You can also mix several datasets together by taking alternating examples from each one to create a new dataset. This is known as *interleaving*, which is enabled by the [`interleave_datasets`] function. Both [`interleave_datasets`] and [`concatenate_datasets`] work with regular [`Dataset`] and [`IterableDataset`] objects. Refer to the [Stream](./stream#interleave) guide for an example of how to interleave datasets.

</Tip>

You can also concatenate two datasets horizontally by setting `axis=1` as long as the datasets have the same number of rows:

```py
@@ -503,6 +497,38 @@ You can also concatenate two datasets horizontally by setting `axis=1` as long a
>>> bookcorpus_with_ids = concatenate_datasets([bookcorpus, bookcorpus_ids], axis=1)
```

### Interleave

You can also mix several datasets together by taking alternating examples from each one to create a new dataset. This is known as *interleaving*, which is enabled by the [`interleave_datasets`] function. Both [`interleave_datasets`] and [`concatenate_datasets`] work with regular [`Dataset`] and [`IterableDataset`] objects.
Refer to the [Stream](./stream#interleave) guide for an example of how to interleave [`IterableDataset`] objects.

You can define sampling probabilities for each of the original datasets to specify how to interleave the datasets.
In this case, the new dataset is constructed by getting examples one by one from a random dataset until one of the datasets runs out of samples.

```py
>>> seed = 42
>>> probabilities = [0.3, 0.5, 0.2]
>>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
>>> d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
>>> d3 = Dataset.from_dict({"a": [20, 21, 22]})
>>> dataset = interleave_datasets([d1, d2, d3], probabilities=probabilities, seed=seed)
>>> dataset["a"]
[10, 11, 20, 12, 0, 21, 13]
```

In the case of [`Dataset`] objects, you can also specify the `stopping_strategy`. The default strategy, `first_exhausted`, is a subsampling strategy, i.e. the dataset construction is stopped as soon as one of the datasets runs out of samples.
You can specify `stopping_strategy="all_exhausted"` to execute an oversampling strategy. In this case, the dataset construction is stopped as soon as every sample in every dataset has been added at least once. In practice, it means that if a dataset is exhausted, it will return to the beginning of this dataset until the stop criterion has been reached.
Note that if no sampling probabilities are specified, the new dataset will have `max_length_datasets * nb_dataset` samples.

```py
>>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
>>> d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
>>> d3 = Dataset.from_dict({"a": [20, 21, 22]})
>>> dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")
>>> dataset["a"]
[0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 20]
```
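
You can also combine `probabilities` with `stopping_strategy="all_exhausted"`. In that case, examples keep being drawn from a random source until every sample of every dataset has been visited at least once, so exhausted datasets wrap around to their beginning. As a sketch, reusing `d1`, `d2`, `d3`, `probabilities` and `seed` from the examples above (the output shown is the one hardcoded in this PR's test for the same lengths, values, probabilities and seed):

```py
>>> dataset = interleave_datasets([d1, d2, d3], probabilities=probabilities, seed=seed, stopping_strategy="all_exhausted")
>>> dataset["a"]
[10, 11, 20, 12, 0, 21, 13, 10, 1, 11, 12, 22, 13, 20, 10, 2]
```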

## Format

The [`~Dataset.set_format`] function changes the format of a column to be compatible with some common data formats. Specify the output you'd like in the `type` parameter and the columns you want to format. Formatting is applied on-the-fly.
54 changes: 51 additions & 3 deletions src/datasets/arrow_dataset.py
@@ -4851,6 +4851,7 @@ def _interleave_map_style_datasets(
seed: Optional[int] = None,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
stopping_strategy: Optional[str] = "first_exhausted",
**kwargs,
) -> "Dataset":
"""
@@ -4866,24 +4867,62 @@ def _interleave_map_style_datasets(
seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc.
split (:class:`NamedSplit`, optional): Name of the dataset split.
stopping_strategy (Optional :obj:`str`, defaults to `first_exhausted`):
Two strategies are proposed right now.
By default, `first_exhausted` is an undersampling strategy, i.e. the dataset construction is stopped as soon as one dataset has run out of samples.
If the strategy is `all_exhausted`, we use an oversampling strategy, i.e. the dataset construction is stopped as soon as every sample of every dataset has been added at least once.
Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous:
- with no probabilities, the resulting dataset will have max_length_datasets * nb_dataset samples.
- with given probabilities, the resulting dataset will have more samples if some datasets have a really low probability of being visited.
**kwargs (additional keyword arguments): Keyword arguments to be passed to :meth:`datasets.Datasets.select` when selecting the indices used to interleave the datasets.

Output:
:class:`datasets.Dataset`
"""
if stopping_strategy not in ["first_exhausted", "all_exhausted"]:
raise ValueError(
f"{stopping_strategy} stopping strategy in `interleave_datasets` is not implemented yet with a list of {type(datasets[0])}"
)

# To interleave the datasets, we concatenate them and then we re-order the indices
concatenated_datasets = _concatenate_map_style_datasets(datasets, info=info, split=split)

# Let's now build the indices to pass to .select()
lengths = [len(dset) for dset in datasets]
offsets = np.cumsum([0] + lengths[:-1])

# if stopping_strategy is "first_exhausted", it is an undersampling situation whereas it is an oversampling situation if it is "all_exhausted"
oversampling = stopping_strategy == "all_exhausted"

if probabilities is None and not oversampling:
# Undersampling situation with cycling between each source
# Example: If lengths of the datasets are [3, 4, 5]
# Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9]
# Note that we only have 3 examples per dataset since the first dataset ran out of examples

# Reasoning behind the following operation: keeping the min_length first indices of each dataset
# while offsetting in order to correspond to the right indices of the concatenated dataset
# and flattening to effectively interleave the datasets
indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
elif probabilities is None:
# Oversampling situation with cycling between each source
# Example: If lengths of the datasets are [3, 4, 5]
# Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 3, 11]
# Note that we have 5 examples per dataset with a rolling window since the longest dataset has 5 samples

# Reasoning behind the following operation: for each dataset's indices (i.e. each column), repeat the indices to have max_length indices per dataset
# For example, if the max_length is 5 and the i-th dataset has 3 samples, the i-th column will be [0, 1, 2, 0, 1]
indices = np.mod(np.arange(max(lengths)).reshape(-1, 1), np.array(lengths).reshape(1, -1))
Member: nice trick !

Contributor Author: Thanks for that comment and for the review!

# We have to add the respective dataset offsets to the indices and flatten them to effectively interleave the datasets
indices = (indices + offsets).flatten().tolist()

else:
# boolean array indicating whether the dataset at index i has been fully exhausted
is_exhausted = np.full(len(lengths), False)

# if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted
# if oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted, i.e. as soon as every sample of every dataset has been visited at least once
bool_strategy_func = np.all if oversampling else np.any

def iter_random_indices():
"""Get an infinite iterator that randomly samples the index of the source to pick examples from."""
@@ -4894,12 +4933,21 @@ def iter_random_indices():
current_index = [0] * len(datasets)
indices = []
for source_idx in iter_random_indices():
# If no oversampling, we stop as soon as a dataset has run out of examples (np.any)
# Otherwise, we stop as soon as every dataset has run out of examples (np.all)
if bool_strategy_func(is_exhausted):
# the stopping condition was reached, let's stop
break

# let's add the example at the current index of the `source_idx`-th dataset
indices.append(current_index[source_idx] + offsets[source_idx])
current_index[source_idx] += 1

# we've run out of examples for the current dataset, let's update our boolean array and bring the current_index back to 0
if current_index[source_idx] >= lengths[source_idx]:
is_exhausted[source_idx] = True
current_index[source_idx] = 0

return concatenated_datasets.select(indices, **kwargs)
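
For intuition, here is a small standalone sketch (illustrative only, not part of this diff) that reproduces the oversampling index construction for dataset lengths `[3, 4, 5]`; it should print the same indices as the comment in the `elif probabilities is None` branch above:

```python
import numpy as np

lengths = [3, 4, 5]
offsets = np.cumsum([0] + lengths[:-1])  # [0, 3, 7]

# each column cycles over one dataset's local indices until max(lengths) rows are filled
local_indices = np.mod(np.arange(max(lengths)).reshape(-1, 1), np.array(lengths).reshape(1, -1))

# shift by the dataset offsets and flatten to interleave
indices = (local_indices + offsets).flatten().tolist()
print(indices)  # [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 3, 11]
```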


48 changes: 43 additions & 5 deletions src/datasets/combine.py
@@ -19,6 +19,7 @@ def interleave_datasets(
seed: Optional[int] = None,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
stopping_strategy: Optional[str] = "first_exhausted",
) -> DatasetType:
"""
Interleave several datasets (sources) into a single dataset.
@@ -29,7 +30,8 @@
If ``probabilities`` is ``None`` (default) the new dataset is constructed by cycling between each source to get the examples.
If ``probabilities`` is not ``None``, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities.

The resulting dataset ends when one of the source datasets runs out of examples, except when ``stopping_strategy`` is ``all_exhausted`` and :class:`Dataset` objects are used,
in which case the resulting dataset ends when all datasets have run out of examples at least once.

Args:
datasets (:obj:`List[Dataset]` or :obj:`List[IterableDataset]`): list of datasets to interleave
@@ -40,7 +42,14 @@
<Added version="2.4.0"/>
split ([`NamedSplit`], *optional*): Name of the dataset split.
<Added version="2.4.0"/>

stopping_strategy (Optional :obj:`str`, defaults to `first_exhausted`):
Two strategies are proposed right now for :class:`Dataset` objects.
For :class:`IterableDataset` objects, only `first_exhausted` is proposed right now.
By default, `first_exhausted` is an undersampling strategy, i.e. the dataset construction is stopped as soon as one dataset has run out of samples.
If the strategy is `all_exhausted`, we use an oversampling strategy, i.e. the dataset construction is stopped as soon as every sample of every dataset has been added at least once.
Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous:
- with no probabilities, the resulting dataset will have max_length_datasets * nb_dataset samples.
- with given probabilities, the resulting dataset will have more samples if some datasets have a really low probability of being visited.
Returns:
:class:`Dataset` or :class:`IterableDataset`: Return type depends on the input `datasets`
parameter. `Dataset` if the input is a list of `Dataset`, `IterableDataset` if the input is a list of
@@ -50,17 +59,38 @@

For regular datasets (map-style):


>>> from datasets import Dataset, interleave_datasets
>>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
>>> d2 = Dataset.from_dict({"a": [10, 11, 12]})
>>> d3 = Dataset.from_dict({"a": [20, 21, 22]})
>>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42, stopping_strategy="all_exhausted")
>>> dataset["a"]
[10, 0, 11, 1, 2, 20, 12, 10, 0, 1, 2, 21, 0, 11, 1, 2, 0, 1, 12, 2, 10, 0, 22]
>>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42)
>>> dataset["a"]
[10, 0, 11, 1, 2]
>>> dataset = interleave_datasets([d1, d2, d3])
>>> dataset["a"]
[0, 10, 20, 1, 11, 21, 2, 12, 22]
>>> dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")
>>> dataset["a"]
[0, 10, 20, 1, 11, 21, 2, 12, 22]
>>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
>>> d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
>>> d3 = Dataset.from_dict({"a": [20, 21, 22, 23, 24]})
>>> dataset = interleave_datasets([d1, d2, d3])
>>> dataset["a"]
[0, 10, 20, 1, 11, 21, 2, 12, 22]
>>> dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")
>>> dataset["a"]
[0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 23, 1, 10, 24]
>>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42)
>>> dataset["a"]
[10, 0, 11, 1, 2]
>>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42, stopping_strategy="all_exhausted")
>>> dataset["a"]
[10, 0, 11, 1, 2, 20, 12, 13, ..., 0, 1, 2, 0, 24]
For datasets in streaming mode (iterable):

>>> from datasets import load_dataset, interleave_datasets
@@ -89,8 +119,16 @@ def interleave_datasets(
raise ValueError(
f"Unable to interleave a {type(datasets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects."
)
if iterable and stopping_strategy != "first_exhausted":
raise NotImplementedError(
f"{stopping_strategy} stopping strategy in `interleave_datasets` is not implemented yet with a list of {type(datasets[0])}."
)
if stopping_strategy not in ["first_exhausted", "all_exhausted"]:
raise ValueError(f"{stopping_strategy} is not supported. Please enter a valid stopping_strategy.")
if map_style:
return _interleave_map_style_datasets(datasets, probabilities, seed, info=info, split=split)
return _interleave_map_style_datasets(
datasets, probabilities, seed, info=info, split=split, stopping_strategy=stopping_strategy
)
else:
return _interleave_iterable_datasets(datasets, probabilities, seed, info=info, split=split)
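
As a quick illustration of the new guard for streaming datasets: with `IterableDataset` inputs, requesting the oversampling strategy is expected to raise the `NotImplementedError` added above. A hedged sketch (the dataset names are only an example of streaming inputs; any `IterableDataset` objects would do):

```python
from datasets import load_dataset, interleave_datasets

# streaming datasets are IterableDataset objects
en = load_dataset("oscar", "unshuffled_deduplicated_en", split="train", streaming=True)
fr = load_dataset("oscar", "unshuffled_deduplicated_fr", split="train", streaming=True)

try:
    interleave_datasets([en, fr], stopping_strategy="all_exhausted")
except NotImplementedError as err:
    # only "first_exhausted" is supported for iterable datasets at this point
    print(err)
```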

35 changes: 35 additions & 0 deletions tests/test_arrow_dataset.py
@@ -2691,6 +2691,41 @@ def test_interleave_datasets_probabilities():
)


def test_interleave_datasets_oversampling_strategy():
d1 = Dataset.from_dict({"a": [0, 1, 2]})
d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0])
dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")
expected_length = 3 * max(len(d1), len(d2), len(d3))
expected_values = [0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 20] # hardcoded
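# derivation of expected_values: the sources are cycled d1 -> d2 -> d3; d2 is the longest (4 samples),
# so in the fourth round d1 and d3 wrap back to their first samples (0 and 20) while d2 contributes its last sample (13)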
assert isinstance(dataset, Dataset)
assert len(dataset) == expected_length
assert dataset["a"] == expected_values
assert dataset._fingerprint == interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")._fingerprint


def test_interleave_datasets_probabilities_oversampling_strategy():
seed = 42
probabilities = [0.3, 0.5, 0.2]
d1 = Dataset.from_dict({"a": [0, 1, 2]})
d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0])
dataset = interleave_datasets(
[d1, d2, d3], stopping_strategy="all_exhausted", probabilities=probabilities, seed=seed
)
expected_length = 16 # hardcoded
expected_values = [10, 11, 20, 12, 0, 21, 13, 10, 1, 11, 12, 22, 13, 20, 10, 2] # hardcoded
assert isinstance(dataset, Dataset)
assert len(dataset) == expected_length
assert dataset["a"] == expected_values
assert (
dataset._fingerprint
== interleave_datasets(
[d1, d2, d3], stopping_strategy="all_exhausted", probabilities=probabilities, seed=seed
)._fingerprint
)


@pytest.mark.parametrize(
"column, expected_dtype",
[(["a", "b", "c", "d"], "string"), ([1, 2, 3, 4], "int64"), ([1.0, 2.0, 3.0, 4.0], "float64")],