Fix map_overlap in order to accept pandas arguments #9571

Merged
6 commits merged on Nov 30, 2022
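This PR lets the module-level `dask.dataframe.rolling.map_overlap` accept raw pandas objects: a pandas Series or DataFrame passed as `df` is wrapped in a single-partition dask collection before the overlap machinery runs. The sketch below is not part of the PR; the example frame and `partition_rolling_sum` are invented for illustration, and it assumes a dask release that includes this change:

```python
import pandas as pd

import dask.dataframe.rolling

# A plain pandas DataFrame -- no explicit dd.from_pandas call needed.
pdf = pd.DataFrame({"a": range(10), "b": range(10, 20)})


def partition_rolling_sum(part):
    # Runs on each overlapped partition as an ordinary pandas object.
    return part.rolling(3).sum()


# before=2: each partition also sees the last 2 rows of the previous partition;
# after=0: no rows are borrowed from the following partition.
res = dask.dataframe.rolling.map_overlap(partition_rolling_sum, pdf, 2, 0)
print(res.compute())
```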
dask/dataframe/rolling.py: 17 additions & 4 deletions
@@ -21,8 +21,14 @@
     no_default,
     partitionwise_graph,
 )
+from dask.dataframe.io import from_pandas
 from dask.dataframe.multi import _maybe_align_partitions
-from dask.dataframe.utils import insert_meta_param_description
+from dask.dataframe.utils import (
+    insert_meta_param_description,
+    is_dask_collection,
+    is_dataframe_like,
+    is_series_like,
+)
 from dask.delayed import unpack_collections
 from dask.highlevelgraph import HighLevelGraph
 from dask.utils import M, apply, derived_from, funcname, has_keyword
@@ -147,17 +153,22 @@ def map_overlap(
     --------
     dd.DataFrame.map_overlap
     """
-    args = (df,) + args
 
-    dfs = [df for df in args if isinstance(df, _Frame)]
+    df = (
+        from_pandas(df, 1)
+        if (is_series_like(df) or is_dataframe_like(df)) and not is_dask_collection(df)
+        else df
+    )
 
+    args = (df,) + args
 
     if isinstance(before, str):
         before = pd.to_timedelta(before)
     if isinstance(after, str):
         after = pd.to_timedelta(after)
 
     if isinstance(before, datetime.timedelta) or isinstance(after, datetime.timedelta):
-        if not is_datetime64_any_dtype(dfs[0].index._meta_nonempty.inferred_type):
+        if not is_datetime64_any_dtype(df.index._meta_nonempty.inferred_type):
             raise TypeError(
                 "Must have a `DatetimeIndex` when using string offset "
                 "for `before` and `after`"
@@ -192,6 +203,8 @@ def map_overlap(
                 "calling `map_overlap` directly, pass `align_dataframes=False`."
             ) from e
 
+    dfs = [df for df in args if isinstance(df, _Frame)]
+
     meta = _get_meta_map_partitions(args, dfs, func, kwargs, meta, parent_meta)
 
     if all(isinstance(arg, Scalar) for arg in args):
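The heart of the rolling.py change is the guard added at the top of `map_overlap`: a bare pandas Series/DataFrame is converted with `from_pandas(df, 1)`, the `DatetimeIndex` validation now reads `df.index` instead of `dfs[0].index`, and the `dfs` list is rebuilt after the argument-alignment step. Below is a hypothetical helper that restates the new guard in isolation (the PR inlines this logic rather than naming it), reusing the same imports the diff adds:

```python
from dask.dataframe.io import from_pandas
from dask.dataframe.utils import (
    is_dask_collection,
    is_dataframe_like,
    is_series_like,
)


def ensure_dask_frame(df, npartitions=1):
    # Raw pandas Series/DataFrames become single-partition dask collections;
    # dask collections (and non-frame arguments) pass through unchanged.
    if (is_series_like(df) or is_dataframe_like(df)) and not is_dask_collection(df):
        return from_pandas(df, npartitions)
    return df
```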
dask/dataframe/tests/test_rolling.py: 29 additions & 9 deletions
@@ -6,6 +6,7 @@
 from packaging.version import parse as parse_version
 
 import dask.dataframe as dd
+import dask.dataframe.rolling
 from dask.dataframe.utils import assert_eq
 
 N = 40
@@ -53,20 +54,29 @@ def shifted_sum(df, before, after, c=0):
 
 
 @pytest.mark.parametrize("npartitions", [1, 4])
-def test_map_overlap(npartitions):
-    ddf = dd.from_pandas(df, npartitions)
+@pytest.mark.parametrize("use_dask_input", [True, False])
+def test_map_overlap(npartitions, use_dask_input):
+    ddf = df
+    if use_dask_input:
+        ddf = dd.from_pandas(df, npartitions)
 
     for before, after in [(0, 3), (3, 0), (3, 3), (0, 0)]:
         # DataFrame
-        res = ddf.map_overlap(shifted_sum, before, after, before, after, c=2)
+        res = dask.dataframe.rolling.map_overlap(
+            shifted_sum, ddf, before, after, before, after, c=2
+        )
         sol = shifted_sum(df, before, after, c=2)
         assert_eq(res, sol)
 
         # Series
-        res = ddf.b.map_overlap(shifted_sum, before, after, before, after, c=2)
+        res = dask.dataframe.rolling.map_overlap(
+            shifted_sum, ddf.b, before, after, before, after, c=2
+        )
         sol = shifted_sum(df.b, before, after, c=2)
         assert_eq(res, sol)
 
 
+@pytest.mark.parametrize("use_dask_input", [True, False])
 @pytest.mark.parametrize("npartitions", [1, 4])
 @pytest.mark.parametrize("enforce_metadata", [True, False])
 @pytest.mark.parametrize("transform_divisions", [True, False])
@@ -87,12 +97,20 @@ def test_map_overlap(npartitions):
     ],
 )
 def test_map_overlap_multiple_dataframes(
-    npartitions, enforce_metadata, transform_divisions, align_dataframes, overlap_setup
+    use_dask_input,
+    npartitions,
+    enforce_metadata,
+    transform_divisions,
+    align_dataframes,
+    overlap_setup,
 ):
     dataframe, before, after = overlap_setup
 
-    ddf = dd.from_pandas(dataframe, npartitions)
-    ddf2 = dd.from_pandas(dataframe * 2, npartitions)
+    ddf = dataframe
+    ddf2 = dataframe * 2
+    if use_dask_input:
+        ddf = dd.from_pandas(ddf, npartitions)
+        ddf2 = dd.from_pandas(ddf2, npartitions)
 
     def get_shifted_sum_arg(overlap):
         return (
@@ -104,8 +122,9 @@ def get_shifted_sum_arg(overlap):
     ), get_shifted_sum_arg(after)
 
     # DataFrame
-    res = ddf.map_overlap(
+    res = dask.dataframe.rolling.map_overlap(
         shifted_sum,
+        ddf,
         before,
         after,
         before_shifted_sum,
@@ -119,8 +138,9 @@ def get_shifted_sum_arg(overlap):
     assert_eq(res, sol)
 
     # Series
-    res = ddf.b.map_overlap(
+    res = dask.dataframe.rolling.map_overlap(
         shifted_sum,
+        ddf.b,
         before,
         after,
         before_shifted_sum,
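The tests now exercise both entry paths: `use_dask_input=True` keeps the original behaviour via `dd.from_pandas`, while `use_dask_input=False` feeds the pandas objects straight to `dask.dataframe.rolling.map_overlap`. A condensed, self-contained variant of the Series case, using an invented `shifted_mean` helper in place of the suite's `shifted_sum`:

```python
import pandas as pd

import dask.dataframe.rolling
from dask.dataframe.utils import assert_eq

pdf = pd.DataFrame({"b": range(12)})


def shifted_mean(part):
    # Illustrative stand-in for the shifted_sum helper used by the real tests.
    return part.rolling(4).mean()


# Series path: a raw pandas Series goes straight to the module-level function;
# assert_eq compares the lazy dask result against the eager pandas computation.
res = dask.dataframe.rolling.map_overlap(shifted_mean, pdf.b, 3, 0)
assert_eq(res, shifted_mean(pdf.b))
```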