Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow connection objects in from_sql + small doc improvement #5091

Merged
merged 7 commits into from Oct 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 10 additions & 1 deletion docs/source/loading.mdx
Expand Up @@ -210,10 +210,19 @@ For example, a table from a SQLite file can be loaded with:
Use a query for a more precise read:

```py
>>> from sqlite3 import connect
>>> con = connect(":memory")
>>> # db writes ...
>>> from datasets import Dataset
>>> dataset = Dataset.from_sql("SELECT text FROM data_table WHERE length(text) > 100 LIMIT 10", "sqlite:///sqlite_file.db")
>>> dataset = Dataset.from_sql("SELECT text FROM table WHERE length(text) > 100 LIMIT 10", con)
```

<Tip>

You can specify [`Dataset.from_sql#con`] as a [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) for the 🤗 Datasets caching to work across sessions.

</Tip>

## In-memory data

🤗 Datasets will also allow you to create a [`Dataset`] directly from in-memory data structures like Python dictionaries and Pandas DataFrames.
Expand Down
11 changes: 6 additions & 5 deletions src/datasets/arrow_dataset.py
Expand Up @@ -1104,7 +1104,7 @@ def from_text(
@staticmethod
def from_sql(
sql: Union[str, "sqlalchemy.sql.Selectable"],
con: str,
con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"],
features: Optional[Features] = None,
cache_dir: str = None,
keep_in_memory: bool = False,
Expand All @@ -1114,7 +1114,8 @@ def from_sql(

Args:
sql (`str` or :obj:`sqlalchemy.sql.Selectable`): SQL query to be executed or a table name.
con (`str`): A connection URI string used to instantiate a database connection.
con (`str` or :obj:`sqlite3.Connection` or :obj:`sqlalchemy.engine.Connection` or :obj:`sqlalchemy.engine.Connection`):
A [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) used to instantiate a database connection or a SQLite3/SQLAlchemy connection object.
features (:class:`Features`, optional): Dataset features.
cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data.
keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory.
Expand All @@ -1137,7 +1138,7 @@ def from_sql(
```

<Tip {warning=true}>
`sqlalchemy` needs to be installed to use this function.
The returned dataset can only be cached if `con` is specified as URI string.
</Tip>
"""
from .io.sql import SqlDatasetReader
Expand Down Expand Up @@ -4218,8 +4219,8 @@ def to_sql(

Args:
name (`str`): Name of SQL table.
con (`str` or `sqlite3.Connection` or `sqlalchemy.engine.Connection` or `sqlalchemy.engine.Connection`):
A database connection URI string or an existing SQLite3/SQLAlchemy connection used to write to a database.
con (`str` or :obj:`sqlite3.Connection` or :obj:`sqlalchemy.engine.Connection` or :obj:`sqlalchemy.engine.Connection`):
A [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) or a SQLite3/SQLAlchemy connection object used to write to a database.
batch_size (:obj:`int`, optional): Size of the batch to load in memory and write at once.
Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
**sql_writer_kwargs (additional keyword arguments): Parameters to pass to pandas's :function:`Dataframe.to_sql`
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/io/sql.py
Expand Up @@ -18,7 +18,7 @@ class SqlDatasetReader(AbstractDatasetInputStream):
def __init__(
self,
sql: Union[str, "sqlalchemy.sql.Selectable"],
con: str,
con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"],
features: Optional[Features] = None,
cache_dir: str = None,
keep_in_memory: bool = False,
Expand Down
18 changes: 14 additions & 4 deletions src/datasets/packaged_modules/sql/sql.py
Expand Up @@ -12,15 +12,20 @@


if TYPE_CHECKING:
import sqlite3

import sqlalchemy


logger = datasets.utils.logging.get_logger(__name__)


@dataclass
class SqlConfig(datasets.BuilderConfig):
"""BuilderConfig for SQL."""

sql: Union[str, "sqlalchemy.sql.Selectable"] = None
con: str = None
con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"] = None
index_col: Optional[Union[str, List[str]]] = None
coerce_float: bool = True
params: Optional[Union[List, Tuple, Dict]] = None
Expand All @@ -34,14 +39,13 @@ def __post_init__(self):
raise ValueError("sql must be specified")
if self.con is None:
raise ValueError("con must be specified")
if not isinstance(self.con, str):
raise ValueError(f"con must be a database URI string, but got {self.con} with type {type(self.con)}.")

def create_config_id(
self,
config_kwargs: dict,
custom_features: Optional[datasets.Features] = None,
) -> str:
config_kwargs = config_kwargs.copy()
# We need to stringify the Selectable object to make its hash deterministic

# The process of stringifying is explained here: http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html
Expand All @@ -51,7 +55,6 @@ def create_config_id(
import sqlalchemy

if isinstance(sql, sqlalchemy.sql.Selectable):
config_kwargs = config_kwargs.copy()
engine = sqlalchemy.create_engine(config_kwargs["con"].split("://")[0] + "://")
sql_str = str(sql.compile(dialect=engine.dialect))
config_kwargs["sql"] = sql_str
Expand All @@ -63,6 +66,13 @@ def create_config_id(
raise TypeError(
f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}"
)
con = config_kwargs["con"]
if not isinstance(con, str):
config_kwargs["con"] = id(con)
logger.info(
f"SQL connection 'con' of type {type(con)} couldn't be hashed properly. To enable hashing, specify 'con' as URI string instead."
)

return super().create_config_id(config_kwargs, custom_features=custom_features)

@property
Expand Down
30 changes: 30 additions & 0 deletions tests/test_arrow_dataset.py
Expand Up @@ -3356,6 +3356,36 @@ def _check_sql_dataset(dataset, expected_features):
assert dataset.features[feature].dtype == expected_dtype


@require_sqlalchemy
@pytest.mark.parametrize("con_type", ["string", "engine"])
def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path):
cache_dir = tmp_path / "cache"
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
if con_type == "string":
con = "sqlite:///" + sqlite_path
elif con_type == "engine":
import sqlalchemy

con = sqlalchemy.create_engine("sqlite:///" + sqlite_path)
# # https://github.com/huggingface/datasets/issues/2832 needs to be fixed first for this to work
# with caplog.at_level(INFO):
# dataset = Dataset.from_sql(
# "dataset",
# con,
# cache_dir=cache_dir,
# )
# if con_type == "string":
# assert "couldn't be hashed properly" not in caplog.text
# elif con_type == "engine":
# assert "couldn't be hashed properly" in caplog.text
dataset = Dataset.from_sql(
"dataset",
con,
cache_dir=cache_dir,
)
_check_sql_dataset(dataset, expected_features)


@require_sqlalchemy
@pytest.mark.parametrize(
"features",
Expand Down