diff --git a/dask/dataframe/_compat.py b/dask/dataframe/_compat.py index c4f74bef049..1e899d85eaa 100644 --- a/dask/dataframe/_compat.py +++ b/dask/dataframe/_compat.py @@ -4,18 +4,19 @@ import numpy as np import pandas as pd -from packaging.version import parse as parse_version - -PANDAS_VERSION = parse_version(pd.__version__) -PANDAS_GT_104 = PANDAS_VERSION >= parse_version("1.0.4") -PANDAS_GT_110 = PANDAS_VERSION >= parse_version("1.1.0") -PANDAS_GT_120 = PANDAS_VERSION >= parse_version("1.2.0") -PANDAS_GT_121 = PANDAS_VERSION >= parse_version("1.2.1") -PANDAS_GT_130 = PANDAS_VERSION >= parse_version("1.3.0") -PANDAS_GT_131 = PANDAS_VERSION >= parse_version("1.3.1") -PANDAS_GT_133 = PANDAS_VERSION >= parse_version("1.3.3") -PANDAS_GT_140 = PANDAS_VERSION >= parse_version("1.4.0") -PANDAS_GT_150 = PANDAS_VERSION >= parse_version("1.5.0") +from packaging.version import Version + +PANDAS_VERSION = Version(pd.__version__) +PANDAS_GT_104 = PANDAS_VERSION >= Version("1.0.4") +PANDAS_GT_110 = PANDAS_VERSION >= Version("1.1.0") +PANDAS_GT_120 = PANDAS_VERSION >= Version("1.2.0") +PANDAS_GT_121 = PANDAS_VERSION >= Version("1.2.1") +PANDAS_GT_130 = PANDAS_VERSION >= Version("1.3.0") +PANDAS_GT_131 = PANDAS_VERSION >= Version("1.3.1") +PANDAS_GT_133 = PANDAS_VERSION >= Version("1.3.3") +PANDAS_GT_140 = PANDAS_VERSION >= Version("1.4.0") +PANDAS_GT_150 = PANDAS_VERSION >= Version("1.5.0") +PANDAS_GT_200 = PANDAS_VERSION.major >= 2 import pandas.testing as tm diff --git a/dask/dataframe/_pyarrow_compat.py b/dask/dataframe/_pyarrow_compat.py index 8bf4cf0a291..844a818a3bf 100644 --- a/dask/dataframe/_pyarrow_compat.py +++ b/dask/dataframe/_pyarrow_compat.py @@ -1,7 +1,5 @@ import copyreg -import math -import numpy as np import pandas as pd try: @@ -9,6 +7,8 @@ except ImportError: pa = None +from dask.dataframe._compat import PANDAS_GT_130, PANDAS_GT_150, PANDAS_GT_200 + # Pickling of pyarrow arrays is effectively broken - pickling a slice of an # array ends up pickling the entire backing array. # @@ -16,128 +16,32 @@ # # This comes up when using pandas `string[pyarrow]` dtypes, which are backed by # a `pyarrow.StringArray`. To fix this, we register a *global* override for -# pickling `pandas.core.arrays.ArrowStringArray` types. We do this at the -# pandas level rather than the pyarrow level for efficiency reasons (a pandas -# ArrowStringArray may contain many small pyarrow StringArray objects). -# -# This pickling implementation manually mucks with the backing buffers in a -# fairly efficient way: -# -# - The data buffer is never copied -# - The offsets buffer is only copied if the array is sliced with a start index -# (x[start:]) -# - The mask buffer is never copied -# -# This implementation works with pickle protocol 5, allowing support for true -# zero-copy sends. +# pickling `ArrowStringArray` or `ArrowExtensionArray` types (where available). +# We do this at the pandas level rather than the pyarrow level for efficiency reasons +# (a pandas ArrowStringArray may contain many small pyarrow StringArray objects). # -# XXX: Once pyarrow (or pandas) has fixed this bug, we should skip registering -# with copyreg for versions that lack this issue. - - -def pyarrow_stringarray_to_parts(array): - """Decompose a ``pyarrow.StringArray`` into a tuple of components. - - The resulting tuple can be passed to - ``pyarrow_stringarray_from_parts(*components)`` to reconstruct the - ``pyarrow.StringArray``. - """ - # Access the backing buffers. - # - # - mask: None, or a bitmask of length ceil(nitems / 8). 
0 bits mark NULL - # elements, only present if NULL data is present, commonly None. - # - offsets: A uint32 array of nitems + 1 items marking the start/stop - # indices for the individual elements in `data` - # - data: All the utf8 string data concatenated together - # - # The structure of these buffers comes from the arrow format, documented at - # https://arrow.apache.org/docs/format/Columnar.html#physical-memory-layout. - # In particular, this is a `StringArray` (4 byte offsets), rather than a - # `LargeStringArray` (8 byte offsets). - assert pa.types.is_string(array.type) - - mask, offsets, data = array.buffers() - nitems = len(array) - - if not array.offset: - # No leading offset, only need to slice any unnecessary data from the - # backing buffers - offsets = offsets[: 4 * (nitems + 1)] - data_stop = int.from_bytes(offsets[-4:], "little") - data = data[:data_stop] - if mask is None: - return nitems, offsets, data - else: - mask = mask[: math.ceil(nitems / 8)] - return nitems, offsets, data, mask - - # There is a leading offset. This complicates things a bit. - offsets_start = array.offset * 4 - offsets_stop = offsets_start + (nitems + 1) * 4 - data_start = int.from_bytes(offsets[offsets_start : offsets_start + 4], "little") - data_stop = int.from_bytes(offsets[offsets_stop - 4 : offsets_stop], "little") - data = data[data_start:data_stop] - - if mask is None: - npad = 0 - else: - # Since the mask is a bitmask, it can only represent even units of 8 - # elements. To avoid shifting any bits, we pad the array with up to 7 - # elements so the mask array can always be serialized zero copy. - npad = array.offset % 8 - mask_start = array.offset // 8 - mask_stop = math.ceil((array.offset + nitems) / 8) - mask = mask[mask_start:mask_stop] - - # Subtract the offset of the starting element from every used offset in the - # offsets array, ensuring the first element in the serialized `offsets` - # array is always 0. - offsets_array = np.frombuffer(offsets, dtype="i4") - offsets_array = ( - offsets_array[array.offset : array.offset + nitems + 1] - - offsets_array[array.offset] - ) - # Pad the new offsets by `npad` offsets of 0 (see the `mask` comment above). We wrap - # this in a `pyarrow.py_buffer`, since this type transparently supports pickle 5, - # avoiding an extra copy inside the pickler. - offsets = pa.py_buffer( - b"\x00" * (4 * npad) + offsets_array.data if npad else offsets_array.data - ) - - if mask is None: - return nitems, offsets, data - else: - return nitems, offsets, data, mask, npad - - -def pyarrow_stringarray_from_parts(nitems, data_offsets, data, mask=None, offset=0): - """Reconstruct a ``pyarrow.StringArray`` from the parts returned by - ``pyarrow_stringarray_to_parts``.""" - return pa.StringArray.from_buffers(nitems, data_offsets, data, mask, offset=offset) +# The implementation here is based on https://github.com/pandas-dev/pandas/pull/49078 +# which is included in pandas=2+. We can remove all this once Dask's minimum +# supported pandas version is at least 2.0.0. 
-def rebuild_arrowstringarray(*chunk_parts): - """Rebuild a ``pandas.core.arrays.ArrowStringArray``""" - array = pa.chunked_array( - [pyarrow_stringarray_from_parts(*parts) for parts in chunk_parts], - type=pa.string(), - ) - return pd.arrays.ArrowStringArray(array) +def rebuild_arrowextensionarray(type_, chunks): + array = pa.chunked_array(chunks) + return type_(array) -def reduce_arrowstringarray(x): - """A pickle override for ``pandas.core.arrays.ArrowStringArray`` that avoids - serializing unnecessary data, while also avoiding/minimizing data copies""" - # Decompose each chunk in the backing ChunkedArray into their individual - # components for serialization. We filter out 0-length chunks, since they - # add no meaningful value to the chunked array. - chunks = tuple( - pyarrow_stringarray_to_parts(chunk) - for chunk in x._data.chunks - if len(chunk) > 0 - ) - return (rebuild_arrowstringarray, chunks) +def reduce_arrowextensionarray(x): + return (rebuild_arrowextensionarray, (type(x), x._data.combine_chunks())) -if hasattr(pd.arrays, "ArrowStringArray") and pa is not None: - copyreg.dispatch_table[pd.arrays.ArrowStringArray] = reduce_arrowstringarray +# `pandas=2` includes efficient serialization of `pyarrow`-backed extension arrays. +# See https://github.com/pandas-dev/pandas/pull/49078 for details. +# We only need to backport efficient serialization for `pandas<2`. +if pa is not None and not PANDAS_GT_200: + if PANDAS_GT_150: + # Applies to all `pyarrow`-backed extension arrays (e.g. `string[pyarrow]`, `int64[pyarrow]`) + for type_ in [pd.arrays.ArrowExtensionArray, pd.arrays.ArrowStringArray]: + copyreg.dispatch_table[type_] = reduce_arrowextensionarray + elif PANDAS_GT_130: + # Only `string[pyarrow]` is implemented, so just patch that + copyreg.dispatch_table[pd.arrays.ArrowStringArray] = reduce_arrowextensionarray diff --git a/dask/dataframe/tests/test_pyarrow_compat.py b/dask/dataframe/tests/test_pyarrow_compat.py index 18394a0aba2..ecd8daa853b 100644 --- a/dask/dataframe/tests/test_pyarrow_compat.py +++ b/dask/dataframe/tests/test_pyarrow_compat.py @@ -1,110 +1,130 @@ -import math import pickle -import random -import string +from datetime import date, datetime, time, timedelta import pandas as pd +import pandas._testing as tm import pytest pa = pytest.importorskip("pyarrow") -from dask.dataframe._pyarrow_compat import ( - pyarrow_stringarray_from_parts, - pyarrow_stringarray_to_parts, -) - -if not hasattr(pd.arrays, "ArrowStringArray"): - pytestmark = pytest.mark.skip("pandas.arrays.ArrowStringArray is not available") +from dask.dataframe._compat import PANDAS_GT_130, PANDAS_GT_150 +pytestmark = pytest.mark.skipif( + not PANDAS_GT_130, reason="No `pyarrow`-backed extension arrays are available" +) -def randstr(i): - """A random string, prefixed with the index number to make it clearer what the data - boundaries are""" - return str(i) + "".join( - random.choices(string.ascii_letters, k=random.randint(3, 8)) - ) +# Tests are from https://github.com/pandas-dev/pandas/pull/49078 + + +@pytest.fixture +def data(dtype): + if PANDAS_GT_150: + pa_dtype = dtype.pyarrow_dtype + else: + pa_dtype = pa.string() + if pa.types.is_boolean(pa_dtype): + data = [True, False] * 4 + [None] + [True, False] * 44 + [None] + [True, False] + elif pa.types.is_floating(pa_dtype): + data = [1.0, 0.0] * 4 + [None] + [-2.0, -1.0] * 44 + [None] + [0.5, 99.5] + elif pa.types.is_signed_integer(pa_dtype): + data = [1, 0] * 4 + [None] + [-2, -1] * 44 + [None] + [1, 99] + elif pa.types.is_unsigned_integer(pa_dtype): 
+ data = [1, 0] * 4 + [None] + [2, 1] * 44 + [None] + [1, 99] + elif pa.types.is_date(pa_dtype): + data = ( + [date(2022, 1, 1), date(1999, 12, 31)] * 4 + + [None] + + [date(2022, 1, 1), date(2022, 1, 1)] * 44 + + [None] + + [date(1999, 12, 31), date(1999, 12, 31)] + ) + elif pa.types.is_timestamp(pa_dtype): + data = ( + [datetime(2020, 1, 1, 1, 1, 1, 1), datetime(1999, 1, 1, 1, 1, 1, 1)] * 4 + + [None] + + [datetime(2020, 1, 1, 1), datetime(1999, 1, 1, 1)] * 44 + + [None] + + [datetime(2020, 1, 1), datetime(1999, 1, 1)] + ) + elif pa.types.is_duration(pa_dtype): + data = ( + [timedelta(1), timedelta(1, 1)] * 4 + + [None] + + [timedelta(-1), timedelta(0)] * 44 + + [None] + + [timedelta(-10), timedelta(10)] + ) + elif pa.types.is_time(pa_dtype): + data = ( + [time(12, 0), time(0, 12)] * 4 + + [None] + + [time(0, 0), time(1, 1)] * 44 + + [None] + + [time(0, 5), time(5, 0)] + ) + elif pa.types.is_string(pa_dtype): + data = ["a", "b"] * 4 + [None] + ["1", "2"] * 44 + [None] + ["!", ">"] + elif pa.types.is_binary(pa_dtype): + data = [b"a", b"b"] * 4 + [None] + [b"1", b"2"] * 44 + [None] + [b"!", b">"] + else: + raise NotImplementedError + return pd.array(data * 100, dtype=dtype) + + +PYARROW_TYPES = tm.ALL_PYARROW_DTYPES if PANDAS_GT_150 else [pa.string()] + + +@pytest.fixture(params=PYARROW_TYPES, ids=str) +def dtype(request): + if PANDAS_GT_150: + return pd.ArrowDtype(pyarrow_dtype=request.param) + else: + return pd.StringDtype("pyarrow") + + +def test_pickle_roundtrip(data): + expected = pd.Series(data) + expected_sliced = expected.head(2) + full_pickled = pickle.dumps(expected) + sliced_pickled = pickle.dumps(expected_sliced) + + # Make sure slicing gives a large reduction in serialized bytes + assert len(full_pickled) > len(sliced_pickled) * 3 + + result = pickle.loads(full_pickled) + tm.assert_series_equal(result, expected) + + result_sliced = pickle.loads(sliced_pickled) + tm.assert_series_equal(result_sliced, expected_sliced) -@pytest.mark.parametrize("length", [6, 8, 12, 20]) @pytest.mark.parametrize( - "slc", + "string_dtype", [ - slice(None), - slice(0, 5), - slice(2), - slice(2, 5), - slice(2, None, 2), - slice(0, 0), - slice(7, 10), - slice(7, 19), - slice(15, 19), + "stringdtype", + pytest.param( + "arrowdtype", + marks=pytest.mark.skipif(not PANDAS_GT_150, reason="Requires ArrowDtype"), + ), ], ) -@pytest.mark.parametrize("has_mask", [True, False]) -def test_roundtrip_stringarray(length, slc, has_mask): - x = pa.array( - [randstr(i) if (not has_mask or i % 3) else None for i in range(length)], - )[slc] - - def unpack(nitems, offsets, data, mask=None, offset=0): - return nitems, offsets, data, mask, offset - - parts = pyarrow_stringarray_to_parts(x) - nitems, offsets, data, mask, offset = unpack(*parts) - - # Check individual serialized components are correct - assert nitems == len(x) - - assert len(offsets) == 4 * (nitems + offset + 1) - - expected_data = "".join(x.drop_null().tolist()).encode("utf-8") - assert bytes(data) == expected_data - - if mask is not None: - assert len(mask) == math.ceil((nitems + offset) / 8) - assert x.offset % 8 == offset - - # Test rebuilding from components works - x2 = pyarrow_stringarray_from_parts(*parts) - assert x == x2 - - # Test pickle roundtrip works - pd_x = pd.arrays.ArrowStringArray(x) - pd_x2 = pickle.loads(pickle.dumps(pd_x)) - assert pd_x.equals(pd_x2) - - -@pytest.mark.parametrize("has_mask", [True, False]) -@pytest.mark.parametrize("start,end", [(None, -1), (1, None), (1, -1)]) -def 
test_pickle_stringarray_slice_doesnt_serialize_whole_array(has_mask, start, end): - x = pd.array( - ["apple", "banana", "carrot", "durian", "eggplant", "fennel", "grape"], - dtype="string[pyarrow]", - ) - if has_mask: - x[3] = None - - x_sliced = x[start:end] - buf = pickle.dumps(x_sliced) - loaded = pickle.loads(buf) - assert loaded.equals(x_sliced) - - if start is not None: - assert b"apple" not in buf - if end is not None: - assert b"grape" not in buf - - -@pytest.mark.parametrize("has_mask", [True, False]) -def test_pickle_stringarray_supports_pickle_5(has_mask): - x = pd.array( - ["apple", "banana", "carrot", "durian", "eggplant", "fennel", "grape"], - dtype="string[pyarrow]", - ) - x[3] = None - - buffers = [] - buf = pickle.dumps(x, protocol=5, buffer_callback=buffers.append) - assert buffers - x2 = pickle.loads(buf, buffers=buffers) - assert x.equals(x2) +def test_pickle_roundtrip_pyarrow_string_implementations(string_dtype): + # There are two pyarrow string implementations in pandas. + # This tests that both implementations have similar serialization performance. + if string_dtype == "stringdtype": + string_dtype = pd.StringDtype("pyarrow") + else: + string_dtype = pd.ArrowDtype(pa.string()) + expected = pd.Series(map(str, range(1_000)), dtype=string_dtype) + expected_sliced = expected.head(2) + full_pickled = pickle.dumps(expected) + sliced_pickled = pickle.dumps(expected_sliced) + + # Make sure slicing gives a large reduction in serialized bytes + assert len(full_pickled) > len(sliced_pickled) * 3 + + result = pickle.loads(full_pickled) + tm.assert_series_equal(result, expected) + + result_sliced = pickle.loads(sliced_pickled) + tm.assert_series_equal(result_sliced, expected_sliced)
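
A brief aside on the new `PANDAS_GT_200 = PANDAS_VERSION.major >= 2` flag in `_compat.py`: the diff does not say why the major component is checked rather than comparing against `Version("2.0.0")`, but a plausible motivation is that a full comparison would evaluate to False on pandas 2.0 pre-release builds. A minimal sketch of the difference, for illustration only:

from packaging.version import Version

dev = Version("2.0.0.dev0")           # e.g. a pandas nightly build
assert not (dev >= Version("2.0.0"))  # pre-releases sort before the final release
assert dev.major >= 2                 # but the major component is already 2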
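
For readers unfamiliar with the `copyreg.dispatch_table` mechanism that `reduce_arrowextensionarray` is registered into: pickle consults that table for instances of the registered type, and the `(rebuild_callable, args)` tuple returned by the reduce function replaces the default pickling behaviour. Below is a minimal, self-contained sketch of the same pattern; `Wrapper`, `reduce_wrapper`, and `rebuild_wrapper` are illustrative stand-ins, not part of the dask or pandas code changed above.

import copyreg
import pickle


class Wrapper:
    # Toy stand-in for an extension-array-like container.
    def __init__(self, values):
        self.values = list(values)


def rebuild_wrapper(cls, values):
    # Reconstruct the object from the pieces chosen by reduce_wrapper.
    return cls(values)


def reduce_wrapper(obj):
    # Return (callable, args); pickle serializes only these pieces.
    return (rebuild_wrapper, (type(obj), obj.values))


# Global override, analogous to registering reduce_arrowextensionarray above.
copyreg.dispatch_table[Wrapper] = reduce_wrapper

roundtripped = pickle.loads(pickle.dumps(Wrapper([1, 2, 3])))
assert roundtripped.values == [1, 2, 3]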