From 034bf186985dae77f1a4781c1c1c6f568f2de070 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Fri, 7 Oct 2022 14:28:35 +0200 Subject: [PATCH 1/7] Allow connection objects in `from_sql` --- src/datasets/arrow_dataset.py | 11 ++++++----- src/datasets/io/sql.py | 2 +- src/datasets/packaged_modules/sql/sql.py | 18 ++++++++++++++---- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index eda0b318128..17a4412693d 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1104,7 +1104,7 @@ def from_text( @staticmethod def from_sql( sql: Union[str, "sqlalchemy.sql.Selectable"], - con: str, + con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"], features: Optional[Features] = None, cache_dir: str = None, keep_in_memory: bool = False, @@ -1114,7 +1114,8 @@ def from_sql( Args: sql (`str` or :obj:`sqlalchemy.sql.Selectable`): SQL query to be executed or a table name. - con (`str`): A connection URI string used to instantiate a database connection. + con (`str` or :obj:`sqlite3.Connection` or :obj:`sqlalchemy.engine.Connection` or :obj:`sqlalchemy.engine.Engine`): + A [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) used to instantiate a database connection or a SQLite3/SQLAlchemy connection object. features (:class:`Features`, optional): Dataset features. cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data. keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory. @@ -1137,7 +1138,7 @@ def from_sql( ``` - `sqlalchemy` needs to be installed to use this function. + The returned dataset can only be cached if `con` is specified as URI string. """ from .io.sql import SqlDatasetReader @@ -4218,8 +4219,8 @@ def to_sql( Args: name (`str`): Name of SQL table. 
- con (`str` or `sqlite3.Connection` or `sqlalchemy.engine.Connection` or `sqlalchemy.engine.Connection`): - A database connection URI string or an existing SQLite3/SQLAlchemy connection used to write to a database. + con (`str` or :obj:`sqlite3.Connection` or :obj:`sqlalchemy.engine.Connection` or :obj:`sqlalchemy.engine.Engine`): + A [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) or a SQLite3/SQLAlchemy connection object used to write to a database. batch_size (:obj:`int`, optional): Size of the batch to load in memory and write at once. Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`. **sql_writer_kwargs (additional keyword arguments): Parameters to pass to pandas's :function:`Dataframe.to_sql` diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py index 0301908f50c..c88cad49398 100644 --- a/src/datasets/io/sql.py +++ b/src/datasets/io/sql.py @@ -18,7 +18,7 @@ class SqlDatasetReader(AbstractDatasetInputStream): def __init__( self, sql: Union[str, "sqlalchemy.sql.Selectable"], - con: str, + con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"], features: Optional[Features] = None, cache_dir: str = None, keep_in_memory: bool = False, diff --git a/src/datasets/packaged_modules/sql/sql.py b/src/datasets/packaged_modules/sql/sql.py index 25c0178e264..85727f99730 100644 --- a/src/datasets/packaged_modules/sql/sql.py +++ b/src/datasets/packaged_modules/sql/sql.py @@ -12,15 +12,20 @@ if TYPE_CHECKING: + import sqlite3 + import sqlalchemy +logger = datasets.utils.logging.get_logger(__name__) + + @dataclass class SqlConfig(datasets.BuilderConfig): """BuilderConfig for SQL.""" sql: Union[str, "sqlalchemy.sql.Selectable"] = None - con: str = None + con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"] = None index_col: Optional[Union[str, List[str]]] = None coerce_float: bool = True params: Optional[Union[List, Tuple, Dict]] = None @@ 
-34,14 +39,13 @@ def __post_init__(self): raise ValueError("sql must be specified") if self.con is None: raise ValueError("con must be specified") - if not isinstance(self.con, str): - raise ValueError(f"con must be a database URI string, but got {self.con} with type {type(self.con)}.") def create_config_id( self, config_kwargs: dict, custom_features: Optional[datasets.Features] = None, ) -> str: + config_kwargs = config_kwargs.copy() # We need to stringify the Selectable object to make its hash deterministic # The process of stringifying is explained here: http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html @@ -51,7 +55,6 @@ def create_config_id( import sqlalchemy if isinstance(sql, sqlalchemy.sql.Selectable): - config_kwargs = config_kwargs.copy() engine = sqlalchemy.create_engine(config_kwargs["con"].split("://")[0] + "://") sql_str = str(sql.compile(dialect=engine.dialect)) config_kwargs["sql"] = sql_str @@ -63,6 +66,13 @@ def create_config_id( raise TypeError( f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}" ) + con = config_kwargs["con"] + if not isinstance(con, str): + config_kwargs["con"] = id(con) + logger.info( + f"'con' of type {type(con)} couldn't be hashed properly. To enable hashing, specify 'con' as URI string instead." 
+ ) + return super().create_config_id(config_kwargs, custom_features=custom_features) @property From c7a19618bc4783d28168dde179d7fcf5b0b573b8 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Fri, 7 Oct 2022 14:29:03 +0200 Subject: [PATCH 2/7] Improve docs --- docs/source/loading.mdx | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/source/loading.mdx b/docs/source/loading.mdx index d3b0b8fbf31..3e598fbceca 100644 --- a/docs/source/loading.mdx +++ b/docs/source/loading.mdx @@ -210,10 +210,19 @@ For example, a table from a SQLite file can be loaded with: Use a query for a more precise read: ```py +>>> from sqlite3 import connect +>>> con = connect(":memory:") +>>> # db writes ... >>> from datasets import Dataset ->>> dataset = Dataset.from_sql("SELECT text FROM data_table WHERE length(text) > 100 LIMIT 10", "sqlite:///sqlite_file.db") +>>> dataset = Dataset.from_sql("SELECT text FROM data_table WHERE length(text) > 100 LIMIT 10", con) ``` + + +To cache the read, specify [`Dataset.from_sql#con`] as a [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls). + + + ## In-memory data 🤗 Datasets will also allow you to create a [`Dataset`] directly from in-memory data structures like Python dictionaries and Pandas DataFrames. 
From 353bec3ad1b2fc464f267e8e4f2110b7dda28db6 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Fri, 7 Oct 2022 14:31:05 +0200 Subject: [PATCH 3/7] Test --- tests/test_arrow_dataset.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 4fe9ea1ea2b..911086e37de 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -3356,6 +3356,25 @@ def _check_sql_dataset(dataset, expected_features): assert dataset.features[feature].dtype == expected_dtype +@require_sqlalchemy +@pytest.mark.parametrize("con_type", ["string", "engine"]) +def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + if con_type == "string": + con = "sqlite:///" + sqlite_path + elif con_type == "engine": + import sqlalchemy + + con = sqlalchemy.create_engine("sqlite:///" + sqlite_path) + dataset = Dataset.from_sql( + "dataset", + con, + cache_dir=cache_dir, + ) + _check_sql_dataset(dataset, expected_features) + + @require_sqlalchemy @pytest.mark.parametrize( "features", From 9390b7bc194bd76079a5177b0c6cdfa4b7dc3919 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Fri, 7 Oct 2022 15:39:27 +0200 Subject: [PATCH 4/7] Add comment to test --- tests/test_arrow_dataset.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 911086e37de..c4907a4f13a 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -3358,7 +3358,7 @@ def _check_sql_dataset(dataset, expected_features): @require_sqlalchemy @pytest.mark.parametrize("con_type", ["string", "engine"]) -def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path): +def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path, caplog): cache_dir = tmp_path / "cache" expected_features = {"col_1": "string", "col_2": 
"int64", "col_3": "float64"} if con_type == "string": @@ -3367,6 +3367,17 @@ def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path): import sqlalchemy con = sqlalchemy.create_engine("sqlite:///" + sqlite_path) + # # https://github.com/huggingface/datasets/issues/2832 needs to be fixed first for this to work + # with caplog.at_level(INFO): + # dataset = Dataset.from_sql( + # "dataset", + # con, + # cache_dir=cache_dir, + # ) + # if con_type == "string": + # assert "couldn't be hashed properly" not in caplog.text + # elif con_type == "engine": + # assert "couldn't be hashed properly" in caplog.text dataset = Dataset.from_sql( "dataset", con, From b4b95ce6f6593fc05e7ea70938d645470e797bc2 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Fri, 7 Oct 2022 15:40:10 +0200 Subject: [PATCH 5/7] Unused param --- tests/test_arrow_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index c4907a4f13a..54e1c62806a 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -3358,7 +3358,7 @@ def _check_sql_dataset(dataset, expected_features): @require_sqlalchemy @pytest.mark.parametrize("con_type", ["string", "engine"]) -def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path, caplog): +def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path): cache_dir = tmp_path / "cache" expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} if con_type == "string": From 7ff5f87837f6e322a879410edc1c9ae18a00770c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20=C5=A0a=C5=A1ko?= Date: Sun, 9 Oct 2022 14:57:48 +0200 Subject: [PATCH 6/7] Update src/datasets/packaged_modules/sql/sql.py Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> --- src/datasets/packaged_modules/sql/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/packaged_modules/sql/sql.py 
b/src/datasets/packaged_modules/sql/sql.py index 85727f99730..5fe9d74acf1 100644 --- a/src/datasets/packaged_modules/sql/sql.py +++ b/src/datasets/packaged_modules/sql/sql.py @@ -70,7 +70,7 @@ def create_config_id( if not isinstance(con, str): config_kwargs["con"] = id(con) logger.info( - f"'con' of type {type(con)} couldn't be hashed properly. To enable hashing, specify 'con' as URI string instead." + f"SQL connection 'con' of type {type(con)} couldn't be hashed properly. To enable hashing, specify 'con' as URI string instead." ) return super().create_config_id(config_kwargs, custom_features=custom_features) From f1d639b7ac3d606f37a50872d797818b0d8fd8ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20=C5=A0a=C5=A1ko?= Date: Sun, 9 Oct 2022 14:57:56 +0200 Subject: [PATCH 7/7] Update docs/source/loading.mdx Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> --- docs/source/loading.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/loading.mdx b/docs/source/loading.mdx index 3e598fbceca..16b44259efd 100644 --- a/docs/source/loading.mdx +++ b/docs/source/loading.mdx @@ -219,7 +219,7 @@ Use a query for a more precise read: -To cache the read, specify [`Dataset.from_sql#con`] as a [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls). +You can specify [`Dataset.from_sql#con`] as a [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) for the 🤗 Datasets caching to work across sessions.