diff --git a/docs/source/loading.mdx b/docs/source/loading.mdx
index d3b0b8fbf31..16b44259efd 100644
--- a/docs/source/loading.mdx
+++ b/docs/source/loading.mdx
@@ -210,10 +210,19 @@ For example, a table from a SQLite file can be loaded with:
 Use a query for a more precise read:
 
 ```py
+>>> from sqlite3 import connect
+>>> con = connect(":memory:")
+>>> # db writes ...
 >>> from datasets import Dataset
->>> dataset = Dataset.from_sql("SELECT text FROM data_table WHERE length(text) > 100 LIMIT 10", "sqlite:///sqlite_file.db")
+>>> dataset = Dataset.from_sql("SELECT text FROM data_table WHERE length(text) > 100 LIMIT 10", con)
 ```
+
+<Tip>
+
+You can specify [`Dataset.from_sql#con`] as a [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) so that 🤗 Datasets caching works across sessions.
+
+</Tip>
 
 ## In-memory data
 
 🤗 Datasets will also allow you to create a [`Dataset`] directly from in-memory data structures like Python dictionaries and Pandas DataFrames.
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index eda0b318128..17a4412693d 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1104,7 +1104,7 @@ def from_text(
     @staticmethod
     def from_sql(
         sql: Union[str, "sqlalchemy.sql.Selectable"],
-        con: str,
+        con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"],
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
@@ -1114,7 +1114,8 @@
         Args:
             sql (`str` or :obj:`sqlalchemy.sql.Selectable`): SQL query to be executed or a table name.
-            con (`str`): A connection URI string used to instantiate a database connection.
+            con (`str` or :obj:`sqlite3.Connection` or :obj:`sqlalchemy.engine.Connection` or :obj:`sqlalchemy.engine.Engine`):
+                A [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) used to instantiate a database connection or a SQLite3/SQLAlchemy connection object.
             features (:class:`Features`, optional): Dataset features.
             cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data.
             keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory.
@@ -1137,7 +1138,7 @@
         ```
 
-        `sqlalchemy` needs to be installed to use this function.
+        The returned dataset can only be cached if `con` is specified as a URI string.
         """
         from .io.sql import SqlDatasetReader
@@ -4218,8 +4219,8 @@
         Args:
             name (`str`): Name of SQL table.
-            con (`str` or `sqlite3.Connection` or `sqlalchemy.engine.Connection` or `sqlalchemy.engine.Connection`):
-                A database connection URI string or an existing SQLite3/SQLAlchemy connection used to write to a database.
+            con (`str` or :obj:`sqlite3.Connection` or :obj:`sqlalchemy.engine.Connection` or :obj:`sqlalchemy.engine.Engine`):
+                A [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) or a SQLite3/SQLAlchemy connection object used to write to a database.
             batch_size (:obj:`int`, optional): Size of the batch to load in memory and write at once.
                 Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
             **sql_writer_kwargs (additional keyword arguments): Parameters to pass to pandas's :func:`pandas.DataFrame.to_sql`.
diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py
index 0301908f50c..c88cad49398 100644
--- a/src/datasets/io/sql.py
+++ b/src/datasets/io/sql.py
@@ -18,7 +18,7 @@ class SqlDatasetReader(AbstractDatasetInputStream):
     def __init__(
         self,
         sql: Union[str, "sqlalchemy.sql.Selectable"],
-        con: str,
+        con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"],
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
diff --git a/src/datasets/packaged_modules/sql/sql.py b/src/datasets/packaged_modules/sql/sql.py
index 25c0178e264..5fe9d74acf1 100644
--- a/src/datasets/packaged_modules/sql/sql.py
+++ b/src/datasets/packaged_modules/sql/sql.py
@@ -12,15 +12,20 @@
 
 
 if TYPE_CHECKING:
+    import sqlite3
+
     import sqlalchemy
 
 
+logger = datasets.utils.logging.get_logger(__name__)
+
+
 @dataclass
 class SqlConfig(datasets.BuilderConfig):
     """BuilderConfig for SQL."""
 
     sql: Union[str, "sqlalchemy.sql.Selectable"] = None
-    con: str = None
+    con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"] = None
     index_col: Optional[Union[str, List[str]]] = None
     coerce_float: bool = True
     params: Optional[Union[List, Tuple, Dict]] = None
@@ -34,14 +39,13 @@ def __post_init__(self):
             raise ValueError("sql must be specified")
         if self.con is None:
             raise ValueError("con must be specified")
-        if not isinstance(self.con, str):
-            raise ValueError(f"con must be a database URI string, but got {self.con} with type {type(self.con)}.")
 
     def create_config_id(
        self,
        config_kwargs: dict,
        custom_features: Optional[datasets.Features] = None,
    ) -> str:
+        config_kwargs = config_kwargs.copy()
         # We need to stringify the Selectable object to make its hash deterministic
         # The process of stringifying is explained here: http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html
@@ -51,7 +55,6 @@ def create_config_id(
            import sqlalchemy
 
            if isinstance(sql, sqlalchemy.sql.Selectable):
-                config_kwargs = config_kwargs.copy()
                engine = sqlalchemy.create_engine(config_kwargs["con"].split("://")[0] + "://")
                sql_str = str(sql.compile(dialect=engine.dialect))
                config_kwargs["sql"] = sql_str
@@ -63,6 +66,13 @@
            raise TypeError(
                f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}"
            )
+        con = config_kwargs["con"]
+        if not isinstance(con, str):
+            config_kwargs["con"] = id(con)
+            logger.info(
+                f"SQL connection 'con' of type {type(con)} couldn't be hashed properly. To enable hashing, specify 'con' as a URI string instead."
+            )
+
         return super().create_config_id(config_kwargs, custom_features=custom_features)
 
     @property
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 4fe9ea1ea2b..54e1c62806a 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -3356,6 +3356,36 @@ def _check_sql_dataset(dataset, expected_features):
         assert dataset.features[feature].dtype == expected_dtype
 
 
+@require_sqlalchemy
+@pytest.mark.parametrize("con_type", ["string", "engine"])
+def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path):
+    cache_dir = tmp_path / "cache"
+    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
+    if con_type == "string":
+        con = "sqlite:///" + sqlite_path
+    elif con_type == "engine":
+        import sqlalchemy
+
+        con = sqlalchemy.create_engine("sqlite:///" + sqlite_path)
+    # # https://github.com/huggingface/datasets/issues/2832 needs to be fixed first for this to work
+    # with caplog.at_level(INFO):
+    #     dataset = Dataset.from_sql(
+    #         "dataset",
+    #         con,
+    #         cache_dir=cache_dir,
+    #     )
+    #     if con_type == "string":
+    #         assert "couldn't be hashed properly" not in caplog.text
+    #     elif con_type == "engine":
+    #         assert "couldn't be hashed properly" in caplog.text
+    dataset = Dataset.from_sql(
+        "dataset",
+        con,
+        cache_dir=cache_dir,
+    )
+    _check_sql_dataset(dataset, expected_features)
+
+
 @require_sqlalchemy
 @pytest.mark.parametrize(
     "features",
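
To see the behavior this patch adds end to end, here is a minimal sketch (not part of the diff): the database file name, table name, and contents are illustrative, and it assumes this patch is installed, with `sqlalchemy` available for the URI-string path.

```py
import sqlite3

from datasets import Dataset

# Build a small throwaway database to read from (names here are illustrative).
con = sqlite3.connect("demo.db")
con.execute("CREATE TABLE IF NOT EXISTS data_table (text TEXT)")
con.executemany("INSERT INTO data_table VALUES (?)", [("hello",), ("world",)])
con.commit()

# A connection object is now accepted, but create_config_id falls back to
# hashing id(con), so the resulting dataset cannot be reused from the cache
# in a later session.
ds = Dataset.from_sql("SELECT text FROM data_table", con)

# A URI string keeps the config hash deterministic, so caching works across
# sessions; a plain table name is also accepted as the `sql` argument.
ds_cached = Dataset.from_sql("data_table", "sqlite:///demo.db")
```

Falling back to `id(con)` rather than raising keeps non-string connections usable while making the cache miss explicit via the logged message, which is what the commented-out assertions in the new test are meant to cover once the linked logging issue is fixed.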