Load nullables from dask #9632

Closed
wants to merge 21 commits
16 changes: 16 additions & 0 deletions dask/dask-schema.yaml
@@ -72,6 +72,22 @@ properties:
task when reading a parquet dataset from a REMOTE file system.
Specifying 0 will result in serial execution on the client.

dtypes:
type: object
properties:

nullable:
type: boolean

string:
type: object
properties:

storage:
type: string
description: |
The storage backend for a pandas StringDtype. One of "python" or "pyarrow".

array:
type: object
properties:
4 changes: 4 additions & 0 deletions dask/dask.yaml
@@ -12,6 +12,10 @@ dataframe:
parquet:
metadata-task-size-local: 512 # Number of files per local metadata-processing task
metadata-task-size-remote: 16 # Number of files per remote metadata-processing task
dtypes: # Configure use of pandas extension dtypes
nullable: False # Whether to use pandas nullable extension dtypes by default
string:
storage: "python"

array:
backend: "numpy" # Backend array library for input IO and data creation
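
Note: the config values above only set defaults; the new behavior is opted into per read, and the dataframe.parquet.dtypes.string.storage key controls which StringDtype storage the pyarrow engine maps strings to (see arrow.py below). A minimal usage sketch, assuming this branch is installed; the path and resulting dtypes are illustrative only:

import dask.dataframe as dd

# Opt in per read: integer/boolean/string columns come back as pandas
# nullable extension dtypes instead of float64/object with NaN.
ddf = dd.read_parquet("data/", engine="pyarrow", use_nullable_dtypes=True)
print(ddf.dtypes)  # e.g. Int64, boolean, string
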
49 changes: 41 additions & 8 deletions dask/dataframe/io/parquet/arrow.py
@@ -9,8 +9,10 @@
import pyarrow.parquet as pq
from packaging.version import parse as parse_version

from dask import config
from dask.base import tokenize
from dask.core import flatten
from dask.dataframe._compat import PANDAS_GT_120
from dask.dataframe.backends import pyarrow_schema_dispatch
from dask.dataframe.io.parquet.utils import (
Engine,
@@ -43,6 +45,28 @@
partitioning_supported = _pa_version >= parse_version("5.0.0")
del _pa_version

PYARROW_NULLABLE_DTYPE_MAPPING = {
pa.int8(): pd.Int8Dtype(),
pa.int16(): pd.Int16Dtype(),
pa.int32(): pd.Int32Dtype(),
pa.int64(): pd.Int64Dtype(),
pa.uint8(): pd.UInt8Dtype(),
pa.uint16(): pd.UInt16Dtype(),
pa.uint32(): pd.UInt32Dtype(),
pa.uint64(): pd.UInt64Dtype(),
pa.bool_(): pd.BooleanDtype(),
pa.string(): pd.StringDtype(
storage=config.get("dataframe.parquet.dtypes.string.storage")
),
"null": pd.StringDtype(
storage="python"
), # null values cannot be stored with pyarrow
}

if PANDAS_GT_120:
PYARROW_NULLABLE_DTYPE_MAPPING[pa.float32()] = pd.Float32Dtype()
PYARROW_NULLABLE_DTYPE_MAPPING[pa.float64()] = pd.Float64Dtype()

#
# Helper Utilities
#
@@ -177,7 +201,6 @@ def _frag_subset(old_frag, row_groups):

def _get_pandas_metadata(schema):
"""Get pandas-specific metadata from schema."""

has_pandas_metadata = schema.metadata is not None and b"pandas" in schema.metadata
if has_pandas_metadata:
return json.loads(schema.metadata[b"pandas"].decode("utf8"))
@@ -327,6 +350,7 @@ def read_metadata(
paths,
categories=None,
index=None,
use_nullable_dtypes=False,
gather_statistics=None,
filters=None,
split_row_groups=False,
@@ -356,7 +380,7 @@
)

# Stage 2: Generate output `meta`
meta = cls._create_dd_meta(dataset_info)
meta = cls._create_dd_meta(dataset_info, use_nullable_dtypes)

# Stage 3: Generate parts and stats
parts, stats, common_kwargs = cls._construct_collection_plan(dataset_info)
@@ -381,6 +405,7 @@ def read_partition(
pieces,
columns,
index,
use_nullable_dtypes=False,
categories=(),
partitions=(),
filters=None,
@@ -451,7 +476,9 @@ def read_partition(
arrow_table = pa.concat_tables(tables)

# Convert to pandas
df = cls._arrow_table_to_pandas(arrow_table, categories, **kwargs)
df = cls._arrow_table_to_pandas(
arrow_table, categories, use_nullable_dtypes=use_nullable_dtypes, **kwargs
)

# For the pyarrow.dataset API, need to convert partition columns
# to categorical manually for integer types.
@@ -964,7 +991,7 @@ def _collect_dataset_info(
}

@classmethod
def _create_dd_meta(cls, dataset_info):
def _create_dd_meta(cls, dataset_info, use_nullable_dtypes=False):
"""Use parquet schema and hive-partition information
(stored in dataset_info) to construct DataFrame metadata.
"""
@@ -1036,7 +1063,7 @@ def _create_dd_meta(cls, dataset_info):
"categories: {} | columns: {}".format(categories, list(all_columns))
)

dtypes = _get_pyarrow_dtypes(schema, categories)
dtypes = _get_pyarrow_dtypes(schema, categories, use_nullable_dtypes)
dtypes = {storage_name_mapping.get(k, k): v for k, v in dtypes.items()}

index_cols = index or ()
@@ -1547,12 +1574,18 @@ def _read_table(

@classmethod
def _arrow_table_to_pandas(
cls, arrow_table: pa.Table, categories, **kwargs
cls, arrow_table: pa.Table, categories, use_nullable_dtypes=False, **kwargs
) -> pd.DataFrame:
_kwargs = kwargs.get("arrow_to_pandas", {})
_kwargs.update({"use_threads": False, "ignore_metadata": False})

return arrow_table.to_pandas(categories=categories, **_kwargs)
if use_nullable_dtypes is True:
return arrow_table.to_pandas(
categories=categories,
types_mapper=PYARROW_NULLABLE_DTYPE_MAPPING.get,
**_kwargs,
)
else:
return arrow_table.to_pandas(categories=categories, **_kwargs)

@classmethod
def collect_file_metadata(cls, path, fs, file_path):
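
For context on the mapping above: pyarrow's Table.to_pandas accepts a types_mapper callable that is asked for a pandas extension dtype for each Arrow type (returning None falls back to the default conversion), which is why the dict's .get method can be passed directly. A standalone sketch of the same mechanism, with illustrative column names:

import pandas as pd
import pyarrow as pa

mapping = {pa.int64(): pd.Int64Dtype(), pa.string(): pd.StringDtype()}
table = pa.table({"a": [1, None, 3], "b": ["x", None, "z"]})

# dict.get returns None for unmapped Arrow types, so those columns fall
# back to pyarrow's default pandas conversion.
df = table.to_pandas(types_mapper=mapping.get)
print(df.dtypes)  # a -> Int64, b -> string (instead of float64 / object)
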
35 changes: 31 additions & 4 deletions dask/dataframe/io/parquet/core.py
@@ -45,6 +45,7 @@ def __init__(
meta,
columns,
index,
use_nullable_dtypes,
kwargs,
common_kwargs,
):
@@ -53,6 +54,7 @@
self.meta = meta
self._columns = columns
self.index = index
self.use_nullable_dtypes = use_nullable_dtypes

# `kwargs` = user-defined kwargs to be passed
# identically for all partitions.
@@ -78,6 +80,7 @@ def project_columns(self, columns):
self.meta,
columns,
self.index,
self.use_nullable_dtypes,
None, # Already merged into common_kwargs
self.common_kwargs,
)
@@ -101,6 +104,7 @@ def __call__(self, part):
],
self.columns,
self.index,
self.use_nullable_dtypes,
self.common_kwargs,
)

@@ -181,6 +185,7 @@ def read_parquet(
index=None,
storage_options=None,
engine="auto",
use_nullable_dtypes=False,
calculate_divisions=None,
ignore_metadata_file=False,
metadata_task_size=None,
@@ -433,6 +438,7 @@ def read_parquet(
"index": index,
"storage_options": storage_options,
"engine": engine,
"use_nullable_dtypes": use_nullable_dtypes,
"calculate_divisions": calculate_divisions,
"ignore_metadata_file": ignore_metadata_file,
"metadata_task_size": metadata_task_size,
@@ -475,6 +481,7 @@ def read_parquet(
paths,
categories=categories,
index=index,
use_nullable_dtypes=use_nullable_dtypes,
gather_statistics=calculate_divisions,
filters=filters,
split_row_groups=split_row_groups,
@@ -540,6 +547,7 @@ def read_parquet(
meta,
columns,
index,
use_nullable_dtypes,
{}, # All kwargs should now be in `common_kwargs`
common_kwargs,
)
@@ -578,7 +586,9 @@ def check_multi_support(engine):
return hasattr(engine, "multi_support") and engine.multi_support()


def read_parquet_part(fs, engine, meta, part, columns, index, kwargs):
def read_parquet_part(
fs, engine, meta, part, columns, index, use_nullable_dtypes, kwargs
):
"""Read a part of a parquet dataset

This function is used by `read_parquet`."""
@@ -587,22 +597,39 @@ def read_parquet_part(fs, engine, meta, part, columns, index, kwargs):
# Part kwargs expected
func = engine.read_partition
dfs = [
func(fs, rg, columns.copy(), index, **toolz.merge(kwargs, kw))
func(
fs,
rg,
columns.copy(),
index,
use_nullable_dtypes=use_nullable_dtypes,
**toolz.merge(kwargs, kw),
)
for (rg, kw) in part
]
df = concat(dfs, axis=0) if len(dfs) > 1 else dfs[0]
else:
# No part specific kwargs, let engine read
# list of parts at once
df = engine.read_partition(
fs, [p[0] for p in part], columns.copy(), index, **kwargs
fs,
[p[0] for p in part],
columns.copy(),
index,
use_nullable_dtypes=use_nullable_dtypes,
**kwargs,
)
else:
# NOTE: `kwargs` are the same for all parts, while `part_kwargs` may
# be different for each part.
rg, part_kwargs = part
df = engine.read_partition(
fs, rg, columns, index, **toolz.merge(kwargs, part_kwargs)
fs,
rg,
columns,
index,
use_nullable_dtypes=use_nullable_dtypes,
**toolz.merge(kwargs, part_kwargs),
)

if meta.columns.name:
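
The core.py changes are pure plumbing: the flag captured at graph-construction time in ParquetFunctionWrapper is forwarded unchanged to every partition-level engine call, including after column projection. A simplified sketch of what each task now effectively does; the helper name is hypothetical and only condenses read_parquet_part above:

def _read_one_part(engine, fs, piece, columns, index, use_nullable_dtypes, kwargs):
    # The nullable-dtype decision is made once by the user and simply
    # threaded through to the engine for every partition.
    return engine.read_partition(
        fs,
        piece,
        columns.copy(),
        index,
        use_nullable_dtypes=use_nullable_dtypes,
        **kwargs,
    )
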
6 changes: 6 additions & 0 deletions dask/dataframe/io/parquet/fastparquet.py
@@ -820,6 +820,7 @@ def read_metadata(
paths,
categories=None,
index=None,
use_nullable_dtypes=False,
gather_statistics=None,
filters=None,
split_row_groups=False,
@@ -830,6 +831,10 @@
parquet_file_extension=None,
**kwargs,
):
if use_nullable_dtypes:
raise ValueError(
"`use_nullable_dtypes` is not supported by the fastparquet engine"
)

# Stage 1: Collect general dataset information
dataset_info = cls._collect_dataset_info(
@@ -889,6 +894,7 @@ def read_partition(
pieces,
columns,
index,
use_nullable_dtypes=False,
categories=(),
root_cats=None,
root_file_scheme=None,
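
Only the pyarrow engine implements the dtype mapping, so the fastparquet engine rejects the flag eagerly in read_metadata, i.e. at graph-construction time rather than inside a task. A sketch of the expected failure mode, assuming fastparquet is installed and the path is illustrative:

import dask.dataframe as dd

try:
    dd.read_parquet("data/", engine="fastparquet", use_nullable_dtypes=True)
except ValueError as exc:
    print(exc)  # "`use_nullable_dtypes` is not supported by the fastparquet engine"
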
11 changes: 10 additions & 1 deletion dask/dataframe/io/parquet/utils.py
@@ -17,6 +17,7 @@ def read_metadata(
paths,
categories=None,
index=None,
use_nullable_dtypes=False,
gather_statistics=None,
filters=None,
**kwargs,
@@ -37,6 +38,9 @@
The column name(s) to be used as the index.
If set to ``None``, pandas metadata (if available) can be used
to reset the value in this function
use_nullable_dtypes: bool
Whether to use pandas nullable dtypes (like "string" or "Int64")
where appropriate when reading parquet files.
gather_statistics: bool
Whether or not to gather statistics to calculate divisions
for the output DataFrame collection.
@@ -73,7 +77,9 @@
raise NotImplementedError()

@classmethod
def read_partition(cls, fs, piece, columns, index, **kwargs):
def read_partition(
cls, fs, piece, columns, index, use_nullable_dtypes=False, **kwargs
):
"""Read a single piece of a Parquet dataset into a Pandas DataFrame

This function is called many times in individual tasks
@@ -88,6 +94,9 @@ def read_partition(cls, fs, piece, columns, index, **kwargs):
List of column names to pull out of that row group
index: str, List[str], or False
The index name(s).
use_nullable_dtypes: bool
Whether to use pandas nullable dtypes (like "string" or "Int64")
where appropriate when reading parquet files.
**kwargs:
Includes `"kwargs"` values stored within the `parts` output
of `engine.read_metadata`. May also include arguments to be
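
Because the base-class signatures now carry the keyword, a third-party engine only has to accept it, and may then either honor, ignore, or reject it. A hypothetical minimal subclass, for illustration only and not part of this PR:

from dask.dataframe.io.parquet.utils import Engine

class MyEngine(Engine):  # hypothetical engine used only to illustrate the interface
    @classmethod
    def read_partition(
        cls, fs, piece, columns, index, use_nullable_dtypes=False, **kwargs
    ):
        if use_nullable_dtypes:
            raise ValueError("nullable dtypes are not supported by MyEngine")
        ...  # read `piece` into a pandas DataFrame here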