Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for use_nullable_dtypes to dd.read_parquet #9617

Merged
merged 14 commits on Dec 1, 2022
35 changes: 30 additions & 5 deletions dask/dataframe/io/parquet/arrow.py
Expand Up @@ -11,6 +11,7 @@

from dask.base import tokenize
from dask.core import flatten
from dask.dataframe._compat import PANDAS_GT_120
from dask.dataframe.backends import pyarrow_schema_dispatch
from dask.dataframe.io.parquet.utils import (
Engine,
Expand Down Expand Up @@ -43,6 +44,23 @@
partitioning_supported = _pa_version >= parse_version("5.0.0")
del _pa_version

# Mapping from pyarrow types to their pandas nullable extension-dtype
# equivalents.  Its ``.get`` method is used as the ``types_mapper`` argument
# to ``pyarrow.Table.to_pandas`` when ``use_nullable_dtypes=True``.
PYARROW_NULLABLE_DTYPE_MAPPING = {
    pa.int8(): pd.Int8Dtype(),
    pa.int16(): pd.Int16Dtype(),
    pa.int32(): pd.Int32Dtype(),
    pa.int64(): pd.Int64Dtype(),
    pa.uint8(): pd.UInt8Dtype(),
    pa.uint16(): pd.UInt16Dtype(),
    pa.uint32(): pd.UInt32Dtype(),
    pa.uint64(): pd.UInt64Dtype(),
    pa.bool_(): pd.BooleanDtype(),
    pa.string(): pd.StringDtype(),
}

# Nullable float dtypes were only added in pandas 1.2.0
if PANDAS_GT_120:
    PYARROW_NULLABLE_DTYPE_MAPPING[pa.float32()] = pd.Float32Dtype()
    PYARROW_NULLABLE_DTYPE_MAPPING[pa.float64()] = pd.Float64Dtype()

#
# Helper Utilities
#
Expand Down Expand Up @@ -327,6 +345,7 @@ def read_metadata(
paths,
categories=None,
index=None,
use_nullable_dtypes=False,
gather_statistics=None,
filters=None,
split_row_groups=False,
Expand Down Expand Up @@ -356,7 +375,7 @@ def read_metadata(
)

# Stage 2: Generate output `meta`
meta = cls._create_dd_meta(dataset_info)
meta = cls._create_dd_meta(dataset_info, use_nullable_dtypes)

# Stage 3: Generate parts and stats
parts, stats, common_kwargs = cls._construct_collection_plan(dataset_info)
Expand All @@ -381,6 +400,7 @@ def read_partition(
pieces,
columns,
index,
use_nullable_dtypes=False,
categories=(),
partitions=(),
filters=None,
Expand Down Expand Up @@ -451,7 +471,9 @@ def read_partition(
arrow_table = pa.concat_tables(tables)

# Convert to pandas
df = cls._arrow_table_to_pandas(arrow_table, categories, **kwargs)
df = cls._arrow_table_to_pandas(
arrow_table, categories, use_nullable_dtypes=use_nullable_dtypes, **kwargs
)

# For pyarrow.dataset api, need to convert partition columns
to categorical manually for integer types.
Expand Down Expand Up @@ -964,7 +986,7 @@ def _collect_dataset_info(
}

@classmethod
def _create_dd_meta(cls, dataset_info):
def _create_dd_meta(cls, dataset_info, use_nullable_dtypes=False):
"""Use parquet schema and hive-partition information
(stored in dataset_info) to construct DataFrame metadata.
"""
Expand Down Expand Up @@ -1036,7 +1058,7 @@ def _create_dd_meta(cls, dataset_info):
"categories: {} | columns: {}".format(categories, list(all_columns))
)

dtypes = _get_pyarrow_dtypes(schema, categories)
dtypes = _get_pyarrow_dtypes(schema, categories, use_nullable_dtypes)
dtypes = {storage_name_mapping.get(k, k): v for k, v in dtypes.items()}

index_cols = index or ()
Expand Down Expand Up @@ -1547,11 +1569,14 @@ def _read_table(

@classmethod
def _arrow_table_to_pandas(
    cls, arrow_table: pa.Table, categories, use_nullable_dtypes=False, **kwargs
) -> pd.DataFrame:
    """Convert a ``pyarrow.Table`` to a ``pandas.DataFrame``.

    Parameters
    ----------
    arrow_table : pa.Table
        The table to convert.
    categories : list or None
        Column names to load as categoricals; forwarded to
        ``pa.Table.to_pandas``.
    use_nullable_dtypes : bool, default False
        If True, map arrow types to pandas nullable extension dtypes
        (e.g. ``Int64``, ``boolean``, ``string``) via
        ``PYARROW_NULLABLE_DTYPE_MAPPING``.
    **kwargs :
        May contain an ``"arrow_to_pandas"`` dict of extra options for
        ``pa.Table.to_pandas``.
    """
    # Copy the user-supplied options so the updates below don't mutate the
    # caller's dict (which is shared across partitions).
    _kwargs = dict(kwargs.get("arrow_to_pandas", {}))
    _kwargs.update({"use_threads": False, "ignore_metadata": False})

    if use_nullable_dtypes:
        # NOTE(review): ``to_pandas(..., types_mapper=...)`` goes
        # arrow -> numpy -> arrow for types not in the mapping; splitting the
        # table into ChunkedArrays (as pandas' ArrowExtensionArray does) would
        # avoid that round trip if full pyarrow-backed dtypes are ever wanted.
        _kwargs["types_mapper"] = PYARROW_NULLABLE_DTYPE_MAPPING.get

    return arrow_table.to_pandas(categories=categories, **_kwargs)

@classmethod
Expand Down
35 changes: 31 additions & 4 deletions dask/dataframe/io/parquet/core.py
Expand Up @@ -45,6 +45,7 @@ def __init__(
meta,
columns,
index,
use_nullable_dtypes,
kwargs,
common_kwargs,
):
Expand All @@ -53,6 +54,7 @@ def __init__(
self.meta = meta
self._columns = columns
self.index = index
self.use_nullable_dtypes = use_nullable_dtypes

# `kwargs` = user-defined kwargs to be passed
# identically for all partitions.
Expand All @@ -78,6 +80,7 @@ def project_columns(self, columns):
self.meta,
columns,
self.index,
self.use_nullable_dtypes,
None, # Already merged into common_kwargs
self.common_kwargs,
)
Expand All @@ -101,6 +104,7 @@ def __call__(self, part):
],
self.columns,
self.index,
self.use_nullable_dtypes,
self.common_kwargs,
)

Expand Down Expand Up @@ -181,6 +185,7 @@ def read_parquet(
index=None,
storage_options=None,
engine="auto",
use_nullable_dtypes=False,
calculate_divisions=None,
ignore_metadata_file=False,
metadata_task_size=None,
Expand Down Expand Up @@ -433,6 +438,7 @@ def read_parquet(
"index": index,
"storage_options": storage_options,
"engine": engine,
"use_nullable_dtypes": use_nullable_dtypes,
"calculate_divisions": calculate_divisions,
"ignore_metadata_file": ignore_metadata_file,
"metadata_task_size": metadata_task_size,
Expand Down Expand Up @@ -475,6 +481,7 @@ def read_parquet(
paths,
categories=categories,
index=index,
use_nullable_dtypes=use_nullable_dtypes,
gather_statistics=calculate_divisions,
filters=filters,
split_row_groups=split_row_groups,
Expand Down Expand Up @@ -540,6 +547,7 @@ def read_parquet(
meta,
columns,
index,
use_nullable_dtypes,
{}, # All kwargs should now be in `common_kwargs`
common_kwargs,
)
Expand Down Expand Up @@ -578,7 +586,9 @@ def check_multi_support(engine):
return hasattr(engine, "multi_support") and engine.multi_support()


def read_parquet_part(fs, engine, meta, part, columns, index, kwargs):
def read_parquet_part(
fs, engine, meta, part, columns, index, use_nullable_dtypes, kwargs
):
"""Read a part of a parquet dataset

This function is used by `read_parquet`."""
Expand All @@ -587,22 +597,39 @@ def read_parquet_part(fs, engine, meta, part, columns, index, kwargs):
# Part kwargs expected
func = engine.read_partition
dfs = [
func(fs, rg, columns.copy(), index, **toolz.merge(kwargs, kw))
func(
fs,
rg,
columns.copy(),
index,
use_nullable_dtypes=use_nullable_dtypes,
**toolz.merge(kwargs, kw),
)
for (rg, kw) in part
]
df = concat(dfs, axis=0) if len(dfs) > 1 else dfs[0]
else:
# No part specific kwargs, let engine read
# list of parts at once
df = engine.read_partition(
fs, [p[0] for p in part], columns.copy(), index, **kwargs
fs,
[p[0] for p in part],
columns.copy(),
index,
use_nullable_dtypes=use_nullable_dtypes,
**kwargs,
)
else:
# NOTE: `kwargs` are the same for all parts, while `part_kwargs` may
# be different for each part.
rg, part_kwargs = part
df = engine.read_partition(
fs, rg, columns, index, **toolz.merge(kwargs, part_kwargs)
fs,
rg,
columns,
index,
use_nullable_dtypes=use_nullable_dtypes,
**toolz.merge(kwargs, part_kwargs),
)

if meta.columns.name:
Expand Down
6 changes: 6 additions & 0 deletions dask/dataframe/io/parquet/fastparquet.py
Expand Up @@ -820,6 +820,7 @@ def read_metadata(
paths,
categories=None,
index=None,
use_nullable_dtypes=False,
gather_statistics=None,
filters=None,
split_row_groups=False,
Expand All @@ -830,6 +831,10 @@ def read_metadata(
parquet_file_extension=None,
**kwargs,
):
if use_nullable_dtypes:
raise ValueError(
"`use_nullable_dtypes` is not supported by the fastparquet engine"
)

# Stage 1: Collect general dataset information
dataset_info = cls._collect_dataset_info(
Expand Down Expand Up @@ -889,6 +894,7 @@ def read_partition(
pieces,
columns,
index,
use_nullable_dtypes=False,
categories=(),
root_cats=None,
root_file_scheme=None,
Expand Down
11 changes: 10 additions & 1 deletion dask/dataframe/io/parquet/utils.py
Expand Up @@ -17,6 +17,7 @@ def read_metadata(
paths,
categories=None,
index=None,
use_nullable_dtypes=False,
gather_statistics=None,
filters=None,
**kwargs,
Expand All @@ -37,6 +38,9 @@ def read_metadata(
The column name(s) to be used as the index.
If set to ``None``, pandas metadata (if available) can be used
to reset the value in this function
use_nullable_dtypes: boolean
Whether to use pandas nullable dtypes (like "string" or "Int64")
where appropriate when reading parquet files.
gather_statistics: bool
Whether or not to gather statistics to calculate divisions
for the output DataFrame collection.
Expand Down Expand Up @@ -73,7 +77,9 @@ def read_metadata(
raise NotImplementedError()

@classmethod
def read_partition(cls, fs, piece, columns, index, **kwargs):
def read_partition(
cls, fs, piece, columns, index, use_nullable_dtypes=False, **kwargs
):
"""Read a single piece of a Parquet dataset into a Pandas DataFrame

This function is called many times in individual tasks
Expand All @@ -88,6 +94,9 @@ def read_partition(cls, fs, piece, columns, index, **kwargs):
List of column names to pull out of that row group
index: str, List[str], or False
The index name(s).
use_nullable_dtypes: boolean
Whether to use pandas nullable dtypes (like "string" or "Int64")
where appropriate when reading parquet files.
**kwargs:
Includes `"kwargs"` values stored within the `parts` output
of `engine.read_metadata`. May also include arguments to be
Expand Down
79 changes: 77 additions & 2 deletions dask/dataframe/io/tests/test_parquet.py
Expand Up @@ -594,6 +594,81 @@ def test_roundtrip_from_pandas(tmpdir, write_engine, read_engine):
assert_eq(dfp, ddf)


@write_read_engines()
def test_roundtrip_nullable_dtypes(tmpdir, write_engine, read_engine):
    """
    Test round-tripping nullable extension dtypes. Parquet engines will
    typically add dtype metadata for this.
    """
    if read_engine == "fastparquet" or write_engine == "fastparquet":
        pytest.xfail("https://github.com/dask/fastparquet/issues/465")

    fn = str(tmpdir.join("test.parquet"))
    df = pd.DataFrame(
        {
            "a": pd.Series([1, 2, pd.NA, 3, 4], dtype="Int64"),
            "b": pd.Series([True, pd.NA, False, True, False], dtype="boolean"),
            "c": pd.Series([0.1, 0.2, 0.3, pd.NA, 0.4], dtype="Float64"),
            "d": pd.Series(["a", "b", "c", "d", pd.NA], dtype="string"),
        }
    )
    ddf = dd.from_pandas(df, npartitions=2)
    # `write_engine` may be a pyarrow variant name; to_parquet only accepts
    # the base engine names.
    ddf.to_parquet(
        fn, engine="pyarrow" if write_engine.startswith("pyarrow") else "fastparquet"
    )
    # Removed a leftover debug ``print(ddf2.dtypes)`` from the original.
    ddf2 = dd.read_parquet(fn, engine=read_engine)
    assert_eq(df, ddf2)


@PYARROW_MARK
def test_use_nullable_dtypes(tmpdir, engine):
    """
    Reading a parquet dataset that carries no pandas metadata should still
    produce pandas nullable extension dtypes when ``use_nullable_dtypes=True``
    is requested (pyarrow engine only).
    """
    path = str(tmpdir)
    df = pd.DataFrame(
        {
            "a": pd.Series([1, 2, pd.NA, 3, 4], dtype="Int64"),
            "b": pd.Series([True, pd.NA, False, True, False], dtype="boolean"),
            "c": pd.Series([0.1, 0.2, 0.3, pd.NA, 0.4], dtype="Float64"),
            "d": pd.Series(["a", "b", "c", "d", pd.NA], dtype="string"),
        }
    )
    ddf = dd.from_pandas(df, npartitions=2)

    @dask.delayed
    def write_partition(df, i):
        """Write one partition to parquet, stripping the pandas schema metadata."""
        import pyarrow as pa
        import pyarrow.parquet as pq

        table = pa.Table.from_pandas(df).replace_schema_metadata({})
        pq.write_table(table, path + f"/part.{i}.parquet")

    # Build a partitioned, pandas-metadata-free parquet dataset. Without the
    # metadata, the default read will not use nullable extension dtypes.
    writes = [write_partition(part, i) for i, part in enumerate(ddf.to_delayed())]
    dask.compute(writes)

    if engine == "fastparquet":
        # fastparquet does not implement use_nullable_dtypes
        with pytest.raises(ValueError, match="not supported"):
            dd.read_parquet(path, engine=engine, use_nullable_dtypes=True)

    elif "arrow" in engine:
        # Without nullable dtypes the round trip loses the extension dtypes,
        # so the equality check fails.
        with pytest.raises(AssertionError):
            result = dd.read_parquet(path, engine=engine, use_nullable_dtypes=False)
            assert_eq(df, result)

        # With nullable dtypes requested, the round trip succeeds.
        result = dd.read_parquet(path, engine=engine, use_nullable_dtypes=True)
        assert_eq(df, result, check_index=False)


@write_read_engines()
def test_categorical(tmpdir, write_engine, read_engine):
tmp = str(tmpdir)
Expand Down Expand Up @@ -3119,11 +3194,11 @@ def clamp_arrow_datetimes(cls, arrow_table: pa.Table) -> pa.Table:

@classmethod
def _arrow_table_to_pandas(
    cls, arrow_table: pa.Table, categories, use_nullable_dtypes=False, **kwargs
) -> pd.DataFrame:
    """Clamp out-of-range datetimes, then defer to the default conversion."""
    clamped_table = cls.clamp_arrow_datetimes(arrow_table)
    return super()._arrow_table_to_pandas(
        clamped_table,
        categories,
        use_nullable_dtypes=use_nullable_dtypes,
        **kwargs,
    )

# this should not fail, but instead produce timestamps that are in the valid range
Expand Down