From f76b87c3981ab3cadd4f68e1ca949314d3d4e468 Mon Sep 17 00:00:00 2001 From: Dref360 Date: Sat, 3 Sep 2022 15:06:37 -0400 Subject: [PATCH 01/21] Add ability to read-write to SQL databases. --- src/datasets/arrow_dataset.py | 75 ++++++++ src/datasets/io/sql.py | 138 ++++++++++++++ src/datasets/packaged_modules/sql/__init__.py | 0 src/datasets/packaged_modules/sql/sql.py | 99 ++++++++++ tests/fixtures/files.py | 10 + tests/io/test_sql.py | 175 ++++++++++++++++++ tests/packaged_modules/test_sql.py | 35 ++++ tests/test_arrow_dataset.py | 125 +++++++++++++ 8 files changed, 657 insertions(+) create mode 100644 src/datasets/io/sql.py create mode 100644 src/datasets/packaged_modules/sql/__init__.py create mode 100644 src/datasets/packaged_modules/sql/sql.py create mode 100644 tests/io/test_sql.py create mode 100644 tests/packaged_modules/test_sql.py diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 288254cbc9c..0072253925c 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -32,6 +32,7 @@ from math import ceil, floor from pathlib import Path from random import sample +from sqlite3 import Connection from typing import ( TYPE_CHECKING, Any, @@ -1057,6 +1058,48 @@ def from_text( path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs ).read() + @staticmethod + def from_sql( + path_or_paths: Union[PathLike, List[PathLike]], + table_name: str, + split: Optional[NamedSplit] = None, + features: Optional[Features] = None, + cache_dir: str = None, + keep_in_memory: bool = False, + **kwargs, + ): + """Create Dataset from text file(s). + + Args: + path_or_paths (path-like or list of path-like): Path(s) of the text file(s). + table_name (``str``): Name of the SQL table to read from. + split (:class:`NamedSplit`, optional): Split name to be assigned to the dataset. + features (:class:`Features`, optional): Dataset features. + cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data. + keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory. + **kwargs (additional keyword arguments): Keyword arguments to be passed to :class:`SqlConfig`. + + Returns: + :class:`Dataset` + + Example: + + ```py + >>> ds = Dataset.from_sql('path/to/dataset.sqlite') + ``` + """ + from .io.sql import SqlDatasetReader + + return SqlDatasetReader( + path_or_paths, + table_name=table_name, + split=split, + features=features, + cache_dir=cache_dir, + keep_in_memory=keep_in_memory, + **kwargs, + ).read() + def __del__(self): if hasattr(self, "_data"): del self._data @@ -4051,6 +4094,38 @@ def to_parquet( return ParquetDatasetWriter(self, path_or_buf, batch_size=batch_size, **parquet_writer_kwargs).write() + def to_sql( + self, + path_or_buf: Union[PathLike, Connection], + table_name: str, + batch_size: Optional[int] = None, + **sql_writer_kwargs, + ) -> int: + """Exports the dataset to SQLite + + Args: + path_or_buf (``PathLike`` or ``Connection``): Either a path to a file or a SQL connection. + table_name (``str``): Name of the SQL table to write to. + batch_size (:obj:`int`, optional): Size of the batch to load in memory and write at once. + Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`. 
+ **sql_writer_kwargs (additional keyword arguments): Parameters to pass to pandas's :function:`Dataframe.to_sql` + + Returns: + int: The number of characters or bytes written + + Example: + + ```py + >>> ds.to_sql("path/to/dataset/directory") + ``` + """ + # Dynamic import to avoid circular dependency + from .io.sql import SqlDatasetWriter + + return SqlDatasetWriter( + self, path_or_buf, table_name=table_name, batch_size=batch_size, **sql_writer_kwargs + ).write() + def _push_parquet_shards_to_hub( self, repo_id: str, diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py new file mode 100644 index 00000000000..d11a4ca70c0 --- /dev/null +++ b/src/datasets/io/sql.py @@ -0,0 +1,138 @@ +import multiprocessing +import os +from sqlite3 import Connection, connect +from typing import Optional, Union + +from .. import Dataset, Features, NamedSplit, config +from ..formatting import query_table +from ..packaged_modules.sql.sql import Sql +from ..utils import logging +from ..utils.typing import NestedDataStructureLike, PathLike +from .abc import AbstractDatasetReader + + +class SqlDatasetReader(AbstractDatasetReader): + def __init__( + self, + path_or_paths: NestedDataStructureLike[PathLike], + table_name: str, + split: Optional[NamedSplit] = None, + features: Optional[Features] = None, + cache_dir: str = None, + keep_in_memory: bool = False, + **kwargs, + ): + super().__init__( + path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs + ) + path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths} + self.builder = Sql( + cache_dir=cache_dir, + data_files=path_or_paths, + features=features, + table_name=table_name, + **kwargs, + ) + + def read(self): + download_config = None + download_mode = None + ignore_verifications = False + use_auth_token = None + base_path = None + + self.builder.download_and_prepare( + download_config=download_config, + download_mode=download_mode, + ignore_verifications=ignore_verifications, + # try_from_hf_gcs=try_from_hf_gcs, + base_path=base_path, + use_auth_token=use_auth_token, + ) + + # Build dataset for splits + dataset = self.builder.as_dataset( + split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory + ) + return dataset + + +class SqlDatasetWriter: + def __init__( + self, + dataset: Dataset, + path_or_buf: Union[PathLike, Connection], + table_name: str, + batch_size: Optional[int] = None, + num_proc: Optional[int] = None, + **to_sql_kwargs, + ): + + if num_proc is not None and num_proc <= 0: + raise ValueError(f"num_proc {num_proc} must be an integer > 0.") + + self.dataset = dataset + self.path_or_buf = path_or_buf + self.table_name = table_name + self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE + self.num_proc = num_proc + self.encoding = "utf-8" + self.to_sql_kwargs = to_sql_kwargs + + def write(self) -> int: + _ = self.to_sql_kwargs.pop("path_or_buf", None) + + if isinstance(self.path_or_buf, (str, bytes, os.PathLike)): + with connect(self.path_or_buf) as conn: + written = self._write(conn=conn, **self.to_sql_kwargs) + else: + written = self._write(conn=self.path_or_buf, **self.to_sql_kwargs) + return written + + def _batch_sql(self, offset): + batch = query_table( + table=self.dataset.data, + key=slice(offset, offset + self.batch_size), + indices=self.dataset._indices, + ) + return batch.to_pandas() + + def _write(self, conn: Connection, **to_sql_kwargs) -> int: + """Writes the pyarrow table as SQL 
to a binary file handle. + + Caller is responsible for opening and closing the handle. + """ + written = 0 + + if self.num_proc is None or self.num_proc == 1: + for offset in logging.tqdm( + range(0, len(self.dataset), self.batch_size), + unit="ba", + disable=not logging.is_progress_bar_enabled(), + desc="Creating SQL from Arrow format", + ): + df = self._batch_sql(offset) + written += df.to_sql( + self.table_name, conn, **to_sql_kwargs, if_exists="replace" if offset == 0 else "append" + ) + + else: + num_rows, batch_size = len(self.dataset), self.batch_size + with multiprocessing.Pool(self.num_proc) as pool: + for idx, df in logging.tqdm( + enumerate( + pool.imap( + self._batch_sql, + [offset for offset in range(0, num_rows, batch_size)], + ) + ), + total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size, + unit="ba", + disable=not logging.is_progress_bar_enabled(), + desc="Creating SQL from Arrow format", + ): + written += df.to_sql( + self.table_name, conn, **to_sql_kwargs, if_exists="replace" if idx == 0 else "append" + ) + + return written diff --git a/src/datasets/packaged_modules/sql/__init__.py b/src/datasets/packaged_modules/sql/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/datasets/packaged_modules/sql/sql.py b/src/datasets/packaged_modules/sql/sql.py new file mode 100644 index 00000000000..863d3bebf16 --- /dev/null +++ b/src/datasets/packaged_modules/sql/sql.py @@ -0,0 +1,99 @@ +import itertools +from dataclasses import dataclass +from sqlite3 import connect +from typing import Dict, List, Optional, Sequence, Union + +import pandas as pd +import pyarrow as pa +from typing_extensions import Literal + +import datasets +import datasets.config +from datasets.features.features import require_storage_cast +from datasets.table import table_cast + + +logger = datasets.utils.logging.get_logger(__name__) + + +@dataclass +class SqlConfig(datasets.BuilderConfig): + """BuilderConfig for SQL.""" + + index_col: Optional[Union[int, str, List[int], List[str]]] = None + table_name: str = "Dataset" + coerce_float: bool = True + params: Optional[Union[Sequence, Dict]] = None + parse_dates: Optional[Union[List, Dict]] = None + columns: Optional[List[str]] = None + chunksize: int = 10_000 + features: Optional[datasets.Features] = None + encoding_errors: Optional[str] = "strict" + on_bad_lines: Literal["error", "warn", "skip"] = "error" + + @property + def read_sql_kwargs(self): + read_sql_kwargs = dict( + index_col=self.index_col, + columns=self.columns, + params=self.params, + coerce_float=self.coerce_float, + parse_dates=self.parse_dates, + chunksize=self.chunksize, + ) + return read_sql_kwargs + + +class Sql(datasets.ArrowBasedBuilder): + BUILDER_CONFIG_CLASS = SqlConfig + + def _info(self): + return datasets.DatasetInfo(features=self.config.features) + + def _split_generators(self, dl_manager): + """We handle string, list and dicts in datafiles""" + if not self.config.data_files: + raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") + data_files = dl_manager.download_and_extract(self.config.data_files) + if isinstance(data_files, (str, list, tuple)): + files = data_files + if isinstance(files, str): + files = [files] + files = [dl_manager.iter_files(file) for file in files] + return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})] + splits = [] + for split_name, files in data_files.items(): + if isinstance(files, str): + files = [files] + files 
= [dl_manager.iter_files(file) for file in files] + splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) + return splits + + def _cast_table(self, pa_table: pa.Table) -> pa.Table: + if self.config.features is not None: + schema = self.config.features.arrow_schema + if all(not require_storage_cast(feature) for feature in self.config.features.values()): + # cheaper cast + pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema) + else: + # more expensive cast; allows str <-> int/float or str to Audio for example + pa_table = table_cast(pa_table, schema) + return pa_table + + def _generate_tables(self, files): + for file_idx, file in enumerate(itertools.chain.from_iterable(files)): + with connect(file) as conn: + sql_file_reader = pd.read_sql( + f"SELECT * FROM `{self.config.table_name}`", conn, **self.config.read_sql_kwargs + ) + try: + for batch_idx, df in enumerate(sql_file_reader): + # Drop index column as it is not relevant. + pa_table = pa.Table.from_pandas(df.drop("index", axis=1, errors="ignore")) + # Uncomment for debugging (will print the Arrow table size and elements) + # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") + # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) + yield (file_idx, batch_idx), self._cast_table(pa_table) + except ValueError as e: + logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") + raise diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index a6502468549..6bdd6aea72a 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -4,7 +4,9 @@ import tarfile import textwrap import zipfile +from sqlite3 import connect +import pandas as pd import pyarrow as pa import pyarrow.parquet as pq import pytest @@ -238,6 +240,14 @@ def arrow_path(tmp_path_factory): return path +@pytest.fixture(scope="session") +def sql_path(tmp_path_factory): + path = str(tmp_path_factory.mktemp("data") / "dataset.sqlite") + with connect(path) as conn: + pd.DataFrame.from_records(DATA).to_sql("TABLE_NAME", conn) + return path + + @pytest.fixture(scope="session") def csv_path(tmp_path_factory): path = str(tmp_path_factory.mktemp("data") / "dataset.csv") diff --git a/tests/io/test_sql.py b/tests/io/test_sql.py new file mode 100644 index 00000000000..c37ca16b80f --- /dev/null +++ b/tests/io/test_sql.py @@ -0,0 +1,175 @@ +import os +from sqlite3 import connect + +import pandas as pd +import pytest + +from datasets import Dataset, DatasetDict, Features, NamedSplit, Value +from datasets.io.sql import SqlDatasetReader, SqlDatasetWriter + +from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases + + +SQLITE_TABLE_NAME = "TABLE_NAME" + + +def _check_sql_dataset(dataset, expected_features): + assert isinstance(dataset, Dataset) + assert dataset.num_rows == 4 + assert dataset.num_columns == 3 + assert dataset.column_names == ["col_1", "col_2", "col_3"] + for feature, expected_dtype in expected_features.items(): + assert dataset.features[feature].dtype == expected_dtype + + +@pytest.mark.parametrize("keep_in_memory", [False, True]) +def test_dataset_from_sql_keep_in_memory(keep_in_memory, sql_path, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): + dataset = SqlDatasetReader( + sql_path, table_name=SQLITE_TABLE_NAME, 
cache_dir=cache_dir, keep_in_memory=keep_in_memory + ).read() + _check_sql_dataset(dataset, expected_features) + + +@pytest.mark.parametrize( + "features", + [ + None, + {"col_1": "string", "col_2": "int64", "col_3": "float64"}, + {"col_1": "string", "col_2": "string", "col_3": "string"}, + {"col_1": "int32", "col_2": "int32", "col_3": "int32"}, + {"col_1": "float32", "col_2": "float32", "col_3": "float32"}, + ], +) +def test_dataset_from_sql_features(features, sql_path, tmp_path): + cache_dir = tmp_path / "cache" + # CSV file loses col_1 string dtype information: default now is "int64" instead of "string" + default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + expected_features = features.copy() if features else default_expected_features + features = ( + Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None + ) + dataset = SqlDatasetReader(sql_path, table_name=SQLITE_TABLE_NAME, features=features, cache_dir=cache_dir).read() + _check_sql_dataset(dataset, expected_features) + + +@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"]) +def test_dataset_from_sql_split(split, sql_path, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + dataset = SqlDatasetReader(sql_path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir, split=split).read() + _check_sql_dataset(dataset, expected_features) + assert dataset.split == split if split else "train" + + +@pytest.mark.parametrize("path_type", [str, list]) +def test_dataset_from_sql_path_type(path_type, sql_path, tmp_path): + if issubclass(path_type, str): + path = sql_path + elif issubclass(path_type, list): + path = [sql_path] + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + dataset = SqlDatasetReader(path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir).read() + _check_sql_dataset(dataset, expected_features) + + +def _check_sql_datasetdict(dataset_dict, expected_features, splits=("train",)): + assert isinstance(dataset_dict, DatasetDict) + for split in splits: + dataset = dataset_dict[split] + assert dataset.num_rows == 4 + assert dataset.num_columns == 3 + assert dataset.column_names == ["col_1", "col_2", "col_3"] + for feature, expected_dtype in expected_features.items(): + assert dataset.features[feature].dtype == expected_dtype + + +@pytest.mark.parametrize("keep_in_memory", [False, True]) +def test_sql_datasetdict_reader_keep_in_memory(keep_in_memory, sql_path, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): + dataset = SqlDatasetReader( + {"train": sql_path}, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir, keep_in_memory=keep_in_memory + ).read() + _check_sql_datasetdict(dataset, expected_features) + + +@pytest.mark.parametrize( + "features", + [ + None, + {"col_1": "string", "col_2": "int64", "col_3": "float64"}, + {"col_1": "string", "col_2": "string", "col_3": "string"}, + {"col_1": "int32", "col_2": "int32", "col_3": "int32"}, + {"col_1": "float32", "col_2": "float32", "col_3": "float32"}, + ], +) +def test_sql_datasetdict_reader_features(features, sql_path, tmp_path): + cache_dir = tmp_path / "cache" + # CSV file loses col_1 string dtype information: default now is "int64" instead of "string" + 
default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + expected_features = features.copy() if features else default_expected_features + features = ( + Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None + ) + dataset = SqlDatasetReader( + {"train": sql_path}, table_name=SQLITE_TABLE_NAME, features=features, cache_dir=cache_dir + ).read() + _check_sql_datasetdict(dataset, expected_features) + + +@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"]) +def test_sql_datasetdict_reader_split(split, sql_path, tmp_path): + if split: + path = {split: sql_path} + else: + split = "train" + path = {"train": sql_path, "test": sql_path} + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + dataset = SqlDatasetReader(path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir).read() + _check_sql_datasetdict(dataset, expected_features, splits=list(path.keys())) + assert all(dataset[split].split == split for split in path.keys()) + + +def iter_sql_file(sql_path): + with connect(sql_path) as conn: + return pd.read_sql(f"SELECT * FROM {SQLITE_TABLE_NAME}", conn).drop("index", axis=1, errors="ignore") + + +def test_dataset_to_sql(sql_path, tmp_path): + cache_dir = tmp_path / "cache" + output_sql = os.path.join(cache_dir, "tmp.sql") + dataset = SqlDatasetReader({"train": sql_path}, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir).read() + SqlDatasetWriter(dataset["train"], output_sql, table_name=SQLITE_TABLE_NAME, index=False, num_proc=1).write() + + original_sql = iter_sql_file(sql_path) + expected_sql = iter_sql_file(output_sql) + + for row1, row2 in zip(original_sql, expected_sql): + assert row1 == row2 + + +def test_dataset_to_sql_multiproc(sql_path, tmp_path): + cache_dir = tmp_path / "cache" + output_sql = os.path.join(cache_dir, "tmp.sql") + dataset = SqlDatasetReader({"train": sql_path}, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir).read() + SqlDatasetWriter(dataset["train"], output_sql, table_name=SQLITE_TABLE_NAME, index=False, num_proc=2).write() + + original_sql = iter_sql_file(sql_path) + expected_sql = iter_sql_file(output_sql) + + for row1, row2 in zip(original_sql, expected_sql): + assert row1 == row2 + + +def test_dataset_to_sql_invalidproc(sql_path, tmp_path): + cache_dir = tmp_path / "cache" + output_sql = os.path.join(cache_dir, "tmp.sql") + dataset = SqlDatasetReader({"train": sql_path}, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir).read() + with pytest.raises(ValueError): + SqlDatasetWriter(dataset["train"], output_sql, table_name=SQLITE_TABLE_NAME, index=False, num_proc=0) diff --git a/tests/packaged_modules/test_sql.py b/tests/packaged_modules/test_sql.py new file mode 100644 index 00000000000..9614ea00356 --- /dev/null +++ b/tests/packaged_modules/test_sql.py @@ -0,0 +1,35 @@ +from sqlite3 import connect + +import pandas as pd +import pyarrow as pa +import pytest + +from datasets import ClassLabel, Features, Value +from datasets.packaged_modules.sql.sql import Sql + + +@pytest.fixture +def sqlite_file(tmp_path): + filename = str(tmp_path / "malformed_file.sqlite") + with connect(filename) as conn: + pd.DataFrame.from_dict({"header1": [1, 10], "header2": [2, 20], "label": ["good", "bad"]}).to_sql( + "TABLE_NAME", con=conn + ) + return filename + + +def test_csv_cast_label(sqlite_file): + table_name = "TABLE_NAME" + with connect(sqlite_file) as conn: + labels = pd.read_sql(f"SELECT * FROM {table_name}", 
conn)["label"].tolist() + sql = Sql( + features=Features( + {"header1": Value("int32"), "header2": Value("int32"), "label": ClassLabel(names=["good", "bad"])} + ), + table_name=table_name, + ) + generator = sql._generate_tables([[sqlite_file]]) + pa_table = pa.concat_tables([table for _, table in generator]) + assert pa_table.schema.field("label").type == ClassLabel(names=["good", "bad"])() + generated_content = pa_table.to_pydict()["label"] + assert generated_content == [ClassLabel(names=["good", "bad"]).str2int(label) for label in labels] diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 7184fe25255..ed2aff3e910 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -7,6 +7,7 @@ import tempfile from functools import partial from pathlib import Path +from sqlite3 import connect from unittest import TestCase from unittest.mock import patch @@ -60,6 +61,9 @@ ) +SQLITE_TABLE_NAME = "TABLE_NAME" + + class Unpicklable: def __getstate__(self): raise pickle.PicklingError() @@ -2015,6 +2019,65 @@ def test_to_parquet(self, in_memory): self.assertEqual(parquet_dset.shape, dset.shape) self.assertListEqual(list(parquet_dset.columns), list(dset.column_names)) + def test_to_sql(self, in_memory): + table_name = "TABLE_NAME" + + def read_sql(filepath): + + with connect(filepath) as conn: + return pd.read_sql(f"SELECT * FROM {table_name}", conn).drop("index", axis=1, errors="ignore") + + with tempfile.TemporaryDirectory() as tmp_dir: + # File path argument + with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: + file_path = os.path.join(tmp_dir, "test_path.sql") + _ = dset.to_sql(path_or_buf=file_path, table_name=table_name) + + self.assertTrue(os.path.isfile(file_path)) + # self.assertEqual(bytes_written, os.path.getsize(file_path)) + sql_dset = read_sql(file_path) + + self.assertEqual(sql_dset.shape, dset.shape) + self.assertListEqual(list(sql_dset.columns), list(dset.column_names)) + + # File buffer argument + with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: + file_path = os.path.join(tmp_dir, "test_buffer.sql") + with connect(file_path) as buffer: + _ = dset.to_sql(path_or_buf=buffer, table_name=table_name) + + self.assertTrue(os.path.isfile(file_path)) + # self.assertEqual(bytes_written, os.path.getsize(file_path)) + sql_dset = read_sql(file_path) + + self.assertEqual(sql_dset.shape, dset.shape) + self.assertListEqual(list(sql_dset.columns), list(dset.column_names)) + + # After a select/shuffle transform + with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: + dset = dset.select(range(0, len(dset), 2)).shuffle() + file_path = os.path.join(tmp_dir, "test_path.sql") + _ = dset.to_sql(path_or_buf=file_path, table_name=table_name) + + self.assertTrue(os.path.isfile(file_path)) + # self.assertEqual(bytes_written, os.path.getsize(file_path)) + sql_dset = read_sql(file_path) + + self.assertEqual(sql_dset.shape, dset.shape) + self.assertListEqual(list(sql_dset.columns), list(dset.column_names)) + + # With array features + with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset: + file_path = os.path.join(tmp_dir, "test_path.sql") + _ = dset.to_sql(path_or_buf=file_path, table_name=table_name) + + self.assertTrue(os.path.isfile(file_path)) + # self.assertEqual(bytes_written, os.path.getsize(file_path)) + sql_dset = read_sql(file_path) + + self.assertEqual(sql_dset.shape, dset.shape) + self.assertListEqual(list(sql_dset.columns), 
list(dset.column_names)) + def test_train_test_split(self, in_memory): with tempfile.TemporaryDirectory() as tmp_dir: with self._create_dummy_dataset(in_memory, tmp_dir) as dset: @@ -3101,6 +3164,68 @@ def test_dataset_from_text_path_type(path_type, text_path, tmp_path): _check_text_dataset(dataset, expected_features) +def _check_sql_dataset(dataset, expected_features): + assert isinstance(dataset, Dataset) + assert dataset.num_rows == 4 + assert dataset.num_columns == 3 + assert dataset.column_names == ["col_1", "col_2", "col_3"] + for feature, expected_dtype in expected_features.items(): + assert dataset.features[feature].dtype == expected_dtype + + +@pytest.mark.parametrize("keep_in_memory", [False, True]) +def test_dataset_from_sql_keep_in_memory(keep_in_memory, sql_path, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): + dataset = Dataset.from_sql( + sql_path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir, keep_in_memory=keep_in_memory + ) + _check_sql_dataset(dataset, expected_features) + + +@pytest.mark.parametrize( + "features", + [ + None, + {"col_1": "string", "col_2": "int64", "col_3": "float64"}, + {"col_1": "string", "col_2": "string", "col_3": "string"}, + {"col_1": "int32", "col_2": "int32", "col_3": "int32"}, + {"col_1": "float32", "col_2": "float32", "col_3": "float32"}, + ], +) +def test_dataset_from_sql_features(features, sql_path, tmp_path): + cache_dir = tmp_path / "cache" + default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + expected_features = features.copy() if features else default_expected_features + features = ( + Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None + ) + dataset = Dataset.from_sql(sql_path, table_name=SQLITE_TABLE_NAME, features=features, cache_dir=cache_dir) + _check_sql_dataset(dataset, expected_features) + + +@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"]) +def test_dataset_from_sql_split(split, sql_path, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + dataset = Dataset.from_sql(sql_path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir, split=split) + _check_sql_dataset(dataset, expected_features) + assert dataset.split == split if split else "train" + + +@pytest.mark.parametrize("path_type", [str, list]) +def test_dataset_from_sql_path_type(path_type, sql_path, tmp_path): + if issubclass(path_type, str): + path = sql_path + elif issubclass(path_type, list): + path = [sql_path] + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + dataset = Dataset.from_sql(path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir) + _check_sql_dataset(dataset, expected_features) + + def test_dataset_to_json(dataset, tmp_path): file_path = tmp_path / "test_path.jsonl" bytes_written = dataset.to_json(path_or_buf=file_path) From 5747ad6b2c65b754c456f45e421de6aa7dcb1a13 Mon Sep 17 00:00:00 2001 From: Dref360 Date: Sat, 3 Sep 2022 16:44:52 -0400 Subject: [PATCH 02/21] Fix issue where pandas<1.4.0 doesn't return the number of rows --- src/datasets/io/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py index d11a4ca70c0..d127bd09e8b 100644 --- 
a/src/datasets/io/sql.py +++ b/src/datasets/io/sql.py @@ -114,7 +114,7 @@ def _write(self, conn: Connection, **to_sql_kwargs) -> int: df = self._batch_sql(offset) written += df.to_sql( self.table_name, conn, **to_sql_kwargs, if_exists="replace" if offset == 0 else "append" - ) + ) or len(df) else: num_rows, batch_size = len(self.dataset), self.batch_size @@ -133,6 +133,6 @@ def _write(self, conn: Connection, **to_sql_kwargs) -> int: ): written += df.to_sql( self.table_name, conn, **to_sql_kwargs, if_exists="replace" if idx == 0 else "append" - ) + ) or len(df) return written From 3811a5e5260f5a2aebaa5dab3f951a699aee7b7b Mon Sep 17 00:00:00 2001 From: Dref360 Date: Sat, 3 Sep 2022 17:32:50 -0400 Subject: [PATCH 03/21] Fix issue where connections were not closed properly --- src/datasets/io/sql.py | 3 ++- src/datasets/packaged_modules/sql/sql.py | 25 ++++++++++++------------ tests/fixtures/files.py | 3 ++- tests/io/test_sql.py | 3 ++- tests/packaged_modules/test_sql.py | 5 +++-- tests/test_arrow_dataset.py | 5 +++-- 6 files changed, 25 insertions(+), 19 deletions(-) diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py index d127bd09e8b..df188e3bae5 100644 --- a/src/datasets/io/sql.py +++ b/src/datasets/io/sql.py @@ -1,3 +1,4 @@ +import contextlib import multiprocessing import os from sqlite3 import Connection, connect @@ -83,7 +84,7 @@ def write(self) -> int: _ = self.to_sql_kwargs.pop("path_or_buf", None) if isinstance(self.path_or_buf, (str, bytes, os.PathLike)): - with connect(self.path_or_buf) as conn: + with contextlib.closing(connect(self.path_or_buf)) as conn: written = self._write(conn=conn, **self.to_sql_kwargs) else: written = self._write(conn=self.path_or_buf, **self.to_sql_kwargs) diff --git a/src/datasets/packaged_modules/sql/sql.py b/src/datasets/packaged_modules/sql/sql.py index 863d3bebf16..7f5058bf471 100644 --- a/src/datasets/packaged_modules/sql/sql.py +++ b/src/datasets/packaged_modules/sql/sql.py @@ -1,3 +1,4 @@ +import contextlib import itertools from dataclasses import dataclass from sqlite3 import connect @@ -82,18 +83,18 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table: def _generate_tables(self, files): for file_idx, file in enumerate(itertools.chain.from_iterable(files)): - with connect(file) as conn: + with contextlib.closing(connect(file)) as conn: sql_file_reader = pd.read_sql( f"SELECT * FROM `{self.config.table_name}`", conn, **self.config.read_sql_kwargs ) - try: - for batch_idx, df in enumerate(sql_file_reader): - # Drop index column as it is not relevant. - pa_table = pa.Table.from_pandas(df.drop("index", axis=1, errors="ignore")) - # Uncomment for debugging (will print the Arrow table size and elements) - # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") - # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) - yield (file_idx, batch_idx), self._cast_table(pa_table) - except ValueError as e: - logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") - raise + try: + for batch_idx, df in enumerate(sql_file_reader): + # Drop index column as it is not relevant. 
+ pa_table = pa.Table.from_pandas(df.drop("index", axis=1, errors="ignore")) + # Uncomment for debugging (will print the Arrow table size and elements) + # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") + # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) + yield (file_idx, batch_idx), self._cast_table(pa_table) + except ValueError as e: + logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") + raise diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index 6bdd6aea72a..e9d96d50706 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -1,3 +1,4 @@ +import contextlib import csv import json import os @@ -243,7 +244,7 @@ def arrow_path(tmp_path_factory): @pytest.fixture(scope="session") def sql_path(tmp_path_factory): path = str(tmp_path_factory.mktemp("data") / "dataset.sqlite") - with connect(path) as conn: + with contextlib.closing(connect(path)) as conn: pd.DataFrame.from_records(DATA).to_sql("TABLE_NAME", conn) return path diff --git a/tests/io/test_sql.py b/tests/io/test_sql.py index c37ca16b80f..ca952640fa9 100644 --- a/tests/io/test_sql.py +++ b/tests/io/test_sql.py @@ -1,3 +1,4 @@ +import contextlib import os from sqlite3 import connect @@ -137,7 +138,7 @@ def test_sql_datasetdict_reader_split(split, sql_path, tmp_path): def iter_sql_file(sql_path): - with connect(sql_path) as conn: + with contextlib.closing(connect(sql_path)) as conn: return pd.read_sql(f"SELECT * FROM {SQLITE_TABLE_NAME}", conn).drop("index", axis=1, errors="ignore") diff --git a/tests/packaged_modules/test_sql.py b/tests/packaged_modules/test_sql.py index 9614ea00356..c2768788b65 100644 --- a/tests/packaged_modules/test_sql.py +++ b/tests/packaged_modules/test_sql.py @@ -1,3 +1,4 @@ +import contextlib from sqlite3 import connect import pandas as pd @@ -11,7 +12,7 @@ @pytest.fixture def sqlite_file(tmp_path): filename = str(tmp_path / "malformed_file.sqlite") - with connect(filename) as conn: + with contextlib.closing(connect(filename)) as conn: pd.DataFrame.from_dict({"header1": [1, 10], "header2": [2, 20], "label": ["good", "bad"]}).to_sql( "TABLE_NAME", con=conn ) @@ -20,7 +21,7 @@ def sqlite_file(tmp_path): def test_csv_cast_label(sqlite_file): table_name = "TABLE_NAME" - with connect(sqlite_file) as conn: + with contextlib.closing(connect(sqlite_file)) as conn: labels = pd.read_sql(f"SELECT * FROM {table_name}", conn)["label"].tolist() sql = Sql( features=Features( diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index ed2aff3e910..a9c06dc461d 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -1,3 +1,4 @@ +import contextlib import copy import itertools import json @@ -2024,7 +2025,7 @@ def test_to_sql(self, in_memory): def read_sql(filepath): - with connect(filepath) as conn: + with contextlib.closing(connect(filepath)) as conn: return pd.read_sql(f"SELECT * FROM {table_name}", conn).drop("index", axis=1, errors="ignore") with tempfile.TemporaryDirectory() as tmp_dir: @@ -2043,7 +2044,7 @@ def read_sql(filepath): # File buffer argument with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: file_path = os.path.join(tmp_dir, "test_buffer.sql") - with connect(file_path) as buffer: + with contextlib.closing(connect(file_path)) as buffer: _ = dset.to_sql(path_or_buf=buffer, table_name=table_name) self.assertTrue(os.path.isfile(file_path)) From 27d56b72b2aa2e60ad0a657b5e5a447f51c7a496 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Branchaud-Charron?= Date: Mon, 5 Sep 2022 15:30:05 -0400 Subject: [PATCH 04/21] Apply suggestions from code review Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> --- tests/io/test_sql.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/io/test_sql.py b/tests/io/test_sql.py index ca952640fa9..9b7ff318f49 100644 --- a/tests/io/test_sql.py +++ b/tests/io/test_sql.py @@ -46,7 +46,6 @@ def test_dataset_from_sql_keep_in_memory(keep_in_memory, sql_path, tmp_path): ) def test_dataset_from_sql_features(features, sql_path, tmp_path): cache_dir = tmp_path / "cache" - # CSV file loses col_1 string dtype information: default now is "int64" instead of "string" default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} expected_features = features.copy() if features else default_expected_features features = ( @@ -111,7 +110,6 @@ def test_sql_datasetdict_reader_keep_in_memory(keep_in_memory, sql_path, tmp_pat ) def test_sql_datasetdict_reader_features(features, sql_path, tmp_path): cache_dir = tmp_path / "cache" - # CSV file loses col_1 string dtype information: default now is "int64" instead of "string" default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} expected_features = features.copy() if features else default_expected_features features = ( From e9af3cf3c67e708cf27868cde3a8ebd042186f02 Mon Sep 17 00:00:00 2001 From: Dref360 Date: Mon, 5 Sep 2022 16:11:26 -0400 Subject: [PATCH 05/21] Change according to reviews --- docs/source/loading.mdx | 9 +++++++++ src/datasets/arrow_dataset.py | 10 +++++++--- src/datasets/packaged_modules/__init__.py | 2 ++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/docs/source/loading.mdx b/docs/source/loading.mdx index 5c8284d4d05..0cffdd9d80c 100644 --- a/docs/source/loading.mdx +++ b/docs/source/loading.mdx @@ -196,6 +196,15 @@ To load remote Parquet files via HTTP, pass the URLs instead: >>> wiki = load_dataset("parquet", data_files=data_files, split="train") ``` +### SQLite + +Datasets stored as a Table in a SQLite file can be loaded with: + +```py +>>> from datasets import load_dataset +>>> dataset = load_dataset('sql', data_files={'train': 'sqlite_file.db'}, table_name='Dataset') +``` + ## In-memory data 🤗 Datasets will also allow you to create a [`Dataset`] directly from in-memory data structures like Python dictionaries and Pandas DataFrames. diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 0072253925c..b1e662d382b 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1068,10 +1068,10 @@ def from_sql( keep_in_memory: bool = False, **kwargs, ): - """Create Dataset from text file(s). + """Create Dataset from SQLite file(s). Args: - path_or_paths (path-like or list of path-like): Path(s) of the text file(s). + path_or_paths (path-like or list of path-like): Path(s) of the SQLite file(s). table_name (``str``): Name of the SQL table to read from. split (:class:`NamedSplit`, optional): Split name to be assigned to the dataset. features (:class:`Features`, optional): Dataset features. 
@@ -4116,7 +4116,11 @@ def to_sql( Example: ```py - >>> ds.to_sql("path/to/dataset/directory") + >>> ds.to_sql("path/to/dataset/directory", table_name='Dataset') + >>> # Also supports SQLAlchemy engines + >>> from sqlalchemy import create_engine + >>> engine = create_engine('sqlite:///my_own_db.sql', echo=False) + >>> ds.to_sql(engine, table_name='Dataset') ``` """ # Dynamic import to avoid circular dependency diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index deabca1f35a..a2e42f3b7bd 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -9,6 +9,7 @@ from .json import json from .pandas import pandas from .parquet import parquet +from .sql import sql from .text import text @@ -32,6 +33,7 @@ def _hash_python_lines(lines: List[str]) -> str: "pandas": (pandas.__name__, _hash_python_lines(inspect.getsource(pandas).splitlines())), "parquet": (parquet.__name__, _hash_python_lines(inspect.getsource(parquet).splitlines())), "text": (text.__name__, _hash_python_lines(inspect.getsource(text).splitlines())), + "sql": (sql.__name__, _hash_python_lines(inspect.getsource(sql).splitlines())), "imagefolder": (imagefolder.__name__, _hash_python_lines(inspect.getsource(imagefolder).splitlines())), "audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())), } From 87eeb1a4b4b944809f301b44e94a434faabdbe6a Mon Sep 17 00:00:00 2001 From: Dref360 Date: Sat, 17 Sep 2022 09:15:19 -0400 Subject: [PATCH 06/21] Change according to reviews --- docs/source/loading.mdx | 4 ++-- src/datasets/io/sql.py | 8 ++++---- src/datasets/packaged_modules/__init__.py | 2 -- src/datasets/packaged_modules/sql/sql.py | 22 +++++++++++++++------- tests/packaged_modules/test_sql.py | 21 ++++++++++++++++++++- 5 files changed, 41 insertions(+), 16 deletions(-) diff --git a/docs/source/loading.mdx b/docs/source/loading.mdx index 0cffdd9d80c..4daf2c93b48 100644 --- a/docs/source/loading.mdx +++ b/docs/source/loading.mdx @@ -201,8 +201,8 @@ To load remote Parquet files via HTTP, pass the URLs instead: Datasets stored as a Table in a SQLite file can be loaded with: ```py ->>> from datasets import load_dataset ->>> dataset = load_dataset('sql', data_files={'train': 'sqlite_file.db'}, table_name='Dataset') +>>> from datasets import Dataset +>>> dataset = Dataset.from_sql('sqlite_file.db' table_name='Dataset') ``` ## In-memory data diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py index df188e3bae5..dbd62670f66 100644 --- a/src/datasets/io/sql.py +++ b/src/datasets/io/sql.py @@ -15,7 +15,7 @@ class SqlDatasetReader(AbstractDatasetReader): def __init__( self, - path_or_paths: NestedDataStructureLike[PathLike], + conn: NestedDataStructureLike[Union[PathLike, Connection]], table_name: str, split: Optional[NamedSplit] = None, features: Optional[Features] = None, @@ -24,12 +24,12 @@ def __init__( **kwargs, ): super().__init__( - path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs + conn, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs ) - path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths} + conn = conn if isinstance(conn, dict) else {self.split: conn} self.builder = Sql( cache_dir=cache_dir, - data_files=path_or_paths, + conn=conn, features=features, table_name=table_name, **kwargs, diff --git a/src/datasets/packaged_modules/__init__.py 
b/src/datasets/packaged_modules/__init__.py index a2e42f3b7bd..deabca1f35a 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -9,7 +9,6 @@ from .json import json from .pandas import pandas from .parquet import parquet -from .sql import sql from .text import text @@ -33,7 +32,6 @@ def _hash_python_lines(lines: List[str]) -> str: "pandas": (pandas.__name__, _hash_python_lines(inspect.getsource(pandas).splitlines())), "parquet": (parquet.__name__, _hash_python_lines(inspect.getsource(parquet).splitlines())), "text": (text.__name__, _hash_python_lines(inspect.getsource(text).splitlines())), - "sql": (sql.__name__, _hash_python_lines(inspect.getsource(sql).splitlines())), "imagefolder": (imagefolder.__name__, _hash_python_lines(inspect.getsource(imagefolder).splitlines())), "audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())), } diff --git a/src/datasets/packaged_modules/sql/sql.py b/src/datasets/packaged_modules/sql/sql.py index 7f5058bf471..15c7308d84b 100644 --- a/src/datasets/packaged_modules/sql/sql.py +++ b/src/datasets/packaged_modules/sql/sql.py @@ -1,7 +1,7 @@ import contextlib import itertools from dataclasses import dataclass -from sqlite3 import connect +from sqlite3 import Connection, connect from typing import Dict, List, Optional, Sequence, Union import pandas as pd @@ -10,8 +10,10 @@ import datasets import datasets.config +from datasets import NamedSplit from datasets.features.features import require_storage_cast from datasets.table import table_cast +from datasets.utils.typing import NestedDataStructureLike, PathLike logger = datasets.utils.logging.get_logger(__name__) @@ -22,7 +24,9 @@ class SqlConfig(datasets.BuilderConfig): """BuilderConfig for SQL.""" index_col: Optional[Union[int, str, List[int], List[str]]] = None - table_name: str = "Dataset" + conn: Dict[Optional[NamedSplit], NestedDataStructureLike[Union[PathLike, Connection]]] = None + table_name: str = None + query: str = "SELECT * FROM `{table_name}`" coerce_float: bool = True params: Optional[Union[Sequence, Dict]] = None parse_dates: Optional[Union[List, Dict]] = None @@ -32,6 +36,12 @@ class SqlConfig(datasets.BuilderConfig): encoding_errors: Optional[str] = "strict" on_bad_lines: Literal["error", "warn", "skip"] = "error" + def __post_init__(self): + if self.table_name is None: + raise ValueError("Expected argument `table_name` to be supplied.") + if self.conn is None: + raise ValueError("Expected argument `conn` to connect to the database") + @property def read_sql_kwargs(self): read_sql_kwargs = dict( @@ -52,10 +62,8 @@ def _info(self): return datasets.DatasetInfo(features=self.config.features) def _split_generators(self, dl_manager): - """We handle string, list and dicts in datafiles""" - if not self.config.data_files: - raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") - data_files = dl_manager.download_and_extract(self.config.data_files) + """We handle string, Connection, list and dicts in datafiles""" + data_files = self.config.conn if isinstance(data_files, (str, list, tuple)): files = data_files if isinstance(files, str): @@ -85,7 +93,7 @@ def _generate_tables(self, files): for file_idx, file in enumerate(itertools.chain.from_iterable(files)): with contextlib.closing(connect(file)) as conn: sql_file_reader = pd.read_sql( - f"SELECT * FROM `{self.config.table_name}`", conn, **self.config.read_sql_kwargs + 
self.config.query.format(table_name=self.config.table_name), conn, **self.config.read_sql_kwargs ) try: for batch_idx, df in enumerate(sql_file_reader): diff --git a/tests/packaged_modules/test_sql.py b/tests/packaged_modules/test_sql.py index c2768788b65..65261ee63db 100644 --- a/tests/packaged_modules/test_sql.py +++ b/tests/packaged_modules/test_sql.py @@ -19,11 +19,12 @@ def sqlite_file(tmp_path): return filename -def test_csv_cast_label(sqlite_file): +def test_sql_cast_label(sqlite_file): table_name = "TABLE_NAME" with contextlib.closing(connect(sqlite_file)) as conn: labels = pd.read_sql(f"SELECT * FROM {table_name}", conn)["label"].tolist() sql = Sql( + conn=sqlite_file, features=Features( {"header1": Value("int32"), "header2": Value("int32"), "label": ClassLabel(names=["good", "bad"])} ), @@ -34,3 +35,21 @@ def test_csv_cast_label(sqlite_file): assert pa_table.schema.field("label").type == ClassLabel(names=["good", "bad"])() generated_content = pa_table.to_pydict()["label"] assert generated_content == [ClassLabel(names=["good", "bad"]).str2int(label) for label in labels] + + +def test_missing_args(sqlite_file): + with pytest.raises(ValueError, match="Expected argument `table_name`"): + _ = Sql( + conn=sqlite_file, + features=Features( + {"header1": Value("int32"), "header2": Value("int32"), "label": ClassLabel(names=["good", "bad"])} + ), + ) + + with pytest.raises(ValueError, match="Expected argument `conn`"): + _ = Sql( + features=Features( + {"header1": Value("int32"), "header2": Value("int32"), "label": ClassLabel(names=["good", "bad"])} + ), + table_name="TABLE_NAME", + ) From c3597c997af596f2a181389e54aceb364c62af49 Mon Sep 17 00:00:00 2001 From: Dref360 Date: Sat, 17 Sep 2022 10:47:48 -0400 Subject: [PATCH 07/21] Inherit from AbstractDatasetInputStream in SqlDatasetReader --- src/datasets/arrow_dataset.py | 5 +-- src/datasets/io/sql.py | 15 +++---- tests/io/test_sql.py | 82 +++-------------------------------- tests/test_arrow_dataset.py | 9 ---- 4 files changed, 14 insertions(+), 97 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index f3a87a3c03f..bc5719e656f 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1096,7 +1096,6 @@ def from_text( def from_sql( path_or_paths: Union[PathLike, List[PathLike]], table_name: str, - split: Optional[NamedSplit] = None, features: Optional[Features] = None, cache_dir: str = None, keep_in_memory: bool = False, @@ -1105,9 +1104,8 @@ def from_sql( """Create Dataset from SQLite file(s). Args: - path_or_paths (path-like or list of path-like): Path(s) of the SQLite file(s). + path_or_paths (path-like or list of path-like): Path(s) of the SQLite URI(s). table_name (``str``): Name of the SQL table to read from. - split (:class:`NamedSplit`, optional): Split name to be assigned to the dataset. features (:class:`Features`, optional): Dataset features. cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data. keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory. @@ -1127,7 +1125,6 @@ def from_sql( return SqlDatasetReader( path_or_paths, table_name=table_name, - split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py index dbd62670f66..1669f5333b1 100644 --- a/src/datasets/io/sql.py +++ b/src/datasets/io/sql.py @@ -4,29 +4,26 @@ from sqlite3 import Connection, connect from typing import Optional, Union -from .. 
import Dataset, Features, NamedSplit, config +from .. import Dataset, Features, config from ..formatting import query_table from ..packaged_modules.sql.sql import Sql from ..utils import logging from ..utils.typing import NestedDataStructureLike, PathLike -from .abc import AbstractDatasetReader +from .abc import AbstractDatasetInputStream -class SqlDatasetReader(AbstractDatasetReader): +class SqlDatasetReader(AbstractDatasetInputStream): def __init__( self, conn: NestedDataStructureLike[Union[PathLike, Connection]], table_name: str, - split: Optional[NamedSplit] = None, features: Optional[Features] = None, cache_dir: str = None, keep_in_memory: bool = False, **kwargs, ): - super().__init__( - conn, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs - ) - conn = conn if isinstance(conn, dict) else {self.split: conn} + super().__init__(features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs) + conn = conn if isinstance(conn, dict) else {"train": conn} self.builder = Sql( cache_dir=cache_dir, conn=conn, @@ -53,7 +50,7 @@ def read(self): # Build dataset for splits dataset = self.builder.as_dataset( - split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory + split="train", ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory ) return dataset diff --git a/tests/io/test_sql.py b/tests/io/test_sql.py index 9b7ff318f49..244517a6fdb 100644 --- a/tests/io/test_sql.py +++ b/tests/io/test_sql.py @@ -5,7 +5,7 @@ import pandas as pd import pytest -from datasets import Dataset, DatasetDict, Features, NamedSplit, Value +from datasets import Dataset, Features, Value from datasets.io.sql import SqlDatasetReader, SqlDatasetWriter from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases @@ -55,15 +55,6 @@ def test_dataset_from_sql_features(features, sql_path, tmp_path): _check_sql_dataset(dataset, expected_features) -@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"]) -def test_dataset_from_sql_split(split, sql_path, tmp_path): - cache_dir = tmp_path / "cache" - expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} - dataset = SqlDatasetReader(sql_path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir, split=split).read() - _check_sql_dataset(dataset, expected_features) - assert dataset.split == split if split else "train" - - @pytest.mark.parametrize("path_type", [str, list]) def test_dataset_from_sql_path_type(path_type, sql_path, tmp_path): if issubclass(path_type, str): @@ -76,65 +67,6 @@ def test_dataset_from_sql_path_type(path_type, sql_path, tmp_path): _check_sql_dataset(dataset, expected_features) -def _check_sql_datasetdict(dataset_dict, expected_features, splits=("train",)): - assert isinstance(dataset_dict, DatasetDict) - for split in splits: - dataset = dataset_dict[split] - assert dataset.num_rows == 4 - assert dataset.num_columns == 3 - assert dataset.column_names == ["col_1", "col_2", "col_3"] - for feature, expected_dtype in expected_features.items(): - assert dataset.features[feature].dtype == expected_dtype - - -@pytest.mark.parametrize("keep_in_memory", [False, True]) -def test_sql_datasetdict_reader_keep_in_memory(keep_in_memory, sql_path, tmp_path): - cache_dir = tmp_path / "cache" - expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} - with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): - dataset = SqlDatasetReader( 
- {"train": sql_path}, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir, keep_in_memory=keep_in_memory - ).read() - _check_sql_datasetdict(dataset, expected_features) - - -@pytest.mark.parametrize( - "features", - [ - None, - {"col_1": "string", "col_2": "int64", "col_3": "float64"}, - {"col_1": "string", "col_2": "string", "col_3": "string"}, - {"col_1": "int32", "col_2": "int32", "col_3": "int32"}, - {"col_1": "float32", "col_2": "float32", "col_3": "float32"}, - ], -) -def test_sql_datasetdict_reader_features(features, sql_path, tmp_path): - cache_dir = tmp_path / "cache" - default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} - expected_features = features.copy() if features else default_expected_features - features = ( - Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None - ) - dataset = SqlDatasetReader( - {"train": sql_path}, table_name=SQLITE_TABLE_NAME, features=features, cache_dir=cache_dir - ).read() - _check_sql_datasetdict(dataset, expected_features) - - -@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"]) -def test_sql_datasetdict_reader_split(split, sql_path, tmp_path): - if split: - path = {split: sql_path} - else: - split = "train" - path = {"train": sql_path, "test": sql_path} - cache_dir = tmp_path / "cache" - expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} - dataset = SqlDatasetReader(path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir).read() - _check_sql_datasetdict(dataset, expected_features, splits=list(path.keys())) - assert all(dataset[split].split == split for split in path.keys()) - - def iter_sql_file(sql_path): with contextlib.closing(connect(sql_path)) as conn: return pd.read_sql(f"SELECT * FROM {SQLITE_TABLE_NAME}", conn).drop("index", axis=1, errors="ignore") @@ -143,8 +75,8 @@ def iter_sql_file(sql_path): def test_dataset_to_sql(sql_path, tmp_path): cache_dir = tmp_path / "cache" output_sql = os.path.join(cache_dir, "tmp.sql") - dataset = SqlDatasetReader({"train": sql_path}, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir).read() - SqlDatasetWriter(dataset["train"], output_sql, table_name=SQLITE_TABLE_NAME, index=False, num_proc=1).write() + dataset = SqlDatasetReader(sql_path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir).read() + SqlDatasetWriter(dataset, output_sql, table_name=SQLITE_TABLE_NAME, index=False, num_proc=1).write() original_sql = iter_sql_file(sql_path) expected_sql = iter_sql_file(output_sql) @@ -156,8 +88,8 @@ def test_dataset_to_sql(sql_path, tmp_path): def test_dataset_to_sql_multiproc(sql_path, tmp_path): cache_dir = tmp_path / "cache" output_sql = os.path.join(cache_dir, "tmp.sql") - dataset = SqlDatasetReader({"train": sql_path}, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir).read() - SqlDatasetWriter(dataset["train"], output_sql, table_name=SQLITE_TABLE_NAME, index=False, num_proc=2).write() + dataset = SqlDatasetReader(sql_path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir).read() + SqlDatasetWriter(dataset, output_sql, table_name=SQLITE_TABLE_NAME, index=False, num_proc=2).write() original_sql = iter_sql_file(sql_path) expected_sql = iter_sql_file(output_sql) @@ -169,6 +101,6 @@ def test_dataset_to_sql_multiproc(sql_path, tmp_path): def test_dataset_to_sql_invalidproc(sql_path, tmp_path): cache_dir = tmp_path / "cache" output_sql = os.path.join(cache_dir, "tmp.sql") - dataset = SqlDatasetReader({"train": sql_path}, table_name=SQLITE_TABLE_NAME, 
cache_dir=cache_dir).read() + dataset = SqlDatasetReader(sql_path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir).read() with pytest.raises(ValueError): - SqlDatasetWriter(dataset["train"], output_sql, table_name=SQLITE_TABLE_NAME, index=False, num_proc=0) + SqlDatasetWriter(dataset, output_sql, table_name=SQLITE_TABLE_NAME, index=False, num_proc=0) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index f44a60763c0..0ba669f006f 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -3337,15 +3337,6 @@ def test_dataset_from_sql_keep_in_memory(keep_in_memory, sql_path, tmp_path): _check_sql_dataset(dataset, expected_features) -@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"]) -def test_dataset_from_sql_split(split, sql_path, tmp_path): - cache_dir = tmp_path / "cache" - expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} - dataset = Dataset.from_sql(sql_path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir, split=split) - _check_sql_dataset(dataset, expected_features) - assert dataset.split == split if split else "train" - - @pytest.mark.parametrize("path_type", [str, list]) def test_dataset_from_sql_path_type(path_type, sql_path, tmp_path): if issubclass(path_type, str): From 61cf29a14bb075cafccc852054985366ee67c134 Mon Sep 17 00:00:00 2001 From: Dref360 Date: Sun, 18 Sep 2022 10:49:11 -0400 Subject: [PATCH 08/21] Revert typing in SQLDatasetReader as we do not support Connexion --- src/datasets/io/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py index 1669f5333b1..1fca23041e5 100644 --- a/src/datasets/io/sql.py +++ b/src/datasets/io/sql.py @@ -15,7 +15,7 @@ class SqlDatasetReader(AbstractDatasetInputStream): def __init__( self, - conn: NestedDataStructureLike[Union[PathLike, Connection]], + conn: NestedDataStructureLike[PathLike], table_name: str, features: Optional[Features] = None, cache_dir: str = None, From 453f2c3f8d5d5155ce64b802a081565966f49408 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Wed, 21 Sep 2022 14:26:06 +0200 Subject: [PATCH 09/21] Align API with Pandas/Daskk --- src/datasets/arrow_dataset.py | 44 ++++++----- src/datasets/config.py | 3 + src/datasets/io/sql.py | 67 ++++++++--------- src/datasets/packaged_modules/sql/sql.py | 95 +++++++++++------------- 4 files changed, 97 insertions(+), 112 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index bc5719e656f..abf73b51820 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -33,7 +33,6 @@ from math import ceil, floor from pathlib import Path from random import sample -from sqlite3 import Connection from typing import ( TYPE_CHECKING, Any, @@ -111,6 +110,10 @@ if TYPE_CHECKING: + import sqlite3 + + import sqlalchemy + from .dataset_dict import DatasetDict logger = logging.get_logger(__name__) @@ -1094,18 +1097,18 @@ def from_text( @staticmethod def from_sql( - path_or_paths: Union[PathLike, List[PathLike]], - table_name: str, + sql: Union[str, "sqlalchemy.sql.Selectable"], + con: str, features: Optional[Features] = None, cache_dir: str = None, keep_in_memory: bool = False, **kwargs, ): - """Create Dataset from SQLite file(s). + """Create Dataset from SQL query or database table. Args: - path_or_paths (path-like or list of path-like): Path(s) of the SQLite URI(s). - table_name (``str``): Name of the SQL table to read from. 
+ sql (`str` or :obj:`sqlalchemy.sql.Selectable`): SQL query to be executed or a table name. + con (`str`): A connection URI string used to instantiate a database connection. features (:class:`Features`, optional): Dataset features. cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data. keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory. @@ -1117,14 +1120,14 @@ def from_sql( Example: ```py - >>> ds = Dataset.from_sql('path/to/dataset.sqlite') + >>> ds = Dataset.from_sql("test_data", "postgres:///db_name"s) ``` """ from .io.sql import SqlDatasetReader return SqlDatasetReader( - path_or_paths, - table_name=table_name, + sql, + con, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, @@ -4127,39 +4130,34 @@ def to_parquet( def to_sql( self, - path_or_buf: Union[PathLike, Connection], - table_name: str, + name: str, + con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"], batch_size: Optional[int] = None, **sql_writer_kwargs, ) -> int: - """Exports the dataset to SQLite + """Exports the dataset to a SQL database. Args: - path_or_buf (``PathLike`` or ``Connection``): Either a path to a file or a SQL connection. - table_name (``str``): Name of the SQL table to write to. + name (`str`): Name of SQL table. + con (`str` or `sqlite3.Connection` or `sqlalchemy.engine.Connection` or `sqlalchemy.engine.Connection`): + A database connection URI string or an existing SQLite3/SQLAlchemy connection used to write to a database. batch_size (:obj:`int`, optional): Size of the batch to load in memory and write at once. Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`. **sql_writer_kwargs (additional keyword arguments): Parameters to pass to pandas's :function:`Dataframe.to_sql` Returns: - int: The number of characters or bytes written + int: The number of records written. Example: ```py - >>> ds.to_sql("path/to/dataset/directory", table_name='Dataset') - >>> # Also supports SQLAlchemy engines - >>> from sqlalchemy import create_engine - >>> engine = create_engine('sqlite:///my_own_db.sql', echo=False) - >>> ds.to_sql(engine, table_name='Dataset') + >>> ds.to_sql("data", "sqlite:///my_own_db.sql") ``` """ # Dynamic import to avoid circular dependency from .io.sql import SqlDatasetWriter - return SqlDatasetWriter( - self, path_or_buf, table_name=table_name, batch_size=batch_size, **sql_writer_kwargs - ).write() + return SqlDatasetWriter(self, name, con, batch_size=batch_size, **sql_writer_kwargs).write() def _push_parquet_shards_to_hub( self, diff --git a/src/datasets/config.py b/src/datasets/config.py index 2bd5419cbe3..e98166a8676 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -125,6 +125,9 @@ logger.info("Disabling Apache Beam because USE_BEAM is set to False") +# Optional tools for data loading +SQLALCHEMY_AVAILABLE = importlib.util.find_spec("sqlalchemy") is not None + # Optional tools for feature decoding PIL_AVAILABLE = importlib.util.find_spec("PIL") is not None diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py index 1fca23041e5..c2f2fc7b6a9 100644 --- a/src/datasets/io/sql.py +++ b/src/datasets/io/sql.py @@ -1,34 +1,35 @@ -import contextlib import multiprocessing -import os -from sqlite3 import Connection, connect -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from .. 
import Dataset, Features, config from ..formatting import query_table from ..packaged_modules.sql.sql import Sql from ..utils import logging -from ..utils.typing import NestedDataStructureLike, PathLike from .abc import AbstractDatasetInputStream +if TYPE_CHECKING: + import sqlite3 + + import sqlalchemy + + class SqlDatasetReader(AbstractDatasetInputStream): def __init__( self, - conn: NestedDataStructureLike[PathLike], - table_name: str, + sql: Union[str, "sqlalchemy.sql.Selectable"], + con: str, features: Optional[Features] = None, cache_dir: str = None, keep_in_memory: bool = False, **kwargs, ): super().__init__(features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs) - conn = conn if isinstance(conn, dict) else {"train": conn} self.builder = Sql( cache_dir=cache_dir, - conn=conn, features=features, - table_name=table_name, + sql=sql, + con=con, **kwargs, ) @@ -59,8 +60,8 @@ class SqlDatasetWriter: def __init__( self, dataset: Dataset, - path_or_buf: Union[PathLike, Connection], - table_name: str, + name: str, + con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"], batch_size: Optional[int] = None, num_proc: Optional[int] = None, **to_sql_kwargs, @@ -70,35 +71,35 @@ def __init__( raise ValueError(f"num_proc {num_proc} must be an integer > 0.") self.dataset = dataset - self.path_or_buf = path_or_buf - self.table_name = table_name + self.name = name + self.con = con self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE self.num_proc = num_proc - self.encoding = "utf-8" self.to_sql_kwargs = to_sql_kwargs def write(self) -> int: - _ = self.to_sql_kwargs.pop("path_or_buf", None) + _ = self.to_sql_kwargs.pop("sql", None) + _ = self.to_sql_kwargs.pop("con", None) - if isinstance(self.path_or_buf, (str, bytes, os.PathLike)): - with contextlib.closing(connect(self.path_or_buf)) as conn: - written = self._write(conn=conn, **self.to_sql_kwargs) - else: - written = self._write(conn=self.path_or_buf, **self.to_sql_kwargs) + written = self._write(**self.to_sql_kwargs) return written - def _batch_sql(self, offset): + def _batch_sql(self, args): + offset, to_sql_kwargs = args + to_sql_kwargs = {**to_sql_kwargs, "if_exists": "append"} if offset > 0 else to_sql_kwargs batch = query_table( table=self.dataset.data, key=slice(offset, offset + self.batch_size), indices=self.dataset._indices, ) - return batch.to_pandas() + df = batch.to_pandas() + num_rows = df.to_sql(self.name, self.con, **to_sql_kwargs) + return num_rows or len(df) - def _write(self, conn: Connection, **to_sql_kwargs) -> int: - """Writes the pyarrow table as SQL to a binary file handle. + def _write(self, **to_sql_kwargs) -> int: + """Writes the pyarrow table as SQL to a database. - Caller is responsible for opening and closing the handle. + Caller is responsible for opening and closing the SQL connection. 
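For readers skimming the patch, the chunked write above boils down to a simple pandas pattern: slice the table, convert each slice to a `DataFrame`, and call `DataFrame.to_sql`, forcing `if_exists="append"` for every chunk after the first. A minimal, self-contained sketch of that pattern — not the library code; it assumes `sqlalchemy` is installed so a database URI string can be passed as `con`, and the table/file names are illustrative:

```py
import pandas as pd


def write_in_batches(df: pd.DataFrame, name: str, con: str, batch_size: int = 2, **to_sql_kwargs) -> int:
    written = 0
    for offset in range(0, len(df), batch_size):
        batch = df.iloc[offset : offset + batch_size]
        # every chunk after the first one must append rather than replace/fail
        kwargs = {**to_sql_kwargs, "if_exists": "append"} if offset > 0 else to_sql_kwargs
        # to_sql may return None (older pandas), hence the fallback to len(batch)
        written += batch.to_sql(name, con, **kwargs) or len(batch)
    return written


df = pd.DataFrame({"col_1": ["a", "b", "c"], "col_2": [1, 2, 3]})
print(write_in_batches(df, "data", "sqlite:///example.db", index=False, if_exists="replace"))  # 3
```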
""" written = 0 @@ -109,19 +110,15 @@ def _write(self, conn: Connection, **to_sql_kwargs) -> int: disable=not logging.is_progress_bar_enabled(), desc="Creating SQL from Arrow format", ): - df = self._batch_sql(offset) - written += df.to_sql( - self.table_name, conn, **to_sql_kwargs, if_exists="replace" if offset == 0 else "append" - ) or len(df) - + written += self._batch_sql((offset, to_sql_kwargs)) else: num_rows, batch_size = len(self.dataset), self.batch_size with multiprocessing.Pool(self.num_proc) as pool: - for idx, df in logging.tqdm( + for num_rows in logging.tqdm( enumerate( pool.imap( self._batch_sql, - [offset for offset in range(0, num_rows, batch_size)], + [(offset, to_sql_kwargs) for offset in range(0, num_rows, batch_size)], ) ), total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size, @@ -129,8 +126,6 @@ def _write(self, conn: Connection, **to_sql_kwargs) -> int: disable=not logging.is_progress_bar_enabled(), desc="Creating SQL from Arrow format", ): - written += df.to_sql( - self.table_name, conn, **to_sql_kwargs, if_exists="replace" if idx == 0 else "append" - ) or len(df) + written += num_rows return written diff --git a/src/datasets/packaged_modules/sql/sql.py b/src/datasets/packaged_modules/sql/sql.py index 15c7308d84b..1e2d9920453 100644 --- a/src/datasets/packaged_modules/sql/sql.py +++ b/src/datasets/packaged_modules/sql/sql.py @@ -1,46 +1,60 @@ -import contextlib -import itertools from dataclasses import dataclass -from sqlite3 import Connection, connect -from typing import Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import pandas as pd import pyarrow as pa -from typing_extensions import Literal import datasets import datasets.config -from datasets import NamedSplit from datasets.features.features import require_storage_cast from datasets.table import table_cast -from datasets.utils.typing import NestedDataStructureLike, PathLike -logger = datasets.utils.logging.get_logger(__name__) +if TYPE_CHECKING: + import sqlalchemy @dataclass class SqlConfig(datasets.BuilderConfig): """BuilderConfig for SQL.""" - index_col: Optional[Union[int, str, List[int], List[str]]] = None - conn: Dict[Optional[NamedSplit], NestedDataStructureLike[Union[PathLike, Connection]]] = None - table_name: str = None - query: str = "SELECT * FROM `{table_name}`" + sql: Union[str, "sqlalchemy.sql.Selectable"] = None + con: str = None + index_col: Optional[Union[str, List[str]]] = None coerce_float: bool = True - params: Optional[Union[Sequence, Dict]] = None + params: Optional[Union[List, Tuple, Dict]] = None parse_dates: Optional[Union[List, Dict]] = None columns: Optional[List[str]] = None - chunksize: int = 10_000 + chunksize: Optional[int] = 10_000 features: Optional[datasets.Features] = None - encoding_errors: Optional[str] = "strict" - on_bad_lines: Literal["error", "warn", "skip"] = "error" def __post_init__(self): - if self.table_name is None: - raise ValueError("Expected argument `table_name` to be supplied.") - if self.conn is None: - raise ValueError("Expected argument `conn` to connect to the database") + assert self.sql is not None, "sql must be specified" + assert self.con is not None, "con must be specified" + + if not isinstance(self.con, str): + raise ValueError(f"con must be a database URI string, but got {self.con} with type {type(self.con)}.") + + def create_config_id( + self, + config_kwargs: dict, + custom_features: Optional[datasets.Features] = None, + ) -> str: + # We need to 
stringify the Selectable object to make its hash deterministic + + # The process of stringifying is explained here: http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html + sql = config_kwargs["sql"] + if not isinstance(sql, str): + if not datasets.config.SQLALCHEMY_AVAILABLE: + raise ImportError("Please pip install sqlalchemy.") + + import sqlalchemy + + config_kwargs = config_kwargs.copy() + engine = sqlalchemy.create_engine(config_kwargs["con"].split("://")[0] + "://") + sql_str = str(sql.compile(dialect=engine.dialect)) + config_kwargs["sql"] = sql_str + return super().create_config_id(config_kwargs, custom_features=custom_features) @property def read_sql_kwargs(self): @@ -50,7 +64,6 @@ def read_sql_kwargs(self): params=self.params, coerce_float=self.coerce_float, parse_dates=self.parse_dates, - chunksize=self.chunksize, ) return read_sql_kwargs @@ -62,21 +75,7 @@ def _info(self): return datasets.DatasetInfo(features=self.config.features) def _split_generators(self, dl_manager): - """We handle string, Connection, list and dicts in datafiles""" - data_files = self.config.conn - if isinstance(data_files, (str, list, tuple)): - files = data_files - if isinstance(files, str): - files = [files] - files = [dl_manager.iter_files(file) for file in files] - return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})] - splits = [] - for split_name, files in data_files.items(): - if isinstance(files, str): - files = [files] - files = [dl_manager.iter_files(file) for file in files] - splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) - return splits + return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={})] def _cast_table(self, pa_table: pa.Table) -> pa.Table: if self.config.features is not None: @@ -89,20 +88,10 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table: pa_table = table_cast(pa_table, schema) return pa_table - def _generate_tables(self, files): - for file_idx, file in enumerate(itertools.chain.from_iterable(files)): - with contextlib.closing(connect(file)) as conn: - sql_file_reader = pd.read_sql( - self.config.query.format(table_name=self.config.table_name), conn, **self.config.read_sql_kwargs - ) - try: - for batch_idx, df in enumerate(sql_file_reader): - # Drop index column as it is not relevant. 
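Stepping back to `create_config_id` above: the deterministic config hash relies on compiling the `Selectable` against a dialect-only engine built from the URI scheme. A small sketch of that trick, assuming `sqlalchemy` is installed (pre-2.0 calling style, matching the docstring examples elsewhere in this patch series; the column and table names are illustrative):

```py
from sqlalchemy import create_engine, select, text

con = "sqlite:///example.db"
stmt = select([text("col_1")]).select_from(text("dataset"))

# Build an engine from the URI scheme only ("sqlite://") to obtain its dialect,
# then render the Selectable as a plain SQL string that can be hashed.
engine = create_engine(con.split("://")[0] + "://")
print(str(stmt.compile(dialect=engine.dialect)))  # e.g. "SELECT col_1 FROM dataset"
```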
- pa_table = pa.Table.from_pandas(df.drop("index", axis=1, errors="ignore")) - # Uncomment for debugging (will print the Arrow table size and elements) - # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") - # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) - yield (file_idx, batch_idx), self._cast_table(pa_table) - except ValueError as e: - logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") - raise + def _generate_tables(self): + chunksize = self.config.chunksize + sql_reader = pd.read_sql(self.config.sql, self.config.con, chunksize=chunksize, **self.config.read_sql_kwargs) + sql_reader = [sql_reader] if chunksize is None else sql_reader + for chunk_idx, df in enumerate(sql_reader): + pa_table = pa.Table.from_pandas(df) + yield chunk_idx, self._cast_table(pa_table) From 5410f5121550c0aa273ac29e29c1c5f2d1091452 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Wed, 21 Sep 2022 14:26:22 +0200 Subject: [PATCH 10/21] Update tests --- tests/fixtures/files.py | 13 +++--- tests/packaged_modules/test_sql.py | 55 ------------------------ tests/test_arrow_dataset.py | 67 ++++++++++-------------------- tests/utils.py | 14 +++++++ 4 files changed, 43 insertions(+), 106 deletions(-) delete mode 100644 tests/packaged_modules/test_sql.py diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index e9d96d50706..b1dd1f8785c 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -2,12 +2,11 @@ import csv import json import os +import sqlite3 import tarfile import textwrap import zipfile -from sqlite3 import connect -import pandas as pd import pyarrow as pa import pyarrow.parquet as pq import pytest @@ -242,10 +241,14 @@ def arrow_path(tmp_path_factory): @pytest.fixture(scope="session") -def sql_path(tmp_path_factory): +def sqlite_path(tmp_path_factory): path = str(tmp_path_factory.mktemp("data") / "dataset.sqlite") - with contextlib.closing(connect(path)) as conn: - pd.DataFrame.from_records(DATA).to_sql("TABLE_NAME", conn) + with contextlib.closing(sqlite3.connect(path)) as con: + cur = con.cursor() + cur.execute("CREATE TABLE dataset(col_1 text, col_2 int, col_3 real)") + for item in DATA: + cur.execute("INSERT INTO dataset(col_1, col_2, col_3) VALUES (?, ?, ?)", tuple(item.values())) + con.commit() return path diff --git a/tests/packaged_modules/test_sql.py b/tests/packaged_modules/test_sql.py deleted file mode 100644 index 65261ee63db..00000000000 --- a/tests/packaged_modules/test_sql.py +++ /dev/null @@ -1,55 +0,0 @@ -import contextlib -from sqlite3 import connect - -import pandas as pd -import pyarrow as pa -import pytest - -from datasets import ClassLabel, Features, Value -from datasets.packaged_modules.sql.sql import Sql - - -@pytest.fixture -def sqlite_file(tmp_path): - filename = str(tmp_path / "malformed_file.sqlite") - with contextlib.closing(connect(filename)) as conn: - pd.DataFrame.from_dict({"header1": [1, 10], "header2": [2, 20], "label": ["good", "bad"]}).to_sql( - "TABLE_NAME", con=conn - ) - return filename - - -def test_sql_cast_label(sqlite_file): - table_name = "TABLE_NAME" - with contextlib.closing(connect(sqlite_file)) as conn: - labels = pd.read_sql(f"SELECT * FROM {table_name}", conn)["label"].tolist() - sql = Sql( - conn=sqlite_file, - features=Features( - {"header1": Value("int32"), "header2": Value("int32"), "label": ClassLabel(names=["good", "bad"])} - ), - table_name=table_name, - ) - generator = sql._generate_tables([[sqlite_file]]) - pa_table = 
pa.concat_tables([table for _, table in generator]) - assert pa_table.schema.field("label").type == ClassLabel(names=["good", "bad"])() - generated_content = pa_table.to_pydict()["label"] - assert generated_content == [ClassLabel(names=["good", "bad"]).str2int(label) for label in labels] - - -def test_missing_args(sqlite_file): - with pytest.raises(ValueError, match="Expected argument `table_name`"): - _ = Sql( - conn=sqlite_file, - features=Features( - {"header1": Value("int32"), "header2": Value("int32"), "label": ClassLabel(names=["good", "bad"])} - ), - ) - - with pytest.raises(ValueError, match="Expected argument `conn`"): - _ = Sql( - features=Features( - {"header1": Value("int32"), "header2": Value("int32"), "label": ClassLabel(names=["good", "bad"])} - ), - table_name="TABLE_NAME", - ) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 0ba669f006f..789c59dc58e 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -1,4 +1,3 @@ -import contextlib import copy import itertools import json @@ -8,7 +7,6 @@ import tempfile from functools import partial from pathlib import Path -from sqlite3 import connect from unittest import TestCase from unittest.mock import patch @@ -55,6 +53,7 @@ require_jax, require_pil, require_s3, + require_sqlalchemy, require_tf, require_torch, require_transformers, @@ -62,9 +61,6 @@ ) -SQLITE_TABLE_NAME = "TABLE_NAME" - - class Unpicklable: def __getstate__(self): raise pickle.PicklingError() @@ -2037,36 +2033,27 @@ def test_to_parquet(self, in_memory): self.assertEqual(parquet_dset.shape, dset.shape) self.assertListEqual(list(parquet_dset.columns), list(dset.column_names)) + @require_sqlalchemy def test_to_sql(self, in_memory): - table_name = "TABLE_NAME" - - def read_sql(filepath): - - with contextlib.closing(connect(filepath)) as conn: - return pd.read_sql(f"SELECT * FROM {table_name}", conn).drop("index", axis=1, errors="ignore") - with tempfile.TemporaryDirectory() as tmp_dir: # File path argument with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: - file_path = os.path.join(tmp_dir, "test_path.sql") - _ = dset.to_sql(path_or_buf=file_path, table_name=table_name) + file_path = os.path.join(tmp_dir, "test_path.sqlite") + _ = dset.to_sql("data", "sqlite:///" + file_path, index=False) self.assertTrue(os.path.isfile(file_path)) - # self.assertEqual(bytes_written, os.path.getsize(file_path)) - sql_dset = read_sql(file_path) + sql_dset = pd.read_sql("data", "sqlite:///" + file_path) self.assertEqual(sql_dset.shape, dset.shape) self.assertListEqual(list(sql_dset.columns), list(dset.column_names)) - # File buffer argument + # Test writing to a database in chunks with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: - file_path = os.path.join(tmp_dir, "test_buffer.sql") - with contextlib.closing(connect(file_path)) as buffer: - _ = dset.to_sql(path_or_buf=buffer, table_name=table_name) + file_path = os.path.join(tmp_dir, "test_path.sqlite") + _ = dset.to_sql("data", "sqlite:///" + file_path, batch_size=1, index=False, if_exists="replace") self.assertTrue(os.path.isfile(file_path)) - # self.assertEqual(bytes_written, os.path.getsize(file_path)) - sql_dset = read_sql(file_path) + sql_dset = pd.read_sql("data", "sqlite:///" + file_path) self.assertEqual(sql_dset.shape, dset.shape) self.assertListEqual(list(sql_dset.columns), list(dset.column_names)) @@ -2074,24 +2061,22 @@ def read_sql(filepath): # After a select/shuffle transform with 
self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: dset = dset.select(range(0, len(dset), 2)).shuffle() - file_path = os.path.join(tmp_dir, "test_path.sql") - _ = dset.to_sql(path_or_buf=file_path, table_name=table_name) + file_path = os.path.join(tmp_dir, "test_path.sqlite") + _ = dset.to_sql("data", "sqlite:///" + file_path, index=False, if_exists="replace") self.assertTrue(os.path.isfile(file_path)) - # self.assertEqual(bytes_written, os.path.getsize(file_path)) - sql_dset = read_sql(file_path) + sql_dset = pd.read_sql("data", "sqlite:///" + file_path) self.assertEqual(sql_dset.shape, dset.shape) self.assertListEqual(list(sql_dset.columns), list(dset.column_names)) # With array features with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset: - file_path = os.path.join(tmp_dir, "test_path.sql") - _ = dset.to_sql(path_or_buf=file_path, table_name=table_name) + file_path = os.path.join(tmp_dir, "test_path.sqlite") + _ = dset.to_sql("data", "sqlite:///" + file_path, index=False, if_exists="replace") self.assertTrue(os.path.isfile(file_path)) - # self.assertEqual(bytes_written, os.path.getsize(file_path)) - sql_dset = read_sql(file_path) + sql_dset = pd.read_sql("data", "sqlite:///" + file_path) self.assertEqual(sql_dset.shape, dset.shape) self.assertListEqual(list(sql_dset.columns), list(dset.column_names)) @@ -3305,6 +3290,7 @@ def _check_sql_dataset(dataset, expected_features): assert dataset.features[feature].dtype == expected_dtype +@require_sqlalchemy @pytest.mark.parametrize( "features", [ @@ -3315,40 +3301,29 @@ def _check_sql_dataset(dataset, expected_features): {"col_1": "float32", "col_2": "float32", "col_3": "float32"}, ], ) -def test_dataset_from_sql_features(features, sql_path, tmp_path): +def test_dataset_from_sql_features(features, sqlite_path, tmp_path): cache_dir = tmp_path / "cache" default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} expected_features = features.copy() if features else default_expected_features features = ( Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None ) - dataset = Dataset.from_sql(sql_path, table_name=SQLITE_TABLE_NAME, features=features, cache_dir=cache_dir) + dataset = Dataset.from_sql("dataset", "sqlite:///" + sqlite_path, features=features, cache_dir=cache_dir) _check_sql_dataset(dataset, expected_features) +@require_sqlalchemy @pytest.mark.parametrize("keep_in_memory", [False, True]) -def test_dataset_from_sql_keep_in_memory(keep_in_memory, sql_path, tmp_path): +def test_dataset_from_sql_keep_in_memory(keep_in_memory, sqlite_path, tmp_path): cache_dir = tmp_path / "cache" expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): dataset = Dataset.from_sql( - sql_path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir, keep_in_memory=keep_in_memory + "dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory ) _check_sql_dataset(dataset, expected_features) -@pytest.mark.parametrize("path_type", [str, list]) -def test_dataset_from_sql_path_type(path_type, sql_path, tmp_path): - if issubclass(path_type, str): - path = sql_path - elif issubclass(path_type, list): - path = [sql_path] - cache_dir = tmp_path / "cache" - expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} - dataset = Dataset.from_sql(path, table_name=SQLITE_TABLE_NAME, 
cache_dir=cache_dir) - _check_sql_dataset(dataset, expected_features) - - def test_dataset_to_json(dataset, tmp_path): file_path = tmp_path / "test_path.jsonl" bytes_written = dataset.to_json(path_or_buf=file_path) diff --git a/tests/utils.py b/tests/utils.py index 70853c6cdcb..6e4bdda4fe2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -121,6 +121,20 @@ def require_elasticsearch(test_case): return test_case +def require_sqlalchemy(test_case): + """ + Decorator marking a test that requires SQLAlchemy. + + These tests are skipped when SQLAlchemy isn't installed. + + """ + try: + import sqlalchemy # noqa + except ImportError: + test_case = unittest.skip("test requires sqlalchemy")(test_case) + return test_case + + def require_torch(test_case): """ Decorator marking a test that requires PyTorch. From 3c128be2af495395b6997b17b76c1e10acb3e4d4 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Wed, 21 Sep 2022 14:26:41 +0200 Subject: [PATCH 11/21] Update docs --- docs/source/loading.mdx | 15 ++++++++++++--- docs/source/package_reference/loading_methods.mdx | 4 ++++ docs/source/package_reference/main_classes.mdx | 2 ++ docs/source/process.mdx | 1 + setup.py | 1 + 5 files changed, 20 insertions(+), 3 deletions(-) diff --git a/docs/source/loading.mdx b/docs/source/loading.mdx index efa46f86a61..a01e8daff4f 100644 --- a/docs/source/loading.mdx +++ b/docs/source/loading.mdx @@ -196,13 +196,22 @@ To load remote Parquet files via HTTP, pass the URLs instead: >>> wiki = load_dataset("parquet", data_files=data_files, split="train") ``` -### SQLite +### SQL -Datasets stored as a Table in a SQLite file can be loaded with: +Read database contents with with [`Dataset.from_sql`]. Both table names and queries are supported. + +For example, a table from a SQLite file can be loaded with: + +```py +>>> from datasets import Dataset +>>> dataset = Dataset.from_sql("data_table", "sqlite:///sqlite_file.db") +``` + +Use a query for a more precise read: ```py >>> from datasets import Dataset ->>> dataset = Dataset.from_sql('sqlite_file.db' table_name='Dataset') +>>> dataset = Dataset.from_sql("SELECT text FROM data_table WHERE length(text) > 100 LIMIT 10", "sqlite:///sqlite_file.db") ``` ## In-memory data diff --git a/docs/source/package_reference/loading_methods.mdx b/docs/source/package_reference/loading_methods.mdx index bc478ffe777..c23d177b9dd 100644 --- a/docs/source/package_reference/loading_methods.mdx +++ b/docs/source/package_reference/loading_methods.mdx @@ -65,6 +65,10 @@ load_dataset("csv", data_dir="path/to/data/dir", sep="\t") [[autodoc]] datasets.packaged_modules.parquet.ParquetConfig +### SQL + +[[autodoc]] datasets.packaged_modules.sql.SqlConfig + ### Images [[autodoc]] datasets.packaged_modules.imagefolder.ImageFolderConfig diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx index b4dba764abb..1b2b700ff71 100644 --- a/docs/source/package_reference/main_classes.mdx +++ b/docs/source/package_reference/main_classes.mdx @@ -58,6 +58,7 @@ The base class [`Dataset`] implements a Dataset backed by an Apache Arrow table. - to_dict - to_json - to_parquet + - to_sql - add_faiss_index - add_faiss_index_from_external_arrays - save_faiss_index @@ -90,6 +91,7 @@ The base class [`Dataset`] implements a Dataset backed by an Apache Arrow table. 
- from_json - from_parquet - from_text + - from_sql - prepare_for_task - align_labels_with_mapping diff --git a/docs/source/process.mdx b/docs/source/process.mdx index 72c60416aa0..60022d887dc 100644 --- a/docs/source/process.mdx +++ b/docs/source/process.mdx @@ -609,6 +609,7 @@ Want to save your dataset to a cloud storage provider? Read our [Cloud Storage]( | CSV | [`Dataset.to_csv`] | | JSON | [`Dataset.to_json`] | | Parquet | [`Dataset.to_parquet`] | +| SQL | [`Dataset.to_sql`] | | In-memory Python object | [`Dataset.to_pandas`] or [`Dataset.to_dict`] | For example, export your dataset to a CSV file like this: diff --git a/setup.py b/setup.py index 38925961ac9..ca1ef8f3408 100644 --- a/setup.py +++ b/setup.py @@ -155,6 +155,7 @@ "scipy", "sentencepiece", # for bleurt "seqeval", + "sqlalchemy" "tldextract", # to speed up pip backtracking "toml>=0.10.1", From 40268aeecc34a4b2a45586f8e8254025fec7b10d Mon Sep 17 00:00:00 2001 From: mariosasko Date: Wed, 21 Sep 2022 15:10:44 +0200 Subject: [PATCH 12/21] Update some more tests --- src/datasets/io/sql.py | 8 ++--- tests/io/test_sql.py | 74 +++++++++++++++++++----------------------- 2 files changed, 36 insertions(+), 46 deletions(-) diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py index c2f2fc7b6a9..0301908f50c 100644 --- a/src/datasets/io/sql.py +++ b/src/datasets/io/sql.py @@ -115,11 +115,9 @@ def _write(self, **to_sql_kwargs) -> int: num_rows, batch_size = len(self.dataset), self.batch_size with multiprocessing.Pool(self.num_proc) as pool: for num_rows in logging.tqdm( - enumerate( - pool.imap( - self._batch_sql, - [(offset, to_sql_kwargs) for offset in range(0, num_rows, batch_size)], - ) + pool.imap( + self._batch_sql, + [(offset, to_sql_kwargs) for offset in range(0, num_rows, batch_size)], ), total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size, unit="ba", diff --git a/tests/io/test_sql.py b/tests/io/test_sql.py index 244517a6fdb..143e1aa2201 100644 --- a/tests/io/test_sql.py +++ b/tests/io/test_sql.py @@ -1,17 +1,13 @@ import contextlib import os -from sqlite3 import connect +import sqlite3 -import pandas as pd import pytest from datasets import Dataset, Features, Value from datasets.io.sql import SqlDatasetReader, SqlDatasetWriter -from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases - - -SQLITE_TABLE_NAME = "TABLE_NAME" +from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_sqlalchemy def _check_sql_dataset(dataset, expected_features): @@ -23,17 +19,19 @@ def _check_sql_dataset(dataset, expected_features): assert dataset.features[feature].dtype == expected_dtype +@require_sqlalchemy @pytest.mark.parametrize("keep_in_memory", [False, True]) -def test_dataset_from_sql_keep_in_memory(keep_in_memory, sql_path, tmp_path): +def test_dataset_from_sql_keep_in_memory(keep_in_memory, sqlite_path, tmp_path): cache_dir = tmp_path / "cache" expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): dataset = SqlDatasetReader( - sql_path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir, keep_in_memory=keep_in_memory + "dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory ).read() _check_sql_dataset(dataset, expected_features) +@require_sqlalchemy @pytest.mark.parametrize( "features", [ @@ -44,63 +42,57 @@ def test_dataset_from_sql_keep_in_memory(keep_in_memory, sql_path, 
tmp_path): {"col_1": "float32", "col_2": "float32", "col_3": "float32"}, ], ) -def test_dataset_from_sql_features(features, sql_path, tmp_path): +def test_dataset_from_sql_features(features, sqlite_path, tmp_path): cache_dir = tmp_path / "cache" default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} expected_features = features.copy() if features else default_expected_features features = ( Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None ) - dataset = SqlDatasetReader(sql_path, table_name=SQLITE_TABLE_NAME, features=features, cache_dir=cache_dir).read() - _check_sql_dataset(dataset, expected_features) - - -@pytest.mark.parametrize("path_type", [str, list]) -def test_dataset_from_sql_path_type(path_type, sql_path, tmp_path): - if issubclass(path_type, str): - path = sql_path - elif issubclass(path_type, list): - path = [sql_path] - cache_dir = tmp_path / "cache" - expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} - dataset = SqlDatasetReader(path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir).read() + dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, features=features, cache_dir=cache_dir).read() _check_sql_dataset(dataset, expected_features) -def iter_sql_file(sql_path): - with contextlib.closing(connect(sql_path)) as conn: - return pd.read_sql(f"SELECT * FROM {SQLITE_TABLE_NAME}", conn).drop("index", axis=1, errors="ignore") +def iter_sql_file(sqlite_path): + with contextlib.closing(sqlite3.connect(sqlite_path)) as con: + cur = con.cursor() + cur.execute("SELECT * FROM dataset") + for row in cur: + yield row -def test_dataset_to_sql(sql_path, tmp_path): +@require_sqlalchemy +def test_dataset_to_sql(sqlite_path, tmp_path): cache_dir = tmp_path / "cache" - output_sql = os.path.join(cache_dir, "tmp.sql") - dataset = SqlDatasetReader(sql_path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir).read() - SqlDatasetWriter(dataset, output_sql, table_name=SQLITE_TABLE_NAME, index=False, num_proc=1).write() + output_sqlite_path = os.path.join(cache_dir, "tmp.sql") + dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir).read() + SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, index=False, num_proc=1).write() - original_sql = iter_sql_file(sql_path) - expected_sql = iter_sql_file(output_sql) + original_sql = iter_sql_file(sqlite_path) + expected_sql = iter_sql_file(output_sqlite_path) for row1, row2 in zip(original_sql, expected_sql): assert row1 == row2 -def test_dataset_to_sql_multiproc(sql_path, tmp_path): +@require_sqlalchemy +def test_dataset_to_sql_multiproc(sqlite_path, tmp_path): cache_dir = tmp_path / "cache" - output_sql = os.path.join(cache_dir, "tmp.sql") - dataset = SqlDatasetReader(sql_path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir).read() - SqlDatasetWriter(dataset, output_sql, table_name=SQLITE_TABLE_NAME, index=False, num_proc=2).write() + output_sqlite_path = os.path.join(cache_dir, "tmp.sql") + dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir).read() + SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, index=False, num_proc=2).write() - original_sql = iter_sql_file(sql_path) - expected_sql = iter_sql_file(output_sql) + original_sql = iter_sql_file(sqlite_path) + expected_sql = iter_sql_file(output_sqlite_path) for row1, row2 in zip(original_sql, expected_sql): assert row1 == row2 -def 
test_dataset_to_sql_invalidproc(sql_path, tmp_path): +@require_sqlalchemy +def test_dataset_to_sql_invalidproc(sqlite_path, tmp_path): cache_dir = tmp_path / "cache" - output_sql = os.path.join(cache_dir, "tmp.sql") - dataset = SqlDatasetReader(sql_path, table_name=SQLITE_TABLE_NAME, cache_dir=cache_dir).read() + output_sqlite_path = os.path.join(cache_dir, "tmp.sql") + dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir).read() with pytest.raises(ValueError): - SqlDatasetWriter(dataset, output_sql, table_name=SQLITE_TABLE_NAME, index=False, num_proc=0) + SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, index=False, num_proc=0).write() From dc005df8c8fef88b98f0ae4be99c1f196ec955a9 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Wed, 21 Sep 2022 15:30:44 +0200 Subject: [PATCH 13/21] Missing comma --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index cde119b765a..6a89b08bc99 100644 --- a/setup.py +++ b/setup.py @@ -155,7 +155,7 @@ "scipy", "sentencepiece", # for bleurt "seqeval", - "sqlalchemy" + "sqlalchemy", "tldextract", # to speed up pip backtracking "toml>=0.10.1", From a3c39d9b8a39b820a37c9e89b4d8e7ca414cce17 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Wed, 21 Sep 2022 17:14:56 +0200 Subject: [PATCH 14/21] Small docs fix --- src/datasets/packaged_modules/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index deabca1f35a..9cbaf083619 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -9,6 +9,7 @@ from .json import json from .pandas import pandas from .parquet import parquet +from .sql import sql # needed for the docs from .text import text From 7c4999e1c001a799217e194081863cc6520be8cc Mon Sep 17 00:00:00 2001 From: mariosasko Date: Wed, 21 Sep 2022 17:32:19 +0200 Subject: [PATCH 15/21] Style --- src/datasets/packaged_modules/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index 9cbaf083619..f3553b0b961 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -9,7 +9,7 @@ from .json import json from .pandas import pandas from .parquet import parquet -from .sql import sql # needed for the docs +from .sql import sql # noqa F401 from .text import text From 920de97c955a09232c9fab931718b472a198896a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20=C5=A0a=C5=A1ko?= Date: Fri, 23 Sep 2022 14:44:18 +0200 Subject: [PATCH 16/21] Update src/datasets/arrow_dataset.py Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> --- src/datasets/arrow_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 5a6645f33a3..9cb2bcd1c93 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1121,7 +1121,7 @@ def from_sql( Example: ```py - >>> ds = Dataset.from_sql("test_data", "postgres:///db_name"s) + >>> ds = Dataset.from_sql("test_data", "postgres:///db_name") ``` """ from .io.sql import SqlDatasetReader From 9ecdb1f173eadf50c9f1cb53d81cfcd98c046d7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20=C5=A0a=C5=A1ko?= Date: Fri, 23 Sep 2022 14:59:01 +0200 Subject: [PATCH 17/21] Update src/datasets/packaged_modules/sql/sql.py Co-authored-by: Quentin Lhoest 
<42851186+lhoestq@users.noreply.github.com> --- src/datasets/packaged_modules/sql/sql.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/datasets/packaged_modules/sql/sql.py b/src/datasets/packaged_modules/sql/sql.py index 1e2d9920453..94e3f7a5641 100644 --- a/src/datasets/packaged_modules/sql/sql.py +++ b/src/datasets/packaged_modules/sql/sql.py @@ -45,15 +45,18 @@ def create_config_id( # The process of stringifying is explained here: http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html sql = config_kwargs["sql"] if not isinstance(sql, str): - if not datasets.config.SQLALCHEMY_AVAILABLE: - raise ImportError("Please pip install sqlalchemy.") - - import sqlalchemy - - config_kwargs = config_kwargs.copy() - engine = sqlalchemy.create_engine(config_kwargs["con"].split("://")[0] + "://") - sql_str = str(sql.compile(dialect=engine.dialect)) - config_kwargs["sql"] = sql_str + if datasets.config.SQLALCHEMY_AVAILABLE and "sqlalchemy" in sys.modules:: + import sqlalchemy + + if isinstance(sql, sqlalchemy.sql.Selectable): + config_kwargs = config_kwargs.copy() + engine = sqlalchemy.create_engine(config_kwargs["con"].split("://")[0] + "://") + sql_str = str(sql.compile(dialect=engine.dialect)) + config_kwargs["sql"] = sql_str + else: + raise TypeError(f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}") + else: + raise TypeError(f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}") return super().create_config_id(config_kwargs, custom_features=custom_features) @property From 27c9674f33e951fbc9c6bd90e456c5af892c0bdf Mon Sep 17 00:00:00 2001 From: mariosasko Date: Fri, 23 Sep 2022 14:59:47 +0200 Subject: [PATCH 18/21] Address some comments --- src/datasets/packaged_modules/csv/csv.py | 18 +++++++++--------- src/datasets/packaged_modules/sql/sql.py | 15 ++++++++------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/datasets/packaged_modules/csv/csv.py b/src/datasets/packaged_modules/csv/csv.py index c3e3f9b906a..5220335ad03 100644 --- a/src/datasets/packaged_modules/csv/csv.py +++ b/src/datasets/packaged_modules/csv/csv.py @@ -70,8 +70,8 @@ def __post_init__(self): self.names = self.column_names @property - def read_csv_kwargs(self): - read_csv_kwargs = dict( + def pd_read_csv_kwargs(self): + pd_read_csv_kwargs = dict( sep=self.sep, header=self.header, names=self.names, @@ -112,16 +112,16 @@ def read_csv_kwargs(self): # some kwargs must not be passed if they don't have a default value # some others are deprecated and we can also not pass them if they are the default value - for read_csv_parameter in _PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS + _PANDAS_READ_CSV_DEPRECATED_PARAMETERS: - if read_csv_kwargs[read_csv_parameter] == getattr(CsvConfig(), read_csv_parameter): - del read_csv_kwargs[read_csv_parameter] + for pd_read_csv_parameter in _PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS + _PANDAS_READ_CSV_DEPRECATED_PARAMETERS: + if pd_read_csv_kwargs[pd_read_csv_parameter] == getattr(CsvConfig(), pd_read_csv_parameter): + del pd_read_csv_kwargs[pd_read_csv_parameter] # Remove 1.3 new arguments if not (datasets.config.PANDAS_VERSION.major >= 1 and datasets.config.PANDAS_VERSION.minor >= 3): - for read_csv_parameter in _PANDAS_READ_CSV_NEW_1_3_0_PARAMETERS: - del read_csv_kwargs[read_csv_parameter] + for pd_read_csv_parameter in _PANDAS_READ_CSV_NEW_1_3_0_PARAMETERS: + del pd_read_csv_kwargs[pd_read_csv_parameter] - return read_csv_kwargs + return 
pd_read_csv_kwargs class Csv(datasets.ArrowBasedBuilder): @@ -172,7 +172,7 @@ def _generate_tables(self, files): else None ) for file_idx, file in enumerate(itertools.chain.from_iterable(files)): - csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.read_csv_kwargs) + csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.pd_read_csv_kwargs) try: for batch_idx, df in enumerate(csv_file_reader): pa_table = pa.Table.from_pandas(df) diff --git a/src/datasets/packaged_modules/sql/sql.py b/src/datasets/packaged_modules/sql/sql.py index 1e2d9920453..e22ca2ceb97 100644 --- a/src/datasets/packaged_modules/sql/sql.py +++ b/src/datasets/packaged_modules/sql/sql.py @@ -29,9 +29,10 @@ class SqlConfig(datasets.BuilderConfig): features: Optional[datasets.Features] = None def __post_init__(self): - assert self.sql is not None, "sql must be specified" - assert self.con is not None, "con must be specified" - + if self.sql is None: + raise ValueError("sql must be specified") + if self.con is None: + raise ValueError("con must be specified") if not isinstance(self.con, str): raise ValueError(f"con must be a database URI string, but got {self.con} with type {type(self.con)}.") @@ -57,15 +58,15 @@ def create_config_id( return super().create_config_id(config_kwargs, custom_features=custom_features) @property - def read_sql_kwargs(self): - read_sql_kwargs = dict( + def pd_read_sql_kwargs(self): + pd_read_sql_kwargs = dict( index_col=self.index_col, columns=self.columns, params=self.params, coerce_float=self.coerce_float, parse_dates=self.parse_dates, ) - return read_sql_kwargs + return pd_read_sql_kwargs class Sql(datasets.ArrowBasedBuilder): @@ -90,7 +91,7 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table: def _generate_tables(self): chunksize = self.config.chunksize - sql_reader = pd.read_sql(self.config.sql, self.config.con, chunksize=chunksize, **self.config.read_sql_kwargs) + sql_reader = pd.read_sql(self.config.sql, self.config.con, chunksize=chunksize, **self.config.pd_read_sql_kwargs) sql_reader = [sql_reader] if chunksize is None else sql_reader for chunk_idx, df in enumerate(sql_reader): pa_table = pa.Table.from_pandas(df) From 81ad0e46c683cbb3bdddcdfbd830f901f4aac9c0 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Fri, 23 Sep 2022 15:21:53 +0200 Subject: [PATCH 19/21] Address the rest --- src/datasets/arrow_dataset.py | 12 ++++++++++++ src/datasets/packaged_modules/sql/sql.py | 17 ++++++++++++----- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 9cb2bcd1c93..ae201b1ed92 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1121,7 +1121,14 @@ def from_sql( Example: ```py + >>> # Fetch a database table >>> ds = Dataset.from_sql("test_data", "postgres:///db_name") + >>> # Execute a SQL query on the table + >>> ds = Dataset.from_sql("SELECT sentence FROM test_data", "postgres:///db_name") + >>> # Use a Selectable object to specify the query + >>> from sqlalchemy import select, text + >>> stmt = select([text("sentence")]).select_from(text("test_data")) + >>> ds = Dataset.from_sql(stmt, "postgres:///db_name") ``` """ from .io.sql import SqlDatasetReader @@ -4164,7 +4171,12 @@ def to_sql( Example: ```py + >>> # con provided as a connection URI string >>> ds.to_sql("data", "sqlite:///my_own_db.sql") + >>> # con provided as a sqlite3 connection object + >>> import sqlite3 + >>> with sqlite3.connect("my_own_db.sql") as con: + ... 
ds.to_sql("data", con) ``` """ # Dynamic import to avoid circular dependency diff --git a/src/datasets/packaged_modules/sql/sql.py b/src/datasets/packaged_modules/sql/sql.py index 72684b2302b..25c0178e264 100644 --- a/src/datasets/packaged_modules/sql/sql.py +++ b/src/datasets/packaged_modules/sql/sql.py @@ -1,3 +1,4 @@ +import sys from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union @@ -46,18 +47,22 @@ def create_config_id( # The process of stringifying is explained here: http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html sql = config_kwargs["sql"] if not isinstance(sql, str): - if datasets.config.SQLALCHEMY_AVAILABLE and "sqlalchemy" in sys.modules:: + if datasets.config.SQLALCHEMY_AVAILABLE and "sqlalchemy" in sys.modules: import sqlalchemy - + if isinstance(sql, sqlalchemy.sql.Selectable): config_kwargs = config_kwargs.copy() engine = sqlalchemy.create_engine(config_kwargs["con"].split("://")[0] + "://") sql_str = str(sql.compile(dialect=engine.dialect)) config_kwargs["sql"] = sql_str else: - raise TypeError(f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}") + raise TypeError( + f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}" + ) else: - raise TypeError(f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}") + raise TypeError( + f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}" + ) return super().create_config_id(config_kwargs, custom_features=custom_features) @property @@ -94,7 +99,9 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table: def _generate_tables(self): chunksize = self.config.chunksize - sql_reader = pd.read_sql(self.config.sql, self.config.con, chunksize=chunksize, **self.config.pd_read_sql_kwargs) + sql_reader = pd.read_sql( + self.config.sql, self.config.con, chunksize=chunksize, **self.config.pd_read_sql_kwargs + ) sql_reader = [sql_reader] if chunksize is None else sql_reader for chunk_idx, df in enumerate(sql_reader): pa_table = pa.Table.from_pandas(df) From 3714fb04e70395ead63674a935c394f9370619da Mon Sep 17 00:00:00 2001 From: mariosasko Date: Fri, 23 Sep 2022 15:40:26 +0200 Subject: [PATCH 20/21] Improve tests --- src/datasets/arrow_dataset.py | 3 ++- tests/test_arrow_dataset.py | 17 ++++++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index ae201b1ed92..66ce0b56fa3 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -4175,7 +4175,8 @@ def to_sql( >>> ds.to_sql("data", "sqlite:///my_own_db.sql") >>> # con provided as a sqlite3 connection object >>> import sqlite3 - >>> with sqlite3.connect("my_own_db.sql") as con: + >>> con = sqlite3.connect("my_own_db.sql") + >>> with con: ... 
ds.to_sql("data", con) ``` """ diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 48a072f0a2b..3f1d45f420d 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -1,3 +1,4 @@ +import contextlib import copy import itertools import json @@ -2053,7 +2054,7 @@ def test_to_parquet(self, in_memory): @require_sqlalchemy def test_to_sql(self, in_memory): with tempfile.TemporaryDirectory() as tmp_dir: - # File path argument + # Destionation specified as database URI string with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: file_path = os.path.join(tmp_dir, "test_path.sqlite") _ = dset.to_sql("data", "sqlite:///" + file_path, index=False) @@ -2064,6 +2065,20 @@ def test_to_sql(self, in_memory): self.assertEqual(sql_dset.shape, dset.shape) self.assertListEqual(list(sql_dset.columns), list(dset.column_names)) + # Destionation specified as sqlite3 connection + with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: + import sqlite3 + + file_path = os.path.join(tmp_dir, "test_path.sqlite") + with contextlib.closing(sqlite3.connect(file_path)) as con: + _ = dset.to_sql("data", con, index=False, if_exists="replace") + + self.assertTrue(os.path.isfile(file_path)) + sql_dset = pd.read_sql("data", "sqlite:///" + file_path) + + self.assertEqual(sql_dset.shape, dset.shape) + self.assertListEqual(list(sql_dset.columns), list(dset.column_names)) + # Test writing to a database in chunks with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: file_path = os.path.join(tmp_dir, "test_path.sqlite") From f3610c8d47673b9af63821ec22733bf7971fce1a Mon Sep 17 00:00:00 2001 From: mariosasko Date: Mon, 3 Oct 2022 18:00:09 +0200 Subject: [PATCH 21/21] sqlalchemy required tip --- src/datasets/arrow_dataset.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 66ce0b56fa3..89b781626ce 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1130,6 +1130,10 @@ def from_sql( >>> stmt = select([text("sentence")]).select_from(text("test_data")) >>> ds = Dataset.from_sql(stmt, "postgres:///db_name") ``` + + + `sqlalchemy` needs to be installed to use this function. + """ from .io.sql import SqlDatasetReader