From dc5cb17289061c0ad13ab6a79a1b8981d3beb37b Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Wed, 24 Aug 2022 18:46:07 +0200 Subject: [PATCH] Add oversampling strategies to interleave datasets (#4831) * add a new strategy for interleave_datasets (oversampling strat) * format code according to the library style * update interleave_datasets description * Add correct Error type for a non implemented strategy in interleave_datasets * correcting an example in the comments * adding comment to the default case of _interleave_map_style_datasets * correct the case of oversampling strategy with no probabilities of _interleave_map_style_datasets and add comments * reformat with datasets's style * add tests for oversampling strategy in interleave_datasets * mention of the sampling strategy of interleave_datasets in the documentation of process.mdx --- docs/source/process.mdx | 38 ++++++++++++++++++++---- src/datasets/arrow_dataset.py | 54 +++++++++++++++++++++++++++++++++-- src/datasets/combine.py | 48 +++++++++++++++++++++++++++---- tests/test_arrow_dataset.py | 35 +++++++++++++++++++++++ 4 files changed, 161 insertions(+), 14 deletions(-) diff --git a/docs/source/process.mdx b/docs/source/process.mdx index 14d7be77cc2..160b2243cbe 100644 --- a/docs/source/process.mdx +++ b/docs/source/process.mdx @@ -489,12 +489,6 @@ Separate datasets can be concatenated if they share the same column types. Conca >>> bert_dataset = concatenate_datasets([bookcorpus, wiki]) ``` - - -You can also mix several datasets together by taking alternating examples from each one to create a new dataset. This is known as *interleaving*, which is enabled by the [`interleave_datasets`] function. Both [`interleave_datasets`] and [`concatenate_datasets`] work with regular [`Dataset`] and [`IterableDataset`] objects. Refer to the [Stream](./stream#interleave) guide for an example of how to interleave datasets. - - - You can also concatenate two datasets horizontally by setting `axis=1` as long as the datasets have the same number of rows: ```py @@ -503,6 +497,38 @@ You can also concatenate two datasets horizontally by setting `axis=1` as long a >>> bookcorpus_with_ids = concatenate_datasets([bookcorpus, bookcorpus_ids], axis=1) ``` +### Interleave + +You can also mix several datasets together by taking alternating examples from each one to create a new dataset. This is known as *interleaving*, which is enabled by the [`interleave_datasets`] function. Both [`interleave_datasets`] and [`concatenate_datasets`] work with regular [`Dataset`] and [`IterableDataset`] objects. +Refer to the [Stream](./stream#interleave) guide for an example of how to interleave [`IterableDataset`] objects. + +You can define sampling probabilities for each of the original datasets to specify how to interleave the datasets. +In this case, the new dataset is constructed by getting examples one by one from a random dataset until one of the datasets runs out of samples. + +```py +>>> seed = 42 +>>> probabilities = [0.3, 0.5, 0.2] +>>> d1 = Dataset.from_dict({"a": [0, 1, 2]}) +>>> d2 = Dataset.from_dict({"a": [10, 11, 12, 13]}) +>>> d3 = Dataset.from_dict({"a": [20, 21, 22]}) +>>> dataset = interleave_datasets([d1, d2, d3], probabilities=probabilities, seed=seed) +>>> dataset["a"] +[10, 11, 20, 12, 0, 21, 13] +``` + +In the case of [`Dataset`] objects, you can also specify the `stopping_strategy`. The default strategy, `first_exhausted`, is an subsampling strategy, i.e the dataset construction is stopped as soon one of the dataset runs out of samples. +You can specify `stopping_strategy=all_exhausted` to execute an oversampling strategy. In this case, the dataset construction is stopped as soon as every samples in every dataset has been added at least once. In practice, it means that if a dataset is exhausted, it will return to the beginning of this dataset until the stop criterion has been reached. +Note that if no sampling probabilities are specified, the new dataset will have `max_length_datasets*nb_dataset samples` + +```py +>>> d1 = Dataset.from_dict({"a": [0, 1, 2]}) +>>> d2 = Dataset.from_dict({"a": [10, 11, 12, 13]}) +>>> d3 = Dataset.from_dict({"a": [20, 21, 22]}) +>>> dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted") +>>> dataset["a"] +[0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 20] +``` + ## Format The [`~Dataset.set_format`] function changes the format of a column to be compatible with some common data formats. Specify the output you'd like in the `type` parameter and the columns you want to format. Formatting is applied on-the-fly. diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 7bf64c0e836..bc2c5ab37ba 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -4851,6 +4851,7 @@ def _interleave_map_style_datasets( seed: Optional[int] = None, info: Optional[DatasetInfo] = None, split: Optional[NamedSplit] = None, + stopping_strategy: Optional[str] = "first_exhausted", **kwargs, ) -> "Dataset": """ @@ -4866,11 +4867,22 @@ def _interleave_map_style_datasets( seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example. info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc. split (:class:`NamedSplit`, optional): Name of the dataset split. + stopping_strategy (Optional :obj:`str`, defaults to `first_exhausted`): + Two strategies are proposed right now. + By default, `first_exhausted` is an undersampling strategy, i.e the dataset construction is stopped as soon as one dataset has ran out of samples. + If the strategy is `all_exhausted`, we use an oversampling strategy, i.e the dataset construction is stopped as soon as every samples of every dataset has been added at least once. + Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous: + - with no probabilities, the resulting dataset will have max_length_datasets*nb_dataset samples. + - with given probabilities, the resulting dataset will have more samples if some datasets have really low probability of visiting. **kwargs (additional keyword arguments): Keyword arguments to be passed to :meth:`datasets.Datasets.select` when selecting the indices used to interleave the datasets. Output: :class:`datasets.Dataset` """ + if stopping_strategy not in ["first_exhausted", "all_exhausted"]: + raise ValueError( + f"{stopping_strategy} stopping strategy in `interleave_datasets` is not implemented yet with a list of {type(datasets[0])}" + ) # To interleave the datasets, we concatenate them and then we re-order the indices concatenated_datasets = _concatenate_map_style_datasets(datasets, info=info, split=split) @@ -4878,12 +4890,39 @@ def _interleave_map_style_datasets( # Let's now build the indices to pass to .select() lengths = [len(dset) for dset in datasets] offsets = np.cumsum([0] + lengths[:-1]) - if probabilities is None: + + # if stopping_strategy is "first_exhausted", it is an undersampling situation whereas it is an oversampling situation if it is "all_exhausted" + oversampling = stopping_strategy == "all_exhausted" + + if probabilities is None and not oversampling: + # Undersampling situation with cycling between each sources # Example:: If lengths of the datasets are [3, 4, 5] # Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 6, 9] # Note that we only have 3 examples per dataset since the first dataset ran out of examples + + # Reasoning behind the following operation: keeping the min_length first indices of each dataset + # while offsetting in order to correspond to the right indices of the concatenated dataset + # and flattening to effectively interleave the datasets indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist() + elif probabilities is None: + # Oversampling situation with cycling between each sources + # Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 3, 11] + # Note that we have 5 examples per dataset with a rolling window since the longest dataset has 5 samples + + # Reasoning behind the following operation: for each dataset indices (i.e column) repeat the indices to have max_length indices per dataset + # For example, if the max_length is 5 and the i-th dataset has 3 samples, the i-th column will be [0,1,2,0,1] + indices = np.mod(np.arange(max(lengths)).reshape(-1, 1), np.array(lengths).reshape(1, -1)) + + # We have to keep the indices to their respective dataset offsets and to flatten to effectively interleave the datasets + indices = (indices + offsets).flatten().tolist() + else: + # boolean array indicating if at index i if the dataset_i has been fully exhausted + is_exhausted = np.full(len(lengths), False) + + # if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted + # if oversampling ("all_exhausted"), we stop as soons as every dataset is exhausted, i.e as soon as every samples of every dataset has been visited at least once + bool_strategy_func = np.all if oversampling else np.any def iter_random_indices(): """Get an infinite iterator that randomly samples the index of the source to pick examples from.""" @@ -4894,12 +4933,21 @@ def iter_random_indices(): current_index = [0] * len(datasets) indices = [] for source_idx in iter_random_indices(): - # we ran out of examples, let's stop - if current_index[source_idx] >= lengths[source_idx]: + # If no oversampling, we stop as soon as a dataset has ran out of examples (np.any) + # Otherwise, we stop as soon as every dataset has ran out of examples (np.all) + if bool_strategy_func(is_exhausted): + # the stopping condition was reached, let's stop break + # let's add the example at the current index of the `source_idx`-th dataset indices.append(current_index[source_idx] + offsets[source_idx]) current_index[source_idx] += 1 + + # we've ran out of examples for the current dataset, let's update our boolean array and bring the current_index back to 0 + if current_index[source_idx] >= lengths[source_idx]: + is_exhausted[source_idx] = True + current_index[source_idx] = 0 + return concatenated_datasets.select(indices, **kwargs) diff --git a/src/datasets/combine.py b/src/datasets/combine.py index 11ee2676d15..b2463d173af 100644 --- a/src/datasets/combine.py +++ b/src/datasets/combine.py @@ -19,6 +19,7 @@ def interleave_datasets( seed: Optional[int] = None, info: Optional[DatasetInfo] = None, split: Optional[NamedSplit] = None, + stopping_strategy: Optional[str] = "first_exhausted", ) -> DatasetType: """ Interleave several datasets (sources) into a single dataset. @@ -29,7 +30,8 @@ def interleave_datasets( If ``probabilities`` is ``None`` (default) the new dataset is constructed by cycling between each source to get the examples. If ``probabilities`` is not ``None``, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities. - The resulting dataset ends when one of the source datasets runs out of examples. + The resulting dataset ends when one of the source datasets runs out of examples except when ``oversampling`` is ``True`` and :class:`Dataset` objects are used, + in which case, the resulting dataset ends when all datasets have ran out of examples at least one time. Args: datasets (:obj:`List[Dataset]` or :obj:`List[IterableDataset]`): list of datasets to interleave @@ -40,7 +42,14 @@ def interleave_datasets( split ([`NamedSplit`], *optional*): Name of the dataset split. - + stopping_strategy (Optional :obj:`str`, defaults to `first_exhausted`): + Two strategies are proposed right now for :class:`Dataset` objects. + For :class:`IterableDataset` objects, only `first_exhausted` is proposed right now. + By default, `first_exhausted` is an undersampling strategy, i.e the dataset construction is stopped as soon as one dataset has ran out of samples. + If the strategy is `all_exhausted`, we use an oversampling strategy, i.e the dataset construction is stopped as soon as every samples of every dataset has been added at least once. + Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous: + - with no probabilities, the resulting dataset will have max_length_datasets*nb_dataset samples. + - with given probabilities, the resulting dataset will have more samples if some datasets have really low probability of visiting. Returns: :class:`Dataset` or :class:`IterableDataset`: Return type depends on the input `datasets` parameter. `Dataset` if the input is a list of `Dataset`, `IterableDataset` if the input is a list of @@ -50,17 +59,38 @@ def interleave_datasets( For regular datasets (map-style): + >>> from datasets import Dataset, interleave_datasets >>> d1 = Dataset.from_dict({"a": [0, 1, 2]}) >>> d2 = Dataset.from_dict({"a": [10, 11, 12]}) >>> d3 = Dataset.from_dict({"a": [20, 21, 22]}) + >>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42, stopping_strategy="all_exhausted") + >>> dataset["a"] + [10, 0, 11, 1, 2, 20, 12, 10, 0, 1, 2, 21, 0, 11, 1, 2, 0, 1, 12, 2, 10, 0, 22] + >>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42) + >>> dataset["a"] + [10, 0, 11, 1, 2] + >>> dataset = interleave_datasets([d1, d2, d3]) + >>> dataset["a"] + [0, 10, 20, 1, 11, 21, 2, 12, 22] + >>> dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted") + >>> dataset["a"] + [0, 10, 20, 1, 11, 21, 2, 12, 22] + >>> d1 = Dataset.from_dict({"a": [0, 1, 2]}) + >>> d2 = Dataset.from_dict({"a": [10, 11, 12, 13]}) + >>> d3 = Dataset.from_dict({"a": [20, 21, 22, 23, 24]}) >>> dataset = interleave_datasets([d1, d2, d3]) >>> dataset["a"] [0, 10, 20, 1, 11, 21, 2, 12, 22] + >>> dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted") + >>> dataset["a"] + [0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 23, 1, 0, 24] >>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42) >>> dataset["a"] - [10, 0, 11, 1, 2, 20, 12] - + [10, 0, 11, 1, 2] + >>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42, stopping_strategy="all_exhausted") + >>> dataset["a"] + [10, 0, 11, 1, 2, 20, 12, 13, ..., 0, 1, 2, 0, 24] For datasets in streaming mode (iterable): >>> from datasets import load_dataset, interleave_datasets @@ -89,8 +119,16 @@ def interleave_datasets( raise ValueError( f"Unable to interleave a {type(datasets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects." ) + if iterable and stopping_strategy != "first_exhausted": + raise NotImplementedError( + f"{stopping_strategy} stopping strategy in `interleave_datasets` is not implemented yet with a list of {type(datasets[0])}." + ) + if stopping_strategy not in ["first_exhausted", "all_exhausted"]: + raise ValueError(f"{stopping_strategy} is not supported. Please enter a valid stopping_strategy.") if map_style: - return _interleave_map_style_datasets(datasets, probabilities, seed, info=info, split=split) + return _interleave_map_style_datasets( + datasets, probabilities, seed, info=info, split=split, stopping_strategy=stopping_strategy + ) else: return _interleave_iterable_datasets(datasets, probabilities, seed, info=info, split=split) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index e637e681998..7184fe25255 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -2691,6 +2691,41 @@ def test_interleave_datasets_probabilities(): ) +def test_interleave_datasets_oversampling_strategy(): + d1 = Dataset.from_dict({"a": [0, 1, 2]}) + d2 = Dataset.from_dict({"a": [10, 11, 12, 13]}) + d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0]) + dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted") + expected_length = 3 * max(len(d1), len(d2), len(d3)) + expected_values = [0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 20] # hardcoded + assert isinstance(dataset, Dataset) + assert len(dataset) == expected_length + assert dataset["a"] == expected_values + assert dataset._fingerprint == interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")._fingerprint + + +def test_interleave_datasets_probabilities_oversampling_strategy(): + seed = 42 + probabilities = [0.3, 0.5, 0.2] + d1 = Dataset.from_dict({"a": [0, 1, 2]}) + d2 = Dataset.from_dict({"a": [10, 11, 12, 13]}) + d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0]) + dataset = interleave_datasets( + [d1, d2, d3], stopping_strategy="all_exhausted", probabilities=probabilities, seed=seed + ) + expected_length = 16 # hardcoded + expected_values = [10, 11, 20, 12, 0, 21, 13, 10, 1, 11, 12, 22, 13, 20, 10, 2] # hardcoded + assert isinstance(dataset, Dataset) + assert len(dataset) == expected_length + assert dataset["a"] == expected_values + assert ( + dataset._fingerprint + == interleave_datasets( + [d1, d2, d3], stopping_strategy="all_exhausted", probabilities=probabilities, seed=seed + )._fingerprint + ) + + @pytest.mark.parametrize( "column, expected_dtype", [(["a", "b", "c", "d"], "string"), ([1, 2, 3, 4], "int64"), ([1.0, 2.0, 3.0, 4.0], "float64")],