
Support dtype_backend="pandas|pyarrow" configuration #9719

Merged: 11 commits, Dec 16, 2022
8 changes: 8 additions & 0 deletions dask/dask-schema.yaml
@@ -72,6 +72,14 @@ properties:
task when reading a parquet dataset from a REMOTE file system.
Specifying 0 will result in serial execution on the client.

dtype_backend:
enum:
- pandas
- pyarrow
description: |
The nullable dtype implementation to use. Must be either "pandas" or
"pyarrow". Default is "pandas".

array:
type: object
properties:
1 change: 1 addition & 0 deletions dask/dask.yaml
@@ -12,6 +12,7 @@ dataframe:
parquet:
metadata-task-size-local: 512 # Number of files per local metadata-processing task
metadata-task-size-remote: 16 # Number of files per remote metadata-processing task
dtype_backend: "pandas" # Dtype implementation to use

array:
backend: "numpy" # Backend array library for input IO and data creation
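For context (not part of the diff): the YAML entry above corresponds to the dataframe.dtype_backend config key referenced in the code changes below, so it can also be set programmatically. A minimal sketch, assuming pyarrow is installed and, for the "pyarrow" backend, pandas 1.5+ (the dataset path is hypothetical):

import dask
import dask.dataframe as dd

# Use pyarrow-backed extension dtypes (int64[pyarrow], string[pyarrow], ...) instead of
# the default numpy-backed nullable dtypes (Int64, string[python], ...).
with dask.config.set({"dataframe.dtype_backend": "pyarrow"}):
    ddf = dd.read_parquet("my-dataset/", engine="pyarrow", use_nullable_dtypes=True)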
22 changes: 17 additions & 5 deletions dask/dataframe/io/parquet/arrow.py
@@ -1580,19 +1580,31 @@ def _arrow_table_to_pandas(
_kwargs.update({"use_threads": False, "ignore_metadata": False})

if use_nullable_dtypes:
        # Determine whether `pandas`- or `pyarrow`-backed dtypes should be used
if use_nullable_dtypes == "pandas":
default_types_mapper = PYARROW_NULLABLE_DTYPE_MAPPING.get
else:
# use_nullable_dtypes == "pyarrow"

def default_types_mapper(pyarrow_dtype): # type: ignore
                # Special-case pyarrow strings to use a more feature-complete dtype
# See https://github.com/pandas-dev/pandas/issues/50074
if pyarrow_dtype == pa.string():
return pd.StringDtype("pyarrow")
else:
return pd.ArrowDtype(pyarrow_dtype)

if "types_mapper" in _kwargs:
# User-provided entries take priority over PYARROW_NULLABLE_DTYPE_MAPPING
# User-provided entries take priority over default_types_mapper
types_mapper = _kwargs["types_mapper"]

def _types_mapper(pa_type):
return types_mapper(pa_type) or PYARROW_NULLABLE_DTYPE_MAPPING.get(
pa_type
)
return types_mapper(pa_type) or default_types_mapper(pa_type)

_kwargs["types_mapper"] = _types_mapper

else:
_kwargs["types_mapper"] = PYARROW_NULLABLE_DTYPE_MAPPING.get
_kwargs["types_mapper"] = default_types_mapper

return arrow_table.to_pandas(categories=categories, **_kwargs)

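As a standalone illustration (not part of the diff) of the default_types_mapper added above: pyarrow's Table.to_pandas accepts a types_mapper callable, and the sketch below shows how strings get the more feature-complete pandas string dtype with pyarrow storage while other types fall back to pd.ArrowDtype (requires pandas 1.5+):

import pandas as pd
import pyarrow as pa

table = pa.table({"x": ["a", None, "c"], "y": [1, None, 3]})

def types_mapper(pyarrow_dtype):
    # Mirror the special-casing above: strings -> pandas string dtype with pyarrow storage,
    # everything else -> generic pyarrow-backed extension dtype.
    if pyarrow_dtype == pa.string():
        return pd.StringDtype("pyarrow")
    return pd.ArrowDtype(pyarrow_dtype)

df = table.to_pandas(types_mapper=types_mapper)
print(df.dtypes)  # x: string dtype (pyarrow storage), y: int64[pyarrow]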
23 changes: 21 additions & 2 deletions dask/dataframe/io/parquet/core.py
@@ -185,7 +185,7 @@ def read_parquet(
index=None,
storage_options=None,
engine="auto",
use_nullable_dtypes=False,
use_nullable_dtypes: bool = False,
calculate_divisions=None,
ignore_metadata_file=False,
metadata_task_size=None,
@@ -257,6 +257,22 @@ def read_parquet(
engine : {'auto', 'pyarrow', 'fastparquet'}, default 'auto'
Parquet library to use. Defaults to 'auto', which uses ``pyarrow`` if
it is installed, and falls back to ``fastparquet`` otherwise.
use_nullable_dtypes : {False, True}
Whether to use extension dtypes for the resulting ``DataFrame``.
``use_nullable_dtypes=True`` is only supported when ``engine="pyarrow"``.

.. note::

Use the ``dataframe.dtype_backend`` config option to select which
dtype implementation to use.

``dataframe.dtype_backend="pandas"`` (the default) will use
pandas' ``numpy``-backed nullable dtypes (e.g. ``Int64``,
``string[python]``, etc.) while ``dataframe.dtype_backend="pyarrow"``
will use ``pyarrow``-backed extension dtypes (e.g. ``int64[pyarrow]``,
``string[pyarrow]``, etc.). ``dataframe.dtype_backend="pyarrow"``
requires ``pandas`` 1.5+.

calculate_divisions : bool, default False
Whether to use min/max statistics from the footer metadata (or global
``_metadata`` file) to calculate divisions for the output DataFrame
@@ -366,6 +382,9 @@ def read_parquet(
pyarrow.parquet.ParquetDataset
"""

if use_nullable_dtypes:
use_nullable_dtypes = dask.config.get("dataframe.dtype_backend")

# "Pre-deprecation" warning for `chunksize`
if chunksize:
warnings.warn(
@@ -552,7 +571,7 @@ def read_parquet(
if "retries" not in annotations and not _is_local_fs(fs):
ctx = dask.annotate(retries=5)
else:
ctx = contextlib.nullcontext()
ctx = contextlib.nullcontext() # type: ignore

with ctx:
# Construct the output collection with from_map
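To make the docstring's pandas/pyarrow distinction concrete, a short sketch (not part of the diff) of the dtype objects each backend produces for a 64-bit integer column:

import pandas as pd
import pyarrow as pa

# dataframe.dtype_backend="pandas" -> numpy-backed nullable dtype
print(pd.Int64Dtype())            # Int64

# dataframe.dtype_backend="pyarrow" -> pyarrow-backed extension dtype (pandas 1.5+)
print(pd.ArrowDtype(pa.int64()))  # int64[pyarrow]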
68 changes: 48 additions & 20 deletions dask/dataframe/io/tests/test_parquet.py
@@ -15,7 +15,12 @@
import dask.dataframe as dd
import dask.multiprocessing
from dask.blockwise import Blockwise, optimize_blockwise
from dask.dataframe._compat import PANDAS_GT_110, PANDAS_GT_121, PANDAS_GT_130
from dask.dataframe._compat import (
PANDAS_GT_110,
PANDAS_GT_121,
PANDAS_GT_130,
PANDAS_GT_150,
)
from dask.dataframe.io.parquet.core import get_engine
from dask.dataframe.io.parquet.utils import _parse_pandas_metadata
from dask.dataframe.optimize import optimize_dataframe_getitem
@@ -618,17 +623,37 @@ def test_roundtrip_nullable_dtypes(tmp_path, write_engine, read_engine):


@PYARROW_MARK
def test_use_nullable_dtypes(tmp_path, engine):
@pytest.mark.parametrize(
"dtype_backend",
[
"pandas",
pytest.param(
"pyarrow",
marks=pytest.mark.skipif(
not PANDAS_GT_150, reason="Requires pyarrow-backed nullable dtypes"
),
),
],
)
def test_use_nullable_dtypes(tmp_path, engine, dtype_backend):
"""
Test reading a parquet file without pandas metadata,
but forcing use of nullable dtypes where appropriate
"""

if dtype_backend == "pandas":
dtype_extra = ""
else:
# dtype_backend == "pyarrow"
dtype_extra = "[pyarrow]"
df = pd.DataFrame(
{
"a": pd.Series([1, 2, pd.NA, 3, 4], dtype="Int64"),
"b": pd.Series([True, pd.NA, False, True, False], dtype="boolean"),
"c": pd.Series([0.1, 0.2, 0.3, pd.NA, 0.4], dtype="Float64"),
"d": pd.Series(["a", "b", "c", "d", pd.NA], dtype="string"),
"a": pd.Series([1, 2, pd.NA, 3, 4], dtype=f"Int64{dtype_extra}"),
"b": pd.Series(
[True, pd.NA, False, True, False], dtype=f"boolean{dtype_extra}"
),
"c": pd.Series([0.1, 0.2, 0.3, pd.NA, 0.4], dtype=f"Float64{dtype_extra}"),
"d": pd.Series(["a", "b", "c", "d", pd.NA], dtype=f"string{dtype_extra}"),
}
)
ddf = dd.from_pandas(df, npartitions=2)
@@ -644,21 +669,24 @@ def write_partition(df, i):
partitions = ddf.to_delayed()
dask.compute([write_partition(p, i) for i, p in enumerate(partitions)])

# Not supported by fastparquet
if engine == "fastparquet":
with pytest.raises(ValueError, match="`use_nullable_dtypes` is not supported"):
dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)
with dask.config.set({"dataframe.dtype_backend": dtype_backend}):
# Not supported by fastparquet
if engine == "fastparquet":
with pytest.raises(
ValueError, match="`use_nullable_dtypes` is not supported"
):
dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)

# Works in pyarrow
else:
# Doesn't round-trip by default when we aren't using nullable dtypes
with pytest.raises(AssertionError):
ddf2 = dd.read_parquet(tmp_path, engine=engine)
assert_eq(df, ddf2)

# Round trip works when we use nullable dtypes
ddf2 = dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)
assert_eq(df, ddf2, check_index=False)
# Works in pyarrow
else:
# Doesn't round-trip by default when we aren't using nullable dtypes
with pytest.raises(AssertionError):
ddf2 = dd.read_parquet(tmp_path, engine=engine)
assert_eq(df, ddf2)

# Round trip works when we use nullable dtypes
ddf2 = dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)
assert_eq(df, ddf2, check_index=False)


@pytest.mark.xfail(
45 changes: 44 additions & 1 deletion dask/tests/test_spark_compat.py
@@ -1,19 +1,22 @@
import decimal
import signal
import sys
import threading

import pytest

import dask
from dask.datasets import timeseries

dd = pytest.importorskip("dask.dataframe")
pyspark = pytest.importorskip("pyspark")
pytest.importorskip("pyarrow")
pa = pytest.importorskip("pyarrow")
pytest.importorskip("fastparquet")

import numpy as np
import pandas as pd

from dask.dataframe._compat import PANDAS_GT_150
from dask.dataframe.utils import assert_eq

pytestmark = pytest.mark.skipif(
@@ -149,3 +152,43 @@ def test_roundtrip_parquet_spark_to_dask_extension_dtypes(spark_session, tmpdir)
[pd.api.types.is_extension_array_dtype(dtype) for dtype in ddf.dtypes]
), ddf.dtypes
assert_eq(ddf, pdf, check_index=False)


@pytest.mark.skipif(not PANDAS_GT_150, reason="Requires pyarrow-backed nullable dtypes")
def test_read_decimal_dtype_pyarrow(spark_session, tmpdir):
Review comment (Member Author):
One additional benefit of adding support for pyarrow dtypes is that we actually end up getting better Spark interoperability. For example, I ran into a user group offline who were using Spark with decimal-type data. When they tried to read the corresponding Spark-written Parquet dataset, Dask would end up converting those columns to object. With this PR we can now use dask.config.set({"dataframe.dtype_backend": "pyarrow"}) to read that data in backed by pyarrow's decimal128 type.

Anyway, that's the context around this test.

tmpdir = str(tmpdir)
npartitions = 3
size = 6

decimal_data = [
decimal.Decimal("8093.234"),
decimal.Decimal("8094.234"),
decimal.Decimal("8095.234"),
decimal.Decimal("8096.234"),
decimal.Decimal("8097.234"),
decimal.Decimal("8098.234"),
]
pdf = pd.DataFrame(
{
"a": range(size),
"b": decimal_data,
}
)
sdf = spark_session.createDataFrame(pdf)
sdf = sdf.withColumn("b", sdf["b"].cast(pyspark.sql.types.DecimalType(7, 3)))
# We are not overwriting any data, but spark complains if the directory
# already exists (as tmpdir does) and we don't set overwrite
sdf.repartition(npartitions).write.parquet(tmpdir, mode="overwrite")

with dask.config.set({"dataframe.dtype_backend": "pyarrow"}):
ddf = dd.read_parquet(tmpdir, engine="pyarrow", use_nullable_dtypes=True)
assert ddf.b.dtype.pyarrow_dtype == pa.decimal128(7, 3)
assert ddf.b.compute().dtype.pyarrow_dtype == pa.decimal128(7, 3)
expected = pdf.astype(
{
"a": "int64[pyarrow]",
"b": pd.ArrowDtype(pa.decimal128(7, 3)),
}
)

assert_eq(ddf, expected, check_index=False)