Add ability to read-write to SQL databases.
Showing 8 changed files with 657 additions and 0 deletions.
@@ -0,0 +1,138 @@
import multiprocessing
import os
from sqlite3 import Connection, connect
from typing import Optional, Union

from .. import Dataset, Features, NamedSplit, config
from ..formatting import query_table
from ..packaged_modules.sql.sql import Sql
from ..utils import logging
from ..utils.typing import NestedDataStructureLike, PathLike
from .abc import AbstractDatasetReader

class SqlDatasetReader(AbstractDatasetReader):
    def __init__(
        self,
        path_or_paths: NestedDataStructureLike[PathLike],
        table_name: str,
        split: Optional[NamedSplit] = None,
        features: Optional[Features] = None,
        cache_dir: Optional[str] = None,
        keep_in_memory: bool = False,
        **kwargs,
    ):
        super().__init__(
            path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
        )
        path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
        self.builder = Sql(
            cache_dir=cache_dir,
            data_files=path_or_paths,
            features=features,
            table_name=table_name,
            **kwargs,
        )

    def read(self):
        download_config = None
        download_mode = None
        ignore_verifications = False
        use_auth_token = None
        base_path = None

        self.builder.download_and_prepare(
            download_config=download_config,
            download_mode=download_mode,
            ignore_verifications=ignore_verifications,
            # try_from_hf_gcs=try_from_hf_gcs,
            base_path=base_path,
            use_auth_token=use_auth_token,
        )

        # Build dataset for splits
        dataset = self.builder.as_dataset(
            split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
        )
        return dataset

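For illustration, here is a minimal usage sketch of the reader, assuming this module lives at datasets/io/sql.py alongside the other readers (as the relative imports suggest); the example.sqlite path and the states table name are hypothetical:

from datasets.io.sql import SqlDatasetReader

# Hypothetical SQLite database containing a table named "states".
reader = SqlDatasetReader("example.sqlite", table_name="states")
dataset = reader.read()  # runs the Sql builder, then materializes the requested split
print(dataset)
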
class SqlDatasetWriter:
    def __init__(
        self,
        dataset: Dataset,
        path_or_buf: Union[PathLike, Connection],
        table_name: str,
        batch_size: Optional[int] = None,
        num_proc: Optional[int] = None,
        **to_sql_kwargs,
    ):
        if num_proc is not None and num_proc <= 0:
            raise ValueError(f"num_proc {num_proc} must be an integer > 0.")

        self.dataset = dataset
        self.path_or_buf = path_or_buf
        self.table_name = table_name
        self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
        self.num_proc = num_proc
        self.encoding = "utf-8"
        self.to_sql_kwargs = to_sql_kwargs

    def write(self) -> int:
        _ = self.to_sql_kwargs.pop("path_or_buf", None)

        if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
            with connect(self.path_or_buf) as conn:
                written = self._write(conn=conn, **self.to_sql_kwargs)
        else:
            written = self._write(conn=self.path_or_buf, **self.to_sql_kwargs)
        return written

    def _batch_sql(self, offset):
        batch = query_table(
            table=self.dataset.data,
            key=slice(offset, offset + self.batch_size),
            indices=self.dataset._indices,
        )
        return batch.to_pandas()

    def _write(self, conn: Connection, **to_sql_kwargs) -> int:
        """Writes the pyarrow table as SQL to a database connection.

        Caller is responsible for opening and closing the connection.
        """
        written = 0

        if self.num_proc is None or self.num_proc == 1:
            for offset in logging.tqdm(
                range(0, len(self.dataset), self.batch_size),
                unit="ba",
                disable=not logging.is_progress_bar_enabled(),
                desc="Creating SQL from Arrow format",
            ):
                df = self._batch_sql(offset)
                written += df.to_sql(
                    self.table_name, conn, **to_sql_kwargs, if_exists="replace" if offset == 0 else "append"
                )
        else:
            num_rows, batch_size = len(self.dataset), self.batch_size
            with multiprocessing.Pool(self.num_proc) as pool:
                for idx, df in logging.tqdm(
                    enumerate(
                        pool.imap(
                            self._batch_sql,
                            [offset for offset in range(0, num_rows, batch_size)],
                        )
                    ),
                    total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
                    unit="ba",
                    disable=not logging.is_progress_bar_enabled(),
                    desc="Creating SQL from Arrow format",
                ):
                    written += df.to_sql(
                        self.table_name, conn, **to_sql_kwargs, if_exists="replace" if idx == 0 else "append"
                    )

        return written
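
A corresponding sketch for the writer; the toy dataset and the out.sqlite path are made up. Note that relying on df.to_sql to return a row count assumes pandas >= 1.4.0, where its return value changed from None to the number of affected rows; on earlier pandas the written += accumulation here would fail:

from datasets import Dataset
from datasets.io.sql import SqlDatasetWriter  # assumed module path, as above

# Toy in-memory dataset for illustration.
dataset = Dataset.from_dict({"id": [1, 2, 3], "name": ["a", "b", "c"]})
writer = SqlDatasetWriter(dataset, "out.sqlite", table_name="states", batch_size=2)
written = writer.write()  # first batch replaces any existing table, later batches append
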
Empty file.
@@ -0,0 +1,99 @@
import itertools
from dataclasses import dataclass
from sqlite3 import connect
from typing import Dict, List, Optional, Sequence, Union

import pandas as pd
import pyarrow as pa
from typing_extensions import Literal

import datasets
import datasets.config
from datasets.features.features import require_storage_cast
from datasets.table import table_cast


logger = datasets.utils.logging.get_logger(__name__)

@dataclass
class SqlConfig(datasets.BuilderConfig):
    """BuilderConfig for SQL."""

    index_col: Optional[Union[int, str, List[int], List[str]]] = None
    table_name: str = "Dataset"
    coerce_float: bool = True
    params: Optional[Union[Sequence, Dict]] = None
    parse_dates: Optional[Union[List, Dict]] = None
    columns: Optional[List[str]] = None
    chunksize: int = 10_000
    features: Optional[datasets.Features] = None
    encoding_errors: Optional[str] = "strict"
    on_bad_lines: Literal["error", "warn", "skip"] = "error"

    @property
    def read_sql_kwargs(self):
        read_sql_kwargs = dict(
            index_col=self.index_col,
            columns=self.columns,
            params=self.params,
            coerce_float=self.coerce_float,
            parse_dates=self.parse_dates,
            chunksize=self.chunksize,
        )
        return read_sql_kwargs

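Because chunksize defaults to 10_000 and is always forwarded in read_sql_kwargs, the pd.read_sql call in _generate_tables below returns an iterator of DataFrames rather than a single frame, which is what lets the builder stream one Arrow table per batch. A standalone sketch of that pandas behavior, with a hypothetical file and table:

import pandas as pd
from sqlite3 import connect

with connect("example.sqlite") as conn:
    # With chunksize set, read_sql returns an iterator of DataFrames,
    # each holding at most `chunksize` rows.
    for df in pd.read_sql("SELECT * FROM states", conn, chunksize=10_000):
        print(df.shape)
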
class Sql(datasets.ArrowBasedBuilder):
    BUILDER_CONFIG_CLASS = SqlConfig

    def _info(self):
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager):
        """We handle string, list and dicts in data_files."""
        if not self.config.data_files:
            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
        data_files = dl_manager.download_and_extract(self.config.data_files)
        if isinstance(data_files, (str, list, tuple)):
            files = data_files
            if isinstance(files, str):
                files = [files]
            files = [dl_manager.iter_files(file) for file in files]
            return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
        splits = []
        for split_name, files in data_files.items():
            if isinstance(files, str):
                files = [files]
            files = [dl_manager.iter_files(file) for file in files]
            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
        return splits

    def _cast_table(self, pa_table: pa.Table) -> pa.Table:
        if self.config.features is not None:
            schema = self.config.features.arrow_schema
            if all(not require_storage_cast(feature) for feature in self.config.features.values()):
                # cheaper cast
                pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
            else:
                # more expensive cast; allows str <-> int/float or str to Audio for example
                pa_table = table_cast(pa_table, schema)
        return pa_table

    def _generate_tables(self, files):
        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
            with connect(file) as conn:
                sql_file_reader = pd.read_sql(
                    f"SELECT * FROM `{self.config.table_name}`", conn, **self.config.read_sql_kwargs
                )
                try:
                    for batch_idx, df in enumerate(sql_file_reader):
                        # Drop index column as it is not relevant.
                        pa_table = pa.Table.from_pandas(df.drop("index", axis=1, errors="ignore"))
                        # Uncomment for debugging (will print the Arrow table size and elements)
                        # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
                        # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
                        yield (file_idx, batch_idx), self._cast_table(pa_table)
                except ValueError as e:
                    logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
                    raise
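
Finally, assuming this builder gets registered as a packaged module under the name "sql" (its location under packaged_modules/sql suggests this, though the registration itself is not part of the files shown here), loading a SQLite file could look like the following sketch; the file path and table name are again hypothetical:

from datasets import load_dataset

ds = load_dataset("sql", data_files="example.sqlite", table_name="states", split="train")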