Add oversampling strategy iterable datasets interleave #5036

Merged
4 changes: 2 additions & 2 deletions docs/source/process.mdx
@@ -516,9 +516,9 @@ In this case, the new dataset is constructed by getting examples one by one from
[10, 11, 20, 12, 0, 21, 13]
```

In the case of [`Dataset`] objects, you can also specify the `stopping_strategy`. The default strategy, `first_exhausted`, is an subsampling strategy, i.e the dataset construction is stopped as soon one of the dataset runs out of samples.
You can also specify the `stopping_strategy`. The default strategy, `first_exhausted`, is a subsampling strategy, i.e. the dataset construction is stopped as soon as one of the datasets runs out of samples.
You can specify `stopping_strategy=all_exhausted` to execute an oversampling strategy. In this case, the dataset construction is stopped as soon as every sample in every dataset has been added at least once. In practice, this means that if a dataset is exhausted, iteration restarts from the beginning of that dataset until the stop criterion has been reached.
Note that if no sampling probabilities are specified, the new dataset will have `max_length_datasets*nb_dataset samples`
Note that if no sampling probabilities are specified, the new dataset will have `max_length_datasets*nb_dataset` samples.

```py
>>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
4 changes: 4 additions & 0 deletions docs/source/stream.mdx
@@ -109,6 +109,10 @@ Define sampling probabilities from each of the original datasets for more control

Around 80% of the final dataset is made of the `en_dataset`, and 20% of the `fr_dataset`.

You can also specify the `stopping_strategy`. The default strategy, `first_exhausted`, is a subsampling strategy, i.e. the dataset construction is stopped as soon as one of the datasets runs out of samples.
You can specify `stopping_strategy=all_exhausted` to execute an oversampling strategy. In this case, the dataset construction is stopped as soon as every sample in every dataset has been added at least once. In practice, this means that if a dataset is exhausted, iteration restarts from the beginning of that dataset until the stop criterion has been reached.
Note that if no sampling probabilities are specified, the new dataset will have `max_length_datasets*nb_dataset` samples.
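
For instance, here is a minimal sketch of the oversampling strategy, reusing the `en_dataset` and `fr_dataset` streaming datasets defined earlier on this page (the `seed` value is only illustrative):

```py
>>> from datasets import interleave_datasets
>>> dataset = interleave_datasets(
...     [en_dataset, fr_dataset],
...     probabilities=[0.8, 0.2],
...     seed=42,
...     stopping_strategy="all_exhausted",
... )
```

With `all_exhausted`, iteration stops only once every example of both datasets has been yielded at least once; a dataset that runs out earlier is simply restarted from its beginning in the meantime.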

## Rename, remove, and cast

The following methods allow you to modify the columns of a dataset. These methods are useful for renaming or removing columns and changing columns to a new set of features.
13 changes: 5 additions & 8 deletions src/datasets/combine.py
@@ -30,7 +30,7 @@ def interleave_datasets(
If ``probabilities`` is ``None`` (default) the new dataset is constructed by cycling between each source to get the examples.
If ``probabilities`` is not ``None``, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities.

The resulting dataset ends when one of the source datasets runs out of examples except when ``oversampling`` is ``True`` and :class:`Dataset` objects are used,
The resulting dataset ends when one of the source datasets runs out of examples, except when ``oversampling`` is ``True``,
in which case the resulting dataset ends when all datasets have run out of examples at least once.

Args:
@@ -43,8 +43,7 @@
split ([`NamedSplit`], *optional*): Name of the dataset split.
<Added version="2.4.0"/>
stopping_strategy (Optional :obj:`str`, defaults to `first_exhausted`):
Two strategies are proposed right now for :class:`Dataset` objects.
For :class:`IterableDataset` objects, only `first_exhausted` is proposed right now.
Two strategies are proposed right now: `first_exhausted` and `all_exhausted`.
By default, `first_exhausted` is an undersampling strategy, i.e. the dataset construction is stopped as soon as one dataset has run out of samples.
If the strategy is `all_exhausted`, we use an oversampling strategy, i.e. the dataset construction is stopped as soon as every sample of every dataset has been added at least once.
Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous:
@@ -119,18 +118,16 @@ def interleave_datasets(
raise ValueError(
f"Unable to interleave a {type(datasets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects."
)
if iterable and stopping_strategy != "first_exhausted":
raise NotImplementedError(
f"{stopping_strategy} stopping strategy in `interleave_datasets` is not implemented yet with a list of {type(datasets[0])}."
)
if stopping_strategy not in ["first_exhausted", "all_exhausted"]:
raise ValueError(f"{stopping_strategy} is not supported. Please enter a valid stopping_strategy.")
if map_style:
return _interleave_map_style_datasets(
datasets, probabilities, seed, info=info, split=split, stopping_strategy=stopping_strategy
)
else:
return _interleave_iterable_datasets(datasets, probabilities, seed, info=info, split=split)
return _interleave_iterable_datasets(
datasets, probabilities, seed, info=info, split=split, stopping_strategy=stopping_strategy
)


def concatenate_datasets(
Expand Down
105 changes: 86 additions & 19 deletions src/datasets/iterable_dataset.py
@@ -44,6 +44,35 @@ def _batch_to_examples(batch: Dict[str, list]) -> List[Dict[str, Any]]:
yield {col: array[i] for col, array in batch.items()}


class HasNextIterator(Iterator):
"""Iterator with an hasnext() function. Taken from https://stackoverflow.com/questions/1966591/has-next-in-python-iterators."""

def __init__(self, it):
self.it = iter(it)
self._hasnext = None

def __iter__(self):
return self

def __next__(self):
if self._hasnext:
result = self._thenext
else:
result = next(self.it)
self._hasnext = None
return result

def hasnext(self):
if self._hasnext is None:
try:
self._thenext = next(self.it)
except StopIteration:
self._hasnext = False
else:
self._hasnext = True
return self._hasnext
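
# A usage sketch of HasNextIterator (illustration only, not part of the library code): hasnext()
# pre-fetches one element so it can be called repeatedly without consuming data, and __next__
# returns the pre-fetched element when there is one:
#
#     it = HasNextIterator([1, 2])
#     assert next(it) == 1
#     assert it.hasnext()       # pre-fetches 2
#     assert next(it) == 2      # returns the pre-fetched element
#     assert not it.hasnext()   # underlying iterator is exhausted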


class _BaseExamplesIterable:
"""Base class for the examples iterable used by an IterableDataset"""

@@ -146,23 +175,53 @@ def shard_data_sources(self, shard_idx: int) -> "MappedExamplesIterable":


class CyclingMultiSourcesExamplesIterable(_BaseExamplesIterable):
def __init__(self, ex_iterables: List[_BaseExamplesIterable]):
def __init__(
self, ex_iterables: List[_BaseExamplesIterable], stopping_strategy: Optional[str] = "first_exhausted"
):
self.ex_iterables = ex_iterables
self.stopping_strategy = stopping_strategy

def __iter__(self):
iterators = [iter(ex_iterable) for ex_iterable in self.ex_iterables]
# if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted
# if oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted, i.e. as soon as every sample of every dataset has been visited at least once
self.bool_strategy_func = np.all if (stopping_strategy == "all_exhausted") else np.any

def _give_indice_iterator(self):
# this is an infinite iterator to keep track of which iterator we want to pick examples from
indices_iterator = cycle(range(len(iterators)))
return cycle(range(len(self.ex_iterables)))

def __iter__(self):
iterators = [HasNextIterator(ex_iterable) for ex_iterable in self.ex_iterables]

indices_iterator = self._give_indice_iterator()

is_exhausted = np.full(len(self.ex_iterables), False)
for i in indices_iterator:
try: # let's pick one example from the iterator at index i
yield next(iterators[i])
except StopIteration: # if we ran out of examples on this iterator, break the main for loop
break

# it will resume from the yield at the next call so that we can directly test if the iterable is exhausted and if we need to break out of the loop
if not iterators[i].hasnext():
is_exhausted[i] = True

if self.bool_strategy_func(is_exhausted):
# if the stopping criteria is met, break the main for loop
break
# otherwise reinitialise the iterator and yield the first example
iterators[i] = HasNextIterator(self.ex_iterables[i])

except StopIteration:
# here it means that the i-th iterable dataset is empty, i.e. we never have the occasion to yield an element of the i-th dataset.
# we still check if the stopping criterion is met and break out of the loop if so, in case of an oversampling strategy
is_exhausted[i] = True

if self.bool_strategy_func(is_exhausted):
# if the stopping criteria is met, break the main for loop
break

def shuffle_data_sources(self, generator: np.random.Generator) -> "CyclingMultiSourcesExamplesIterable":
"""Shuffle each underlying examples iterable."""
ex_iterables = [ex_iterable.shuffle_data_sources(generator) for ex_iterable in self.ex_iterables]
return CyclingMultiSourcesExamplesIterable(ex_iterables)
return CyclingMultiSourcesExamplesIterable(ex_iterables, self.stopping_strategy)

@property
def n_shards(self) -> int:
@@ -280,8 +339,14 @@ def shard_data_sources(self, shard_idx: int) -> "HorizontallyConcatenatedMultiSo


class RandomlyCyclingMultiSourcesExamplesIterable(CyclingMultiSourcesExamplesIterable):
def __init__(self, ex_iterables, generator: np.random.Generator, probabilities: Optional[List[float]] = None):
super().__init__(ex_iterables)
def __init__(
self,
ex_iterables,
generator: np.random.Generator,
probabilities: Optional[List[float]] = None,
stopping_strategy: Optional[str] = "first_exhausted",
):
super().__init__(ex_iterables, stopping_strategy)
self.generator = deepcopy(generator)
self.probabilities = probabilities

@@ -300,16 +365,10 @@ def _iter_random_indices(
while True:
yield from (int(i) for i in rng.choice(num_sources, size=random_batch_size, p=p))

def __iter__(self):
def _give_indice_iterator(self):
rng = deepcopy(self.generator)
iterators = [iter(ex_iterable) for ex_iterable in self.ex_iterables]
# this is an infinite iterator that randomly samples the index of the source to pick examples from
indices_iterator = self._iter_random_indices(rng, len(iterators), p=self.probabilities)
for i in indices_iterator:
try: # let's pick one example from the iterator at index i
yield next(iterators[i])
except StopIteration: # if we ran out of examples on this iterator, break the main for loop
break
return self._iter_random_indices(rng, len(self.ex_iterables), p=self.probabilities)

def shuffle_data_sources(self, generator: np.random.Generator) -> "RandomlyCyclingMultiSourcesExamplesIterable":
"""Shuffle the data sources of each wrapped examples iterable."""
@@ -1371,6 +1430,7 @@ def _interleave_iterable_datasets(
seed: Optional[int] = None,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
stopping_strategy: Optional[str] = "first_exhausted",
) -> IterableDataset:
"""
Interleave several iterable datasets (sources) into a single iterable dataset.
@@ -1385,6 +1445,13 @@ probabilities (:obj:`List[float]`, optional, default None): If specified, the new iterable dataset samples
probabilities (:obj:`List[float]`, optional, default None): If specified, the new iterable dataset samples
examples from one source at a time according to these probabilities.
seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
stopping_strategy (Optional :obj:`str`, defaults to `first_exhausted`):
Two strategies are proposed right now.
By default, `first_exhausted` is an undersampling strategy, i.e. the dataset construction is stopped as soon as one dataset has run out of samples.
If the strategy is `all_exhausted`, we use an oversampling strategy, i.e. the dataset construction is stopped as soon as every sample of every dataset has been added at least once.
Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous:
- with no probabilities, the resulting dataset will have max_length_datasets*nb_dataset samples.
- with given probabilities, the resulting dataset will have more samples if some datasets have a very low probability of being visited.

Output:
:class:`datasets.IterableDataset`
@@ -1404,11 +1471,11 @@

# Use cycling or random cycling or sources
if probabilities is None:
ex_iterable = CyclingMultiSourcesExamplesIterable(ex_iterables)
ex_iterable = CyclingMultiSourcesExamplesIterable(ex_iterables, stopping_strategy=stopping_strategy)
else:
generator = np.random.default_rng(seed)
ex_iterable = RandomlyCyclingMultiSourcesExamplesIterable(
ex_iterables, generator=generator, probabilities=probabilities
ex_iterables, generator=generator, probabilities=probabilities, stopping_strategy=stopping_strategy
)
# Set new info - we update the features
# setting the features also ensures to fill missing columns with None
Expand Down
67 changes: 54 additions & 13 deletions tests/test_iterable_dataset.py
@@ -150,6 +150,9 @@ def test_cycling_multi_sources_examples_iterable():
ex_iterable = CyclingMultiSourcesExamplesIterable([ex_iterable1, ex_iterable2])
expected = list(chain(*zip(generate_examples_fn(text="foo"), generate_examples_fn(text="bar"))))

# The cycling stops as soon as one iterable runs out of examples (here ex_iterable1), so the last sample from ex_iterable2 is unnecessary
expected = expected[:-1]

assert next(iter(ex_iterable)) == expected[0]
assert list(ex_iterable) == expected
assert all((x["id"], x["text"]) == (i // 2, "bar" if i % 2 else "foo") for i, (_, x) in enumerate(ex_iterable))
@@ -172,9 +175,13 @@ def test_randomly_cycling_multi_sources_examples_iterable(probabilities):
rng, len(iterators), p=probabilities
)
expected = []
lengths = [len(list(ex_iterable1)), len(list(ex_iterable2))]
for i in indices_iterator:
if lengths[0] == 0 or lengths[1] == 0:
break
for key, example in iterators[i]:
expected.append((key, example))
lengths[i] -= 1
break
else:
break
@@ -995,22 +1002,29 @@ def test_concatenate_datasets_axis_1_with_different_lengths():


@pytest.mark.parametrize(
"probas, seed, expected_length",
"probas, seed, expected_length, stopping_strategy",
[
(None, None, 3 * DEFAULT_N_EXAMPLES),
([1, 0, 0], None, DEFAULT_N_EXAMPLES),
([0, 1, 0], None, DEFAULT_N_EXAMPLES),
([0.2, 0.5, 0.3], 42, None),
([0.1, 0.1, 0.8], 1337, None),
([0.5, 0.2, 0.3], 101010, None),
(None, None, 3 * (DEFAULT_N_EXAMPLES - 1) + 1, "first_exhausted"),
([1, 0, 0], None, DEFAULT_N_EXAMPLES, "first_exhausted"),
([0, 1, 0], None, DEFAULT_N_EXAMPLES, "first_exhausted"),
([0.2, 0.5, 0.3], 42, None, "first_exhausted"),
([0.1, 0.1, 0.8], 1337, None, "first_exhausted"),
([0.5, 0.2, 0.3], 101010, None, "first_exhausted"),
(None, None, 3 * DEFAULT_N_EXAMPLES, "all_exhausted"),
([0.2, 0.5, 0.3], 42, None, "all_exhausted"),
([0.1, 0.1, 0.8], 1337, None, "all_exhausted"),
([0.5, 0.2, 0.3], 101010, None, "all_exhausted"),
],
)
def test_interleave_datasets(dataset: IterableDataset, probas, seed, expected_length):
def test_interleave_datasets(dataset: IterableDataset, probas, seed, expected_length, stopping_strategy):
d1 = dataset
d2 = dataset.map(lambda x: {"id+1": x["id"] + 1, **x})
d3 = dataset.with_format("python")
datasets = [d1, d2, d3]
merged_dataset = interleave_datasets(datasets, probabilities=probas, seed=seed)

merged_dataset = interleave_datasets(
datasets, probabilities=probas, seed=seed, stopping_strategy=stopping_strategy
)

def fill_default(example):
return {"id": None, "id+1": None, **example}
@@ -1021,7 +1035,9 @@ def fill_default(example):
)
# Check that it is deterministic
if seed is not None:
merged_dataset2 = interleave_datasets([d1, d2, d3], probabilities=probas, seed=seed)
merged_dataset2 = interleave_datasets(
[d1, d2, d3], probabilities=probas, seed=seed, stopping_strategy=stopping_strategy
)
assert list(merged_dataset) == list(merged_dataset2)
# Check features
assert merged_dataset.features == Features({"id": Value("int64"), "id+1": Value("int64")})
@@ -1035,13 +1051,14 @@ def fill_default(example):
# Compute length in case it's random
if expected_length is None:
expected_length = 0
counts = [len(list(d)) for d in datasets]
counts = np.array([len(list(d)) for d in datasets])
bool_strategy_func = np.all if stopping_strategy == "all_exhausted" else np.any
rng = np.random.default_rng(seed)
for i in RandomlyCyclingMultiSourcesExamplesIterable._iter_random_indices(rng, len(datasets), p=probas):
if counts[i] == 0:
break
counts[i] -= 1
expected_length += 1
if bool_strategy_func(counts <= 0):
break
# Check length
assert len(list(merged_dataset)) == expected_length

@@ -1060,3 +1077,27 @@ def test_interleave_datasets_with_features(

merged_dataset = interleave_datasets([dataset, dataset_with_features])
assert merged_dataset.features == features


def test_interleave_datasets_with_oversampling():
# Test hardcoded results
d1 = IterableDataset(ExamplesIterable((lambda: (yield from [(i, {"a": i}) for i in [0, 1, 2]])), {}))
d2 = IterableDataset(ExamplesIterable((lambda: (yield from [(i, {"a": i}) for i in [10, 11, 12, 13]])), {}))
d3 = IterableDataset(ExamplesIterable((lambda: (yield from [(i, {"a": i}) for i in [20, 21, 22, 23, 24]])), {}))

expected_values = [0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 23, 1, 10, 24]
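# (d1, d2 and d3 have 3, 4 and 5 examples; without probabilities the sources are cycled
# round-robin and an exhausted source restarts from its beginning, so iteration only stops
# once the longest source, d3, is exhausted: 5 rounds of 3 sources = 15 examples.)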

# Check oversampling strategy without probabilities
assert [x["a"] for x in interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")] == expected_values

# Check oversampling strategy with probabilities
expected_values = [20, 0, 21, 10, 1, 22, 23, 24, 2, 0, 1, 20, 11, 21, 2, 0, 12, 1, 22, 13]

values = [
x["a"]
for x in interleave_datasets(
[d1, d2, d3], probabilities=[0.5, 0.2, 0.3], seed=42, stopping_strategy="all_exhausted"
)
]

assert values == expected_values