Add ability to read-write to SQL databases.
Dref360 committed Sep 3, 2022
1 parent a50f268 commit f76b87c
Showing 8 changed files with 657 additions and 0 deletions.
75 changes: 75 additions & 0 deletions src/datasets/arrow_dataset.py
@@ -32,6 +32,7 @@
from math import ceil, floor
from pathlib import Path
from random import sample
from sqlite3 import Connection
from typing import (
TYPE_CHECKING,
Any,
@@ -1057,6 +1058,48 @@ def from_text(
path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
).read()

@staticmethod
def from_sql(
path_or_paths: Union[PathLike, List[PathLike]],
table_name: str,
split: Optional[NamedSplit] = None,
features: Optional[Features] = None,
cache_dir: str = None,
keep_in_memory: bool = False,
**kwargs,
):
"""Create Dataset from text file(s).
Args:
path_or_paths (path-like or list of path-like): Path(s) of the SQLite file(s).
table_name (``str``): Name of the SQL table to read from.
split (:class:`NamedSplit`, optional): Split name to be assigned to the dataset.
features (:class:`Features`, optional): Dataset features.
cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data.
keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory.
**kwargs (additional keyword arguments): Keyword arguments to be passed to :class:`SqlConfig`.
Returns:
:class:`Dataset`
Example:
```py
>>> ds = Dataset.from_sql('path/to/dataset.sqlite', table_name='my_table')
```
"""
from .io.sql import SqlDatasetReader

return SqlDatasetReader(
path_or_paths,
table_name=table_name,
split=split,
features=features,
cache_dir=cache_dir,
keep_in_memory=keep_in_memory,
**kwargs,
).read()
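
For context, a minimal usage sketch of the reader added above; the file path and table name are placeholders, not part of this commit:

```py
from datasets import Dataset

# Read a whole table from a local SQLite file into a Dataset.
ds = Dataset.from_sql("path/to/dataset.sqlite", table_name="my_table")

# keep_in_memory=True copies the data into RAM instead of memory-mapping the Arrow cache.
ds_in_memory = Dataset.from_sql(
    "path/to/dataset.sqlite", table_name="my_table", keep_in_memory=True
)
```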

def __del__(self):
if hasattr(self, "_data"):
del self._data
@@ -4051,6 +4094,38 @@ def to_parquet(

return ParquetDatasetWriter(self, path_or_buf, batch_size=batch_size, **parquet_writer_kwargs).write()

def to_sql(
self,
path_or_buf: Union[PathLike, Connection],
table_name: str,
batch_size: Optional[int] = None,
**sql_writer_kwargs,
) -> int:
"""Exports the dataset to SQLite
Args:
path_or_buf (``PathLike`` or ``Connection``): Either a path to a file or a SQL connection.
table_name (``str``): Name of the SQL table to write to.
batch_size (:obj:`int`, optional): Size of the batch to load in memory and write at once.
Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
**sql_writer_kwargs (additional keyword arguments): Parameters to pass to pandas's :func:`DataFrame.to_sql`.
Returns:
int: The number of records written.
Example:
```py
>>> ds.to_sql("path/to/dataset.sqlite", table_name="my_table")
```
"""
# Dynamic import to avoid circular dependency
from .io.sql import SqlDatasetWriter

return SqlDatasetWriter(
self, path_or_buf, table_name=table_name, batch_size=batch_size, **sql_writer_kwargs
).write()
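
Taken together with `from_sql` above, a short round-trip sketch; the file path and table name below are placeholders:

```py
import sqlite3

from datasets import Dataset

ds = Dataset.from_dict({"text": ["foo", "bar"], "label": [0, 1]})

# Write to a SQLite file; the table is created (if_exists="replace") on the
# first batch and appended to on subsequent batches.
ds.to_sql("data.sqlite", table_name="my_table")

# Or write through an existing sqlite3 connection.
with sqlite3.connect("data.sqlite") as conn:
    ds.to_sql(conn, table_name="my_table")

# Read the table back into a Dataset.
reloaded = Dataset.from_sql("data.sqlite", table_name="my_table")
```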

def _push_parquet_shards_to_hub(
self,
repo_id: str,
138 changes: 138 additions & 0 deletions src/datasets/io/sql.py
@@ -0,0 +1,138 @@
import multiprocessing
import os
from sqlite3 import Connection, connect
from typing import Optional, Union

from .. import Dataset, Features, NamedSplit, config
from ..formatting import query_table
from ..packaged_modules.sql.sql import Sql
from ..utils import logging
from ..utils.typing import NestedDataStructureLike, PathLike
from .abc import AbstractDatasetReader


class SqlDatasetReader(AbstractDatasetReader):
def __init__(
self,
path_or_paths: NestedDataStructureLike[PathLike],
table_name: str,
split: Optional[NamedSplit] = None,
features: Optional[Features] = None,
cache_dir: str = None,
keep_in_memory: bool = False,
**kwargs,
):
super().__init__(
path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
)
path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
self.builder = Sql(
cache_dir=cache_dir,
data_files=path_or_paths,
features=features,
table_name=table_name,
**kwargs,
)

def read(self):
download_config = None
download_mode = None
ignore_verifications = False
use_auth_token = None
base_path = None

self.builder.download_and_prepare(
download_config=download_config,
download_mode=download_mode,
ignore_verifications=ignore_verifications,
# try_from_hf_gcs=try_from_hf_gcs,
base_path=base_path,
use_auth_token=use_auth_token,
)

# Build dataset for splits
dataset = self.builder.as_dataset(
split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
)
return dataset


class SqlDatasetWriter:
def __init__(
self,
dataset: Dataset,
path_or_buf: Union[PathLike, Connection],
table_name: str,
batch_size: Optional[int] = None,
num_proc: Optional[int] = None,
**to_sql_kwargs,
):

if num_proc is not None and num_proc <= 0:
raise ValueError(f"num_proc {num_proc} must be an integer > 0.")

self.dataset = dataset
self.path_or_buf = path_or_buf
self.table_name = table_name
self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
self.num_proc = num_proc
self.encoding = "utf-8"
self.to_sql_kwargs = to_sql_kwargs

def write(self) -> int:
_ = self.to_sql_kwargs.pop("path_or_buf", None)

if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
with connect(self.path_or_buf) as conn:
written = self._write(conn=conn, **self.to_sql_kwargs)
else:
written = self._write(conn=self.path_or_buf, **self.to_sql_kwargs)
return written

def _batch_sql(self, offset):
batch = query_table(
table=self.dataset.data,
key=slice(offset, offset + self.batch_size),
indices=self.dataset._indices,
)
return batch.to_pandas()

def _write(self, conn: Connection, **to_sql_kwargs) -> int:
"""Writes the pyarrow table as SQL to a binary file handle.
Caller is responsible for opening and closing the handle.
"""
written = 0

if self.num_proc is None or self.num_proc == 1:
for offset in logging.tqdm(
range(0, len(self.dataset), self.batch_size),
unit="ba",
disable=not logging.is_progress_bar_enabled(),
desc="Creating SQL from Arrow format",
):
df = self._batch_sql(offset)
written += df.to_sql(
self.table_name, conn, **to_sql_kwargs, if_exists="replace" if offset == 0 else "append"
)

else:
num_rows, batch_size = len(self.dataset), self.batch_size
with multiprocessing.Pool(self.num_proc) as pool:
for idx, df in logging.tqdm(
enumerate(
pool.imap(
self._batch_sql,
[offset for offset in range(0, num_rows, batch_size)],
)
),
total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
unit="ba",
disable=not logging.is_progress_bar_enabled(),
desc="Creating SQL from Arrow format",
):
written += df.to_sql(
self.table_name, conn, **to_sql_kwargs, if_exists="replace" if idx == 0 else "append"
)

return written
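
For reference, a minimal sketch of driving `SqlDatasetWriter` directly with an already-open connection, as the `_write` docstring implies; the file and table names are placeholders:

```py
import sqlite3

from datasets import Dataset
from datasets.io.sql import SqlDatasetWriter

ds = Dataset.from_dict({"text": ["foo", "bar"], "label": [0, 1]})

# Reuse an existing connection; the writer only converts Arrow batches to
# pandas DataFrames and appends them to the target table batch by batch.
conn = sqlite3.connect("data.sqlite")
try:
    written = SqlDatasetWriter(ds, conn, table_name="my_table", batch_size=1000).write()
finally:
    conn.close()
```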
Empty file.
99 changes: 99 additions & 0 deletions src/datasets/packaged_modules/sql/sql.py
@@ -0,0 +1,99 @@
import itertools
from dataclasses import dataclass
from sqlite3 import connect
from typing import Dict, List, Optional, Sequence, Union

import pandas as pd
import pyarrow as pa
from typing_extensions import Literal

import datasets
import datasets.config
from datasets.features.features import require_storage_cast
from datasets.table import table_cast


logger = datasets.utils.logging.get_logger(__name__)


@dataclass
class SqlConfig(datasets.BuilderConfig):
"""BuilderConfig for SQL."""

index_col: Optional[Union[int, str, List[int], List[str]]] = None
table_name: str = "Dataset"
coerce_float: bool = True
params: Optional[Union[Sequence, Dict]] = None
parse_dates: Optional[Union[List, Dict]] = None
columns: Optional[List[str]] = None
chunksize: int = 10_000
features: Optional[datasets.Features] = None
encoding_errors: Optional[str] = "strict"
on_bad_lines: Literal["error", "warn", "skip"] = "error"

@property
def read_sql_kwargs(self):
read_sql_kwargs = dict(
index_col=self.index_col,
columns=self.columns,
params=self.params,
coerce_float=self.coerce_float,
parse_dates=self.parse_dates,
chunksize=self.chunksize,
)
return read_sql_kwargs
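
Roughly, the property above just forwards these fields to `pandas.read_sql`; a sketch of the equivalent call (with a placeholder file path and the config defaults), showing why a non-None `chunksize` yields batched reads:

```py
from sqlite3 import connect

import pandas as pd

# "Dataset" is SqlConfig's default table_name; the path is a placeholder.
with connect("path/to/dataset.sqlite") as conn:
    # With chunksize set, pandas returns an iterator of DataFrames;
    # _generate_tables below consumes it one chunk (= one Arrow batch) at a time.
    for df in pd.read_sql(
        "SELECT * FROM `Dataset`",
        conn,
        index_col=None,
        coerce_float=True,
        params=None,
        parse_dates=None,
        columns=None,
        chunksize=10_000,
    ):
        print(f"read chunk with {len(df)} rows")
```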


class Sql(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = SqlConfig

def _info(self):
return datasets.DatasetInfo(features=self.config.features)

def _split_generators(self, dl_manager):
"""We handle string, list and dicts in datafiles"""
if not self.config.data_files:
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
data_files = dl_manager.download_and_extract(self.config.data_files)
if isinstance(data_files, (str, list, tuple)):
files = data_files
if isinstance(files, str):
files = [files]
files = [dl_manager.iter_files(file) for file in files]
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
splits = []
for split_name, files in data_files.items():
if isinstance(files, str):
files = [files]
files = [dl_manager.iter_files(file) for file in files]
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
return splits

def _cast_table(self, pa_table: pa.Table) -> pa.Table:
if self.config.features is not None:
schema = self.config.features.arrow_schema
if all(not require_storage_cast(feature) for feature in self.config.features.values()):
# cheaper cast
pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
else:
# more expensive cast; allows str <-> int/float or str to Audio for example
pa_table = table_cast(pa_table, schema)
return pa_table

def _generate_tables(self, files):
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
with connect(file) as conn:
sql_file_reader = pd.read_sql(
f"SELECT * FROM `{self.config.table_name}`", conn, **self.config.read_sql_kwargs
)
try:
for batch_idx, df in enumerate(sql_file_reader):
# Drop index column as it is not relevant.
pa_table = pa.Table.from_pandas(df.drop("index", axis=1, errors="ignore"))
# Uncomment for debugging (will print the Arrow table size and elements)
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
yield (file_idx, batch_idx), self._cast_table(pa_table)
except ValueError as e:
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
raise
10 changes: 10 additions & 0 deletions tests/fixtures/files.py
@@ -4,7 +4,9 @@
import tarfile
import textwrap
import zipfile
from sqlite3 import connect

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
@@ -238,6 +240,14 @@ def arrow_path(tmp_path_factory):
return path


@pytest.fixture(scope="session")
def sql_path(tmp_path_factory):
path = str(tmp_path_factory.mktemp("data") / "dataset.sqlite")
with connect(path) as conn:
pd.DataFrame.from_records(DATA).to_sql("TABLE_NAME", conn)
return path


@pytest.fixture(scope="session")
def csv_path(tmp_path_factory):
path = str(tmp_path_factory.mktemp("data") / "dataset.csv")
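Not shown in this excerpt are the new test files among the 8 changed files; purely as an illustration, a round-trip test over the `sql_path` fixture above might look like the following sketch (function name and assertions are hypothetical, not the commit's actual tests):

```py
from datasets import Dataset


def test_sql_roundtrip(sql_path, tmp_path):
    cache_dir = str(tmp_path / "cache")

    # Read the table written by the sql_path fixture above.
    ds = Dataset.from_sql(sql_path, table_name="TABLE_NAME", cache_dir=cache_dir)
    assert ds.num_rows > 0

    # Write it back out to a fresh SQLite file and reload it.
    output_path = str(tmp_path / "output.sqlite")
    ds.to_sql(output_path, table_name="TABLE_NAME")
    reloaded = Dataset.from_sql(output_path, table_name="TABLE_NAME", cache_dir=cache_dir)
    assert reloaded.column_names == ds.column_names
```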
