added from_generator method to IterableDataset class. #5052

Merged
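For context: this PR threads a `streaming` flag through the packaged io readers so that, when it is set, `read()` builds an `IterableDataset` with `as_streaming_dataset()` instead of a cached, map-style `Dataset`. A minimal usage sketch of the resulting API (the generator and its fields are illustrative, and this assumes a `datasets` release that includes this PR):

```python
from datasets import Dataset, IterableDataset

# Illustrative generator; any callable that yields example dicts works.
def gen():
    for i in range(3):
        yield {"id": i, "text": f"example {i}"}

# Map-style: materialized and cached via download_and_prepare().
ds = Dataset.from_generator(gen)

# Iterable: built lazily via as_streaming_dataset(), nothing written to cache.
ids = IterableDataset.from_generator(gen)
print(next(iter(ids)))  # {'id': 0, 'text': 'example 0'}
```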
11 changes: 7 additions & 4 deletions src/datasets/io/abc.py
@@ -1,8 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Optional, Union

-from .. import DatasetDict, Features, NamedSplit
-from ..arrow_dataset import Dataset
+from .. import Dataset, DatasetDict, Features, IterableDataset, IterableDatasetDict, NamedSplit
 from ..utils.typing import NestedDataStructureLike, PathLike


@@ -14,17 +13,19 @@ def __init__(
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
+        streaming: bool = False,
         **kwargs,
     ):
         self.path_or_paths = path_or_paths
         self.split = split if split or isinstance(path_or_paths, dict) else "train"
         self.features = features
         self.cache_dir = cache_dir
         self.keep_in_memory = keep_in_memory
+        self.streaming = streaming
         self.kwargs = kwargs

     @abstractmethod
-    def read(self) -> Union[Dataset, DatasetDict]:
+    def read(self) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]:
         pass


@@ -34,13 +35,15 @@ def __init__(
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
+        streaming: bool = False,
         **kwargs,
     ):
         self.features = features
         self.cache_dir = cache_dir
         self.keep_in_memory = keep_in_memory
+        self.streaming = streaming
         self.kwargs = kwargs

     @abstractmethod
-    def read(self) -> Dataset:
+    def read(self) -> Union[Dataset, IterableDataset]:
         pass
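Since `read()` can now return either flavor, callers choose the behavior with the `streaming` flag. A hedged sketch against the CSV reader changed below (the file path is illustrative):

```python
from datasets import IterableDataset
from datasets.io.csv import CsvDatasetReader

# streaming=False (default) -> cached map-style Dataset;
# streaming=True -> lazily evaluated IterableDataset.
dataset = CsvDatasetReader("data.csv", streaming=True).read()
assert isinstance(dataset, IterableDataset)
```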
50 changes: 30 additions & 20 deletions src/datasets/io/csv.py
@@ -18,10 +18,17 @@ def __init__(
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
+        streaming: bool = False,
         **kwargs,
     ):
         super().__init__(
-            path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
+            path_or_paths,
+            split=split,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            streaming=streaming,
+            **kwargs,
         )
         path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
         self.builder = Csv(
@@ -32,25 +39,28 @@ def __init__(
         )

     def read(self):
-        download_config = None
-        download_mode = None
-        ignore_verifications = False
-        use_auth_token = None
-        base_path = None
+        # Build iterable dataset
+        if self.streaming:
+            dataset = self.builder.as_streaming_dataset(split=self.split)
+        # Build regular (map-style) dataset
+        else:
+            download_config = None
+            download_mode = None
+            ignore_verifications = False
+            use_auth_token = None
+            base_path = None

-        self.builder.download_and_prepare(
-            download_config=download_config,
-            download_mode=download_mode,
-            ignore_verifications=ignore_verifications,
-            # try_from_hf_gcs=try_from_hf_gcs,
-            base_path=base_path,
-            use_auth_token=use_auth_token,
-        )
-
-        # Build dataset for splits
-        dataset = self.builder.as_dataset(
-            split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
-        )
+            self.builder.download_and_prepare(
+                download_config=download_config,
+                download_mode=download_mode,
+                ignore_verifications=ignore_verifications,
+                # try_from_hf_gcs=try_from_hf_gcs,
+                base_path=base_path,
+                use_auth_token=use_auth_token,
+            )
+            dataset = self.builder.as_dataset(
+                split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+            )
         return dataset


44 changes: 25 additions & 19 deletions src/datasets/io/generator.py
@@ -12,10 +12,13 @@ def __init__(
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
+        streaming: bool = False,
         gen_kwargs: Optional[dict] = None,
         **kwargs,
     ):
-        super().__init__(features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs)
+        super().__init__(
+            features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, streaming=streaming, **kwargs
+        )
         self.builder = Generator(
             cache_dir=cache_dir,
             features=features,
@@ -25,23 +28,26 @@ def __init__(
         )

     def read(self):
-        download_config = None
-        download_mode = None
-        ignore_verifications = False
-        use_auth_token = None
-        base_path = None
+        # Build iterable dataset
+        if self.streaming:
+            dataset = self.builder.as_streaming_dataset(split="train")
+        # Build regular (map-style) dataset
+        else:
+            download_config = None
+            download_mode = None
+            ignore_verifications = False
+            use_auth_token = None
+            base_path = None

-        self.builder.download_and_prepare(
-            download_config=download_config,
-            download_mode=download_mode,
-            ignore_verifications=ignore_verifications,
-            # try_from_hf_gcs=try_from_hf_gcs,
-            base_path=base_path,
-            use_auth_token=use_auth_token,
-        )
-
-        # Build dataset for splits
-        dataset = self.builder.as_dataset(
-            split="train", ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
-        )
+            self.builder.download_and_prepare(
+                download_config=download_config,
+                download_mode=download_mode,
+                ignore_verifications=ignore_verifications,
+                # try_from_hf_gcs=try_from_hf_gcs,
+                base_path=base_path,
+                use_auth_token=use_auth_token,
+            )
+            dataset = self.builder.as_dataset(
+                split="train", ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+            )
         return dataset
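This reader is the plumbing behind the new `IterableDataset.from_generator`; note it always exposes a single `"train"` split. A small sketch driving it directly (the generator contents are illustrative):

```python
from datasets.io.generator import GeneratorDatasetInputStream

def squares():
    for i in range(5):
        yield {"n": i, "square": i * i}

# streaming=True takes the as_streaming_dataset(split="train") branch above,
# so the generator is consumed lazily and no Arrow cache is produced.
stream = GeneratorDatasetInputStream(generator=squares, streaming=True).read()
print(next(iter(stream)))  # {'n': 0, 'square': 0}
```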
51 changes: 30 additions & 21 deletions src/datasets/io/json.py
@@ -20,11 +20,18 @@ def __init__(
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
+        streaming: bool = False,
         field: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(
-            path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
+            path_or_paths,
+            split=split,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            streaming=streaming,
+            **kwargs,
         )
         self.field = field
         path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
@@ -37,26 +44,28 @@ def __init__(
         )

     def read(self):
-        download_config = None
-        download_mode = None
-        ignore_verifications = True
-        try_from_hf_gcs = False
-        use_auth_token = None
-        base_path = None
+        # Build iterable dataset
+        if self.streaming:
+            dataset = self.builder.as_streaming_dataset(split=self.split)
+        # Build regular (map-style) dataset
+        else:
+            download_config = None
+            download_mode = None
+            ignore_verifications = False
+            use_auth_token = None
+            base_path = None

-        self.builder.download_and_prepare(
-            download_config=download_config,
-            download_mode=download_mode,
-            ignore_verifications=ignore_verifications,
-            try_from_hf_gcs=try_from_hf_gcs,
-            base_path=base_path,
-            use_auth_token=use_auth_token,
-        )
-
-        # Build dataset for splits
-        dataset = self.builder.as_dataset(
-            split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
-        )
+            self.builder.download_and_prepare(
+                download_config=download_config,
+                download_mode=download_mode,
+                ignore_verifications=ignore_verifications,
+                # try_from_hf_gcs=try_from_hf_gcs,
+                base_path=base_path,
+                use_auth_token=use_auth_token,
+            )
+            dataset = self.builder.as_dataset(
+                split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+            )
         return dataset


48 changes: 29 additions & 19 deletions src/datasets/io/parquet.py
@@ -20,10 +20,17 @@ def __init__(
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
+        streaming: bool = False,
         **kwargs,
     ):
         super().__init__(
-            path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
+            path_or_paths,
+            split=split,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            streaming=streaming,
+            **kwargs,
         )
         path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
         hash = _PACKAGED_DATASETS_MODULES["parquet"][1]
@@ -36,25 +43,28 @@ def __init__(
        )

     def read(self):
-        download_config = None
-        download_mode = None
-        ignore_verifications = False
-        use_auth_token = None
-        base_path = None
+        # Build iterable dataset
+        if self.streaming:
+            dataset = self.builder.as_streaming_dataset(split=self.split)
+        # Build regular (map-style) dataset
+        else:
+            download_config = None
+            download_mode = None
+            ignore_verifications = False
+            use_auth_token = None
+            base_path = None

-        self.builder.download_and_prepare(
-            download_config=download_config,
-            download_mode=download_mode,
-            ignore_verifications=ignore_verifications,
-            # try_from_hf_gcs=try_from_hf_gcs,
-            base_path=base_path,
-            use_auth_token=use_auth_token,
-        )
-
-        # Build dataset for splits
-        dataset = self.builder.as_dataset(
-            split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
-        )
+            self.builder.download_and_prepare(
+                download_config=download_config,
+                download_mode=download_mode,
+                ignore_verifications=ignore_verifications,
+                # try_from_hf_gcs=try_from_hf_gcs,
+                base_path=base_path,
+                use_auth_token=use_auth_token,
+            )
+            dataset = self.builder.as_dataset(
+                split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+            )
         return dataset


48 changes: 29 additions & 19 deletions src/datasets/io/text.py
@@ -14,10 +14,17 @@ def __init__(
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
+        streaming: bool = False,
         **kwargs,
     ):
         super().__init__(
-            path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
+            path_or_paths,
+            split=split,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            streaming=streaming,
+            **kwargs,
         )
         path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
         self.builder = Text(
@@ -28,23 +35,26 @@ def __init__(
         )

     def read(self):
-        download_config = None
-        download_mode = None
-        ignore_verifications = False
-        use_auth_token = None
-        base_path = None
+        # Build iterable dataset
+        if self.streaming:
+            dataset = self.builder.as_streaming_dataset(split=self.split)
+        # Build regular (map-style) dataset
+        else:
+            download_config = None
+            download_mode = None
+            ignore_verifications = False
+            use_auth_token = None
+            base_path = None

-        self.builder.download_and_prepare(
-            download_config=download_config,
-            download_mode=download_mode,
-            ignore_verifications=ignore_verifications,
-            # try_from_hf_gcs=try_from_hf_gcs,
-            base_path=base_path,
-            use_auth_token=use_auth_token,
-        )
-
-        # Build dataset for splits
-        dataset = self.builder.as_dataset(
-            split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
-        )
+            self.builder.download_and_prepare(
+                download_config=download_config,
+                download_mode=download_mode,
+                ignore_verifications=ignore_verifications,
+                # try_from_hf_gcs=try_from_hf_gcs,
+                base_path=base_path,
+                use_auth_token=use_auth_token,
+            )
+            dataset = self.builder.as_dataset(
+                split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+            )
         return dataset
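All five readers take the same two branches, so both paths should yield identical rows. A quick illustrative check with the text reader (the file path is assumed, and the row-for-row equality is an expectation of the design rather than something this diff asserts):

```python
from datasets.io.text import TextDatasetReader

map_style = TextDatasetReader("notes.txt").read()                  # Dataset
streamed = TextDatasetReader("notes.txt", streaming=True).read()   # IterableDataset

# The text builder yields one {"text": ...} row per line in either mode.
assert map_style[0] == next(iter(streamed))
```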