Add oversampling strategies to interleave datasets #4831

47 changes: 44 additions & 3 deletions src/datasets/arrow_dataset.py
@@ -4851,6 +4851,7 @@ def _interleave_map_style_datasets(
seed: Optional[int] = None,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
stopping_strategy: Optional[str] = "first_exhausted",
**kwargs,
) -> "Dataset":
"""
@@ -4866,24 +4867,55 @@ def _interleave_map_style_datasets(
seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc.
split (:class:`NamedSplit`, optional): Name of the dataset split.
stopping_strategy (Optional :obj:`str`, defaults to `first_exhausted`):
Two strategies are proposed right now.
By default, `first_exhausted` is an undersampling strategy, i.e. the dataset construction is stopped as soon as one dataset has run out of samples.
If the strategy is `all_exhausted`, we use an oversampling strategy, i.e. the dataset construction is stopped as soon as every sample of every dataset has been added at least once.
Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous:
- with no probabilities, the resulting dataset will have max_length_datasets * nb_dataset samples.
- with given probabilities, the resulting dataset will have more samples if some datasets have a very low probability of being visited.
**kwargs (additional keyword arguments): Keyword arguments to be passed to :meth:`datasets.Dataset.select` when selecting the indices used to interleave the datasets.

Output:
:class:`datasets.Dataset`
"""
if stopping_strategy not in ["first_exhausted", "all_exhausted"]:
raise ValueError(
f"{stopping_strategy} is not a valid stopping strategy; expected 'first_exhausted' or 'all_exhausted'"
)

# To interleave the datasets, we concatenate them and then we re-order the indices
concatenated_datasets = _concatenate_map_style_datasets(datasets, info=info, split=split)

# Let's now build the indices to pass to .select()
lengths = [len(dset) for dset in datasets]
offsets = np.cumsum([0] + lengths[:-1])
# if stopping_strategy is "first_exhausted", it is an undersampling situation, whereas it is an oversampling situation if it is "all_exhausted"
oversampling = stopping_strategy == "all_exhausted"

if probabilities is None and not oversampling:
# Undersampling situation with cycling between the sources
# Example: if the lengths of the datasets are [3, 4, 5],
# then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9]
# Note that we only have 3 examples per dataset since the shortest dataset runs out of examples first
indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
elif probabilities is None:
# Oversampling situation with cycling between the sources
# Example: if the lengths of the datasets are [3, 4, 5],
# then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 0, 11]
# Note that we have 5 examples per dataset with a rolling window since the longest dataset has 5 samples
indices = offsets.reshape(1, -1) + np.arange(max(lengths)).reshape(-1, 1)

# Bring the raw indices back within the bounds of the concatenated dataset by taking them modulo the cumulative lengths
indices = np.mod(indices, np.cumsum(lengths)).flatten().tolist()

else:
# boolean array indicating whether the dataset at index i has been fully exhausted
is_exhausted = np.full(len(lengths), False)

# if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted
# if oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted, i.e. as soon as every sample of every dataset has been visited at least once
bool_strategy_func = np.all if oversampling else np.any

def iter_random_indices():
"""Get an infinite iterator that randomly samples the index of the source to pick examples from."""
@@ -4894,12 +4926,21 @@ def iter_random_indices():
current_index = [0] * len(datasets)
indices = []
for source_idx in iter_random_indices():
# If no oversampling, we stop as soon as one dataset has run out of examples (np.any)
# Otherwise, we stop as soon as every dataset has run out of examples (np.all)
if bool_strategy_func(is_exhausted):
# the stopping condition was reached, let's stop
break

# let's add the example at the current index of the `source_idx`-th dataset
indices.append(current_index[source_idx] + offsets[source_idx])
current_index[source_idx] += 1

# we've run out of examples for the current dataset, so let's update our boolean array and bring current_index back to 0
if current_index[source_idx] >= lengths[source_idx]:
is_exhausted[source_idx] = True
current_index[source_idx] = 0

return concatenated_datasets.select(indices, **kwargs)
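
To make the two index-building paths above concrete, here is a minimal standalone NumPy sketch (illustrative only, not part of the patch) reproducing both computations for datasets of lengths [3, 4, 5]:

import numpy as np

lengths = [3, 4, 5]
offsets = np.cumsum([0] + lengths[:-1])  # [0, 3, 7]

# undersampling ("first_exhausted"): cycle only up to the length of the shortest dataset
under = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
print(under)  # [0, 3, 7, 1, 4, 8, 2, 5, 9]

# oversampling ("all_exhausted"): cycle up to the length of the longest dataset,
# then wrap the out-of-range indices modulo the cumulative lengths
over = offsets.reshape(1, -1) + np.arange(max(lengths)).reshape(-1, 1)
over = np.mod(over, np.cumsum(lengths)).flatten().tolist()
print(over)  # [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 0, 11]

Note that len(over) == max(lengths) * len(lengths) == 15, matching the max_length_datasets * nb_dataset size claim in the docstring.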


48 changes: 43 additions & 5 deletions src/datasets/combine.py
@@ -19,6 +19,7 @@ def interleave_datasets(
seed: Optional[int] = None,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
stopping_strategy: Optional[str] = "first_exhausted",
) -> DatasetType:
"""
Interleave several datasets (sources) into a single dataset.
@@ -29,7 +30,8 @@
If ``probabilities`` is ``None`` (default) the new dataset is constructed by cycling between each source to get the examples.
If ``probabilities`` is not ``None``, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities.

The resulting dataset ends when one of the source datasets runs out of examples, except when ``stopping_strategy`` is ``"all_exhausted"`` and :class:`Dataset` objects are used,
in which case the resulting dataset ends when all datasets have run out of examples at least once.

Args:
datasets (:obj:`List[Dataset]` or :obj:`List[IterableDataset]`): list of datasets to interleave
@@ -40,7 +42,14 @@
<Added version="2.4.0"/>
split ([`NamedSplit`], *optional*): Name of the dataset split.
<Added version="2.4.0"/>

stopping_strategy (Optional :obj:`str`, defaults to `first_exhausted`):
Two strategies are proposed right now for :class:`Dataset` objects.
For :class:`IterableDataset` objects, only `first_exhausted` is proposed right now.
By default, `first_exhausted` is an undersampling strategy, i.e. the dataset construction is stopped as soon as one dataset has run out of samples.
If the strategy is `all_exhausted`, we use an oversampling strategy, i.e. the dataset construction is stopped as soon as every sample of every dataset has been added at least once.
Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous:
- with no probabilities, the resulting dataset will have max_length_datasets * nb_dataset samples.
- with given probabilities, the resulting dataset will have more samples if some datasets have a very low probability of being visited.
Returns:
:class:`Dataset` or :class:`IterableDataset`: Return type depends on the input `datasets`
parameter. `Dataset` if the input is a list of `Dataset`, `IterableDataset` if the input is a list of
@@ -50,17 +59,38 @@

For regular datasets (map-style):


>>> from datasets import Dataset, interleave_datasets
>>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
>>> d2 = Dataset.from_dict({"a": [10, 11, 12]})
>>> d3 = Dataset.from_dict({"a": [20, 21, 22]})
>>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42, stopping_strategy="all_exhausted")
>>> dataset["a"]
[10, 0, 11, 1, 2, 20, 12, 10, 0, 1, 2, 21, 0, 11, 1, 2, 0, 1, 12, 2, 10, 0, 22]
>>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42)
>>> dataset["a"]
[10, 0, 11, 1, 2]
>>> dataset = interleave_datasets([d1, d2, d3])
>>> dataset["a"]
[0, 10, 20, 1, 11, 21, 2, 12, 22]
>>> dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")
>>> dataset["a"]
[0, 10, 20, 1, 11, 21, 2, 12, 22]
>>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
>>> d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
>>> d3 = Dataset.from_dict({"a": [20, 21, 22, 23, 24]})
>>> dataset = interleave_datasets([d1, d2, d3])
>>> dataset["a"]
[0, 10, 20, 1, 11, 21, 2, 12, 22]
>>> dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")
>>> dataset["a"]
[0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 23, 1, 0, 24]
>>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42)
>>> dataset["a"]
[10, 0, 11, 1, 2, 20, 12]

>>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42, stopping_strategy="all_exhausted")
>>> dataset["a"]
[10, 0, 11, 1, 2, 20, 12, 13, ..., 0, 1, 2, 0, 24]

For datasets in streaming mode (iterable):

>>> from datasets import load_dataset, interleave_datasets
@@ -89,8 +119,16 @@
raise ValueError(
f"Unable to interleave a {type(datasets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects."
)
if (iterable and stopping_strategy != "first_exhausted") or (
stopping_strategy not in ["first_exhausted", "all_exhausted"]
):
raise ValueError(
f"{stopping_strategy} stopping strategy in `interleave_datasets` is not implemented yet with a list of {type(datasets[0])}"
)

Review comment (Member): I think it should be a ValueError if stopping_strategy not in ["first_exhausted", "all_exhausted"] and a NotImplementedError if iterable and stopping_strategy != "first_exhausted"
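
A minimal sketch of the check restructured along the lines of that review suggestion (illustrative only, not the committed code):

if stopping_strategy not in ["first_exhausted", "all_exhausted"]:
    raise ValueError(f"{stopping_strategy} is not a valid stopping strategy; expected 'first_exhausted' or 'all_exhausted'")
if iterable and stopping_strategy != "first_exhausted":
    raise NotImplementedError(f"{stopping_strategy} is not implemented yet for a list of {type(datasets[0])}")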
if map_style:
return _interleave_map_style_datasets(
datasets, probabilities, seed, info=info, split=split, stopping_strategy=stopping_strategy
)
else:
return _interleave_iterable_datasets(datasets, probabilities, seed, info=info, split=split)
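
For reference, here is a small self-contained simulation of the probability-driven loop in _interleave_map_style_datasets under the "all_exhausted" strategy (illustrative only: it uses a local NumPy generator rather than the library's internal iter_random_indices helper):

import numpy as np

lengths = [3, 4, 5]
offsets = np.cumsum([0] + lengths[:-1])
probabilities = [0.7, 0.2, 0.1]
rng = np.random.default_rng(42)

is_exhausted = np.full(len(lengths), False)
current_index = [0] * len(lengths)
indices = []
while not np.all(is_exhausted):  # np.any(...) here would give "first_exhausted"
    source_idx = rng.choice(len(lengths), p=probabilities)
    # add the example at the current index of the source_idx-th dataset
    indices.append(current_index[source_idx] + offsets[source_idx])
    current_index[source_idx] += 1
    if current_index[source_idx] >= lengths[source_idx]:
        is_exhausted[source_idx] = True
        current_index[source_idx] = 0  # roll over so this source keeps contributing

Because the source with probability 0.1 is rarely drawn, the loop typically runs well past sum(lengths) iterations before terminating, which is why the docstring warns that "all_exhausted" can make the interleaved dataset enormous.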
