Add support for use_nullable_dtypes to dd.read_parquet #9617

Merged
merged 14 commits on Dec 1, 2022
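
For context, a minimal usage sketch of the option this PR adds (the file path below is hypothetical, not part of the change):

```python
import dask.dataframe as dd

# Read a parquet dataset using pandas nullable extension dtypes
# (Int64, boolean, string, ...) instead of the default NumPy dtypes,
# so missing values do not force integer columns to float64.
# Only the pyarrow engine supports this; see the fastparquet check below.
ddf = dd.read_parquet(
    "data/*.parquet",          # hypothetical path
    engine="pyarrow",
    use_nullable_dtypes=True,
)
print(ddf.dtypes)  # e.g. Int64 / boolean / string rather than int64 / bool / object
```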
46 changes: 42 additions & 4 deletions dask/dataframe/io/parquet/arrow.py
@@ -11,6 +11,7 @@

from dask.base import tokenize
from dask.core import flatten
from dask.dataframe._compat import PANDAS_GT_120
from dask.dataframe.backends import pyarrow_schema_dispatch
from dask.dataframe.io.parquet.utils import (
Engine,
@@ -37,6 +38,23 @@
partitioning_supported = _pa_version >= parse_version("5.0.0")
del _pa_version

PYARROW_NULLABLE_DTYPE_MAPPING = {
pa.int8(): pd.Int8Dtype(),
pa.int16(): pd.Int16Dtype(),
pa.int32(): pd.Int32Dtype(),
pa.int64(): pd.Int64Dtype(),
pa.uint8(): pd.UInt8Dtype(),
pa.uint16(): pd.UInt16Dtype(),
pa.uint32(): pd.UInt32Dtype(),
pa.uint64(): pd.UInt64Dtype(),
pa.bool_(): pd.BooleanDtype(),
pa.string(): pd.StringDtype(),
}

if PANDAS_GT_120:
PYARROW_NULLABLE_DTYPE_MAPPING[pa.float32()] = pd.Float32Dtype()
PYARROW_NULLABLE_DTYPE_MAPPING[pa.float64()] = pd.Float64Dtype()

#
# Helper Utilities
#
@@ -321,6 +339,7 @@ def read_metadata(
paths,
categories=None,
index=None,
use_nullable_dtypes=False,
gather_statistics=None,
filters=None,
split_row_groups=False,
@@ -350,7 +369,7 @@
)

# Stage 2: Generate output `meta`
meta = cls._create_dd_meta(dataset_info)
meta = cls._create_dd_meta(dataset_info, use_nullable_dtypes)

# Stage 3: Generate parts and stats
parts, stats, common_kwargs = cls._construct_collection_plan(dataset_info)
@@ -375,6 +394,7 @@ def read_partition(
pieces,
columns,
index,
use_nullable_dtypes=False,
categories=(),
partitions=(),
filters=None,
@@ -445,7 +465,9 @@ def read_partition(
arrow_table = pa.concat_tables(tables)

# Convert to pandas
df = cls._arrow_table_to_pandas(arrow_table, categories, **kwargs)
df = cls._arrow_table_to_pandas(
arrow_table, categories, use_nullable_dtypes=use_nullable_dtypes, **kwargs
)

# For pyarrow.dataset api, need to convert partition columns
# to categorical manually for integer types.
@@ -958,7 +980,7 @@ def _collect_dataset_info(
}

@classmethod
def _create_dd_meta(cls, dataset_info):
def _create_dd_meta(cls, dataset_info, use_nullable_dtypes=False):
"""Use parquet schema and hive-partition information
(stored in dataset_info) to construct DataFrame metadata.
"""
@@ -989,6 +1011,7 @@ def _create_dd_meta(cls, dataset_info):
schema.empty_table(),
categories,
arrow_to_pandas=arrow_to_pandas,
use_nullable_dtypes=use_nullable_dtypes,
)
index_names = list(meta.index.names)
column_names = list(meta.columns)
@@ -1543,11 +1566,26 @@ def _read_table(

@classmethod
def _arrow_table_to_pandas(
cls, arrow_table: pa.Table, categories, **kwargs
cls, arrow_table: pa.Table, categories, use_nullable_dtypes=False, **kwargs
) -> pd.DataFrame:
_kwargs = kwargs.get("arrow_to_pandas", {})
_kwargs.update({"use_threads": False, "ignore_metadata": False})

if use_nullable_dtypes:
if "types_mapper" in _kwargs:
# User-provided entries take priority over PYARROW_NULLABLE_DTYPE_MAPPING
types_mapper = _kwargs["types_mapper"]

def _types_mapper(pa_type):
return types_mapper(pa_type) or PYARROW_NULLABLE_DTYPE_MAPPING.get(
pa_type
)

_kwargs["types_mapper"] = _types_mapper

else:
_kwargs["types_mapper"] = PYARROW_NULLABLE_DTYPE_MAPPING.get

return arrow_table.to_pandas(categories=categories, **_kwargs)

@classmethod
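
The conversion above relies on pyarrow's `Table.to_pandas(types_mapper=...)` hook, which maps each arrow type to a pandas extension dtype and falls back to the default conversion when the mapper returns `None`. A simplified, standalone sketch of how a user-supplied mapper is chained in front of `PYARROW_NULLABLE_DTYPE_MAPPING.get` (illustrative only, not the dask code itself):

```python
import pandas as pd
import pyarrow as pa

# Small subset of the nullable mapping, for illustration.
nullable_mapping = {pa.int64(): pd.Int64Dtype(), pa.string(): pd.StringDtype()}

def user_mapper(pa_type):
    # Pretend the user only overrides string columns.
    return pd.StringDtype() if pa_type == pa.string() else None

def chained_mapper(pa_type):
    # User-provided entries take priority; otherwise fall back to the
    # nullable mapping, and finally to pyarrow's default conversion (None).
    return user_mapper(pa_type) or nullable_mapping.get(pa_type)

table = pa.table({"a": [1, None, 3], "b": ["x", None, "z"]})
df = table.to_pandas(types_mapper=chained_mapper)
print(df.dtypes)  # a: Int64, b: string
```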
35 changes: 31 additions & 4 deletions dask/dataframe/io/parquet/core.py
@@ -45,6 +45,7 @@ def __init__(
meta,
columns,
index,
use_nullable_dtypes,
kwargs,
common_kwargs,
):
@@ -53,6 +54,7 @@
self.meta = meta
self._columns = columns
self.index = index
self.use_nullable_dtypes = use_nullable_dtypes

# `kwargs` = user-defined kwargs to be passed
# identically for all partitions.
@@ -78,6 +80,7 @@ def project_columns(self, columns):
self.meta,
columns,
self.index,
self.use_nullable_dtypes,
None, # Already merged into common_kwargs
self.common_kwargs,
)
@@ -101,6 +104,7 @@ def __call__(self, part):
],
self.columns,
self.index,
self.use_nullable_dtypes,
self.common_kwargs,
)

@@ -181,6 +185,7 @@ def read_parquet(
index=None,
storage_options=None,
engine="auto",
use_nullable_dtypes=False,
calculate_divisions=None,
ignore_metadata_file=False,
metadata_task_size=None,
@@ -433,6 +438,7 @@ def read_parquet(
"index": index,
"storage_options": storage_options,
"engine": engine,
"use_nullable_dtypes": use_nullable_dtypes,
"calculate_divisions": calculate_divisions,
"ignore_metadata_file": ignore_metadata_file,
"metadata_task_size": metadata_task_size,
@@ -475,6 +481,7 @@ def read_parquet(
paths,
categories=categories,
index=index,
use_nullable_dtypes=use_nullable_dtypes,
gather_statistics=calculate_divisions,
filters=filters,
split_row_groups=split_row_groups,
@@ -540,6 +547,7 @@ def read_parquet(
meta,
columns,
index,
use_nullable_dtypes,
{}, # All kwargs should now be in `common_kwargs`
common_kwargs,
)
@@ -578,7 +586,9 @@ def check_multi_support(engine):
return hasattr(engine, "multi_support") and engine.multi_support()


def read_parquet_part(fs, engine, meta, part, columns, index, kwargs):
def read_parquet_part(
fs, engine, meta, part, columns, index, use_nullable_dtypes, kwargs
):
"""Read a part of a parquet dataset

This function is used by `read_parquet`."""
@@ -587,22 +597,39 @@ def read_parquet_part(fs, engine, meta, part, columns, index, kwargs):
# Part kwargs expected
func = engine.read_partition
dfs = [
func(fs, rg, columns.copy(), index, **toolz.merge(kwargs, kw))
func(
fs,
rg,
columns.copy(),
index,
use_nullable_dtypes=use_nullable_dtypes,
**toolz.merge(kwargs, kw),
)
for (rg, kw) in part
]
df = concat(dfs, axis=0) if len(dfs) > 1 else dfs[0]
else:
# No part specific kwargs, let engine read
# list of parts at once
df = engine.read_partition(
fs, [p[0] for p in part], columns.copy(), index, **kwargs
fs,
[p[0] for p in part],
columns.copy(),
index,
use_nullable_dtypes=use_nullable_dtypes,
**kwargs,
)
else:
# NOTE: `kwargs` are the same for all parts, while `part_kwargs` may
# be different for each part.
rg, part_kwargs = part
df = engine.read_partition(
fs, rg, columns, index, **toolz.merge(kwargs, part_kwargs)
fs,
rg,
columns,
index,
use_nullable_dtypes=use_nullable_dtypes,
**toolz.merge(kwargs, part_kwargs),
)

if meta.columns.name:
6 changes: 6 additions & 0 deletions dask/dataframe/io/parquet/fastparquet.py
@@ -821,6 +821,7 @@ def read_metadata(
paths,
categories=None,
index=None,
use_nullable_dtypes=False,
gather_statistics=None,
filters=None,
split_row_groups=False,
@@ -831,6 +832,10 @@
parquet_file_extension=None,
**kwargs,
):
if use_nullable_dtypes:
raise ValueError(
"`use_nullable_dtypes` is not supported by the fastparquet engine"
)

# Stage 1: Collect general dataset information
dataset_info = cls._collect_dataset_info(
@@ -890,6 +895,7 @@ def read_partition(
pieces,
columns,
index,
use_nullable_dtypes=False,
categories=(),
root_cats=None,
root_file_scheme=None,
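
The fastparquet engine does not implement the option in this PR, so the flag is rejected up front rather than being silently ignored. Roughly, the expected user-facing behavior (dataset path is hypothetical):

```python
import dask.dataframe as dd

try:
    dd.read_parquet(
        "data/*.parquet",  # hypothetical path
        engine="fastparquet",
        use_nullable_dtypes=True,
    )
except ValueError as err:
    # "`use_nullable_dtypes` is not supported by the fastparquet engine"
    print(err)
```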
11 changes: 10 additions & 1 deletion dask/dataframe/io/parquet/utils.py
@@ -17,6 +17,7 @@ def read_metadata(
paths,
categories=None,
index=None,
use_nullable_dtypes=False,
gather_statistics=None,
filters=None,
**kwargs,
@@ -37,6 +38,9 @@
The column name(s) to be used as the index.
If set to ``None``, pandas metadata (if available) can be used
to reset the value in this function
use_nullable_dtypes: boolean
Whether to use pandas nullable dtypes (like "string" or "Int64")
where appropriate when reading parquet files.
gather_statistics: bool
Whether or not to gather statistics to calculate divisions
for the output DataFrame collection.
@@ -73,7 +77,9 @@
raise NotImplementedError()

@classmethod
def read_partition(cls, fs, piece, columns, index, **kwargs):
def read_partition(
cls, fs, piece, columns, index, use_nullable_dtypes=False, **kwargs
):
"""Read a single piece of a Parquet dataset into a Pandas DataFrame

This function is called many times in individual tasks
@@ -88,6 +94,9 @@ def read_partition(cls, fs, piece, columns, index, **kwargs):
List of column names to pull out of that row group
index: str, List[str], or False
The index name(s).
use_nullable_dtypes: boolean
Whether to use pandas nullable dtypes (like "string" or "Int64")
where appropriate when reading parquet files.
**kwargs:
Includes `"kwargs"` values stored within the `parts` output
of `engine.read_metadata`. May also include arguments to be