Fix map_overlap in order to accept pandas arguments (#9571)
* Fix map_overlap in order to accept pandas arguments

* pre-commit fix

* Linting

* Linting

* Linting
faulaire committed Nov 30, 2022
1 parent 99123bd commit a7efdf9
Showing 2 changed files with 46 additions and 13 deletions.
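
For context before the diff, a minimal usage sketch (not part of the commit) of what the change enables: passing a plain pandas DataFrame to the module-level dask.dataframe.rolling.map_overlap, modeled on the updated test below. The shifted_sum helper is an illustrative stand-in for the test suite's helper of the same name.

import pandas as pd

import dask.dataframe.rolling
from dask.dataframe.utils import assert_eq


def shifted_sum(df, before, after, c=0):
    # Illustrative stand-in for the test helper: each row is combined with
    # shifted neighbours, which is why rows from adjacent partitions are needed.
    return df + df.shift(before) + df.shift(-after) + c


pdf = pd.DataFrame({"a": range(20), "b": range(20)})

# map_overlap(func, df, before, after, *args, **kwargs): the first 2, 2 set the
# overlap, the trailing 2, 2 and c=1 are forwarded to shifted_sum.
# With this commit, `df` may be a plain pandas object; the new guard in
# rolling.py converts it to a single-partition dask DataFrame via from_pandas
# before the overlap logic runs.
res = dask.dataframe.rolling.map_overlap(shifted_sum, pdf, 2, 2, 2, 2, c=1)
sol = shifted_sum(pdf, 2, 2, c=1)
assert_eq(res, sol)
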
21 changes: 17 additions & 4 deletions dask/dataframe/rolling.py
@@ -21,8 +21,14 @@
     no_default,
     partitionwise_graph,
 )
+from dask.dataframe.io import from_pandas
 from dask.dataframe.multi import _maybe_align_partitions
-from dask.dataframe.utils import insert_meta_param_description
+from dask.dataframe.utils import (
+    insert_meta_param_description,
+    is_dask_collection,
+    is_dataframe_like,
+    is_series_like,
+)
 from dask.delayed import unpack_collections
 from dask.highlevelgraph import HighLevelGraph
 from dask.utils import M, apply, derived_from, funcname, has_keyword
@@ -147,17 +153,22 @@ def map_overlap(
     --------
     dd.DataFrame.map_overlap
     """
-    args = (df,) + args

-    dfs = [df for df in args if isinstance(df, _Frame)]
+    df = (
+        from_pandas(df, 1)
+        if (is_series_like(df) or is_dataframe_like(df)) and not is_dask_collection(df)
+        else df
+    )
+
+    args = (df,) + args

     if isinstance(before, str):
         before = pd.to_timedelta(before)
     if isinstance(after, str):
         after = pd.to_timedelta(after)

     if isinstance(before, datetime.timedelta) or isinstance(after, datetime.timedelta):
-        if not is_datetime64_any_dtype(dfs[0].index._meta_nonempty.inferred_type):
+        if not is_datetime64_any_dtype(df.index._meta_nonempty.inferred_type):
             raise TypeError(
                 "Must have a `DatetimeIndex` when using string offset "
                 "for `before` and `after`"
@@ -192,6 +203,8 @@ def map_overlap(
"calling `map_overlap` directly, pass `align_dataframes=False`."
) from e

dfs = [df for df in args if isinstance(df, _Frame)]

meta = _get_meta_map_partitions(args, dfs, func, kwargs, meta, parent_meta)

if all(isinstance(arg, Scalar) for arg in args):
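
The core of the rolling.py change is the from_pandas guard shown above. Below is a standalone sketch of that pattern, assuming the dask version this commit targets (where from_pandas and the is_* helpers are importable as in the new import block); the helper name ensure_dask_frame is illustrative, not part of the commit.

import pandas as pd

from dask.dataframe.io import from_pandas
from dask.dataframe.utils import (
    is_dask_collection,
    is_dataframe_like,
    is_series_like,
)


def ensure_dask_frame(obj):
    # Illustrative helper mirroring the guard added to map_overlap: wrap a
    # pandas DataFrame/Series in a single-partition dask collection, and
    # pass dask inputs (or non-frame arguments) through unchanged.
    if (is_series_like(obj) or is_dataframe_like(obj)) and not is_dask_collection(obj):
        return from_pandas(obj, 1)
    return obj


print(type(ensure_dask_frame(pd.Series(range(5)))))          # dask Series
print(type(ensure_dask_frame(pd.DataFrame({"a": [1, 2]}))))  # dask DataFrame
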
38 changes: 29 additions & 9 deletions dask/dataframe/tests/test_rolling.py
@@ -6,6 +6,7 @@
 from packaging.version import parse as parse_version

 import dask.dataframe as dd
+import dask.dataframe.rolling
 from dask.dataframe.utils import assert_eq

 N = 40
@@ -53,20 +54,29 @@ def shifted_sum(df, before, after, c=0):


 @pytest.mark.parametrize("npartitions", [1, 4])
-def test_map_overlap(npartitions):
-    ddf = dd.from_pandas(df, npartitions)
+@pytest.mark.parametrize("use_dask_input", [True, False])
+def test_map_overlap(npartitions, use_dask_input):
+    ddf = df
+    if use_dask_input:
+        ddf = dd.from_pandas(df, npartitions)
+
     for before, after in [(0, 3), (3, 0), (3, 3), (0, 0)]:
         # DataFrame
-        res = ddf.map_overlap(shifted_sum, before, after, before, after, c=2)
+        res = dask.dataframe.rolling.map_overlap(
+            shifted_sum, ddf, before, after, before, after, c=2
+        )
         sol = shifted_sum(df, before, after, c=2)
         assert_eq(res, sol)

         # Series
-        res = ddf.b.map_overlap(shifted_sum, before, after, before, after, c=2)
+        res = dask.dataframe.rolling.map_overlap(
+            shifted_sum, ddf.b, before, after, before, after, c=2
+        )
         sol = shifted_sum(df.b, before, after, c=2)
         assert_eq(res, sol)


+@pytest.mark.parametrize("use_dask_input", [True, False])
 @pytest.mark.parametrize("npartitions", [1, 4])
 @pytest.mark.parametrize("enforce_metadata", [True, False])
 @pytest.mark.parametrize("transform_divisions", [True, False])
@@ -87,12 +97,20 @@ def test_map_overlap(npartitions):
     ],
 )
 def test_map_overlap_multiple_dataframes(
-    npartitions, enforce_metadata, transform_divisions, align_dataframes, overlap_setup
+    use_dask_input,
+    npartitions,
+    enforce_metadata,
+    transform_divisions,
+    align_dataframes,
+    overlap_setup,
 ):
     dataframe, before, after = overlap_setup

-    ddf = dd.from_pandas(dataframe, npartitions)
-    ddf2 = dd.from_pandas(dataframe * 2, npartitions)
+    ddf = dataframe
+    ddf2 = dataframe * 2
+    if use_dask_input:
+        ddf = dd.from_pandas(ddf, npartitions)
+        ddf2 = dd.from_pandas(ddf2, npartitions)

     def get_shifted_sum_arg(overlap):
         return (
@@ -104,8 +122,9 @@ def get_shifted_sum_arg(overlap):
     ), get_shifted_sum_arg(after)

     # DataFrame
-    res = ddf.map_overlap(
+    res = dask.dataframe.rolling.map_overlap(
         shifted_sum,
+        ddf,
         before,
         after,
         before_shifted_sum,
@@ -119,8 +138,9 @@ def get_shifted_sum_arg(overlap):
     assert_eq(res, sol)

     # Series
-    res = ddf.b.map_overlap(
+    res = dask.dataframe.rolling.map_overlap(
         shifted_sum,
+        ddf.b,
         before,
         after,
         before_shifted_sum,
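
A second consequence of the change, again as a hedged sketch (the ident function here is illustrative): the datetime-index validation now inspects the converted df rather than dfs[0], which previously had nothing to index when only pandas objects were passed, so the string-offset path raises the intended TypeError for a plain pandas input.

import pandas as pd

import dask.dataframe.rolling


def ident(df, *args, **kwargs):
    # Trivial illustrative function; the check below fires before it is used.
    return df


# A string offset requires a DatetimeIndex; with a RangeIndex the converted
# single-partition frame fails the is_datetime64_any_dtype check up front.
pdf = pd.DataFrame({"a": range(10)})
try:
    dask.dataframe.rolling.map_overlap(ident, pdf, "2s", "2s")
except TypeError as e:
    print(e)  # Must have a `DatetimeIndex` when using string offset ...
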
