-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
parquet.py
113 lines (100 loc) · 3.97 KB
/
parquet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
from typing import BinaryIO, Optional, Union
import pyarrow as pa
import pyarrow.parquet as pq
from .. import Dataset, Features, NamedSplit, config
from ..formatting import query_table
from ..packaged_modules import _PACKAGED_DATASETS_MODULES
from ..packaged_modules.parquet.parquet import Parquet
from ..utils.typing import NestedDataStructureLike, PathLike
from .abc import AbstractDatasetReader
class ParquetDatasetReader(AbstractDatasetReader):
    """Dataset reader that loads Parquet file(s) through the packaged ``Parquet`` builder.

    The constructor forwards its arguments to ``AbstractDatasetReader`` and then
    instantiates the builder; ``read`` turns that builder into either a
    streaming dataset or a regular (map-style) dataset.
    """

    def __init__(
        self,
        path_or_paths: NestedDataStructureLike[PathLike],
        split: Optional[NamedSplit] = None,
        features: Optional[Features] = None,
        cache_dir: str = None,
        keep_in_memory: bool = False,
        streaming: bool = False,
        **kwargs,
    ):
        """Set up the reader and its underlying ``Parquet`` builder.

        Args:
            path_or_paths: Parquet file location(s). A dict is handed to the
                builder as-is; anything else is wrapped as ``{self.split: path_or_paths}``.
            split: Forwarded to the parent initializer.
            features: Forwarded to the parent initializer and to the builder.
            cache_dir: Forwarded to the parent initializer and to the builder.
            keep_in_memory: Forwarded to the parent initializer.
            streaming: Forwarded to the parent initializer.
            **kwargs: Forwarded to both the parent initializer and the builder.
        """
        super().__init__(
            path_or_paths,
            split=split,
            features=features,
            cache_dir=cache_dir,
            keep_in_memory=keep_in_memory,
            streaming=streaming,
            **kwargs,
        )

        # NOTE(review): self.split is presumably assigned by the parent
        # initializer (not visible here), hence the lookup after super().__init__.
        if isinstance(path_or_paths, dict):
            data_files = path_or_paths
        else:
            data_files = {self.split: path_or_paths}

        # Second element of the packaged-module registry entry for "parquet",
        # passed to the builder as its hash.
        builder_hash = _PACKAGED_DATASETS_MODULES["parquet"][1]

        self.builder = Parquet(
            cache_dir=cache_dir,
            data_files=data_files,
            features=features,
            hash=builder_hash,
            **kwargs,
        )

    def read(self):
        """Build and return the dataset described by this reader.

        Returns:
            The result of ``builder.as_streaming_dataset`` when ``self.streaming``
            is truthy; otherwise the result of ``builder.as_dataset`` after
            ``builder.download_and_prepare`` has been run.
        """
        # Iterable dataset: nothing has to be prepared up front.
        if self.streaming:
            return self.builder.as_streaming_dataset(split=self.split)

        # Regular (map-style) dataset: prepare first, then load the split.
        self.builder.download_and_prepare(
            download_config=None,
            download_mode=None,
            ignore_verifications=False,
            base_path=None,
            use_auth_token=None,
        )
        return self.builder.as_dataset(
            split=self.split,
            ignore_verifications=False,
            in_memory=self.keep_in_memory,
        )
class ParquetDatasetWriter:
    """Write a ``Dataset`` as Parquet to a filesystem path or an open binary handle.

    The dataset is written in slices of ``batch_size`` rows through a single
    ``pyarrow.parquet.ParquetWriter``.
    """

    def __init__(
        self,
        dataset: Dataset,
        path_or_buf: Union[PathLike, BinaryIO],
        batch_size: Optional[int] = None,
        **parquet_writer_kwargs,
    ):
        """Store the write configuration; nothing is written until ``write`` is called.

        Args:
            dataset: The dataset to serialize.
            path_or_buf: Destination. A ``str``, ``bytes`` or ``os.PathLike`` is
                opened by ``write``; anything else is used directly as the
                binary file object.
            batch_size: Number of rows per written table. A falsy value
                (``None`` or ``0``) selects ``config.DEFAULT_MAX_BATCH_SIZE``.
            **parquet_writer_kwargs: Extra keyword arguments passed to
                ``pq.ParquetWriter``.
        """
        self.dataset = dataset
        self.path_or_buf = path_or_buf
        self.batch_size = batch_size
        self.parquet_writer_kwargs = parquet_writer_kwargs

    def write(self) -> int:
        """Write the dataset to ``path_or_buf``.

        When ``path_or_buf`` is a path, the file is opened in ``"wb+"`` mode and
        closed again before returning. A caller-supplied handle is left open.

        Returns:
            The value returned by ``_write`` (sum of ``nbytes`` over all
            written batches).
        """
        batch_size = self.batch_size or config.DEFAULT_MAX_BATCH_SIZE

        if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
            with open(self.path_or_buf, "wb+") as buffer:
                return self._write(file_obj=buffer, batch_size=batch_size, **self.parquet_writer_kwargs)

        return self._write(file_obj=self.path_or_buf, batch_size=batch_size, **self.parquet_writer_kwargs)

    def _write(self, file_obj: BinaryIO, batch_size: int, **parquet_writer_kwargs) -> int:
        """Writes the pyarrow table as Parquet to a binary file handle.

        Caller is responsible for opening and closing the handle. The
        ``ParquetWriter`` created here is always closed, even when writing a
        batch fails.

        Args:
            file_obj: Open binary handle that receives the Parquet data.
            batch_size: Number of rows per written table.
            **parquet_writer_kwargs: Passed to ``pq.ParquetWriter``; a
                ``path_or_buf`` entry, if present, is discarded.

        Returns:
            Sum of ``batch.nbytes`` over all written batches. This is the size
            reported by the Arrow tables, which is not necessarily the number
            of bytes that end up in the Parquet output.
        """
        written = 0
        # The destination is given by file_obj; drop a stray duplicate so it is
        # not forwarded to ParquetWriter.
        parquet_writer_kwargs.pop("path_or_buf", None)
        schema = pa.schema(self.dataset.features.type)
        writer = pq.ParquetWriter(file_obj, schema=schema, **parquet_writer_kwargs)
        try:
            for offset in range(0, len(self.dataset), batch_size):
                batch = query_table(
                    table=self.dataset._data,
                    key=slice(offset, offset + batch_size),
                    indices=self.dataset._indices,
                )
                writer.write_table(batch)
                written += batch.nbytes
        finally:
            # Close the writer on every path so it is never left open on a
            # handle that the caller (or write()) is about to close.
            writer.close()
        return written