Add support for use_nullable_dtypes to dd.read_parquet (#9617)
ian-r-rose committed Dec 1, 2022
1 parent f309f9f commit b1e468e
Showing 7 changed files with 244 additions and 13 deletions.
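
For context, the user-facing change is a single new keyword on dd.read_parquet, defaulting to False. A minimal usage sketch (the path here is hypothetical):

import dask.dataframe as dd

# Integer, boolean, and string columns are read as pandas nullable extension
# dtypes (Int64, boolean, string, ...) instead of float/object columns that
# use NaN/None for missing values.
ddf = dd.read_parquet(
    "data/*.parquet",  # hypothetical path
    engine="pyarrow",
    use_nullable_dtypes=True,
)
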
46 changes: 42 additions & 4 deletions dask/dataframe/io/parquet/arrow.py
@@ -11,6 +11,7 @@

from dask.base import tokenize
from dask.core import flatten
from dask.dataframe._compat import PANDAS_GT_120
from dask.dataframe.backends import pyarrow_schema_dispatch
from dask.dataframe.io.parquet.utils import (
Engine,
@@ -37,6 +38,23 @@
partitioning_supported = _pa_version >= parse_version("5.0.0")
del _pa_version

PYARROW_NULLABLE_DTYPE_MAPPING = {
pa.int8(): pd.Int8Dtype(),
pa.int16(): pd.Int16Dtype(),
pa.int32(): pd.Int32Dtype(),
pa.int64(): pd.Int64Dtype(),
pa.uint8(): pd.UInt8Dtype(),
pa.uint16(): pd.UInt16Dtype(),
pa.uint32(): pd.UInt32Dtype(),
pa.uint64(): pd.UInt64Dtype(),
pa.bool_(): pd.BooleanDtype(),
pa.string(): pd.StringDtype(),
}

if PANDAS_GT_120:
PYARROW_NULLABLE_DTYPE_MAPPING[pa.float32()] = pd.Float32Dtype()
PYARROW_NULLABLE_DTYPE_MAPPING[pa.float64()] = pd.Float64Dtype()

#
# Helper Utilities
#
@@ -321,6 +339,7 @@ def read_metadata(
paths,
categories=None,
index=None,
use_nullable_dtypes=False,
gather_statistics=None,
filters=None,
split_row_groups=False,
@@ -350,7 +369,7 @@ def read_metadata(
)

# Stage 2: Generate output `meta`
meta = cls._create_dd_meta(dataset_info)
meta = cls._create_dd_meta(dataset_info, use_nullable_dtypes)

# Stage 3: Generate parts and stats
parts, stats, common_kwargs = cls._construct_collection_plan(dataset_info)
@@ -375,6 +394,7 @@ def read_partition(
pieces,
columns,
index,
use_nullable_dtypes=False,
categories=(),
partitions=(),
filters=None,
@@ -445,7 +465,9 @@ def read_partition(
arrow_table = pa.concat_tables(tables)

# Convert to pandas
df = cls._arrow_table_to_pandas(arrow_table, categories, **kwargs)
df = cls._arrow_table_to_pandas(
arrow_table, categories, use_nullable_dtypes=use_nullable_dtypes, **kwargs
)

# For pyarrow.dataset api, need to convert partition columns
        # to categorical manually for integer types.
@@ -958,7 +980,7 @@ def _collect_dataset_info(
}

@classmethod
def _create_dd_meta(cls, dataset_info):
def _create_dd_meta(cls, dataset_info, use_nullable_dtypes=False):
"""Use parquet schema and hive-partition information
(stored in dataset_info) to construct DataFrame metadata.
"""
@@ -989,6 +1011,7 @@ def _create_dd_meta(cls, dataset_info):
schema.empty_table(),
categories,
arrow_to_pandas=arrow_to_pandas,
use_nullable_dtypes=use_nullable_dtypes,
)
index_names = list(meta.index.names)
column_names = list(meta.columns)
@@ -1543,11 +1566,26 @@ def _read_table(

@classmethod
def _arrow_table_to_pandas(
cls, arrow_table: pa.Table, categories, **kwargs
cls, arrow_table: pa.Table, categories, use_nullable_dtypes=False, **kwargs
) -> pd.DataFrame:
_kwargs = kwargs.get("arrow_to_pandas", {})
_kwargs.update({"use_threads": False, "ignore_metadata": False})

if use_nullable_dtypes:
if "types_mapper" in _kwargs:
# User-provided entries take priority over PYARROW_NULLABLE_DTYPE_MAPPING
types_mapper = _kwargs["types_mapper"]

def _types_mapper(pa_type):
return types_mapper(pa_type) or PYARROW_NULLABLE_DTYPE_MAPPING.get(
pa_type
)

_kwargs["types_mapper"] = _types_mapper

else:
_kwargs["types_mapper"] = PYARROW_NULLABLE_DTYPE_MAPPING.get

return arrow_table.to_pandas(categories=categories, **_kwargs)

@classmethod
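
The key piece in arrow.py above is how _arrow_table_to_pandas combines a user-supplied types_mapper (passed through arrow_to_pandas) with PYARROW_NULLABLE_DTYPE_MAPPING: the user's mapper is consulted first, and the built-in mapping only applies when it returns None. Below is a standalone sketch of that fallback pattern against pyarrow's Table.to_pandas; the sample table and user_mapper are purely illustrative:

import pandas as pd
import pyarrow as pa

NULLABLE = {pa.int64(): pd.Int64Dtype(), pa.string(): pd.StringDtype()}

def user_mapper(pa_type):
    # Pretend the user only overrides string handling.
    return pd.StringDtype() if pa_type == pa.string() else None

def combined(pa_type):
    # User entries win; fall back to the nullable mapping otherwise.
    return user_mapper(pa_type) or NULLABLE.get(pa_type)

table = pa.table({"a": [1, None, 3], "b": ["x", None, "z"]})
df = table.to_pandas(types_mapper=combined)
print(df.dtypes)  # a -> Int64, b -> string
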
35 changes: 31 additions & 4 deletions dask/dataframe/io/parquet/core.py
@@ -45,6 +45,7 @@ def __init__(
meta,
columns,
index,
use_nullable_dtypes,
kwargs,
common_kwargs,
):
@@ -53,6 +54,7 @@
self.meta = meta
self._columns = columns
self.index = index
self.use_nullable_dtypes = use_nullable_dtypes

# `kwargs` = user-defined kwargs to be passed
# identically for all partitions.
@@ -78,6 +80,7 @@ def project_columns(self, columns):
self.meta,
columns,
self.index,
self.use_nullable_dtypes,
None, # Already merged into common_kwargs
self.common_kwargs,
)
@@ -101,6 +104,7 @@ def __call__(self, part):
],
self.columns,
self.index,
self.use_nullable_dtypes,
self.common_kwargs,
)

@@ -181,6 +185,7 @@ def read_parquet(
index=None,
storage_options=None,
engine="auto",
use_nullable_dtypes=False,
calculate_divisions=None,
ignore_metadata_file=False,
metadata_task_size=None,
@@ -433,6 +438,7 @@ def read_parquet(
"index": index,
"storage_options": storage_options,
"engine": engine,
"use_nullable_dtypes": use_nullable_dtypes,
"calculate_divisions": calculate_divisions,
"ignore_metadata_file": ignore_metadata_file,
"metadata_task_size": metadata_task_size,
@@ -475,6 +481,7 @@ def read_parquet(
paths,
categories=categories,
index=index,
use_nullable_dtypes=use_nullable_dtypes,
gather_statistics=calculate_divisions,
filters=filters,
split_row_groups=split_row_groups,
@@ -540,6 +547,7 @@ def read_parquet(
meta,
columns,
index,
use_nullable_dtypes,
{}, # All kwargs should now be in `common_kwargs`
common_kwargs,
)
@@ -578,7 +586,9 @@ def check_multi_support(engine):
return hasattr(engine, "multi_support") and engine.multi_support()


def read_parquet_part(fs, engine, meta, part, columns, index, kwargs):
def read_parquet_part(
fs, engine, meta, part, columns, index, use_nullable_dtypes, kwargs
):
"""Read a part of a parquet dataset
This function is used by `read_parquet`."""
@@ -587,22 +597,39 @@ def read_parquet_part(fs, engine, meta, part, columns, index, kwargs):
# Part kwargs expected
func = engine.read_partition
dfs = [
func(fs, rg, columns.copy(), index, **toolz.merge(kwargs, kw))
func(
fs,
rg,
columns.copy(),
index,
use_nullable_dtypes=use_nullable_dtypes,
**toolz.merge(kwargs, kw),
)
for (rg, kw) in part
]
df = concat(dfs, axis=0) if len(dfs) > 1 else dfs[0]
else:
# No part specific kwargs, let engine read
# list of parts at once
df = engine.read_partition(
fs, [p[0] for p in part], columns.copy(), index, **kwargs
fs,
[p[0] for p in part],
columns.copy(),
index,
use_nullable_dtypes=use_nullable_dtypes,
**kwargs,
)
else:
# NOTE: `kwargs` are the same for all parts, while `part_kwargs` may
# be different for each part.
rg, part_kwargs = part
df = engine.read_partition(
fs, rg, columns, index, **toolz.merge(kwargs, part_kwargs)
fs,
rg,
columns,
index,
use_nullable_dtypes=use_nullable_dtypes,
**toolz.merge(kwargs, part_kwargs),
)

if meta.columns.name:
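
With the keyword threaded through read_parquet_part, every partition is converted with the same setting, so the effect shows up directly in the collection dtypes. A small end-to-end sketch (file name hypothetical; the file is written straight from an Arrow table, without pandas metadata, so the default read genuinely promotes the nulled int column to float64):

import pyarrow as pa
import pyarrow.parquet as pq
import dask.dataframe as dd

pq.write_table(pa.table({"x": [1, None, 3], "s": ["a", None, "c"]}), "example.parquet")

# Default behavior: x comes back as float64 (NaN marks the null), s as object.
print(dd.read_parquet("example.parquet", engine="pyarrow").dtypes)

# With the new flag: x -> Int64, s -> string, and the nulls become pd.NA.
print(
    dd.read_parquet(
        "example.parquet", engine="pyarrow", use_nullable_dtypes=True
    ).dtypes
)
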
6 changes: 6 additions & 0 deletions dask/dataframe/io/parquet/fastparquet.py
@@ -821,6 +821,7 @@ def read_metadata(
paths,
categories=None,
index=None,
use_nullable_dtypes=False,
gather_statistics=None,
filters=None,
split_row_groups=False,
@@ -831,6 +832,10 @@
parquet_file_extension=None,
**kwargs,
):
if use_nullable_dtypes:
raise ValueError(
"`use_nullable_dtypes` is not supported by the fastparquet engine"
)

# Stage 1: Collect general dataset information
dataset_info = cls._collect_dataset_info(
@@ -890,6 +895,7 @@ def read_partition(
pieces,
columns,
index,
use_nullable_dtypes=False,
categories=(),
root_cats=None,
root_file_scheme=None,
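
The fastparquet engine, by contrast, rejects the keyword during the metadata stage, so a collection using it fails immediately at graph-construction time. A sketch of that behavior (path again hypothetical):

import dask.dataframe as dd

try:
    dd.read_parquet("data/*.parquet", engine="fastparquet", use_nullable_dtypes=True)
except ValueError as err:
    print(err)  # "`use_nullable_dtypes` is not supported by the fastparquet engine"
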
11 changes: 10 additions & 1 deletion dask/dataframe/io/parquet/utils.py
@@ -17,6 +17,7 @@ def read_metadata(
paths,
categories=None,
index=None,
use_nullable_dtypes=False,
gather_statistics=None,
filters=None,
**kwargs,
@@ -37,6 +38,9 @@
The column name(s) to be used as the index.
If set to ``None``, pandas metadata (if available) can be used
to reset the value in this function
use_nullable_dtypes: boolean
Whether to use pandas nullable dtypes (like "string" or "Int64")
where appropriate when reading parquet files.
gather_statistics: bool
Whether or not to gather statistics to calculate divisions
for the output DataFrame collection.
@@ -73,7 +77,9 @@
raise NotImplementedError()

@classmethod
def read_partition(cls, fs, piece, columns, index, **kwargs):
def read_partition(
cls, fs, piece, columns, index, use_nullable_dtypes=False, **kwargs
):
"""Read a single piece of a Parquet dataset into a Pandas DataFrame
This function is called many times in individual tasks
@@ -88,6 +94,9 @@ def read_partition(cls, fs, piece, columns, index, **kwargs):
List of column names to pull out of that row group
index: str, List[str], or False
The index name(s).
use_nullable_dtypes: boolean
Whether to use pandas nullable dtypes (like "string" or "Int64")
where appropriate when reading parquet files.
**kwargs:
Includes `"kwargs"` values stored within the `parts` output
of `engine.read_metadata`. May also include arguments to be