Add ability to read-write to SQL databases. #4928

Merged
merged 24 commits on Oct 3, 2022
Changes from 8 commits
Commits
24 commits
f76b87c
Add ability to read-write to SQL databases.
Dref360 Sep 3, 2022
5747ad6
Fix issue where pandas<1.4.0 doesn't return the number of rows
Dref360 Sep 3, 2022
3811a5e
Fix issue where connections were not closed properly
Dref360 Sep 3, 2022
27d56b7
Apply suggestions from code review
Dref360 Sep 5, 2022
e9af3cf
Change according to reviews
Dref360 Sep 5, 2022
87eeb1a
Change according to reviews
Dref360 Sep 17, 2022
70e57c7
Merge main
Dref360 Sep 17, 2022
c3597c9
Inherit from AbstractDatasetInputStream in SqlDatasetReader
Dref360 Sep 17, 2022
61cf29a
Revert typing in SQLDatasetReader as we do not support Connexion
Dref360 Sep 18, 2022
453f2c3
Align API with Pandas/Daskk
mariosasko Sep 21, 2022
5410f51
Update tests
mariosasko Sep 21, 2022
3c128be
Update docs
mariosasko Sep 21, 2022
40268ae
Update some more tests
mariosasko Sep 21, 2022
7830d91
Merge branch 'main' of github.com:huggingface/datasets into HF-3094/i…
mariosasko Sep 21, 2022
dc005df
Missing comma
mariosasko Sep 21, 2022
a3c39d9
Small docs fix
mariosasko Sep 21, 2022
7c4999e
Style
mariosasko Sep 21, 2022
920de97
Update src/datasets/arrow_dataset.py
mariosasko Sep 23, 2022
9ecdb1f
Update src/datasets/packaged_modules/sql/sql.py
mariosasko Sep 23, 2022
27c9674
Address some comments
mariosasko Sep 23, 2022
ad20c27
Merge branch 'HF-3094/io_sql' of github.com:Dref360/datasets into HF-…
mariosasko Sep 23, 2022
81ad0e4
Address the rest
mariosasko Sep 23, 2022
3714fb0
Improve tests
mariosasko Sep 23, 2022
f3610c8
sqlalchemy required tip
mariosasko Oct 3, 2022
9 changes: 9 additions & 0 deletions docs/source/loading.mdx
@@ -196,6 +196,15 @@ To load remote Parquet files via HTTP, pass the URLs instead:
>>> wiki = load_dataset("parquet", data_files=data_files, split="train")
```

### SQLite

A dataset stored as a table in a SQLite file can be loaded with:

```py
>>> from datasets import Dataset
>>> dataset = Dataset.from_sql('sqlite_file.db', table_name='Dataset')
```
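
This PR also adds the write direction, [`Dataset.to_sql`]. A minimal sketch of exporting a dataset to the same kind of SQLite table (the file and table names are illustrative):

```py
>>> from datasets import Dataset
>>> dataset = Dataset.from_dict({"text": ["foo", "bar"]})
>>> dataset.to_sql('sqlite_file.db', table_name='Dataset')
```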

## In-memory data

🤗 Datasets will also allow you to create a [`Dataset`] directly from in-memory data structures like Python dictionaries and Pandas DataFrames.
76 changes: 76 additions & 0 deletions src/datasets/arrow_dataset.py
@@ -33,6 +33,7 @@
from math import ceil, floor
from pathlib import Path
from random import sample
from sqlite3 import Connection
from typing import (
TYPE_CHECKING,
Any,
@@ -1091,6 +1092,45 @@ def from_text(
path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
).read()

@staticmethod
def from_sql(
path_or_paths: Union[PathLike, List[PathLike]],
table_name: str,
features: Optional[Features] = None,
cache_dir: str = None,
keep_in_memory: bool = False,
**kwargs,
):
"""Create Dataset from SQLite file(s).

Args:
path_or_paths (path-like or list of path-like): Path(s) to the SQLite file(s) to read from.
table_name (``str``): Name of the SQL table to read from.
features (:class:`Features`, optional): Dataset features.
cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data.
keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory.
**kwargs (additional keyword arguments): Keyword arguments to be passed to :class:`SqlConfig`.

Returns:
:class:`Dataset`

Example:

```py
>>> ds = Dataset.from_sql('path/to/dataset.sqlite', table_name='Dataset')
```
"""
from .io.sql import SqlDatasetReader

return SqlDatasetReader(
path_or_paths,
table_name=table_name,
features=features,
cache_dir=cache_dir,
keep_in_memory=keep_in_memory,
**kwargs,
).read()

def __del__(self):
if hasattr(self, "_data"):
del self._data
@@ -4085,6 +4125,42 @@ def to_parquet(

return ParquetDatasetWriter(self, path_or_buf, batch_size=batch_size, **parquet_writer_kwargs).write()

def to_sql(
self,
path_or_buf: Union[PathLike, Connection],
table_name: str,
batch_size: Optional[int] = None,
**sql_writer_kwargs,
) -> int:
"""Exports the dataset to SQLite

Args:
path_or_buf (``PathLike`` or ``Connection``): Either a path to a file or a SQL connection.
table_name (``str``): Name of the SQL table to write to.
batch_size (:obj:`int`, optional): Size of the batch to load in memory and write at once.
Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
**sql_writer_kwargs (additional keyword arguments): Parameters to pass to pandas's :func:`pandas.DataFrame.to_sql`

Returns:
int: The number of records written

Example:

```py
>>> ds.to_sql("path/to/dataset/directory", table_name='Dataset')
>>> # Also supports SQLAlchemy engines
>>> from sqlalchemy import create_engine
>>> engine = create_engine('sqlite:///my_own_db.sql', echo=False)
>>> ds.to_sql(engine, table_name='Dataset')
```
"""
# Dynamic import to avoid circular dependency
from .io.sql import SqlDatasetWriter

return SqlDatasetWriter(
self, path_or_buf, table_name=table_name, batch_size=batch_size, **sql_writer_kwargs
).write()

def _push_parquet_shards_to_hub(
self,
repo_id: str,
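
Taken together, the two methods added to `arrow_dataset.py` above give a read/write round trip. A minimal sketch against the API at this commit (the path, table name, and data are illustrative):

```py
from datasets import Dataset

# Small illustrative dataset.
ds = Dataset.from_dict({"id": [1, 2, 3], "text": ["a", "b", "c"]})

# Write it out to a SQLite file; `to_sql` returns the number of rows written
# (falling back to len(df) internally on pandas < 1.4.0, which returns None).
written = ds.to_sql("my_dataset.sqlite", table_name="Dataset")

# Read the table back through the new SQL builder.
reloaded = Dataset.from_sql("my_dataset.sqlite", table_name="Dataset")
assert reloaded.num_rows == ds.num_rows
```
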
136 changes: 136 additions & 0 deletions src/datasets/io/sql.py
@@ -0,0 +1,136 @@
import contextlib
import multiprocessing
import os
from sqlite3 import Connection, connect
from typing import Optional, Union

from .. import Dataset, Features, config
from ..formatting import query_table
from ..packaged_modules.sql.sql import Sql
from ..utils import logging
from ..utils.typing import NestedDataStructureLike, PathLike
from .abc import AbstractDatasetInputStream


class SqlDatasetReader(AbstractDatasetInputStream):
def __init__(
self,
conn: NestedDataStructureLike[Union[PathLike, Connection]],
table_name: str,
features: Optional[Features] = None,
cache_dir: str = None,
keep_in_memory: bool = False,
**kwargs,
):
super().__init__(features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs)
conn = conn if isinstance(conn, dict) else {"train": conn}
self.builder = Sql(
cache_dir=cache_dir,
conn=conn,
features=features,
table_name=table_name,
**kwargs,
)

def read(self):
download_config = None
download_mode = None
ignore_verifications = False
use_auth_token = None
base_path = None

self.builder.download_and_prepare(
download_config=download_config,
download_mode=download_mode,
ignore_verifications=ignore_verifications,
# try_from_hf_gcs=try_from_hf_gcs,
base_path=base_path,
use_auth_token=use_auth_token,
)

# Build dataset for splits
dataset = self.builder.as_dataset(
split="train", ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
)
return dataset


class SqlDatasetWriter:
def __init__(
self,
dataset: Dataset,
path_or_buf: Union[PathLike, Connection],
table_name: str,
batch_size: Optional[int] = None,
num_proc: Optional[int] = None,
**to_sql_kwargs,
):

if num_proc is not None and num_proc <= 0:
raise ValueError(f"num_proc {num_proc} must be an integer > 0.")

self.dataset = dataset
self.path_or_buf = path_or_buf
self.table_name = table_name
self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
self.num_proc = num_proc
self.encoding = "utf-8"
self.to_sql_kwargs = to_sql_kwargs

def write(self) -> int:
_ = self.to_sql_kwargs.pop("path_or_buf", None)

if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
with contextlib.closing(connect(self.path_or_buf)) as conn:
written = self._write(conn=conn, **self.to_sql_kwargs)
else:
written = self._write(conn=self.path_or_buf, **self.to_sql_kwargs)
return written

def _batch_sql(self, offset):
batch = query_table(
table=self.dataset.data,
key=slice(offset, offset + self.batch_size),
indices=self.dataset._indices,
)
return batch.to_pandas()

def _write(self, conn: Connection, **to_sql_kwargs) -> int:
"""Writes the pyarrow table as SQL to a binary file handle.

Caller is responsible for opening and closing the handle.
"""
written = 0

if self.num_proc is None or self.num_proc == 1:
for offset in logging.tqdm(
range(0, len(self.dataset), self.batch_size),
unit="ba",
disable=not logging.is_progress_bar_enabled(),
desc="Creating SQL from Arrow format",
):
df = self._batch_sql(offset)
written += df.to_sql(
self.table_name, conn, **to_sql_kwargs, if_exists="replace" if offset == 0 else "append"
) or len(df)

else:
num_rows, batch_size = len(self.dataset), self.batch_size
with multiprocessing.Pool(self.num_proc) as pool:
for idx, df in logging.tqdm(
enumerate(
pool.imap(
self._batch_sql,
[offset for offset in range(0, num_rows, batch_size)],
)
),
total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
unit="ba",
disable=not logging.is_progress_bar_enabled(),
desc="Creating SQL from Arrow format",
):
written += df.to_sql(
self.table_name, conn, **to_sql_kwargs, if_exists="replace" if idx == 0 else "append"
) or len(df)

return written
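
The `_write` loop above replaces the target table on the first batch and appends on later ones. A standalone sketch of that pattern with plain pandas and `sqlite3` (the file, table name, and data are illustrative):

```py
import contextlib
import sqlite3

import pandas as pd

# Illustrative frame standing in for the batches produced by _batch_sql().
df = pd.DataFrame({"id": range(10), "text": [f"row {i}" for i in range(10)]})
batch_size = 4

with contextlib.closing(sqlite3.connect("batched.sqlite")) as conn:
    written = 0
    for offset in range(0, len(df), batch_size):
        batch = df.iloc[offset : offset + batch_size]
        # First batch replaces any existing table; later batches append.
        written += batch.to_sql(
            "Dataset", conn, if_exists="replace" if offset == 0 else "append"
        ) or len(batch)

print(written)  # total number of rows written (10)
```
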
Empty file.
108 changes: 108 additions & 0 deletions src/datasets/packaged_modules/sql/sql.py
@@ -0,0 +1,108 @@
import contextlib
import itertools
from dataclasses import dataclass
from sqlite3 import Connection, connect
from typing import Dict, List, Optional, Sequence, Union

import pandas as pd
import pyarrow as pa
from typing_extensions import Literal

import datasets
import datasets.config
from datasets import NamedSplit
from datasets.features.features import require_storage_cast
from datasets.table import table_cast
from datasets.utils.typing import NestedDataStructureLike, PathLike


logger = datasets.utils.logging.get_logger(__name__)


@dataclass
class SqlConfig(datasets.BuilderConfig):
"""BuilderConfig for SQL."""

index_col: Optional[Union[int, str, List[int], List[str]]] = None
conn: Dict[Optional[NamedSplit], NestedDataStructureLike[Union[PathLike, Connection]]] = None
table_name: str = None
query: str = "SELECT * FROM `{table_name}`"
coerce_float: bool = True
params: Optional[Union[Sequence, Dict]] = None
parse_dates: Optional[Union[List, Dict]] = None
columns: Optional[List[str]] = None
chunksize: int = 10_000
features: Optional[datasets.Features] = None
encoding_errors: Optional[str] = "strict"
on_bad_lines: Literal["error", "warn", "skip"] = "error"

def __post_init__(self):
if self.table_name is None:
raise ValueError("Expected argument `table_name` to be supplied.")
if self.conn is None:
raise ValueError("Expected argument `conn` to connect to the database")

@property
def read_sql_kwargs(self):
read_sql_kwargs = dict(
index_col=self.index_col,
columns=self.columns,
params=self.params,
coerce_float=self.coerce_float,
parse_dates=self.parse_dates,
chunksize=self.chunksize,
)
return read_sql_kwargs


class Sql(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = SqlConfig

def _info(self):
return datasets.DatasetInfo(features=self.config.features)

def _split_generators(self, dl_manager):
"""We handle string, Connection, list and dicts in datafiles"""
data_files = self.config.conn
if isinstance(data_files, (str, list, tuple)):
files = data_files
if isinstance(files, str):
files = [files]
files = [dl_manager.iter_files(file) for file in files]
return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
splits = []
for split_name, files in data_files.items():
if isinstance(files, str):
files = [files]
files = [dl_manager.iter_files(file) for file in files]
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
return splits

def _cast_table(self, pa_table: pa.Table) -> pa.Table:
if self.config.features is not None:
schema = self.config.features.arrow_schema
if all(not require_storage_cast(feature) for feature in self.config.features.values()):
# cheaper cast
pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema)
else:
# more expensive cast; allows str <-> int/float or str to Audio for example
pa_table = table_cast(pa_table, schema)
return pa_table

def _generate_tables(self, files):
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
with contextlib.closing(connect(file)) as conn:
sql_file_reader = pd.read_sql(
self.config.query.format(table_name=self.config.table_name), conn, **self.config.read_sql_kwargs
)
try:
for batch_idx, df in enumerate(sql_file_reader):
# Drop index column as it is not relevant.
pa_table = pa.Table.from_pandas(df.drop("index", axis=1, errors="ignore"))
# Uncomment for debugging (will print the Arrow table size and elements)
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
yield (file_idx, batch_idx), self._cast_table(pa_table)
except ValueError as e:
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
raise
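
For context, a standalone sketch of the chunked read that `_generate_tables` relies on: with `chunksize` set, `pd.read_sql` returns an iterator of DataFrames, each of which is converted to an Arrow table (the in-memory database and data are illustrative):

```py
import sqlite3

import pandas as pd
import pyarrow as pa

# In-memory database with an illustrative table to read back in chunks.
conn = sqlite3.connect(":memory:")
pd.DataFrame({"id": range(10), "text": [f"row {i}" for i in range(10)]}).to_sql("Dataset", conn)

# With chunksize set, read_sql yields DataFrames instead of loading everything at once.
for batch_idx, df in enumerate(pd.read_sql("SELECT * FROM `Dataset`", conn, chunksize=4)):
    # Drop the default "index" column written by DataFrame.to_sql, as the builder does.
    pa_table = pa.Table.from_pandas(df.drop("index", axis=1, errors="ignore"))
    print(batch_idx, pa_table.num_rows)  # (0, 4), (1, 4), (2, 2)
conn.close()
```
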
11 changes: 11 additions & 0 deletions tests/fixtures/files.py
@@ -1,10 +1,13 @@
import contextlib
import csv
import json
import os
import tarfile
import textwrap
import zipfile
from sqlite3 import connect

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
@@ -238,6 +241,14 @@ def arrow_path(tmp_path_factory):
return path


@pytest.fixture(scope="session")
def sql_path(tmp_path_factory):
path = str(tmp_path_factory.mktemp("data") / "dataset.sqlite")
with contextlib.closing(connect(path)) as conn:
pd.DataFrame.from_records(DATA).to_sql("TABLE_NAME", conn)
return path


@pytest.fixture(scope="session")
def csv_path(tmp_path_factory):
path = str(tmp_path_factory.mktemp("data") / "dataset.csv")
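
A sketch of how a test might consume this fixture (the test itself is illustrative and not part of this diff):

```py
from datasets import Dataset


def test_dataset_from_sql(sql_path):
    # `sql_path` is the session-scoped fixture above; "TABLE_NAME" matches the
    # table written by that fixture.
    dataset = Dataset.from_sql(sql_path, table_name="TABLE_NAME")
    assert dataset.num_rows > 0
    assert len(dataset.column_names) > 0
```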