
Support dtype_backend="pandas|pyarrow" configuration #9719

Merged
merged 11 commits on Dec 16, 2022
21 changes: 16 additions & 5 deletions dask/dataframe/io/parquet/arrow.py
@@ -1580,19 +1580,30 @@ def _arrow_table_to_pandas(
_kwargs.update({"use_threads": False, "ignore_metadata": False})

if use_nullable_dtypes:
# Determine whether `pandas`- or `pyarrow`-backed dtypes should be used
if use_nullable_dtypes in ("pandas", True):
default_types_mapper = PYARROW_NULLABLE_DTYPE_MAPPING.get
elif use_nullable_dtypes == "pyarrow":

def default_types_mapper(pyarrow_dtype): # type: ignore
# Special-case pyarrow strings to use a more feature-complete dtype
# See https://github.com/pandas-dev/pandas/issues/50074
if pyarrow_dtype == pa.string():
return pd.StringDtype("pyarrow")
else:
return pd.ArrowDtype(pyarrow_dtype)

if "types_mapper" in _kwargs:
- # User-provided entries take priority over PYARROW_NULLABLE_DTYPE_MAPPING
+ # User-provided entries take priority over default_types_mapper
types_mapper = _kwargs["types_mapper"]

def _types_mapper(pa_type):
- return types_mapper(pa_type) or PYARROW_NULLABLE_DTYPE_MAPPING.get(
- pa_type
- )
+ return types_mapper(pa_type) or default_types_mapper(pa_type)

_kwargs["types_mapper"] = _types_mapper

else:
_kwargs["types_mapper"] = PYARROW_NULLABLE_DTYPE_MAPPING.get
_kwargs["types_mapper"] = default_types_mapper

return arrow_table.to_pandas(categories=categories, **_kwargs)
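
To make the mapper precedence concrete, here is a minimal standalone sketch (not the dask source; the table contents and mapper names are illustrative, and pd.ArrowDtype assumes pandas 1.5+) of how pyarrow's Table.to_pandas resolves a user-supplied types_mapper against a default one: the `or` falls through to the default only when the user mapper returns None.

import pandas as pd
import pyarrow as pa

table = pa.table({"x": [1, 2, None], "s": ["a", None, "c"]})

def user_mapper(pa_type):
    # Only override integers; return None to fall through for everything else
    return pd.Int32Dtype() if pa_type == pa.int64() else None

def default_mapper(pa_type):
    # Mirrors the special-casing above: pyarrow strings get the more
    # feature-complete string[pyarrow] dtype, everything else pd.ArrowDtype
    if pa_type == pa.string():
        return pd.StringDtype("pyarrow")
    return pd.ArrowDtype(pa_type)

def combined(pa_type):
    # User-provided entries take priority; None falls through to the default
    return user_mapper(pa_type) or default_mapper(pa_type)

df = table.to_pandas(types_mapper=combined)
print(df.dtypes)  # x -> Int32, s -> string[pyarrow]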

22 changes: 20 additions & 2 deletions dask/dataframe/io/parquet/core.py
@@ -3,6 +3,7 @@
import contextlib
import math
import warnings
from typing import Literal

import tlz as toolz
from fsspec.core import get_fs_token_paths
@@ -185,7 +186,7 @@ def read_parquet(
index=None,
storage_options=None,
engine="auto",
- use_nullable_dtypes=False,
+ use_nullable_dtypes: bool | Literal["pandas", "pyarrow"] = False,
Member Author:

Btw, @mroeschke this is the PR I mentioned offline about extending `use_nullable_dtypes` to support "pandas" and "pyarrow".

Contributor:

Ah yeah, this is pretty clean!

Our pandas issues pandas-dev/pandas#48957 (where the offline discussion happened) and pandas-dev/pandas#49997 are examples where there was some discussion of, and preference for, keeping `use_nullable_dtypes` boolean.

calculate_divisions=None,
ignore_metadata_file=False,
metadata_task_size=None,
@@ -257,6 +258,17 @@ def read_parquet(
engine : {'auto', 'pyarrow', 'fastparquet'}, default 'auto'
Parquet library to use. Defaults to 'auto', which uses ``pyarrow`` if
it is installed, and falls back to ``fastparquet`` otherwise.
use_nullable_dtypes : {False, True, "pandas", "pyarrow"}
Whether to use dtypes that use ``pd.NA`` as a missing value indicator
for the resulting ``DataFrame``. ``True`` and ``"pandas"`` will use
pandas nullable dtypes (e.g. ``Int64``, ``string[python]``, etc.) while
``"pyarrow"`` will use ``pyarrow``-backed extension dtypes (e.g.
``int64[pyarrow]``, ``string[pyarrow]``, etc.).

.. note::
``use_nullable_dtypes`` is only supported when ``engine="pyarrow"``,
and ``use_nullable_dtypes="pyarrow"`` requires ``pandas`` 1.5+.

calculate_divisions : bool, default False
Whether to use min/max statistics from the footer metadata (or global
``_metadata`` file) to calculate divisions for the output DataFrame
@@ -366,6 +378,12 @@
pyarrow.parquet.ParquetDataset
"""

if use_nullable_dtypes not in (True, False, "pandas", "pyarrow"):
raise ValueError(
"Invalid value for `use_nullable_dtypes` received. Expected `True`, `False`, "
f"`'pandas'`, or `'pyarrow'` but got {use_nullable_dtypes} instead."
)

# "Pre-deprecation" warning for `chunksize`
if chunksize:
warnings.warn(
@@ -552,7 +570,7 @@ def read_parquet(
if "retries" not in annotations and not _is_local_fs(fs):
ctx = dask.annotate(retries=5)
else:
- ctx = contextlib.nullcontext()
+ ctx = contextlib.nullcontext()  # type: ignore

with ctx:
# Construct the output collection with from_map
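Putting the docstring and the new validation together, usage looks like the following sketch (the path is illustrative, not from this PR, and the "pyarrow" backend additionally assumes pandas 1.5+):

import dask.dataframe as dd

# pandas-backed nullable dtypes (Int64, boolean, string, ...);
# equivalent to use_nullable_dtypes=True
ddf = dd.read_parquet("data/", engine="pyarrow", use_nullable_dtypes="pandas")

# pyarrow-backed extension dtypes (int64[pyarrow], string[pyarrow], ...)
ddf = dd.read_parquet("data/", engine="pyarrow", use_nullable_dtypes="pyarrow")

# any other value raises the ValueError added above
dd.read_parquet("data/", engine="pyarrow", use_nullable_dtypes="arrow")  # ValueError
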
67 changes: 59 additions & 8 deletions dask/dataframe/io/tests/test_parquet.py
@@ -15,7 +15,12 @@
import dask.dataframe as dd
import dask.multiprocessing
from dask.blockwise import Blockwise, optimize_blockwise
- from dask.dataframe._compat import PANDAS_GT_110, PANDAS_GT_121, PANDAS_GT_130
+ from dask.dataframe._compat import (
+ PANDAS_GT_110,
+ PANDAS_GT_121,
+ PANDAS_GT_130,
+ PANDAS_GT_150,
+ )
from dask.dataframe.io.parquet.core import get_engine
from dask.dataframe.io.parquet.utils import _parse_pandas_metadata
from dask.dataframe.optimize import optimize_dataframe_getitem
@@ -618,17 +623,41 @@ def test_roundtrip_nullable_dtypes(tmp_path, write_engine, read_engine):


@PYARROW_MARK
- def test_use_nullable_dtypes(tmp_path, engine):
+ @pytest.mark.parametrize(
+ "use_nullable_dtypes",
+ [
+ True,
+ "pandas",
+ pytest.param(
+ "pyarrow",
+ marks=pytest.mark.skipif(
+ not PANDAS_GT_150, reason="Requires pyarrow-backed nullable dtypes"
+ ),
+ ),
+ ],
+ )
+ def test_use_nullable_dtypes(tmp_path, engine, use_nullable_dtypes):
"""
Test reading a parquet file without pandas metadata,
but forcing use of nullable dtypes where appropriate
"""

if use_nullable_dtypes in (True, "pandas"):
nullable_backend = ""
else:
nullable_backend = "[pyarrow]"
df = pd.DataFrame(
{
"a": pd.Series([1, 2, pd.NA, 3, 4], dtype="Int64"),
"b": pd.Series([True, pd.NA, False, True, False], dtype="boolean"),
"c": pd.Series([0.1, 0.2, 0.3, pd.NA, 0.4], dtype="Float64"),
"d": pd.Series(["a", "b", "c", "d", pd.NA], dtype="string"),
"a": pd.Series([1, 2, pd.NA, 3, 4], dtype=f"Int64{nullable_backend}"),
"b": pd.Series(
[True, pd.NA, False, True, False], dtype=f"boolean{nullable_backend}"
),
"c": pd.Series(
[0.1, 0.2, 0.3, pd.NA, 0.4], dtype=f"Float64{nullable_backend}"
Contributor:

Oh nice, I actually didn't know this was case-insensitive. (The Float64 is parsed by pyarrow.)

Member Author:

Yeah, pyarrow converts everything to lowercase here. That makes it useful for writing these kinds of tests, where I want to easily switch between pandas- and pyarrow-backed extension dtypes. Though once pandas-dev/pandas#50094 lands and is released, I could see us using that too!

(A short sketch of this dtype-string trick follows the test below.)

+ ),
+ "d": pd.Series(
+ ["a", "b", "c", "d", pd.NA], dtype=f"string{nullable_backend}"
+ ),
}
)
ddf = dd.from_pandas(df, npartitions=2)
@@ -647,7 +676,9 @@ def write_partition(df, i):
# Not supported by fastparquet
if engine == "fastparquet":
with pytest.raises(ValueError, match="`use_nullable_dtypes` is not supported"):
- dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)
+ dd.read_parquet(
+ tmp_path, engine=engine, use_nullable_dtypes=use_nullable_dtypes
+ )

# Works in pyarrow
else:
@@ -657,10 +688,30 @@ def write_partition(df, i):
assert_eq(df, ddf2)

# Round trip works when we use nullable dtypes
- ddf2 = dd.read_parquet(tmp_path, engine=engine, use_nullable_dtypes=True)
+ ddf2 = dd.read_parquet(
+ tmp_path, engine=engine, use_nullable_dtypes=use_nullable_dtypes
+ )
assert_eq(df, ddf2, check_index=False)
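
As a footnote to the review exchange above, a small sketch of the dtype-string trick this test relies on (assuming pandas 1.5+ with pyarrow installed): appending "[pyarrow]" to a nullable dtype name selects the pyarrow-backed counterpart, and the base name is parsed case-insensitively.

import pandas as pd

s_pandas = pd.Series([1, None], dtype="Int64")           # pandas-backed nullable
s_arrow = pd.Series([1, None], dtype="Int64[pyarrow]")   # pyarrow-backed extension
print(s_pandas.dtype)  # Int64
print(s_arrow.dtype)   # int64[pyarrow]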


def test_use_nullable_dtypes_raises(tmp_path, engine):
# Raise an informative error message when `use_nullable_dtypes` is invalid
df = pd.DataFrame({"a": pd.Series([1, 2, pd.NA, 3, 4], dtype="Int64")})
ddf = dd.from_pandas(df, npartitions=3)
ddf.to_parquet(tmp_path, engine=engine)

bad_use_nullable_dtypes = "not-a-valid-option"
with pytest.raises(ValueError) as excinfo:
dd.read_parquet(
tmp_path,
engine=engine,
use_nullable_dtypes=bad_use_nullable_dtypes,
)
msg = str(excinfo.value)
assert "Invalid value for `use_nullable_dtypes`" in msg
assert bad_use_nullable_dtypes in msg


@pytest.mark.xfail(
not PANDAS_GT_130,
reason=(