diff --git a/dask/dask-schema.yaml b/dask/dask-schema.yaml
index 8573c0e5b66..7a7bc84680d 100644
--- a/dask/dask-schema.yaml
+++ b/dask/dask-schema.yaml
@@ -72,6 +72,14 @@ properties:
               task when reading a parquet dataset from a REMOTE file system.
               Specifying 0 will result in serial execution on the client.
 
+          dtype_backend:
+            enum:
+              - pandas
+              - pyarrow
+            description: |
+              The nullable dtype implementation to use. Must be either "pandas" or
+              "pyarrow". Default is "pandas".
+
   array:
     type: object
     properties:
diff --git a/dask/dask.yaml b/dask/dask.yaml
index 4b649c80a81..2d640fc64a8 100644
--- a/dask/dask.yaml
+++ b/dask/dask.yaml
@@ -12,6 +12,7 @@ dataframe:
   parquet:
     metadata-task-size-local: 512 # Number of files per local metadata-processing task
     metadata-task-size-remote: 16 # Number of files per remote metadata-processing task
+    dtype_backend: "pandas" # Nullable dtype implementation to use
 
 array:
   backend: "numpy" # Backend array library for input IO and data creation
diff --git a/dask/dataframe/io/parquet/arrow.py b/dask/dataframe/io/parquet/arrow.py
index 369de1707e1..e00cf050c8e 100644
--- a/dask/dataframe/io/parquet/arrow.py
+++ b/dask/dataframe/io/parquet/arrow.py
@@ -1580,19 +1580,31 @@ def _arrow_table_to_pandas(
     _kwargs.update({"use_threads": False, "ignore_metadata": False})
 
     if use_nullable_dtypes:
+        # Determine if `pandas`- or `pyarrow`-backed dtypes should be used
+        if use_nullable_dtypes == "pandas":
+            default_types_mapper = PYARROW_NULLABLE_DTYPE_MAPPING.get
+        else:
+            # use_nullable_dtypes == "pyarrow"
+
+            def default_types_mapper(pyarrow_dtype):  # type: ignore
+                # Special case pyarrow strings to use a more feature-complete dtype
+                # See https://github.com/pandas-dev/pandas/issues/50074
+                if pyarrow_dtype == pa.string():
+                    return pd.StringDtype("pyarrow")
+                else:
+                    return pd.ArrowDtype(pyarrow_dtype)
+
         if "types_mapper" in _kwargs:
-            # User-provided entries take priority over PYARROW_NULLABLE_DTYPE_MAPPING
+            # User-provided entries take priority over default_types_mapper
             types_mapper = _kwargs["types_mapper"]
 
             def _types_mapper(pa_type):
-                return types_mapper(pa_type) or PYARROW_NULLABLE_DTYPE_MAPPING.get(
-                    pa_type
-                )
+                return types_mapper(pa_type) or default_types_mapper(pa_type)
 
             _kwargs["types_mapper"] = _types_mapper
 
         else:
-            _kwargs["types_mapper"] = PYARROW_NULLABLE_DTYPE_MAPPING.get
+            _kwargs["types_mapper"] = default_types_mapper
 
     return arrow_table.to_pandas(categories=categories, **_kwargs)
 
diff --git a/dask/dataframe/io/parquet/core.py b/dask/dataframe/io/parquet/core.py
index 8ed92b6a892..b7beb88931a 100644
--- a/dask/dataframe/io/parquet/core.py
+++ b/dask/dataframe/io/parquet/core.py
@@ -185,7 +185,7 @@ def read_parquet(
     index=None,
     storage_options=None,
     engine="auto",
-    use_nullable_dtypes=False,
+    use_nullable_dtypes: bool = False,
     calculate_divisions=None,
     ignore_metadata_file=False,
     metadata_task_size=None,
@@ -257,6 +257,22 @@ def read_parquet(
     engine : {'auto', 'pyarrow', 'fastparquet'}, default 'auto'
         Parquet library to use. Defaults to 'auto', which uses ``pyarrow`` if
         it is installed, and falls back to ``fastparquet`` otherwise.
+    use_nullable_dtypes : bool, default False
+        Whether to use extension dtypes for the resulting ``DataFrame``.
+        ``use_nullable_dtypes=True`` is only supported when ``engine="pyarrow"``.
+
+        .. note::
+
+            Use the ``dataframe.dtype_backend`` config option to select which
+            dtype implementation to use.
+
+            ``dataframe.dtype_backend="pandas"`` (the default) will use
+            pandas' ``numpy``-backed nullable dtypes (e.g. ``Int64``,
+            ``string[python]``, etc.) while ``dataframe.dtype_backend="pyarrow"``
+            will use ``pyarrow``-backed extension dtypes (e.g. ``int64[pyarrow]``,
+            ``string[pyarrow]``, etc.). ``dataframe.dtype_backend="pyarrow"``
+            requires ``pandas`` 1.5+.
+
     calculate_divisions : bool, default False
         Whether to use min/max statistics from the footer metadata (or global
         ``_metadata`` file) to calculate divisions for the output DataFrame
@@ -366,6 +382,9 @@ def read_parquet(
     pyarrow.parquet.ParquetDataset
     """
 
+    if use_nullable_dtypes:
+        use_nullable_dtypes = dask.config.get("dataframe.dtype_backend")
+
     # "Pre-deprecation" warning for `chunksize`
     if chunksize:
         warnings.warn(
@@ -552,7 +571,7 @@ def read_parquet(
     if "retries" not in annotations and not _is_local_fs(fs):
         ctx = dask.annotate(retries=5)
     else:
-        ctx = contextlib.nullcontext()
+        ctx = contextlib.nullcontext()  # type: ignore
 
     with ctx:
         # Construct the output collection with from_map
diff --git a/dask/dataframe/io/tests/test_parquet.py b/dask/dataframe/io/tests/test_parquet.py
index c12f956f35f..dd8450de134 100644
--- a/dask/dataframe/io/tests/test_parquet.py
+++ b/dask/dataframe/io/tests/test_parquet.py
@@ -15,7 +15,12 @@
 import dask.dataframe as dd
 import dask.multiprocessing
 from dask.blockwise import Blockwise, optimize_blockwise
-from dask.dataframe._compat import PANDAS_GT_110, PANDAS_GT_121, PANDAS_GT_130
+from dask.dataframe._compat import (
+    PANDAS_GT_110,
+    PANDAS_GT_121,
+    PANDAS_GT_130,
+    PANDAS_GT_150,
+)
 from dask.dataframe.io.parquet.core import get_engine
 from dask.dataframe.io.parquet.utils import _parse_pandas_metadata
 from dask.dataframe.optimize import optimize_dataframe_getitem
@@ -618,17 +623,37 @@ def test_roundtrip_nullable_dtypes(tmp_path, write_engine, read_engine):
 
 
 @PYARROW_MARK
-def test_use_nullable_dtypes(tmp_path, engine):
+@pytest.mark.parametrize(
+    "dtype_backend",
+    [
+        "pandas",
+        pytest.param(
+            "pyarrow",
+            marks=pytest.mark.skipif(
+                not PANDAS_GT_150, reason="Requires pyarrow-backed nullable dtypes"
+            ),
+        ),
+    ],
+)
+def test_use_nullable_dtypes(tmp_path, engine, dtype_backend):
     """
     Test reading a parquet file without pandas metadata,
     but forcing use of nullable dtypes where appropriate
     """
+
+    if dtype_backend == "pandas":
+        dtype_extra = ""
+    else:
+        # dtype_backend == "pyarrow"
+        dtype_extra = "[pyarrow]"
     df = pd.DataFrame(
         {
-            "a": pd.Series([1, 2, pd.NA, 3, 4], dtype="Int64"),
-            "b": pd.Series([True, pd.NA, False, True, False], dtype="boolean"),
-            "c": pd.Series([0.1, 0.2, 0.3, pd.NA, 0.4], dtype="Float64"),
-            "d": pd.Series(["a", "b", "c", "d", pd.NA], dtype="string"),
+            "a": pd.Series([1, 2, pd.NA, 3, 4], dtype=f"Int64{dtype_extra}"),
+            "b": pd.Series(
+                [True, pd.NA, False, True, False], dtype=f"boolean{dtype_extra}"
+            ),
+            "c": pd.Series([0.1, 0.2, 0.3, pd.NA, 0.4], dtype=f"Float64{dtype_extra}"),
+            "d": pd.Series(["a", "b", "c", "d", pd.NA], dtype=f"string{dtype_extra}"),
         }
     )
     ddf = dd.from_pandas(df, npartitions=2)
@@ -644,21 +669,24 @@ def write_partition(df, i):
     partitions = ddf.to_delayed()
     dask.compute([write_partition(p, i) for i, p in enumerate(partitions)])
 
-    # Not supported by fastparquet
-    if engine == "fastparquet":
-        with pytest.raises(ValueError, match="`use_nullable_dtypes` is not supported"):
-            dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)
+    with dask.config.set({"dataframe.dtype_backend": dtype_backend}):
+        # Not supported by fastparquet
+        if engine == "fastparquet":
+            with pytest.raises(
+                ValueError, match="`use_nullable_dtypes` is not supported"
+            ):
+                dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)
 
-    # Works in pyarrow
-    else:
-        # Doesn't round-trip by default when we aren't using nullable dtypes
-        with pytest.raises(AssertionError):
-            ddf2 = dd.read_parquet(tmp_path, engine=engine)
-            assert_eq(df, ddf2)
-
-        # Round trip works when we use nullable dtypes
-        ddf2 = dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)
-        assert_eq(df, ddf2, check_index=False)
+        # Works in pyarrow
+        else:
+            # Doesn't round-trip by default when we aren't using nullable dtypes
+            with pytest.raises(AssertionError):
+                ddf2 = dd.read_parquet(tmp_path, engine=engine)
+                assert_eq(df, ddf2)
+
+            # Round trip works when we use nullable dtypes
+            ddf2 = dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)
+            assert_eq(df, ddf2, check_index=False)
 
 
 @pytest.mark.xfail(
diff --git a/dask/tests/test_spark_compat.py b/dask/tests/test_spark_compat.py
index 41e21d05f6e..f26a859dfce 100644
--- a/dask/tests/test_spark_compat.py
+++ b/dask/tests/test_spark_compat.py
@@ -1,19 +1,22 @@
+import decimal
 import signal
 import sys
 import threading
 
 import pytest
 
+import dask
 from dask.datasets import timeseries
 
 dd = pytest.importorskip("dask.dataframe")
 pyspark = pytest.importorskip("pyspark")
-pytest.importorskip("pyarrow")
+pa = pytest.importorskip("pyarrow")
 pytest.importorskip("fastparquet")
 
 import numpy as np
 import pandas as pd
 
+from dask.dataframe._compat import PANDAS_GT_150
 from dask.dataframe.utils import assert_eq
 
 pytestmark = pytest.mark.skipif(
@@ -149,3 +152,43 @@ def test_roundtrip_parquet_spark_to_dask_extension_dtypes(spark_session, tmpdir)
         [pd.api.types.is_extension_array_dtype(dtype) for dtype in ddf.dtypes]
     ), ddf.dtypes
     assert_eq(ddf, pdf, check_index=False)
+
+
+@pytest.mark.skipif(not PANDAS_GT_150, reason="Requires pyarrow-backed nullable dtypes")
+def test_read_decimal_dtype_pyarrow(spark_session, tmpdir):
+    tmpdir = str(tmpdir)
+    npartitions = 3
+    size = 6
+
+    decimal_data = [
+        decimal.Decimal("8093.234"),
+        decimal.Decimal("8094.234"),
+        decimal.Decimal("8095.234"),
+        decimal.Decimal("8096.234"),
+        decimal.Decimal("8097.234"),
+        decimal.Decimal("8098.234"),
+    ]
+    pdf = pd.DataFrame(
+        {
+            "a": range(size),
+            "b": decimal_data,
+        }
+    )
+    sdf = spark_session.createDataFrame(pdf)
+    sdf = sdf.withColumn("b", sdf["b"].cast(pyspark.sql.types.DecimalType(7, 3)))
+    # We are not overwriting any data, but Spark complains if the directory
+    # already exists (as tmpdir does) and we don't set overwrite
+    sdf.repartition(npartitions).write.parquet(tmpdir, mode="overwrite")
+
+    with dask.config.set({"dataframe.dtype_backend": "pyarrow"}):
+        ddf = dd.read_parquet(tmpdir, engine="pyarrow", use_nullable_dtypes=True)
+    assert ddf.b.dtype.pyarrow_dtype == pa.decimal128(7, 3)
+    assert ddf.b.compute().dtype.pyarrow_dtype == pa.decimal128(7, 3)
+    expected = pdf.astype(
+        {
+            "a": "int64[pyarrow]",
+            "b": pd.ArrowDtype(pa.decimal128(7, 3)),
+        }
+    )
+
+    assert_eq(ddf, expected, check_index=False)
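For reference, a minimal usage sketch of the behavior this patch adds: reading a parquet
dataset with ``use_nullable_dtypes=True`` and choosing the dtype implementation through
the new ``dataframe.dtype_backend`` config option. The ``example-data/`` path and the toy
DataFrame below are illustrative only and not part of the patch.

    import dask
    import dask.dataframe as dd
    import pandas as pd

    # Write a small parquet dataset containing missing values (illustrative path)
    pdf = pd.DataFrame({"x": [1, 2, None, 4], "y": ["a", None, "c", "d"]})
    dd.from_pandas(pdf, npartitions=2).to_parquet("example-data/", engine="pyarrow")

    # Default backend ("pandas"): numpy-backed nullable dtypes (e.g. Float64, string)
    ddf = dd.read_parquet("example-data/", engine="pyarrow", use_nullable_dtypes=True)
    print(ddf.dtypes)

    # Opt in to pyarrow-backed extension dtypes (requires pandas 1.5+)
    with dask.config.set({"dataframe.dtype_backend": "pyarrow"}):
        ddf = dd.read_parquet(
            "example-data/", engine="pyarrow", use_nullable_dtypes=True
        )
    print(ddf.dtypes)

Note that ``read_parquet`` resolves ``dataframe.dtype_backend`` at graph-construction
time, so the config option must be set when the collection is created, not when it is
computed.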