From 353bec3ad1b2fc464f267e8e4f2110b7dda28db6 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Fri, 7 Oct 2022 14:31:05 +0200 Subject: [PATCH] Test --- tests/test_arrow_dataset.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 4fe9ea1ea2b..911086e37de 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -3356,6 +3356,25 @@ def _check_sql_dataset(dataset, expected_features): assert dataset.features[feature].dtype == expected_dtype +@require_sqlalchemy +@pytest.mark.parametrize("con_type", ["string", "engine"]) +def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + if con_type == "string": + con = "sqlite:///" + sqlite_path + elif con_type == "engine": + import sqlalchemy + + con = sqlalchemy.create_engine("sqlite:///" + sqlite_path) + dataset = Dataset.from_sql( + "dataset", + con, + cache_dir=cache_dir, + ) + _check_sql_dataset(dataset, expected_features) + + @require_sqlalchemy @pytest.mark.parametrize( "features",