Support dtype_backend="pandas|pyarrow" configuration #9719

Merged (11 commits, Dec 16, 2022)
8 changes: 8 additions & 0 deletions dask/dask-schema.yaml
@@ -72,6 +72,14 @@ properties:
task when reading a parquet dataset from a REMOTE file system.
Specifying 0 will result in serial execution on the client.

nullable_backend:
enum:
- pandas
- pyarrow
description: |
The nullable dtype implementation to use. Must be either "pandas" or
"pyarrow". Default is "pandas".

array:
type: object
properties:
1 change: 1 addition & 0 deletions dask/dask.yaml
@@ -12,6 +12,7 @@ dataframe:
parquet:
metadata-task-size-local: 512 # Number of files per local metadata-processing task
metadata-task-size-remote: 16 # Number of files per remote metadata-processing task
nullable_backend: "pandas" # Nullable dtype implementation to use
@rjzamora (Member) commented on Dec 9, 2022:

How do you expect this option (and its default) to interact with the corresponding pandas 2.0 config option? When pandas-2 is released, should the default just correspond to whatever the pandas default is?

For example, it would be nice if we were able to use a test like this for pandas-2:

with pd.option_context("io.nullable_backend", "pyarrow"):
    df = pd.read_parquet("tmp.parquet", engine="pyarrow", use_nullable_dtypes=True)
    ddf = dd.read_parquet("tmp.parquet", engine="pyarrow", use_nullable_dtypes=True)
assert_eq(df, ddf)

Do client vs. worker config options make this a challenge?

Reply from the PR author:

> Do client vs. worker config options make this a challenge?

It's something we'll definitely need to account for. My guess is that the most pleasant user experience will come from pulling the corresponding config value on the client and then embedding it into the task graph (as we're doing in this PR). That way users won't need to worry about setting config options on the workers. Regardless, I suspect the implementation will be the same whether we pull the pandas or the dask config option (see #9711 for an example).

The downside to supporting pandas config options is that we wouldn't support all of them. We could explicitly document which ones we do support, and when, but that could still be a source of confusion.

Either way, I think this is a good question to ask. I'm not too concerned, though, because there is a smooth path in either direction: if we don't support the pandas option, no changes are needed; if we do, we can either update the default for the dask config value to pull in the current pandas option, or deprecate the dask config value altogether.
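
To make the "pull the config value on the client and embed it into the task graph" idea concrete, here is a minimal, self-contained sketch of the pattern. The helper names (read_like, _read_partition) and the fake paths are illustrative only, not this PR's actual implementation, and it assumes a dask version that defines the dataframe.nullable_backend config key:

import dask
import dask.dataframe as dd
import pandas as pd

def _read_partition(path, dtype_backend):
    # The backend choice arrives as a plain task argument, so workers never
    # need "dataframe.nullable_backend" set in their own config.
    return pd.DataFrame({"path": [path], "backend": [dtype_backend]})

def read_like(paths):
    # Resolve the config value once, on the client, at graph-construction time...
    dtype_backend = dask.config.get("dataframe.nullable_backend")
    # ...and bake it into every task via from_map's broadcast keyword arguments.
    return dd.from_map(_read_partition, paths, dtype_backend=dtype_backend)

with dask.config.set({"dataframe.nullable_backend": "pyarrow"}):
    ddf = read_like(["part.0.parquet", "part.1.parquet"])

# "pyarrow" is already embedded in the graph, so computing later (even after
# the config context has exited) still uses it; no worker-side config needed.
print(ddf.compute())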


array:
backend: "numpy" # Backend array library for input IO and data creation
22 changes: 17 additions & 5 deletions dask/dataframe/io/parquet/arrow.py
@@ -1580,19 +1580,31 @@ def _arrow_table_to_pandas(
_kwargs.update({"use_threads": False, "ignore_metadata": False})

if use_nullable_dtypes:
# Determine whether `pandas`- or `pyarrow`-backed dtypes should be used
if use_nullable_dtypes == "pandas":
default_types_mapper = PYARROW_NULLABLE_DTYPE_MAPPING.get
else:
# use_nullable_dtypes == "pyarrow"

def default_types_mapper(pyarrow_dtype): # type: ignore
# Special case pyarrow strings to use more feature complete dtype
# See https://github.com/pandas-dev/pandas/issues/50074
if pyarrow_dtype == pa.string():
return pd.StringDtype("pyarrow")
else:
return pd.ArrowDtype(pyarrow_dtype)

if "types_mapper" in _kwargs:
# User-provided entries take priority over PYARROW_NULLABLE_DTYPE_MAPPING
# User-provided entries take priority over default_types_mapper
types_mapper = _kwargs["types_mapper"]

def _types_mapper(pa_type):
return types_mapper(pa_type) or PYARROW_NULLABLE_DTYPE_MAPPING.get(
pa_type
)
return types_mapper(pa_type) or default_types_mapper(pa_type)

_kwargs["types_mapper"] = _types_mapper

else:
_kwargs["types_mapper"] = PYARROW_NULLABLE_DTYPE_MAPPING.get
_kwargs["types_mapper"] = default_types_mapper

return arrow_table.to_pandas(categories=categories, **_kwargs)
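
For reference, the mechanism being switched above is pyarrow's types_mapper hook on Table.to_pandas. A minimal standalone sketch (using a toy mapping dict rather than dask's PYARROW_NULLABLE_DTYPE_MAPPING) of how the two backends differ:

import pandas as pd
import pyarrow as pa

table = pa.table({"x": pa.array([1, 2, None], type=pa.int64())})

# "pandas" backend: map pyarrow types to numpy-backed pandas extension dtypes
pandas_backed = table.to_pandas(types_mapper={pa.int64(): pd.Int64Dtype()}.get)
print(pandas_backed.dtypes)  # x -> Int64

# "pyarrow" backend: wrap each pyarrow type in pd.ArrowDtype (pandas 1.5+)
pyarrow_backed = table.to_pandas(types_mapper=pd.ArrowDtype)
print(pyarrow_backed.dtypes)  # x -> int64[pyarrow]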

23 changes: 21 additions & 2 deletions dask/dataframe/io/parquet/core.py
@@ -185,7 +185,7 @@ def read_parquet(
index=None,
storage_options=None,
engine="auto",
use_nullable_dtypes=False,
use_nullable_dtypes: bool = False,
calculate_divisions=None,
ignore_metadata_file=False,
metadata_task_size=None,
@@ -257,6 +257,22 @@ def read_parquet(
engine : {'auto', 'pyarrow', 'fastparquet'}, default 'auto'
Parquet library to use. Defaults to 'auto', which uses ``pyarrow`` if
it is installed, and falls back to ``fastparquet`` otherwise.
use_nullable_dtypes : {False, True}
Whether to use extension dtypes for the resulting ``DataFrame``.
``use_nullable_dtypes=True`` is only supported when ``engine="pyarrow"``.

.. note::

Use the ``dataframe.nullable_backend`` config option to select which
dtype implementation to use.

``dataframe.nullable_backend="pandas"`` (the default) will use
pandas' ``numpy``-backed nullable dtypes (e.g. ``Int64``,
``string[python]``, etc.) while ``dataframe.nullable_backend="pyarrow"``
will use ``pyarrow``-backed extension dtypes (e.g. ``int64[pyarrow]``,
``string[pyarrow]``, etc.). ``dataframe.nullable_backend="pyarrow"``
requires ``pandas`` 1.5+.

calculate_divisions : bool, default False
Whether to use min/max statistics from the footer metadata (or global
``_metadata`` file) to calculate divisions for the output DataFrame
@@ -366,6 +382,9 @@ def read_parquet(
pyarrow.parquet.ParquetDataset
"""

if use_nullable_dtypes:
use_nullable_dtypes = dask.config.get("dataframe.nullable_backend")

# "Pre-deprecation" warning for `chunksize`
if chunksize:
warnings.warn(
@@ -552,7 +571,7 @@ def read_parquet(
if "retries" not in annotations and not _is_local_fs(fs):
ctx = dask.annotate(retries=5)
else:
ctx = contextlib.nullcontext()
ctx = contextlib.nullcontext() # type: ignore

with ctx:
# Construct the output collection with from_map
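
For an end-to-end picture of the option documented above, here is a usage sketch. The file name and toy data are illustrative, and the "pyarrow" backend requires pandas 1.5+:

import dask
import dask.dataframe as dd
import pyarrow as pa
import pyarrow.parquet as pq

# Write a small Parquet file without pandas metadata (toy data, illustrative path)
pq.write_table(
    pa.table({
        "a": pa.array([1, 2, None], type=pa.int64()),
        "b": pa.array(["x", "y", None], type=pa.string()),
    }),
    "data.parquet",
)

# Default backend ("pandas"): numpy-backed nullable dtypes, e.g. Int64, string[python]
ddf = dd.read_parquet("data.parquet", engine="pyarrow", use_nullable_dtypes=True)
print(ddf.dtypes)

# "pyarrow" backend: pyarrow-backed dtypes, e.g. int64[pyarrow], string[pyarrow]
with dask.config.set({"dataframe.nullable_backend": "pyarrow"}):
    ddf = dd.read_parquet("data.parquet", engine="pyarrow", use_nullable_dtypes=True)
print(ddf.dtypes)
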
68 changes: 48 additions & 20 deletions dask/dataframe/io/tests/test_parquet.py
@@ -15,7 +15,12 @@
import dask.dataframe as dd
import dask.multiprocessing
from dask.blockwise import Blockwise, optimize_blockwise
from dask.dataframe._compat import PANDAS_GT_110, PANDAS_GT_121, PANDAS_GT_130
from dask.dataframe._compat import (
PANDAS_GT_110,
PANDAS_GT_121,
PANDAS_GT_130,
PANDAS_GT_150,
)
from dask.dataframe.io.parquet.core import get_engine
from dask.dataframe.io.parquet.utils import _parse_pandas_metadata
from dask.dataframe.optimize import optimize_dataframe_getitem
@@ -618,17 +623,37 @@ def test_roundtrip_nullable_dtypes(tmp_path, write_engine, read_engine):


@PYARROW_MARK
def test_use_nullable_dtypes(tmp_path, engine):
@pytest.mark.parametrize(
"nullable_backend",
[
"pandas",
pytest.param(
"pyarrow",
marks=pytest.mark.skipif(
not PANDAS_GT_150, reason="Requires pyarrow-backed nullable dtypes"
),
),
],
)
def test_use_nullable_dtypes(tmp_path, engine, nullable_backend):
"""
Test reading a parquet file without pandas metadata,
but forcing use of nullable dtypes where appropriate
"""

if nullable_backend == "pandas":
dtype_extra = ""
else:
# nullable_backend == "pyarrow"
dtype_extra = "[pyarrow]"
df = pd.DataFrame(
{
"a": pd.Series([1, 2, pd.NA, 3, 4], dtype="Int64"),
"b": pd.Series([True, pd.NA, False, True, False], dtype="boolean"),
"c": pd.Series([0.1, 0.2, 0.3, pd.NA, 0.4], dtype="Float64"),
"d": pd.Series(["a", "b", "c", "d", pd.NA], dtype="string"),
"a": pd.Series([1, 2, pd.NA, 3, 4], dtype=f"Int64{dtype_extra}"),
"b": pd.Series(
[True, pd.NA, False, True, False], dtype=f"boolean{dtype_extra}"
),
"c": pd.Series([0.1, 0.2, 0.3, pd.NA, 0.4], dtype=f"Float64{dtype_extra}"),
"d": pd.Series(["a", "b", "c", "d", pd.NA], dtype=f"string{dtype_extra}"),
}
)
ddf = dd.from_pandas(df, npartitions=2)
@@ -644,21 +669,24 @@ def write_partition(df, i):
partitions = ddf.to_delayed()
dask.compute([write_partition(p, i) for i, p in enumerate(partitions)])

# Not supported by fastparquet
if engine == "fastparquet":
with pytest.raises(ValueError, match="`use_nullable_dtypes` is not supported"):
dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)
with dask.config.set({"dataframe.nullable_backend": nullable_backend}):
# Not supported by fastparquet
if engine == "fastparquet":
with pytest.raises(
ValueError, match="`use_nullable_dtypes` is not supported"
):
dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)

# Works in pyarrow
else:
# Doesn't round-trip by default when we aren't using nullable dtypes
with pytest.raises(AssertionError):
ddf2 = dd.read_parquet(tmp_path, engine=engine)
assert_eq(df, ddf2)

# Round trip works when we use nullable dtypes
ddf2 = dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)
assert_eq(df, ddf2, check_index=False)
# Works in pyarrow
else:
# Doesn't round-trip by default when we aren't using nullable dtypes
with pytest.raises(AssertionError):
ddf2 = dd.read_parquet(tmp_path, engine=engine)
assert_eq(df, ddf2)

# Round trip works when we use nullable dtypes
ddf2 = dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)
assert_eq(df, ddf2, check_index=False)


@pytest.mark.xfail(
45 changes: 44 additions & 1 deletion dask/tests/test_spark_compat.py
@@ -1,19 +1,22 @@
import decimal
import signal
import sys
import threading

import pytest

import dask
from dask.datasets import timeseries

dd = pytest.importorskip("dask.dataframe")
pyspark = pytest.importorskip("pyspark")
pytest.importorskip("pyarrow")
pa = pytest.importorskip("pyarrow")
pytest.importorskip("fastparquet")

import numpy as np
import pandas as pd

from dask.dataframe._compat import PANDAS_GT_150
from dask.dataframe.utils import assert_eq

pytestmark = pytest.mark.skipif(
@@ -149,3 +152,43 @@ def test_roundtrip_parquet_spark_to_dask_extension_dtypes(spark_session, tmpdir)
[pd.api.types.is_extension_array_dtype(dtype) for dtype in ddf.dtypes]
), ddf.dtypes
assert_eq(ddf, pdf, check_index=False)


@pytest.mark.skipif(not PANDAS_GT_150, reason="Requires pyarrow-backed nullable dtypes")
def test_read_decimal_dtype_pyarrow(spark_session, tmpdir):
Comment from the PR author:

One additional benefit of adding support for pyarrow dtypes is that we end up with better Spark interoperability. For example, I ran into a user group offline who were using Spark with decimal-typed data. When they tried to read the corresponding Spark-written Parquet dataset, Dask converted those columns to object. With this PR they can use dask.config.set({"dataframe.nullable_backend": "pyarrow"}) to read that data backed by pyarrow's decimal128 type.

Anyway, that's the context for this test.

tmpdir = str(tmpdir)
npartitions = 3
size = 6

decimal_data = [
decimal.Decimal("8093.234"),
decimal.Decimal("8094.234"),
decimal.Decimal("8095.234"),
decimal.Decimal("8096.234"),
decimal.Decimal("8097.234"),
decimal.Decimal("8098.234"),
]
pdf = pd.DataFrame(
{
"a": range(size),
"b": decimal_data,
}
)
sdf = spark_session.createDataFrame(pdf)
sdf = sdf.withColumn("b", sdf["b"].cast(pyspark.sql.types.DecimalType(7, 3)))
# We are not overwriting any data, but spark complains if the directory
# already exists (as tmpdir does) and we don't set overwrite
sdf.repartition(npartitions).write.parquet(tmpdir, mode="overwrite")

with dask.config.set({"dataframe.nullable_backend": "pyarrow"}):
ddf = dd.read_parquet(tmpdir, engine="pyarrow", use_nullable_dtypes=True)
assert ddf.b.dtype.pyarrow_dtype == pa.decimal128(7, 3)
assert ddf.b.compute().dtype.pyarrow_dtype == pa.decimal128(7, 3)
expected = pdf.astype(
{
"a": "int64[pyarrow]",
"b": pd.ArrowDtype(pa.decimal128(7, 3)),
}
)

assert_eq(ddf, expected, check_index=False)