diff --git a/docs/source/loading.mdx b/docs/source/loading.mdx
index 6e63488e274..a01e8daff4f 100644
--- a/docs/source/loading.mdx
+++ b/docs/source/loading.mdx
@@ -196,6 +196,24 @@ To load remote Parquet files via HTTP, pass the URLs instead:
 >>> wiki = load_dataset("parquet", data_files=data_files, split="train")
 ```
 
+### SQL
+
+Read database contents with [`Dataset.from_sql`]. Both table names and queries are supported.
+
+For example, a table from a SQLite file can be loaded with:
+
+```py
+>>> from datasets import Dataset
+>>> dataset = Dataset.from_sql("data_table", "sqlite:///sqlite_file.db")
+```
+
+Use a query for a more precise read:
+
+```py
+>>> from datasets import Dataset
+>>> dataset = Dataset.from_sql("SELECT text FROM data_table WHERE length(text) > 100 LIMIT 10", "sqlite:///sqlite_file.db")
+```
+
 ## In-memory data
 
 🤗 Datasets will also allow you to create a [`Dataset`] directly from in-memory data structures like Python dictionaries and Pandas DataFrames.
diff --git a/docs/source/package_reference/loading_methods.mdx b/docs/source/package_reference/loading_methods.mdx
index bc478ffe777..c23d177b9dd 100644
--- a/docs/source/package_reference/loading_methods.mdx
+++ b/docs/source/package_reference/loading_methods.mdx
@@ -65,6 +65,10 @@ load_dataset("csv", data_dir="path/to/data/dir", sep="\t")
 
 [[autodoc]] datasets.packaged_modules.parquet.ParquetConfig
 
+### SQL
+
+[[autodoc]] datasets.packaged_modules.sql.SqlConfig
+
 ### Images
 
 [[autodoc]] datasets.packaged_modules.imagefolder.ImageFolderConfig
diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx
index b4dba764abb..1b2b700ff71 100644
--- a/docs/source/package_reference/main_classes.mdx
+++ b/docs/source/package_reference/main_classes.mdx
@@ -58,6 +58,7 @@ The base class [`Dataset`] implements a Dataset backed by an Apache Arrow table.
     - to_dict
     - to_json
     - to_parquet
+    - to_sql
     - add_faiss_index
     - add_faiss_index_from_external_arrays
     - save_faiss_index
@@ -90,6 +91,7 @@ The base class [`Dataset`] implements a Dataset backed by an Apache Arrow table.
     - from_json
     - from_parquet
     - from_text
+    - from_sql
     - prepare_for_task
     - align_labels_with_mapping
diff --git a/docs/source/process.mdx b/docs/source/process.mdx
index 72c60416aa0..60022d887dc 100644
--- a/docs/source/process.mdx
+++ b/docs/source/process.mdx
@@ -609,6 +609,7 @@ Want to save your dataset to a cloud storage provider?
 Read our [Cloud Storage](
 | CSV | [`Dataset.to_csv`] |
 | JSON | [`Dataset.to_json`] |
 | Parquet | [`Dataset.to_parquet`] |
+| SQL | [`Dataset.to_sql`] |
 | In-memory Python object | [`Dataset.to_pandas`] or [`Dataset.to_dict`] |
 
 For example, export your dataset to a CSV file like this:
diff --git a/setup.py b/setup.py
index 76113499b8c..6a89b08bc99 100644
--- a/setup.py
+++ b/setup.py
@@ -155,6 +155,7 @@
     "scipy",
     "sentencepiece",  # for bleurt
     "seqeval",
+    "sqlalchemy",
     "tldextract",
     # to speed up pip backtracking
     "toml>=0.10.1",
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 64270730cf3..89b781626ce 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -111,6 +111,10 @@
 
 if TYPE_CHECKING:
+    import sqlite3
+
+    import sqlalchemy
+
     from .dataset_dict import DatasetDict
 
 logger = logging.get_logger(__name__)
@@ -1092,6 +1096,56 @@ def from_text(
             path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
         ).read()
 
+    @staticmethod
+    def from_sql(
+        sql: Union[str, "sqlalchemy.sql.Selectable"],
+        con: str,
+        features: Optional[Features] = None,
+        cache_dir: str = None,
+        keep_in_memory: bool = False,
+        **kwargs,
+    ):
+        """Create Dataset from SQL query or database table.
+
+        Args:
+            sql (`str` or :obj:`sqlalchemy.sql.Selectable`): SQL query to be executed or a table name.
+            con (`str`): A connection URI string used to instantiate a database connection.
+            features (:class:`Features`, optional): Dataset features.
+            cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data.
+            keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory.
+            **kwargs (additional keyword arguments): Keyword arguments to be passed to :class:`SqlConfig`.
+
+        Returns:
+            :class:`Dataset`
+
+        Example:
+
+        ```py
+        >>> # Fetch a database table
+        >>> ds = Dataset.from_sql("test_data", "postgres:///db_name")
+        >>> # Execute a SQL query on the table
+        >>> ds = Dataset.from_sql("SELECT sentence FROM test_data", "postgres:///db_name")
+        >>> # Use a Selectable object to specify the query
+        >>> from sqlalchemy import select, text
+        >>> stmt = select([text("sentence")]).select_from(text("test_data"))
+        >>> ds = Dataset.from_sql(stmt, "postgres:///db_name")
+        ```
+
+        `sqlalchemy` needs to be installed to use this function.
+
+        """
+        from .io.sql import SqlDatasetReader
+
+        return SqlDatasetReader(
+            sql,
+            con,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            **kwargs,
+        ).read()
+
     def __del__(self):
         if hasattr(self, "_data"):
             del self._data
@@ -4098,6 +4152,43 @@ def to_parquet(
         return ParquetDatasetWriter(self, path_or_buf, batch_size=batch_size, **parquet_writer_kwargs).write()
 
+    def to_sql(
+        self,
+        name: str,
+        con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"],
+        batch_size: Optional[int] = None,
+        **sql_writer_kwargs,
+    ) -> int:
+        """Exports the dataset to a SQL database.
+
+        Args:
+            name (`str`): Name of SQL table.
+            con (`str` or `sqlite3.Connection` or `sqlalchemy.engine.Connection` or `sqlalchemy.engine.Engine`):
+                A database connection URI string or an existing SQLite3/SQLAlchemy connection used to write to a database.
+            batch_size (:obj:`int`, optional): Size of the batch to load in memory and write at once.
+                Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
+            **sql_writer_kwargs (additional keyword arguments): Parameters to pass to pandas' :func:`pandas.DataFrame.to_sql`.
+
+        Returns:
+            int: The number of records written.
+
+        Example:
+
+        ```py
+        >>> # con provided as a connection URI string
+        >>> ds.to_sql("data", "sqlite:///my_own_db.sql")
+        >>> # con provided as a sqlite3 connection object
+        >>> import sqlite3
+        >>> con = sqlite3.connect("my_own_db.sql")
+        >>> with con:
+        ...     ds.to_sql("data", con)
+        ```
+        """
+        # Dynamic import to avoid circular dependency
+        from .io.sql import SqlDatasetWriter
+
+        return SqlDatasetWriter(self, name, con, batch_size=batch_size, **sql_writer_kwargs).write()
+
     def _push_parquet_shards_to_hub(
         self,
         repo_id: str,
diff --git a/src/datasets/config.py b/src/datasets/config.py
index 2bd5419cbe3..e98166a8676 100644
--- a/src/datasets/config.py
+++ b/src/datasets/config.py
@@ -125,6 +125,9 @@
     logger.info("Disabling Apache Beam because USE_BEAM is set to False")
 
 
+# Optional tools for data loading
+SQLALCHEMY_AVAILABLE = importlib.util.find_spec("sqlalchemy") is not None
+
 # Optional tools for feature decoding
 PIL_AVAILABLE = importlib.util.find_spec("PIL") is not None
 
diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py
new file mode 100644
index 00000000000..0301908f50c
--- /dev/null
+++ b/src/datasets/io/sql.py
@@ -0,0 +1,129 @@
+import multiprocessing
+from typing import TYPE_CHECKING, Optional, Union
+
+from .. import Dataset, Features, config
+from ..formatting import query_table
+from ..packaged_modules.sql.sql import Sql
+from ..utils import logging
+from .abc import AbstractDatasetInputStream
+
+
+if TYPE_CHECKING:
+    import sqlite3
+
+    import sqlalchemy
+
+
+class SqlDatasetReader(AbstractDatasetInputStream):
+    def __init__(
+        self,
+        sql: Union[str, "sqlalchemy.sql.Selectable"],
+        con: str,
+        features: Optional[Features] = None,
+        cache_dir: str = None,
+        keep_in_memory: bool = False,
+        **kwargs,
+    ):
+        super().__init__(features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs)
+        self.builder = Sql(
+            cache_dir=cache_dir,
+            features=features,
+            sql=sql,
+            con=con,
+            **kwargs,
+        )
+
+    def read(self):
+        download_config = None
+        download_mode = None
+        ignore_verifications = False
+        use_auth_token = None
+        base_path = None
+
+        self.builder.download_and_prepare(
+            download_config=download_config,
+            download_mode=download_mode,
+            ignore_verifications=ignore_verifications,
+            # try_from_hf_gcs=try_from_hf_gcs,
+            base_path=base_path,
+            use_auth_token=use_auth_token,
+        )
+
+        # Build dataset for splits
+        dataset = self.builder.as_dataset(
+            split="train", ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+        )
+        return dataset
+
+
+class SqlDatasetWriter:
+    def __init__(
+        self,
+        dataset: Dataset,
+        name: str,
+        con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"],
+        batch_size: Optional[int] = None,
+        num_proc: Optional[int] = None,
+        **to_sql_kwargs,
+    ):
+
+        if num_proc is not None and num_proc <= 0:
+            raise ValueError(f"num_proc {num_proc} must be an integer > 0.")
+
+        self.dataset = dataset
+        self.name = name
+        self.con = con
+        self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
+        self.num_proc = num_proc
+        self.to_sql_kwargs = to_sql_kwargs
+
+    def write(self) -> int:
+        _ = self.to_sql_kwargs.pop("sql", None)
+        _ = self.to_sql_kwargs.pop("con", None)
+
+        written = self._write(**self.to_sql_kwargs)
+        return written
+
+    def _batch_sql(self, args):
+        offset, to_sql_kwargs = args
+        to_sql_kwargs =
{**to_sql_kwargs, "if_exists": "append"} if offset > 0 else to_sql_kwargs + batch = query_table( + table=self.dataset.data, + key=slice(offset, offset + self.batch_size), + indices=self.dataset._indices, + ) + df = batch.to_pandas() + num_rows = df.to_sql(self.name, self.con, **to_sql_kwargs) + return num_rows or len(df) + + def _write(self, **to_sql_kwargs) -> int: + """Writes the pyarrow table as SQL to a database. + + Caller is responsible for opening and closing the SQL connection. + """ + written = 0 + + if self.num_proc is None or self.num_proc == 1: + for offset in logging.tqdm( + range(0, len(self.dataset), self.batch_size), + unit="ba", + disable=not logging.is_progress_bar_enabled(), + desc="Creating SQL from Arrow format", + ): + written += self._batch_sql((offset, to_sql_kwargs)) + else: + num_rows, batch_size = len(self.dataset), self.batch_size + with multiprocessing.Pool(self.num_proc) as pool: + for num_rows in logging.tqdm( + pool.imap( + self._batch_sql, + [(offset, to_sql_kwargs) for offset in range(0, num_rows, batch_size)], + ), + total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size, + unit="ba", + disable=not logging.is_progress_bar_enabled(), + desc="Creating SQL from Arrow format", + ): + written += num_rows + + return written diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index deabca1f35a..f3553b0b961 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -9,6 +9,7 @@ from .json import json from .pandas import pandas from .parquet import parquet +from .sql import sql # noqa F401 from .text import text diff --git a/src/datasets/packaged_modules/csv/csv.py b/src/datasets/packaged_modules/csv/csv.py index c3e3f9b906a..5220335ad03 100644 --- a/src/datasets/packaged_modules/csv/csv.py +++ b/src/datasets/packaged_modules/csv/csv.py @@ -70,8 +70,8 @@ def __post_init__(self): self.names = self.column_names @property - def read_csv_kwargs(self): - read_csv_kwargs = dict( + def pd_read_csv_kwargs(self): + pd_read_csv_kwargs = dict( sep=self.sep, header=self.header, names=self.names, @@ -112,16 +112,16 @@ def read_csv_kwargs(self): # some kwargs must not be passed if they don't have a default value # some others are deprecated and we can also not pass them if they are the default value - for read_csv_parameter in _PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS + _PANDAS_READ_CSV_DEPRECATED_PARAMETERS: - if read_csv_kwargs[read_csv_parameter] == getattr(CsvConfig(), read_csv_parameter): - del read_csv_kwargs[read_csv_parameter] + for pd_read_csv_parameter in _PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS + _PANDAS_READ_CSV_DEPRECATED_PARAMETERS: + if pd_read_csv_kwargs[pd_read_csv_parameter] == getattr(CsvConfig(), pd_read_csv_parameter): + del pd_read_csv_kwargs[pd_read_csv_parameter] # Remove 1.3 new arguments if not (datasets.config.PANDAS_VERSION.major >= 1 and datasets.config.PANDAS_VERSION.minor >= 3): - for read_csv_parameter in _PANDAS_READ_CSV_NEW_1_3_0_PARAMETERS: - del read_csv_kwargs[read_csv_parameter] + for pd_read_csv_parameter in _PANDAS_READ_CSV_NEW_1_3_0_PARAMETERS: + del pd_read_csv_kwargs[pd_read_csv_parameter] - return read_csv_kwargs + return pd_read_csv_kwargs class Csv(datasets.ArrowBasedBuilder): @@ -172,7 +172,7 @@ def _generate_tables(self, files): else None ) for file_idx, file in enumerate(itertools.chain.from_iterable(files)): - csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.read_csv_kwargs) + 
csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.pd_read_csv_kwargs) try: for batch_idx, df in enumerate(csv_file_reader): pa_table = pa.Table.from_pandas(df) diff --git a/src/datasets/packaged_modules/sql/__init__.py b/src/datasets/packaged_modules/sql/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/datasets/packaged_modules/sql/sql.py b/src/datasets/packaged_modules/sql/sql.py new file mode 100644 index 00000000000..25c0178e264 --- /dev/null +++ b/src/datasets/packaged_modules/sql/sql.py @@ -0,0 +1,108 @@ +import sys +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import pandas as pd +import pyarrow as pa + +import datasets +import datasets.config +from datasets.features.features import require_storage_cast +from datasets.table import table_cast + + +if TYPE_CHECKING: + import sqlalchemy + + +@dataclass +class SqlConfig(datasets.BuilderConfig): + """BuilderConfig for SQL.""" + + sql: Union[str, "sqlalchemy.sql.Selectable"] = None + con: str = None + index_col: Optional[Union[str, List[str]]] = None + coerce_float: bool = True + params: Optional[Union[List, Tuple, Dict]] = None + parse_dates: Optional[Union[List, Dict]] = None + columns: Optional[List[str]] = None + chunksize: Optional[int] = 10_000 + features: Optional[datasets.Features] = None + + def __post_init__(self): + if self.sql is None: + raise ValueError("sql must be specified") + if self.con is None: + raise ValueError("con must be specified") + if not isinstance(self.con, str): + raise ValueError(f"con must be a database URI string, but got {self.con} with type {type(self.con)}.") + + def create_config_id( + self, + config_kwargs: dict, + custom_features: Optional[datasets.Features] = None, + ) -> str: + # We need to stringify the Selectable object to make its hash deterministic + + # The process of stringifying is explained here: http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html + sql = config_kwargs["sql"] + if not isinstance(sql, str): + if datasets.config.SQLALCHEMY_AVAILABLE and "sqlalchemy" in sys.modules: + import sqlalchemy + + if isinstance(sql, sqlalchemy.sql.Selectable): + config_kwargs = config_kwargs.copy() + engine = sqlalchemy.create_engine(config_kwargs["con"].split("://")[0] + "://") + sql_str = str(sql.compile(dialect=engine.dialect)) + config_kwargs["sql"] = sql_str + else: + raise TypeError( + f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}" + ) + else: + raise TypeError( + f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}" + ) + return super().create_config_id(config_kwargs, custom_features=custom_features) + + @property + def pd_read_sql_kwargs(self): + pd_read_sql_kwargs = dict( + index_col=self.index_col, + columns=self.columns, + params=self.params, + coerce_float=self.coerce_float, + parse_dates=self.parse_dates, + ) + return pd_read_sql_kwargs + + +class Sql(datasets.ArrowBasedBuilder): + BUILDER_CONFIG_CLASS = SqlConfig + + def _info(self): + return datasets.DatasetInfo(features=self.config.features) + + def _split_generators(self, dl_manager): + return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={})] + + def _cast_table(self, pa_table: pa.Table) -> pa.Table: + if self.config.features is not None: + schema = self.config.features.arrow_schema + if all(not require_storage_cast(feature) for feature in self.config.features.values()): + # cheaper cast + pa_table 
= pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema) + else: + # more expensive cast; allows str <-> int/float or str to Audio for example + pa_table = table_cast(pa_table, schema) + return pa_table + + def _generate_tables(self): + chunksize = self.config.chunksize + sql_reader = pd.read_sql( + self.config.sql, self.config.con, chunksize=chunksize, **self.config.pd_read_sql_kwargs + ) + sql_reader = [sql_reader] if chunksize is None else sql_reader + for chunk_idx, df in enumerate(sql_reader): + pa_table = pa.Table.from_pandas(df) + yield chunk_idx, self._cast_table(pa_table) diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index a6502468549..b1dd1f8785c 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -1,6 +1,8 @@ +import contextlib import csv import json import os +import sqlite3 import tarfile import textwrap import zipfile @@ -238,6 +240,18 @@ def arrow_path(tmp_path_factory): return path +@pytest.fixture(scope="session") +def sqlite_path(tmp_path_factory): + path = str(tmp_path_factory.mktemp("data") / "dataset.sqlite") + with contextlib.closing(sqlite3.connect(path)) as con: + cur = con.cursor() + cur.execute("CREATE TABLE dataset(col_1 text, col_2 int, col_3 real)") + for item in DATA: + cur.execute("INSERT INTO dataset(col_1, col_2, col_3) VALUES (?, ?, ?)", tuple(item.values())) + con.commit() + return path + + @pytest.fixture(scope="session") def csv_path(tmp_path_factory): path = str(tmp_path_factory.mktemp("data") / "dataset.csv") diff --git a/tests/io/test_sql.py b/tests/io/test_sql.py new file mode 100644 index 00000000000..143e1aa2201 --- /dev/null +++ b/tests/io/test_sql.py @@ -0,0 +1,98 @@ +import contextlib +import os +import sqlite3 + +import pytest + +from datasets import Dataset, Features, Value +from datasets.io.sql import SqlDatasetReader, SqlDatasetWriter + +from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_sqlalchemy + + +def _check_sql_dataset(dataset, expected_features): + assert isinstance(dataset, Dataset) + assert dataset.num_rows == 4 + assert dataset.num_columns == 3 + assert dataset.column_names == ["col_1", "col_2", "col_3"] + for feature, expected_dtype in expected_features.items(): + assert dataset.features[feature].dtype == expected_dtype + + +@require_sqlalchemy +@pytest.mark.parametrize("keep_in_memory", [False, True]) +def test_dataset_from_sql_keep_in_memory(keep_in_memory, sqlite_path, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): + dataset = SqlDatasetReader( + "dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory + ).read() + _check_sql_dataset(dataset, expected_features) + + +@require_sqlalchemy +@pytest.mark.parametrize( + "features", + [ + None, + {"col_1": "string", "col_2": "int64", "col_3": "float64"}, + {"col_1": "string", "col_2": "string", "col_3": "string"}, + {"col_1": "int32", "col_2": "int32", "col_3": "int32"}, + {"col_1": "float32", "col_2": "float32", "col_3": "float32"}, + ], +) +def test_dataset_from_sql_features(features, sqlite_path, tmp_path): + cache_dir = tmp_path / "cache" + default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + expected_features = features.copy() if features else default_expected_features + features = ( + Features({feature: Value(dtype) for feature, 
dtype in features.items()}) if features is not None else None
+    )
+    dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, features=features, cache_dir=cache_dir).read()
+    _check_sql_dataset(dataset, expected_features)
+
+
+def iter_sql_file(sqlite_path):
+    with contextlib.closing(sqlite3.connect(sqlite_path)) as con:
+        cur = con.cursor()
+        cur.execute("SELECT * FROM dataset")
+        for row in cur:
+            yield row
+
+
+@require_sqlalchemy
+def test_dataset_to_sql(sqlite_path, tmp_path):
+    cache_dir = tmp_path / "cache"
+    output_sqlite_path = os.path.join(cache_dir, "tmp.sql")
+    dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir).read()
+    SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, index=False, num_proc=1).write()
+
+    original_sql = iter_sql_file(sqlite_path)
+    expected_sql = iter_sql_file(output_sqlite_path)
+
+    for row1, row2 in zip(original_sql, expected_sql):
+        assert row1 == row2
+
+
+@require_sqlalchemy
+def test_dataset_to_sql_multiproc(sqlite_path, tmp_path):
+    cache_dir = tmp_path / "cache"
+    output_sqlite_path = os.path.join(cache_dir, "tmp.sql")
+    dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir).read()
+    SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, index=False, num_proc=2).write()
+
+    original_sql = iter_sql_file(sqlite_path)
+    expected_sql = iter_sql_file(output_sqlite_path)
+
+    for row1, row2 in zip(original_sql, expected_sql):
+        assert row1 == row2
+
+
+@require_sqlalchemy
+def test_dataset_to_sql_invalidproc(sqlite_path, tmp_path):
+    cache_dir = tmp_path / "cache"
+    output_sqlite_path = os.path.join(cache_dir, "tmp.sql")
+    dataset = SqlDatasetReader("dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir).read()
+    with pytest.raises(ValueError):
+        SqlDatasetWriter(dataset, "dataset", "sqlite:///" + output_sqlite_path, index=False, num_proc=0).write()
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 25df537c0ac..3f1d45f420d 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -1,3 +1,4 @@
+import contextlib
 import copy
 import itertools
 import json
@@ -53,6 +54,7 @@
     require_jax,
     require_pil,
     require_s3,
+    require_sqlalchemy,
     require_tf,
     require_torch,
     require_transformers,
@@ -2049,6 +2051,68 @@ def test_to_parquet(self, in_memory):
             self.assertEqual(parquet_dset.shape, dset.shape)
             self.assertListEqual(list(parquet_dset.columns), list(dset.column_names))
 
+    @require_sqlalchemy
+    def test_to_sql(self, in_memory):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Destination specified as database URI string
+            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
+                file_path = os.path.join(tmp_dir, "test_path.sqlite")
+                _ = dset.to_sql("data", "sqlite:///" + file_path, index=False)
+
+                self.assertTrue(os.path.isfile(file_path))
+                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)
+
+                self.assertEqual(sql_dset.shape, dset.shape)
+                self.assertListEqual(list(sql_dset.columns), list(dset.column_names))
+
+            # Destination specified as sqlite3 connection
+            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
+                import sqlite3
+
+                file_path = os.path.join(tmp_dir, "test_path.sqlite")
+                with contextlib.closing(sqlite3.connect(file_path)) as con:
+                    _ = dset.to_sql("data", con, index=False, if_exists="replace")
+
+                self.assertTrue(os.path.isfile(file_path))
+                sql_dset = pd.read_sql("data", "sqlite:///" + file_path)
+
+
self.assertEqual(sql_dset.shape, dset.shape) + self.assertListEqual(list(sql_dset.columns), list(dset.column_names)) + + # Test writing to a database in chunks + with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: + file_path = os.path.join(tmp_dir, "test_path.sqlite") + _ = dset.to_sql("data", "sqlite:///" + file_path, batch_size=1, index=False, if_exists="replace") + + self.assertTrue(os.path.isfile(file_path)) + sql_dset = pd.read_sql("data", "sqlite:///" + file_path) + + self.assertEqual(sql_dset.shape, dset.shape) + self.assertListEqual(list(sql_dset.columns), list(dset.column_names)) + + # After a select/shuffle transform + with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: + dset = dset.select(range(0, len(dset), 2)).shuffle() + file_path = os.path.join(tmp_dir, "test_path.sqlite") + _ = dset.to_sql("data", "sqlite:///" + file_path, index=False, if_exists="replace") + + self.assertTrue(os.path.isfile(file_path)) + sql_dset = pd.read_sql("data", "sqlite:///" + file_path) + + self.assertEqual(sql_dset.shape, dset.shape) + self.assertListEqual(list(sql_dset.columns), list(dset.column_names)) + + # With array features + with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset: + file_path = os.path.join(tmp_dir, "test_path.sqlite") + _ = dset.to_sql("data", "sqlite:///" + file_path, index=False, if_exists="replace") + + self.assertTrue(os.path.isfile(file_path)) + sql_dset = pd.read_sql("data", "sqlite:///" + file_path) + + self.assertEqual(sql_dset.shape, dset.shape) + self.assertListEqual(list(sql_dset.columns), list(dset.column_names)) + def test_train_test_split(self, in_memory): with tempfile.TemporaryDirectory() as tmp_dir: with self._create_dummy_dataset(in_memory, tmp_dir) as dset: @@ -3249,6 +3313,49 @@ def test_dataset_from_generator_features(features, data_generator, tmp_path): _check_generator_dataset(dataset, expected_features) +def _check_sql_dataset(dataset, expected_features): + assert isinstance(dataset, Dataset) + assert dataset.num_rows == 4 + assert dataset.num_columns == 3 + assert dataset.column_names == ["col_1", "col_2", "col_3"] + for feature, expected_dtype in expected_features.items(): + assert dataset.features[feature].dtype == expected_dtype + + +@require_sqlalchemy +@pytest.mark.parametrize( + "features", + [ + None, + {"col_1": "string", "col_2": "int64", "col_3": "float64"}, + {"col_1": "string", "col_2": "string", "col_3": "string"}, + {"col_1": "int32", "col_2": "int32", "col_3": "int32"}, + {"col_1": "float32", "col_2": "float32", "col_3": "float32"}, + ], +) +def test_dataset_from_sql_features(features, sqlite_path, tmp_path): + cache_dir = tmp_path / "cache" + default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + expected_features = features.copy() if features else default_expected_features + features = ( + Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None + ) + dataset = Dataset.from_sql("dataset", "sqlite:///" + sqlite_path, features=features, cache_dir=cache_dir) + _check_sql_dataset(dataset, expected_features) + + +@require_sqlalchemy +@pytest.mark.parametrize("keep_in_memory", [False, True]) +def test_dataset_from_sql_keep_in_memory(keep_in_memory, sqlite_path, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + with assert_arrow_memory_increases() if keep_in_memory else 
assert_arrow_memory_doesnt_increase(): + dataset = Dataset.from_sql( + "dataset", "sqlite:///" + sqlite_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory + ) + _check_sql_dataset(dataset, expected_features) + + def test_dataset_to_json(dataset, tmp_path): file_path = tmp_path / "test_path.jsonl" bytes_written = dataset.to_json(path_or_buf=file_path) diff --git a/tests/utils.py b/tests/utils.py index eecdef966d2..9ce1e4e06ed 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -130,6 +130,20 @@ def require_elasticsearch(test_case): return test_case +def require_sqlalchemy(test_case): + """ + Decorator marking a test that requires SQLAlchemy. + + These tests are skipped when SQLAlchemy isn't installed. + + """ + try: + import sqlalchemy # noqa + except ImportError: + test_case = unittest.skip("test requires sqlalchemy")(test_case) + return test_case + + def require_torch(test_case): """ Decorator marking a test that requires PyTorch.
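Taken together, the changes above add `Dataset.from_sql` and `Dataset.to_sql`. Below is a minimal round-trip sketch of that API, assuming `sqlalchemy` is installed; the SQLite file names and table names are illustrative only, and extra keyword arguments to `to_sql` are forwarded to `pandas.DataFrame.to_sql` as described in the docstrings above.

```py
import sqlite3

from datasets import Dataset

# Read a whole table (or the result of a query) from a database into a Dataset.
# The connection is given as a URI string, matching SqlConfig's requirement.
ds = Dataset.from_sql("dataset", "sqlite:///dataset.sqlite")
ds = Dataset.from_sql("SELECT col_1, col_2 FROM dataset", "sqlite:///dataset.sqlite")

# Write the Dataset back out; `index=False` is passed through to pandas' DataFrame.to_sql.
ds.to_sql("dataset_copy", "sqlite:///copy.sqlite", index=False)

# A sqlite3 connection object also works as the destination, as in the tests above.
with sqlite3.connect("copy.sqlite") as con:
    ds.to_sql("dataset_copy", con, index=False, if_exists="replace")
```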