From edb756a24b6e5792ce6206b3195ff302e55b951a Mon Sep 17 00:00:00 2001
From: Yoach Lacombe
Date: Thu, 11 Aug 2022 17:39:04 +0200
Subject: [PATCH 01/10] add a new strategy for interleave_datasets (oversampling strat)

---
 src/datasets/arrow_dataset.py | 47 ++++++++++++++++++++++++++++++++---
 src/datasets/combine.py       | 41 +++++++++++++++++++++++++++---
 2 files changed, 81 insertions(+), 7 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 68c0fd8da68..978bfe40f6b 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -4851,6 +4851,7 @@ def _interleave_map_style_datasets(
     seed: Optional[int] = None,
     info: Optional[DatasetInfo] = None,
     split: Optional[NamedSplit] = None,
+    stopping_strategy: Optional[str] = "first_exhausted",
     **kwargs,
 ) -> "Dataset":
     """
@@ -4866,11 +4867,22 @@
         seed (:obj:`int`, optional, default None): The random seed used to choose a source for each example.
         info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc.
         split (:class:`NamedSplit`, optional): Name of the dataset split.
+        stopping_strategy (Optional :obj:`str`, defaults to `first_exhausted`):
+            Two strategies are proposed right now.
+            By default, `first_exhausted` is an undersampling strategy, i.e. the dataset construction is stopped as soon as one dataset has run out of samples.
+            If the strategy is `all_exhausted`, we use an oversampling strategy, i.e. the dataset construction is stopped as soon as every sample of every dataset has been added at least once.
+            Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous:
+            - with no probabilities, the resulting dataset will have max_length_datasets * nb_dataset samples.
+            - with given probabilities, the resulting dataset will have more samples if some datasets have a very low probability of being visited.
        **kwargs (additional keyword arguments): Keyword arguments to be passed to :meth:`datasets.Datasets.select` when selecting the indices used to interleave the datasets.
    Output:
        :class:`datasets.Dataset`
    """
+    if stopping_strategy not in ["first_exhausted", "all_exhausted"]:
+        raise ValueError(
+            f"{stopping_strategy} stopping strategy in `interleave_datasets` is not implemented yet with a list of {type(datasets[0])}"
+        )

     # To interleave the datasets, we concatenate them and then we re-order the indices
     concatenated_datasets = _concatenate_map_style_datasets(datasets, info=info, split=split)

@@ -4878,12 +4890,32 @@ def _interleave_map_style_datasets(
     # Let's now build the indices to pass to .select()
     lengths = [len(dset) for dset in datasets]
     offsets = np.cumsum([0] + lengths[:-1])
-    if probabilities is None:
+
+    # if stopping_strategy is "first_exhausted", it is an undersampling situation whereas it is an oversampling situation if it is "all_exhausted"
+    oversampling = (stopping_strategy=="all_exhausted")
+
+    if probabilities is None and not oversampling:
+        # Undersampling situation with cycling between each source
         # Example:: If lengths of the datasets are [3, 4, 5]
         # Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9]
         # Note that we only have 3 examples per dataset since the first dataset ran out of examples
         indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
+    elif probabilities is None:
+        # Oversampling situation with cycling between each source
+        # Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 0, 11]
+        # Note that we have 5 examples per dataset with a rolling window since the longest dataset has 5 samples
+        indices = (offsets.reshape(1, -1) + np.arange(max(lengths)).reshape(-1, 1))
+
+        # We have to keep the indices to their respective dataset offsets
+        indices = np.mod(indices,np.cumsum(lengths)).flatten().tolist()
+    else:
+        # boolean array indicating, at index i, whether dataset_i has been fully exhausted
+        is_exhausted = np.full(len(lengths), False)
+
+        # if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted
+        # if oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted, i.e. as soon as every sample of every dataset has been visited at least once
+        bool_strategy_func = np.all if oversampling else np.any

         def iter_random_indices():
             """Get an infinite iterator that randomly samples the index of the source to pick examples from."""
@@ -4894,12 +4926,21 @@ def iter_random_indices():
         current_index = [0] * len(datasets)
         indices = []
         for source_idx in iter_random_indices():
-            # we ran out of examples, let's stop
-            if current_index[source_idx] >= lengths[source_idx]:
+            # If no oversampling, we stop as soon as a dataset has run out of examples (np.any)
+            # Otherwise, we stop as soon as every dataset has run out of examples (np.all)
+            if bool_strategy_func(is_exhausted):
+                # the stopping condition was reached, let's stop
                 break
+
+            # let's add the example at the current index of the `source_idx`-th dataset
             indices.append(current_index[source_idx] + offsets[source_idx])
             current_index[source_idx] += 1
+
+            # we've run out of examples for the current dataset, let's update our boolean array and bring the current_index back to 0
+            if current_index[source_idx] >= lengths[source_idx]:
+                is_exhausted[source_idx] = True
+                current_index[source_idx] = 0
+
     return concatenated_datasets.select(indices, **kwargs)
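The hunk above elides the body of `iter_random_indices`, so the probability-driven loop can be hard to follow from the diff alone. As a rough standalone sketch of the strategy logic — assuming an `np.random.default_rng`-based sampler, since the PR's exact sampling code is not shown here — the loop amounts to:

```python
import numpy as np

def interleave_indices(lengths, probabilities, seed=42, oversampling=False):
    """Sketch: build indices into the concatenated dataset (not library code)."""
    offsets = np.cumsum([0] + list(lengths[:-1]))  # start of each source in the concatenation
    is_exhausted = np.full(len(lengths), False)
    # "first_exhausted" stops once ANY source is exhausted; "all_exhausted" once ALL are
    bool_strategy_func = np.all if oversampling else np.any
    rng = np.random.default_rng(seed)  # assumed sampler; the diff elides iter_random_indices
    current_index = [0] * len(lengths)
    indices = []
    while not bool_strategy_func(is_exhausted):
        source_idx = int(rng.choice(len(lengths), p=probabilities))
        indices.append(current_index[source_idx] + offsets[source_idx])
        current_index[source_idx] += 1
        if current_index[source_idx] >= lengths[source_idx]:
            # source exhausted: flag it and roll back to its first sample
            is_exhausted[source_idx] = True
            current_index[source_idx] = 0
    return indices
```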
diff --git a/src/datasets/combine.py b/src/datasets/combine.py
index 0c4a1773038..a9d18186cf7 100644
--- a/src/datasets/combine.py
+++ b/src/datasets/combine.py
@@ -19,6 +19,7 @@ def interleave_datasets(
     seed: Optional[int] = None,
     info: Optional[DatasetInfo] = None,
     split: Optional[NamedSplit] = None,
+    stopping_strategy: Optional[str] = "first_exhausted"
 ) -> DatasetType:
     """
     Interleave several datasets (sources) into a single dataset.
@@ -40,7 +41,14 @@ def interleave_datasets(
         split ([`NamedSplit`], *optional*):
             Name of the dataset split.
-
+        stopping_strategy (Optional :obj:`str`, defaults to `first_exhausted`):
+            Two strategies are proposed right now for :class:`Dataset` objects.
+            For :class:`IterableDataset` objects, only `first_exhausted` is supported.
+            By default, `first_exhausted` is an undersampling strategy, i.e. the dataset construction is stopped as soon as one dataset has run out of samples.
+            If the strategy is `all_exhausted`, we use an oversampling strategy, i.e. the dataset construction is stopped as soon as every sample of every dataset has been added at least once.
+            Note that if the strategy is `all_exhausted`, the interleaved dataset size can get enormous:
+            - with no probabilities, the resulting dataset will have max_length_datasets * nb_dataset samples.
+            - with given probabilities, the resulting dataset will have more samples if some datasets have a very low probability of being visited.
     Returns:
         :class:`Dataset` or :class:`IterableDataset`: Return type depends on the input `datasets`
             parameter. `Dataset` if the input is a list of `Dataset`, `IterableDataset` if the input is a list of

    Example:

        For regular datasets (map-style):

+        >>> from datasets import Dataset, interleave_datasets
        >>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
        >>> d2 = Dataset.from_dict({"a": [10, 11, 12]})
        >>> d3 = Dataset.from_dict({"a": [20, 21, 22]})
+        >>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42, stopping_strategy="all_exhausted")
+        >>> dataset["a"]
+        [10, 0, 11, 1, 2, 20, 12, 10, 0, 1, 2, 21, 0, 11, 1, 2, 0, 1, 12, 2, 10, 0, 22]
+        >>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42)
+        >>> dataset["a"]
+        [10, 0, 11, 1, 2]
+        >>> dataset = interleave_datasets([d1, d2, d3])
+        >>> dataset["a"]
+        [0, 10, 20, 1, 11, 21, 2, 12, 22]
+        >>> dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")
+        >>> dataset["a"]
+        [0, 10, 20, 1, 11, 21, 2, 12, 22]
+        >>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
+        >>> d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
+        >>> d3 = Dataset.from_dict({"a": [20, 21, 22, 23, 24]})
        >>> dataset = interleave_datasets([d1, d2, d3])
        >>> dataset["a"]
        [0, 10, 20, 1, 11, 21, 2, 12, 22]
+        >>> dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")
+        >>> dataset["a"]
+        [0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 23, 1, 10, 24]
        >>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42)
        >>> dataset["a"]
-        [10, 0, 11, 1, 2, 20, 12]
-
+        [10, 0, 11, 1, 2]
+        >>> dataset = interleave_datasets([d1, d2, d3], probabilities=[0.7, 0.2, 0.1], seed=42, stopping_strategy="all_exhausted")
+        >>> dataset["a"]
+        [10, 0, 11, 1, 2, 20, 12, 13, ..., 0, 1, 2, 0, 24]

    For datasets in streaming mode (iterable):

        >>> from datasets import load_dataset, interleave_datasets
@@ -89,8 +118,12 @@
        raise ValueError(
            f"Unable to interleave a {type(datasets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects."
        )
+    if (iterable and stopping_strategy != "first_exhausted") or (stopping_strategy not in ["first_exhausted", "all_exhausted"]):
+        raise ValueError(
+            f"{stopping_strategy} stopping strategy in `interleave_datasets` is not implemented yet with a list of {type(datasets[0])}"
+        )
    if map_style:
-        return _interleave_map_style_datasets(datasets, probabilities, seed, info=info, split=split)
+        return _interleave_map_style_datasets(datasets, probabilities, seed, info=info, split=split, stopping_strategy=stopping_strategy)
    else:
        return _interleave_iterable_datasets(datasets, probabilities, seed, info=info, split=split)

From 345a5e3e31885e2a22bd8b5c3383f57e56abb0df Mon Sep 17 00:00:00 2001
From: Yoach Lacombe
Date: Thu, 11 Aug 2022 17:40:43 +0200
Subject: [PATCH 02/10] format code according to the library style

---
 src/datasets/arrow_dataset.py |  6 +++---
 src/datasets/combine.py       | 10 +++++++---
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 978bfe40f6b..869cef4c7ed 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -4892,7 +4892,7 @@ def _interleave_map_style_datasets(
    offsets = np.cumsum([0] + lengths[:-1])

    # if stopping_strategy is "first_exhausted", it is an undersampling situation whereas it is an oversampling situation if it is "all_exhausted"
-    oversampling = (stopping_strategy=="all_exhausted")
+    oversampling = stopping_strategy == "all_exhausted"

    if probabilities is None and not oversampling:
        # Undersampling situation with cycling between each source
@@ -4904,10 +4904,10 @@ def _interleave_map_style_datasets(
        # Oversampling situation with cycling between each source
        # Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 0, 11]
        # Note that we have 5 examples per dataset with a rolling window since the longest dataset has 5 samples
-        indices = (offsets.reshape(1, -1) + np.arange(max(lengths)).reshape(-1, 1))
+        indices = offsets.reshape(1, -1) + np.arange(max(lengths)).reshape(-1, 1)

        # We have to keep the indices to their respective dataset offsets
-        indices = np.mod(indices,np.cumsum(lengths)).flatten().tolist()
+        indices = np.mod(indices, np.cumsum(lengths)).flatten().tolist()
    else:
        # boolean array indicating, at index i, whether dataset_i has been fully exhausted

diff --git a/src/datasets/combine.py b/src/datasets/combine.py
index a9d18186cf7..5937ab3645d 100644
--- a/src/datasets/combine.py
+++ b/src/datasets/combine.py
@@ -19,7 +19,7 @@ def interleave_datasets(
    seed: Optional[int] = None,
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
-    stopping_strategy: Optional[str] = "first_exhausted"
+    stopping_strategy: Optional[str] = "first_exhausted",
 ) -> DatasetType:
    """
    Interleave several datasets (sources) into a single dataset.
@@ -118,12 +118,16 @@ def interleave_datasets(
        raise ValueError(
            f"Unable to interleave a {type(datasets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects."
        )
-    if (iterable and stopping_strategy != "first_exhausted") or (stopping_strategy not in ["first_exhausted", "all_exhausted"]):
+    if (iterable and stopping_strategy != "first_exhausted") or (
+        stopping_strategy not in ["first_exhausted", "all_exhausted"]
+    ):
        raise ValueError(
            f"{stopping_strategy} stopping strategy in `interleave_datasets` is not implemented yet with a list of {type(datasets[0])}"
        )
    if map_style:
-        return _interleave_map_style_datasets(datasets, probabilities, seed, info=info, split=split)
+        return _interleave_map_style_datasets(
+            datasets, probabilities, seed, info=info, split=split, stopping_strategy=stopping_strategy
+        )
    else:
        return _interleave_iterable_datasets(datasets, probabilities, seed, info=info, split=split)

From da58422429280d1ce4e7abc84ea4bcde8ba7f01a Mon Sep 17 00:00:00 2001
From: Yoach Lacombe
Date: Thu, 11 Aug 2022 18:08:22 +0200
Subject: [PATCH 03/10] update interleave_datasets description

---
 src/datasets/combine.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/datasets/combine.py b/src/datasets/combine.py
index 5937ab3645d..41269a92454 100644
--- a/src/datasets/combine.py
+++ b/src/datasets/combine.py
@@ -30,7 +30,8 @@ def interleave_datasets(

    If ``probabilities`` is ``None`` (default) the new dataset is constructed by cycling between each source to get the examples.
    If ``probabilities`` is not ``None``, the new dataset is constructed by getting examples from a random source at a time according to the provided probabilities.
-    The resulting dataset ends when one of the source datasets runs out of examples.
+    The resulting dataset ends when one of the source datasets runs out of examples, except when ``stopping_strategy`` is ``all_exhausted`` and :class:`Dataset` objects are used,
+    in which case the resulting dataset ends when all datasets have run out of examples at least once.

    Args:
        datasets (:obj:`List[Dataset]` or :obj:`List[IterableDataset]`): list of datasets to interleave

From c0236dde7efdfba2cc1eb8f3f5658358c72818aa Mon Sep 17 00:00:00 2001
From: Yoach Lacombe
Date: Wed, 24 Aug 2022 11:27:38 +0200
Subject: [PATCH 04/10] Add correct Error type for a non implemented strategy
 in interleave_datasets

---
 src/datasets/combine.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/datasets/combine.py b/src/datasets/combine.py
index 41269a92454..4df81d06aa0 100644
--- a/src/datasets/combine.py
+++ b/src/datasets/combine.py
@@ -119,12 +119,12 @@ def interleave_datasets(
        raise ValueError(
            f"Unable to interleave a {type(datasets[0])} with a {type(dataset)}. Expected a list of Dataset objects or a list of IterableDataset objects."
        )
-    if (iterable and stopping_strategy != "first_exhausted") or (
-        stopping_strategy not in ["first_exhausted", "all_exhausted"]
-    ):
-        raise ValueError(
-            f"{stopping_strategy} stopping strategy in `interleave_datasets` is not implemented yet with a list of {type(datasets[0])}"
+    if iterable and stopping_strategy != "first_exhausted":
+        raise NotImplementedError(
+            f"{stopping_strategy} stopping strategy in `interleave_datasets` is not implemented yet with a list of {type(datasets[0])}."
        )
+    if stopping_strategy not in ["first_exhausted", "all_exhausted"]:
+        raise ValueError(f"{stopping_strategy} is not supported. Please enter a valid stopping_strategy.")
    if map_style:
        return _interleave_map_style_datasets(
            datasets, probabilities, seed, info=info, split=split, stopping_strategy=stopping_strategy
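The next two patches correct the oversampling example comment and then the computation itself; the problem with the PATCH 01 formula can be reproduced in isolation. A short NumPy check (standalone illustration, not part of the patches):

```python
import numpy as np

lengths = [3, 4, 5]
offsets = np.cumsum([0] + lengths[:-1])  # start of each dataset: [0, 3, 7]

# PATCH 01 wrapped global indices with np.mod against cumsum(lengths)
indices = offsets.reshape(1, -1) + np.arange(max(lengths)).reshape(-1, 1)
indices = np.mod(indices, np.cumsum(lengths)).flatten().tolist()
print(indices)  # [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 0, 11]

# The 14th entry is 0 (an index into d1): once d2 is exhausted, 3 + 4 = 7 wraps
# to 7 % 7 = 0 instead of back to d2's own start at index 3, hence the corrected
# comment (..., 1, 3, 11) in PATCH 05 and the rewritten computation in PATCH 07.
```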
From 76ab4c3a07ed6c0b22107b7f819c2a4cfae53483 Mon Sep 17 00:00:00 2001
From: Yoach Lacombe
Date: Wed, 24 Aug 2022 13:11:27 +0200
Subject: [PATCH 05/10] correcting an example in the comments

---
 src/datasets/arrow_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 869cef4c7ed..67564df92fe 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -4902,7 +4902,7 @@ def _interleave_map_style_datasets(
        indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
    elif probabilities is None:
        # Oversampling situation with cycling between each source
-        # Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 0, 11]
+        # Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 3, 11]
        # Note that we have 5 examples per dataset with a rolling window since the longest dataset has 5 samples
        indices = offsets.reshape(1, -1) + np.arange(max(lengths)).reshape(-1, 1)

From c63379a4a4bf23cbcb47b3a96c5643f98d31cd33 Mon Sep 17 00:00:00 2001
From: Yoach Lacombe
Date: Wed, 24 Aug 2022 13:25:23 +0200
Subject: [PATCH 06/10] adding comment to the default case of
 _interleave_map_style_datasets

---
 src/datasets/arrow_dataset.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 67564df92fe..f531ab8f3f2 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -4899,6 +4899,10 @@ def _interleave_map_style_datasets(
        # Example:: If lengths of the datasets are [3, 4, 5]
        # Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9]
        # Note that we only have 3 examples per dataset since the first dataset ran out of examples
+
+        # Reasoning behind the following operation: keeping the min_length first indices of each dataset
+        # while offsetting in order to correspond to the right indices of the concatenated dataset
+        # and flattening to effectively interleave the datasets
        indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist()
    elif probabilities is None:
        # Oversampling situation with cycling between each source

From 88d12f1a71f58992c70b1d12b58c83d19ae378f4 Mon Sep 17 00:00:00 2001
From: Yoach Lacombe
Date: Wed, 24 Aug 2022 13:26:22 +0200
Subject: [PATCH 07/10] correct the case of oversampling strategy with no
 probabilities of _interleave_map_style_datasets and add comments

---
 src/datasets/arrow_dataset.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index f531ab8f3f2..210fb97dc1c 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -4908,10 +4908,13 @@ def _interleave_map_style_datasets(
        # Oversampling situation with cycling between each source
        # Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 3, 11]
        # Note that we have 5 examples per dataset with a rolling window since the longest dataset has 5 samples
-        indices = offsets.reshape(1, -1) + np.arange(max(lengths)).reshape(-1, 1)

-        # We have to keep the indices to their respective dataset offsets
-        indices = np.mod(indices, np.cumsum(lengths)).flatten().tolist()
+        # Reasoning behind the following operation: for each dataset's indices (i.e. column), repeat the indices to have max_length indices per dataset
+        # For example, if the max_length is 5 and the i-th dataset has 3 samples, the i-th column will be [0,1,2,0,1]
+        indices = np.mod(np.arange(max(lengths)).reshape(-1,1), np.array(lengths).reshape(1,-1))
+
+        # We have to keep the indices to their respective dataset offsets and to flatten to effectively interleave the datasets
+        indices = (indices + offsets).flatten().tolist()
    else:
        # boolean array indicating, at index i, whether dataset_i has been fully exhausted

From de490ebc79338b0f37c570f3997cee64f56552ef Mon Sep 17 00:00:00 2001
From: Yoach Lacombe
Date: Wed, 24 Aug 2022 13:53:54 +0200
Subject: [PATCH 08/10] reformat with datasets's style

---
 src/datasets/arrow_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 210fb97dc1c..ec70c547349 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -4911,7 +4911,7 @@ def _interleave_map_style_datasets(

        # Reasoning behind the following operation: for each dataset's indices (i.e. column), repeat the indices to have max_length indices per dataset
        # For example, if the max_length is 5 and the i-th dataset has 3 samples, the i-th column will be [0,1,2,0,1]
-        indices = np.mod(np.arange(max(lengths)).reshape(-1,1), np.array(lengths).reshape(1,-1))
+        indices = np.mod(np.arange(max(lengths)).reshape(-1, 1), np.array(lengths).reshape(1, -1))

        # We have to keep the indices to their respective dataset offsets and to flatten to effectively interleave the datasets
        indices = (indices + offsets).flatten().tolist()

From 1be09cc4dc9ee93d3fdfa378971cf4898b2c21a9 Mon Sep 17 00:00:00 2001
From: Yoach Lacombe
Date: Wed, 24 Aug 2022 13:54:36 +0200
Subject: [PATCH 09/10] add tests for oversampling strategy in
 interleave_datasets

---
 tests/test_arrow_dataset.py | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index e637e681998..7184fe25255 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -2691,6 +2691,41 @@ def test_interleave_datasets_probabilities():
    )


+def test_interleave_datasets_oversampling_strategy():
+    d1 = Dataset.from_dict({"a": [0, 1, 2]})
+    d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
+    d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0])
+    dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")
+    expected_length = 3 * max(len(d1), len(d2), len(d3))
+    expected_values = [0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 20]  # hardcoded
+    assert isinstance(dataset, Dataset)
+    assert len(dataset) == expected_length
+    assert dataset["a"] == expected_values
+    assert dataset._fingerprint == interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")._fingerprint
+
+
+def test_interleave_datasets_probabilities_oversampling_strategy():
+    seed = 42
+    probabilities = [0.3, 0.5, 0.2]
+    d1 = Dataset.from_dict({"a": [0, 1, 2]})
+    d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
+    d3 = Dataset.from_dict({"a": [22, 21, 20]}).select([2, 1, 0])
+    dataset = interleave_datasets(
+        [d1, d2, d3], stopping_strategy="all_exhausted", probabilities=probabilities, seed=seed
+    )
+    expected_length = 16  # hardcoded
+    expected_values = [10, 11, 20, 12, 0, 21, 13, 10, 1, 11, 12, 22, 13, 20, 10, 2]  # hardcoded
+    assert isinstance(dataset, Dataset)
+    assert len(dataset) == expected_length
+    assert dataset["a"] == expected_values
+    assert (
+        dataset._fingerprint
+        == interleave_datasets(
+            [d1, d2, d3], stopping_strategy="all_exhausted", probabilities=probabilities, seed=seed
+        )._fingerprint
+    )
+
+
 @pytest.mark.parametrize(
    "column, expected_dtype",
    [(["a", "b", "c", "d"], "string"), ([1, 2, 3, 4], "int64"), ([1.0, 2.0, 3.0, 4.0], "float64")],
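The hardcoded expectations in the first test above follow directly from the PATCH 07 index construction; a standalone NumPy check (illustration only, not part of the patches):

```python
import numpy as np

# d1 + d2 + d3 from test_interleave_datasets_oversampling_strategy, concatenated
concatenated = [0, 1, 2, 10, 11, 12, 13, 20, 21, 22]
lengths = [3, 4, 3]
offsets = np.cumsum([0] + lengths[:-1])  # [0, 3, 7]

# cycle each dataset's local indices up to max(lengths) rows, then offset and flatten
indices = np.mod(np.arange(max(lengths)).reshape(-1, 1), np.array(lengths).reshape(1, -1))
indices = (indices + offsets).flatten().tolist()

assert [concatenated[i] for i in indices] == [0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 20]
```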
From 7ce7a15f6084a35b9cf638633379dd053ddf48c7 Mon Sep 17 00:00:00 2001
From: Yoach Lacombe
Date: Wed, 24 Aug 2022 14:27:23 +0200
Subject: [PATCH 10/10] mention of the sampling strategy of interleave_datasets
 in the documentation of process.mdx

---
 docs/source/process.mdx | 38 ++++++++++++++++++++++++++++++++------
 1 file changed, 32 insertions(+), 6 deletions(-)

diff --git a/docs/source/process.mdx b/docs/source/process.mdx
index 14d7be77cc2..160b2243cbe 100644
--- a/docs/source/process.mdx
+++ b/docs/source/process.mdx
@@ -489,12 +489,6 @@ Separate datasets can be concatenated if they share the same column types. Conca
>>> bert_dataset = concatenate_datasets([bookcorpus, wiki])
```

-
-
-You can also mix several datasets together by taking alternating examples from each one to create a new dataset. This is known as *interleaving*, which is enabled by the [`interleave_datasets`] function. Both [`interleave_datasets`] and [`concatenate_datasets`] work with regular [`Dataset`] and [`IterableDataset`] objects. Refer to the [Stream](./stream#interleave) guide for an example of how to interleave datasets.
-
-
You can also concatenate two datasets horizontally by setting `axis=1` as long as the datasets have the same number of rows:

@@ -503,6 +497,38 @@ You can also concatenate two datasets horizontally by setting `axis=1` as long a
>>> bookcorpus_with_ids = concatenate_datasets([bookcorpus, bookcorpus_ids], axis=1)
```

+### Interleave
+
+You can also mix several datasets together by taking alternating examples from each one to create a new dataset. This is known as *interleaving*, which is enabled by the [`interleave_datasets`] function. Both [`interleave_datasets`] and [`concatenate_datasets`] work with regular [`Dataset`] and [`IterableDataset`] objects.
+Refer to the [Stream](./stream#interleave) guide for an example of how to interleave [`IterableDataset`] objects.
+
+You can define sampling probabilities for each of the original datasets to specify how to interleave the datasets.
+In this case, the new dataset is constructed by getting examples one by one from a random dataset until one of the datasets runs out of samples.
+
+```py
+>>> seed = 42
+>>> probabilities = [0.3, 0.5, 0.2]
+>>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
+>>> d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
+>>> d3 = Dataset.from_dict({"a": [20, 21, 22]})
+>>> dataset = interleave_datasets([d1, d2, d3], probabilities=probabilities, seed=seed)
+>>> dataset["a"]
+[10, 11, 20, 12, 0, 21, 13]
+```
+
+In the case of [`Dataset`] objects, you can also specify the `stopping_strategy`. The default strategy, `first_exhausted`, is an undersampling strategy, i.e. the dataset construction is stopped as soon as one of the datasets runs out of samples.
+You can specify `stopping_strategy="all_exhausted"` to execute an oversampling strategy. In this case, the dataset construction is stopped as soon as every sample in every dataset has been added at least once. In practice, this means that once a dataset is exhausted, sampling returns to the beginning of that dataset until the stopping criterion has been reached.
+Note that if no sampling probabilities are specified, the new dataset will have `max_length_datasets * nb_dataset` samples.
+
+```py
+>>> d1 = Dataset.from_dict({"a": [0, 1, 2]})
+>>> d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
+>>> d3 = Dataset.from_dict({"a": [20, 21, 22]})
+>>> dataset = interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")
+>>> dataset["a"]
+[0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 20]
+```
+
## Format

The [`~Dataset.set_format`] function changes the format of a column to be compatible with some common data formats. Specify the output you'd like in the `type` parameter and the columns you want to format. Formatting is applied on-the-fly.
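Taken together, the series can be exercised end to end; a quick sanity script, assuming a build of `datasets` that includes these patches (the expected outputs mirror the examples documented above):

```python
from datasets import Dataset, interleave_datasets

d1 = Dataset.from_dict({"a": [0, 1, 2]})
d2 = Dataset.from_dict({"a": [10, 11, 12, 13]})
d3 = Dataset.from_dict({"a": [20, 21, 22]})

# default undersampling: stop once the shortest dataset is exhausted
print(interleave_datasets([d1, d2, d3])["a"])
# [0, 10, 20, 1, 11, 21, 2, 12, 22]

# oversampling: cycle exhausted datasets until every sample has appeared at least once
print(interleave_datasets([d1, d2, d3], stopping_strategy="all_exhausted")["a"])
# [0, 10, 20, 1, 11, 21, 2, 12, 22, 0, 13, 20]
```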