added from_generator method to IterableDataset class. (#5052)
* added from_generator method to IterableDataset class.

* Move streaming param to __init__

* Test

* Type hint fix

Co-authored-by: mariosasko <mariosasko777@gmail.com>
hamid-vakilzadeh and mariosasko committed Oct 5, 2022
1 parent 41f7fb4 commit 6ad430b
Showing 8 changed files with 201 additions and 102 deletions.
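For context, the new public API this commit wires up can be exercised roughly as follows (a minimal sketch; the generator and its fields are illustrative):

```python
from datasets import IterableDataset

def gen():
    # Any generator that yields dict examples works; this data is made up.
    for i in range(3):
        yield {"id": i, "text": f"example {i}"}

# Builds a streaming dataset backed by the generator; nothing is written
# to the cache and examples are produced lazily on iteration.
ds = IterableDataset.from_generator(gen)
for example in ds:
    print(example)
```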
11 changes: 7 additions & 4 deletions src/datasets/io/abc.py
@@ -1,8 +1,7 @@
from abc import ABC, abstractmethod
from typing import Optional, Union

-from .. import DatasetDict, Features, NamedSplit
-from ..arrow_dataset import Dataset
+from .. import Dataset, DatasetDict, Features, IterableDataset, IterableDatasetDict, NamedSplit
from ..utils.typing import NestedDataStructureLike, PathLike


@@ -14,17 +13,19 @@ def __init__(
        features: Optional[Features] = None,
        cache_dir: str = None,
        keep_in_memory: bool = False,
+        streaming: bool = False,
        **kwargs,
    ):
        self.path_or_paths = path_or_paths
        self.split = split if split or isinstance(path_or_paths, dict) else "train"
        self.features = features
        self.cache_dir = cache_dir
        self.keep_in_memory = keep_in_memory
+        self.streaming = streaming
        self.kwargs = kwargs

    @abstractmethod
-    def read(self) -> Union[Dataset, DatasetDict]:
+    def read(self) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]:
        pass


@@ -34,13 +35,15 @@ def __init__(
        features: Optional[Features] = None,
        cache_dir: str = None,
        keep_in_memory: bool = False,
+        streaming: bool = False,
        **kwargs,
    ):
        self.features = features
        self.cache_dir = cache_dir
        self.keep_in_memory = keep_in_memory
+        self.streaming = streaming
        self.kwargs = kwargs

    @abstractmethod
-    def read(self) -> Dataset:
+    def read(self) -> Union[Dataset, IterableDataset]:
        pass
50 changes: 30 additions & 20 deletions src/datasets/io/csv.py
@@ -18,10 +18,17 @@ def __init__(
        features: Optional[Features] = None,
        cache_dir: str = None,
        keep_in_memory: bool = False,
+        streaming: bool = False,
        **kwargs,
    ):
        super().__init__(
-            path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
+            path_or_paths,
+            split=split,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            streaming=streaming,
+            **kwargs,
        )
        path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
        self.builder = Csv(
@@ -32,25 +39,28 @@ def __init__(
        )

    def read(self):
-        download_config = None
-        download_mode = None
-        ignore_verifications = False
-        use_auth_token = None
-        base_path = None
-
-        self.builder.download_and_prepare(
-            download_config=download_config,
-            download_mode=download_mode,
-            ignore_verifications=ignore_verifications,
-            # try_from_hf_gcs=try_from_hf_gcs,
-            base_path=base_path,
-            use_auth_token=use_auth_token,
-        )
-
-        # Build dataset for splits
-        dataset = self.builder.as_dataset(
-            split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
-        )
+        # Build iterable dataset
+        if self.streaming:
+            dataset = self.builder.as_streaming_dataset(split=self.split)
+        # Build regular (map-style) dataset
+        else:
+            download_config = None
+            download_mode = None
+            ignore_verifications = False
+            use_auth_token = None
+            base_path = None
+
+            self.builder.download_and_prepare(
+                download_config=download_config,
+                download_mode=download_mode,
+                ignore_verifications=ignore_verifications,
+                # try_from_hf_gcs=try_from_hf_gcs,
+                base_path=base_path,
+                use_auth_token=use_auth_token,
+            )
+            dataset = self.builder.as_dataset(
+                split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+            )
        return dataset
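The same branch can be exercised directly through the reader; a minimal sketch, assuming a local data.csv exists (the path is illustrative):

```python
from datasets.io.csv import CsvDatasetReader

# streaming=False (the default) runs download_and_prepare() and returns a
# cached, map-style Dataset; streaming=True skips preparation entirely and
# returns an IterableDataset.
ds = CsvDatasetReader("data.csv", streaming=True).read()
print(next(iter(ds)))
```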


44 changes: 25 additions & 19 deletions src/datasets/io/generator.py
@@ -12,10 +12,13 @@ def __init__(
        features: Optional[Features] = None,
        cache_dir: str = None,
        keep_in_memory: bool = False,
+        streaming: bool = False,
        gen_kwargs: Optional[dict] = None,
        **kwargs,
    ):
-        super().__init__(features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs)
+        super().__init__(
+            features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, streaming=streaming, **kwargs
+        )
        self.builder = Generator(
            cache_dir=cache_dir,
            features=features,
@@ -25,23 +28,26 @@ def __init__(
        )

    def read(self):
-        download_config = None
-        download_mode = None
-        ignore_verifications = False
-        use_auth_token = None
-        base_path = None
-
-        self.builder.download_and_prepare(
-            download_config=download_config,
-            download_mode=download_mode,
-            ignore_verifications=ignore_verifications,
-            # try_from_hf_gcs=try_from_hf_gcs,
-            base_path=base_path,
-            use_auth_token=use_auth_token,
-        )
-
-        # Build dataset for splits
-        dataset = self.builder.as_dataset(
-            split="train", ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
-        )
+        # Build iterable dataset
+        if self.streaming:
+            dataset = self.builder.as_streaming_dataset(split="train")
+        # Build regular (map-style) dataset
+        else:
+            download_config = None
+            download_mode = None
+            ignore_verifications = False
+            use_auth_token = None
+            base_path = None
+
+            self.builder.download_and_prepare(
+                download_config=download_config,
+                download_mode=download_mode,
+                ignore_verifications=ignore_verifications,
+                # try_from_hf_gcs=try_from_hf_gcs,
+                base_path=base_path,
+                use_auth_token=use_auth_token,
+            )
+            dataset = self.builder.as_dataset(
+                split="train", ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+            )
        return dataset
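The generator input stream always reads the "train" split; gen_kwargs is how the generator receives its arguments. A sketch through the public entry point (the sharding scheme is illustrative):

```python
from datasets import IterableDataset

def gen(shards):
    # A made-up sharded source: each shard name contributes two rows.
    for shard in shards:
        for i in range(2):
            yield {"shard": shard, "row": i}

ds = IterableDataset.from_generator(gen, gen_kwargs={"shards": ["a", "b"]})
print(list(ds))  # four rows, streamed lazily from the generator
```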
51 changes: 30 additions & 21 deletions src/datasets/io/json.py
@@ -20,11 +20,18 @@ def __init__(
        features: Optional[Features] = None,
        cache_dir: str = None,
        keep_in_memory: bool = False,
+        streaming: bool = False,
        field: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(
-            path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
+            path_or_paths,
+            split=split,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            streaming=streaming,
+            **kwargs,
        )
        self.field = field
        path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
@@ -37,26 +44,28 @@ def __init__(
        )

    def read(self):
-        download_config = None
-        download_mode = None
-        ignore_verifications = True
-        try_from_hf_gcs = False
-        use_auth_token = None
-        base_path = None
-
-        self.builder.download_and_prepare(
-            download_config=download_config,
-            download_mode=download_mode,
-            ignore_verifications=ignore_verifications,
-            try_from_hf_gcs=try_from_hf_gcs,
-            base_path=base_path,
-            use_auth_token=use_auth_token,
-        )
-
-        # Build dataset for splits
-        dataset = self.builder.as_dataset(
-            split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
-        )
+        # Build iterable dataset
+        if self.streaming:
+            dataset = self.builder.as_streaming_dataset(split=self.split)
+        # Build regular (map-style) dataset
+        else:
+            download_config = None
+            download_mode = None
+            ignore_verifications = False
+            use_auth_token = None
+            base_path = None
+
+            self.builder.download_and_prepare(
+                download_config=download_config,
+                download_mode=download_mode,
+                ignore_verifications=ignore_verifications,
+                # try_from_hf_gcs=try_from_hf_gcs,
+                base_path=base_path,
+                use_auth_token=use_auth_token,
+            )
+            dataset = self.builder.as_dataset(
+                split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+            )
        return dataset
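JsonDatasetReader additionally accepts a field argument for files whose records live under a top-level key; a sketch, assuming a local data.json of the form {"data": [...]} (path and key are illustrative):

```python
from datasets.io.json import JsonDatasetReader

# field="data" selects the list of records under the "data" key;
# streaming=True yields them as an IterableDataset without caching.
ds = JsonDatasetReader("data.json", field="data", streaming=True).read()
```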


48 changes: 29 additions & 19 deletions src/datasets/io/parquet.py
@@ -20,10 +20,17 @@ def __init__(
        features: Optional[Features] = None,
        cache_dir: str = None,
        keep_in_memory: bool = False,
+        streaming: bool = False,
        **kwargs,
    ):
        super().__init__(
-            path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
+            path_or_paths,
+            split=split,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            streaming=streaming,
+            **kwargs,
        )
        path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
        hash = _PACKAGED_DATASETS_MODULES["parquet"][1]
@@ -36,25 +43,28 @@ def __init__(
        )

    def read(self):
-        download_config = None
-        download_mode = None
-        ignore_verifications = False
-        use_auth_token = None
-        base_path = None
-
-        self.builder.download_and_prepare(
-            download_config=download_config,
-            download_mode=download_mode,
-            ignore_verifications=ignore_verifications,
-            # try_from_hf_gcs=try_from_hf_gcs,
-            base_path=base_path,
-            use_auth_token=use_auth_token,
-        )
-
-        # Build dataset for splits
-        dataset = self.builder.as_dataset(
-            split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
-        )
+        # Build iterable dataset
+        if self.streaming:
+            dataset = self.builder.as_streaming_dataset(split=self.split)
+        # Build regular (map-style) dataset
+        else:
+            download_config = None
+            download_mode = None
+            ignore_verifications = False
+            use_auth_token = None
+            base_path = None
+
+            self.builder.download_and_prepare(
+                download_config=download_config,
+                download_mode=download_mode,
+                ignore_verifications=ignore_verifications,
+                # try_from_hf_gcs=try_from_hf_gcs,
+                base_path=base_path,
+                use_auth_token=use_auth_token,
+            )
+            dataset = self.builder.as_dataset(
+                split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+            )
        return dataset
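Per the updated abc.py signature, read() can also return a dataset dict: when path_or_paths is a dict and no split is given, self.split stays None and each key becomes a split. A sketch (file names are illustrative):

```python
from datasets.io.parquet import ParquetDatasetReader

# With a dict of splits and split=None, streaming read() returns an
# IterableDatasetDict keyed by split name ("train" and "test" here).
files = {"train": "train.parquet", "test": "test.parquet"}
ds_dict = ParquetDatasetReader(files, streaming=True).read()
```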


48 changes: 29 additions & 19 deletions src/datasets/io/text.py
@@ -14,10 +14,17 @@ def __init__(
        features: Optional[Features] = None,
        cache_dir: str = None,
        keep_in_memory: bool = False,
+        streaming: bool = False,
        **kwargs,
    ):
        super().__init__(
-            path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
+            path_or_paths,
+            split=split,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            streaming=streaming,
+            **kwargs,
        )
        path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
        self.builder = Text(
@@ -28,23 +35,26 @@ def __init__(
        )

    def read(self):
-        download_config = None
-        download_mode = None
-        ignore_verifications = False
-        use_auth_token = None
-        base_path = None
-
-        self.builder.download_and_prepare(
-            download_config=download_config,
-            download_mode=download_mode,
-            ignore_verifications=ignore_verifications,
-            # try_from_hf_gcs=try_from_hf_gcs,
-            base_path=base_path,
-            use_auth_token=use_auth_token,
-        )
-
-        # Build dataset for splits
-        dataset = self.builder.as_dataset(
-            split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
-        )
+        # Build iterable dataset
+        if self.streaming:
+            dataset = self.builder.as_streaming_dataset(split=self.split)
+        # Build regular (map-style) dataset
+        else:
+            download_config = None
+            download_mode = None
+            ignore_verifications = False
+            use_auth_token = None
+            base_path = None
+
+            self.builder.download_and_prepare(
+                download_config=download_config,
+                download_mode=download_mode,
+                ignore_verifications=ignore_verifications,
+                # try_from_hf_gcs=try_from_hf_gcs,
+                base_path=base_path,
+                use_auth_token=use_auth_token,
+            )
+            dataset = self.builder.as_dataset(
+                split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+            )
        return dataset
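In streaming mode nothing is materialized up front (keep_in_memory and the cache are simply not involved), so only the rows actually consumed are read. A sketch with the text reader, whose builder yields one {"text": line} example per line (the file name is illustrative):

```python
from itertools import islice

from datasets.io.text import TextDatasetReader

ds = TextDatasetReader("corpus.txt", streaming=True).read()
# Only the first three lines of corpus.txt are read from disk.
print(list(islice(ds, 3)))
```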

1 comment on commit 6ad430b

@github-actions
PyArrow==6.0.0


Benchmark: benchmark_array_xd.json

| metric | new | old | diff |
|---|---|---|---|
| read_batch_formatted_as_numpy after write_array2d | 0.010112 | 0.011353 | -0.001241 |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.004478 | 0.011008 | -0.006530 |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.034440 | 0.038508 | -0.004068 |
| read_batch_unformated after write_array2d | 0.036554 | 0.023109 | 0.013445 |
| read_batch_unformated after write_flattened_sequence | 0.344249 | 0.275898 | 0.068351 |
| read_batch_unformated after write_nested_sequence | 0.408554 | 0.323480 | 0.085074 |
| read_col_formatted_as_numpy after write_array2d | 0.006897 | 0.007986 | -0.001089 |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.004011 | 0.004328 | -0.000318 |
| read_col_formatted_as_numpy after write_nested_sequence | 0.007527 | 0.004250 | 0.003277 |
| read_col_unformated after write_array2d | 0.047496 | 0.037052 | 0.010443 |
| read_col_unformated after write_flattened_sequence | 0.375555 | 0.258489 | 0.117066 |
| read_col_unformated after write_nested_sequence | 0.392554 | 0.293841 | 0.098714 |
| read_formatted_as_numpy after write_array2d | 0.047271 | 0.128546 | -0.081276 |
| read_formatted_as_numpy after write_flattened_sequence | 0.014521 | 0.075646 | -0.061125 |
| read_formatted_as_numpy after write_nested_sequence | 0.306025 | 0.419271 | -0.113246 |
| read_unformated after write_array2d | 0.062730 | 0.043533 | 0.019197 |
| read_unformated after write_flattened_sequence | 0.350566 | 0.255139 | 0.095427 |
| read_unformated after write_nested_sequence | 0.396208 | 0.283200 | 0.113008 |
| write_array2d | 0.111605 | 0.141683 | -0.030078 |
| write_flattened_sequence | 1.738083 | 1.452155 | 0.285929 |
| write_nested_sequence | 1.759966 | 1.492716 | 0.267250 |

Benchmark: benchmark_getitem_100B.json

| metric | new | old | diff |
|---|---|---|---|
| get_batch_of_1024_random_rows | 0.240025 | 0.018006 | 0.222019 |
| get_batch_of_1024_rows | 0.539952 | 0.000490 | 0.539462 |
| get_first_row | 0.011887 | 0.000200 | 0.011687 |
| get_last_row | 0.000460 | 0.000054 | 0.000406 |

Benchmark: benchmark_indices_mapping.json

| metric | new | old | diff |
|---|---|---|---|
| select | 0.026451 | 0.037411 | -0.010960 |
| shard | 0.118563 | 0.014526 | 0.104037 |
| shuffle | 0.131528 | 0.176557 | -0.045029 |
| sort | 0.181756 | 0.737135 | -0.555380 |
| train_test_split | 0.131466 | 0.296338 | -0.164872 |

Benchmark: benchmark_iterating.json

| metric | new | old | diff |
|---|---|---|---|
| read 5000 | 0.612041 | 0.215209 | 0.396832 |
| read 50000 | 5.970681 | 2.077655 | 3.893026 |
| read_batch 50000 10 | 2.321092 | 1.504120 | 0.816972 |
| read_batch 50000 100 | 1.925000 | 1.541195 | 0.383805 |
| read_batch 50000 1000 | 1.905375 | 1.468490 | 0.436885 |
| read_formatted numpy 5000 | 0.731952 | 4.584777 | -3.852825 |
| read_formatted pandas 5000 | 5.340776 | 3.745712 | 1.595064 |
| read_formatted tensorflow 5000 | 4.895301 | 5.269862 | -0.374560 |
| read_formatted torch 5000 | 2.692918 | 4.565676 | -1.872759 |
| read_formatted_batch numpy 5000 10 | 0.085743 | 0.424275 | -0.338532 |
| read_formatted_batch numpy 5000 1000 | 0.012912 | 0.007607 | 0.005304 |
| shuffled read 5000 | 0.760442 | 0.226044 | 0.534398 |
| shuffled read 50000 | 7.671428 | 2.268929 | 5.402500 |
| shuffled read_batch 50000 10 | 3.106182 | 55.444624 | -52.338442 |
| shuffled read_batch 50000 100 | 2.443888 | 6.876477 | -4.432589 |
| shuffled read_batch 50000 1000 | 2.597745 | 2.142072 | 0.455672 |
| shuffled read_formatted numpy 5000 | 0.976539 | 4.805227 | -3.828689 |
| shuffled read_formatted_batch numpy 5000 10 | 0.190678 | 6.500664 | -6.309986 |
| shuffled read_formatted_batch numpy 5000 1000 | 0.075290 | 0.075469 | -0.000179 |

Benchmark: benchmark_map_filter.json

| metric | new | old | diff |
|---|---|---|---|
| filter | 1.780881 | 1.841788 | -0.060907 |
| map fast-tokenizer batched | 15.614747 | 8.074308 | 7.540439 |
| map identity | 44.624355 | 10.191392 | 34.432963 |
| map identity batched | 1.121043 | 0.680424 | 0.440619 |
| map no-op batched | 0.672767 | 0.534201 | 0.138566 |
| map no-op batched numpy | 0.493910 | 0.579283 | -0.085373 |
| map no-op batched pandas | 0.624247 | 0.434364 | 0.189883 |
| map no-op batched pytorch | 0.343525 | 0.540337 | -0.196813 |
| map no-op batched tensorflow | 0.354977 | 1.386936 | -1.031959 |
PyArrow==latest

Benchmark: benchmark_array_xd.json

| metric | new | old | diff |
|---|---|---|---|
| read_batch_formatted_as_numpy after write_array2d | 0.007735 | 0.011353 | -0.003618 |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.004686 | 0.011008 | -0.006322 |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.031424 | 0.038508 | -0.007084 |
| read_batch_unformated after write_array2d | 0.034679 | 0.023109 | 0.011570 |
| read_batch_unformated after write_flattened_sequence | 0.373570 | 0.275898 | 0.097672 |
| read_batch_unformated after write_nested_sequence | 0.402536 | 0.323480 | 0.079057 |
| read_col_formatted_as_numpy after write_array2d | 0.004706 | 0.007986 | -0.003279 |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.003780 | 0.004328 | -0.000548 |
| read_col_formatted_as_numpy after write_nested_sequence | 0.005455 | 0.004250 | 0.001205 |
| read_col_unformated after write_array2d | 0.045222 | 0.037052 | 0.008170 |
| read_col_unformated after write_flattened_sequence | 0.381373 | 0.258489 | 0.122884 |
| read_col_unformated after write_nested_sequence | 0.453102 | 0.293841 | 0.159261 |
| read_formatted_as_numpy after write_array2d | 0.047035 | 0.128546 | -0.081512 |
| read_formatted_as_numpy after write_flattened_sequence | 0.011823 | 0.075646 | -0.063823 |
| read_formatted_as_numpy after write_nested_sequence | 0.282766 | 0.419271 | -0.136505 |
| read_unformated after write_array2d | 0.066106 | 0.043533 | 0.022573 |
| read_unformated after write_flattened_sequence | 0.395641 | 0.255139 | 0.140502 |
| read_unformated after write_nested_sequence | 0.408449 | 0.283200 | 0.125249 |
| write_array2d | 0.105341 | 0.141683 | -0.036342 |
| write_flattened_sequence | 1.740375 | 1.452155 | 0.288220 |
| write_nested_sequence | 1.809050 | 1.492716 | 0.316334 |

Benchmark: benchmark_getitem_100B.json

| metric | new | old | diff |
|---|---|---|---|
| get_batch_of_1024_random_rows | 0.242941 | 0.018006 | 0.224935 |
| get_batch_of_1024_rows | 0.507194 | 0.000490 | 0.506704 |
| get_first_row | 0.001236 | 0.000200 | 0.001036 |
| get_last_row | 0.000124 | 0.000054 | 0.000070 |

Benchmark: benchmark_indices_mapping.json

| metric | new | old | diff |
|---|---|---|---|
| select | 0.022605 | 0.037411 | -0.014807 |
| shard | 0.101847 | 0.014526 | 0.087321 |
| shuffle | 0.115288 | 0.176557 | -0.061269 |
| sort | 0.168134 | 0.737135 | -0.569001 |
| train_test_split | 0.115091 | 0.296338 | -0.181247 |

Benchmark: benchmark_iterating.json

| metric | new | old | diff |
|---|---|---|---|
| read 5000 | 0.607324 | 0.215209 | 0.392115 |
| read 50000 | 6.030225 | 2.077655 | 3.952570 |
| read_batch 50000 10 | 2.574127 | 1.504120 | 1.070007 |
| read_batch 50000 100 | 2.229928 | 1.541195 | 0.688733 |
| read_batch 50000 1000 | 2.118670 | 1.468490 | 0.650179 |
| read_formatted numpy 5000 | 0.717035 | 4.584777 | -3.867742 |
| read_formatted pandas 5000 | 5.275224 | 3.745712 | 1.529511 |
| read_formatted tensorflow 5000 | 2.775148 | 5.269862 | -2.494713 |
| read_formatted torch 5000 | 1.818921 | 4.565676 | -2.746756 |
| read_formatted_batch numpy 5000 10 | 0.087328 | 0.424275 | -0.336947 |
| read_formatted_batch numpy 5000 1000 | 0.013697 | 0.007607 | 0.006090 |
| shuffled read 5000 | 0.757987 | 0.226044 | 0.531943 |
| shuffled read 50000 | 7.407147 | 2.268929 | 5.138219 |
| shuffled read_batch 50000 10 | 3.191027 | 55.444624 | -52.253597 |
| shuffled read_batch 50000 100 | 2.495022 | 6.876477 | -4.381455 |
| shuffled read_batch 50000 1000 | 2.543505 | 2.142072 | 0.401432 |
| shuffled read_formatted numpy 5000 | 0.900018 | 4.805227 | -3.905209 |
| shuffled read_formatted_batch numpy 5000 10 | 0.187861 | 6.500664 | -6.312803 |
| shuffled read_formatted_batch numpy 5000 1000 | 0.070130 | 0.075469 | -0.005339 |

Benchmark: benchmark_map_filter.json

| metric | new | old | diff |
|---|---|---|---|
| filter | 1.844456 | 1.841788 | 0.002668 |
| map fast-tokenizer batched | 15.386266 | 8.074308 | 7.311958 |
| map identity | 22.584783 | 10.191392 | 12.393391 |
| map identity batched | 1.159857 | 0.680424 | 0.479433 |
| map no-op batched | 0.771729 | 0.534201 | 0.237528 |
| map no-op batched numpy | 0.447065 | 0.579283 | -0.132218 |
| map no-op batched pandas | 0.574534 | 0.434364 | 0.140170 |
| map no-op batched pytorch | 0.318939 | 0.540337 | -0.221399 |
| map no-op batched tensorflow | 0.332309 | 1.386936 | -1.054627 |
