From e68a5fde7b6bb7a754d893b917e44270aa9cce8e Mon Sep 17 00:00:00 2001 From: Fabien Aulaire Date: Thu, 13 Oct 2022 22:16:59 +0200 Subject: [PATCH 1/5] Fix map_overlap in order to accept pandas arguments --- dask/dataframe/rolling.py | 11 ++++++++--- dask/dataframe/tests/test_rolling.py | 29 +++++++++++++++++++--------- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/dask/dataframe/rolling.py b/dask/dataframe/rolling.py index 3a1d6ef6a31..6fd8e24f847 100644 --- a/dask/dataframe/rolling.py +++ b/dask/dataframe/rolling.py @@ -21,6 +21,8 @@ no_default, partitionwise_graph, ) +from dask.dataframe.io import from_pandas +from dask.dataframe.utils import is_dask_collection, is_series_like, is_dataframe_like from dask.dataframe.multi import _maybe_align_partitions from dask.delayed import unpack_collections from dask.highlevelgraph import HighLevelGraph @@ -145,9 +147,10 @@ def map_overlap( -------- dd.DataFrame.map_overlap """ - args = (df,) + args - dfs = [df for df in args if isinstance(df, _Frame)] + df = from_pandas(df, 1) if (is_series_like(df) or is_dataframe_like(df)) and not is_dask_collection(df) else df + + args = (df,) + args if isinstance(before, str): before = pd.to_timedelta(before) @@ -155,7 +158,7 @@ def map_overlap( after = pd.to_timedelta(after) if isinstance(before, datetime.timedelta) or isinstance(after, datetime.timedelta): - if not is_datetime64_any_dtype(dfs[0].index._meta_nonempty.inferred_type): + if not is_datetime64_any_dtype(df.index._meta_nonempty.inferred_type): raise TypeError( "Must have a `DatetimeIndex` when using string offset " "for `before` and `after`" @@ -190,6 +193,8 @@ def map_overlap( "calling `map_overlap` directly, pass `align_dataframes=False`." ) from e + dfs = [df for df in args if isinstance(df, _Frame)] + meta = _get_meta_map_partitions(args, dfs, func, kwargs, meta, parent_meta) if all(isinstance(arg, Scalar) for arg in args): diff --git a/dask/dataframe/tests/test_rolling.py b/dask/dataframe/tests/test_rolling.py index e05f4491a24..ba72e1fe29f 100644 --- a/dask/dataframe/tests/test_rolling.py +++ b/dask/dataframe/tests/test_rolling.py @@ -6,6 +6,7 @@ from packaging.version import parse as parse_version import dask.dataframe as dd +import dask.dataframe.rolling from dask.dataframe.utils import assert_eq N = 40 @@ -53,20 +54,25 @@ def shifted_sum(df, before, after, c=0): @pytest.mark.parametrize("npartitions", [1, 4]) -def test_map_overlap(npartitions): - ddf = dd.from_pandas(df, npartitions) +@pytest.mark.parametrize("use_dask_input", [True, False]) +def test_map_overlap(npartitions, use_dask_input): + ddf = df + if use_dask_input: + ddf = dd.from_pandas(df, npartitions) + for before, after in [(0, 3), (3, 0), (3, 3), (0, 0)]: # DataFrame - res = ddf.map_overlap(shifted_sum, before, after, before, after, c=2) + res = dask.dataframe.rolling.map_overlap(shifted_sum, ddf, before, after, before, after, c=2) sol = shifted_sum(df, before, after, c=2) assert_eq(res, sol) # Series - res = ddf.b.map_overlap(shifted_sum, before, after, before, after, c=2) + res = dask.dataframe.rolling.map_overlap(shifted_sum, ddf.b, before, after, before, after, c=2) sol = shifted_sum(df.b, before, after, c=2) assert_eq(res, sol) +@pytest.mark.parametrize("use_dask_input", [True, False]) @pytest.mark.parametrize("npartitions", [1, 4]) @pytest.mark.parametrize("enforce_metadata", [True, False]) @pytest.mark.parametrize("transform_divisions", [True, False]) @@ -87,12 +93,15 @@ def test_map_overlap(npartitions): ], ) def test_map_overlap_multiple_dataframes( - npartitions, enforce_metadata, transform_divisions, align_dataframes, overlap_setup + use_dask_input, npartitions, enforce_metadata, transform_divisions, align_dataframes, overlap_setup ): dataframe, before, after = overlap_setup - ddf = dd.from_pandas(dataframe, npartitions) - ddf2 = dd.from_pandas(dataframe * 2, npartitions) + ddf = dataframe + ddf2 = dataframe * 2 + if use_dask_input: + ddf = dd.from_pandas(ddf, npartitions) + ddf2 = dd.from_pandas(ddf2, npartitions) def get_shifted_sum_arg(overlap): return ( @@ -104,8 +113,9 @@ def get_shifted_sum_arg(overlap): ), get_shifted_sum_arg(after) # DataFrame - res = ddf.map_overlap( + res = dask.dataframe.rolling.map_overlap( shifted_sum, + ddf, before, after, before_shifted_sum, @@ -119,8 +129,9 @@ def get_shifted_sum_arg(overlap): assert_eq(res, sol) # Series - res = ddf.b.map_overlap( + res = dask.dataframe.rolling.map_overlap( shifted_sum, + ddf.b, before, after, before_shifted_sum, From 7dd9e72a1c080971ce4887dd68e8916225135df9 Mon Sep 17 00:00:00 2001 From: Fabien Aulaire Date: Thu, 13 Oct 2022 22:23:08 +0200 Subject: [PATCH 2/5] pre-commit fix --- dask/dataframe/rolling.py | 8 ++++++-- dask/dataframe/tests/test_rolling.py | 15 ++++++++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/dask/dataframe/rolling.py b/dask/dataframe/rolling.py index 6fd8e24f847..efe423742ac 100644 --- a/dask/dataframe/rolling.py +++ b/dask/dataframe/rolling.py @@ -22,8 +22,8 @@ partitionwise_graph, ) from dask.dataframe.io import from_pandas -from dask.dataframe.utils import is_dask_collection, is_series_like, is_dataframe_like from dask.dataframe.multi import _maybe_align_partitions +from dask.dataframe.utils import is_dask_collection, is_dataframe_like, is_series_like from dask.delayed import unpack_collections from dask.highlevelgraph import HighLevelGraph from dask.utils import M, apply, derived_from, funcname, has_keyword @@ -148,7 +148,11 @@ def map_overlap( dd.DataFrame.map_overlap """ - df = from_pandas(df, 1) if (is_series_like(df) or is_dataframe_like(df)) and not is_dask_collection(df) else df + df = ( + from_pandas(df, 1) + if (is_series_like(df) or is_dataframe_like(df)) and not is_dask_collection(df) + else df + ) args = (df,) + args diff --git a/dask/dataframe/tests/test_rolling.py b/dask/dataframe/tests/test_rolling.py index ba72e1fe29f..3362ce98181 100644 --- a/dask/dataframe/tests/test_rolling.py +++ b/dask/dataframe/tests/test_rolling.py @@ -62,12 +62,16 @@ def test_map_overlap(npartitions, use_dask_input): for before, after in [(0, 3), (3, 0), (3, 3), (0, 0)]: # DataFrame - res = dask.dataframe.rolling.map_overlap(shifted_sum, ddf, before, after, before, after, c=2) + res = dask.dataframe.rolling.map_overlap( + shifted_sum, ddf, before, after, before, after, c=2 + ) sol = shifted_sum(df, before, after, c=2) assert_eq(res, sol) # Series - res = dask.dataframe.rolling.map_overlap(shifted_sum, ddf.b, before, after, before, after, c=2) + res = dask.dataframe.rolling.map_overlap( + shifted_sum, ddf.b, before, after, before, after, c=2 + ) sol = shifted_sum(df.b, before, after, c=2) assert_eq(res, sol) @@ -93,7 +97,12 @@ def test_map_overlap(npartitions, use_dask_input): ], ) def test_map_overlap_multiple_dataframes( - use_dask_input, npartitions, enforce_metadata, transform_divisions, align_dataframes, overlap_setup + use_dask_input, + npartitions, + enforce_metadata, + transform_divisions, + align_dataframes, + overlap_setup, ): dataframe, before, after = overlap_setup From 00ba7ec7f67e0317af5506a88bd30437e9b15617 Mon Sep 17 00:00:00 2001 From: Fabien Aulaire <306648+faulaire@users.noreply.github.com> Date: Wed, 30 Nov 2022 21:29:41 +0100 Subject: [PATCH 3/5] Liting --- dask/dataframe/rolling.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dask/dataframe/rolling.py b/dask/dataframe/rolling.py index 29657d93992..66273a6fd90 100644 --- a/dask/dataframe/rolling.py +++ b/dask/dataframe/rolling.py @@ -23,7 +23,12 @@ ) from dask.dataframe.io import from_pandas from dask.dataframe.multi import _maybe_align_partitions -from dask.dataframe.utils import is_dask_collection, is_dataframe_like, is_series_like, insert_meta_param_description +from dask.dataframe.utils import ( + is_dask_collection, + is_dataframe_like, + is_series_like, + insert_meta_param_description +) from dask.delayed import unpack_collections from dask.highlevelgraph import HighLevelGraph from dask.utils import M, apply, derived_from, funcname, has_keyword From 2a0706c136bc2d7f37d53d7badd8683c5a04aeb8 Mon Sep 17 00:00:00 2001 From: Fabien Aulaire <306648+faulaire@users.noreply.github.com> Date: Wed, 30 Nov 2022 21:49:13 +0100 Subject: [PATCH 4/5] Linting --- dask/dataframe/rolling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dask/dataframe/rolling.py b/dask/dataframe/rolling.py index 66273a6fd90..44807d848cf 100644 --- a/dask/dataframe/rolling.py +++ b/dask/dataframe/rolling.py @@ -24,10 +24,10 @@ from dask.dataframe.io import from_pandas from dask.dataframe.multi import _maybe_align_partitions from dask.dataframe.utils import ( + insert_meta_param_description is_dask_collection, is_dataframe_like, - is_series_like, - insert_meta_param_description + is_series_like ) from dask.delayed import unpack_collections from dask.highlevelgraph import HighLevelGraph From d93de2b2df44482cd9813d08765218a3b2849c3f Mon Sep 17 00:00:00 2001 From: Fabien Aulaire Date: Wed, 30 Nov 2022 21:58:16 +0100 Subject: [PATCH 5/5] Linting --- dask/dataframe/rolling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dask/dataframe/rolling.py b/dask/dataframe/rolling.py index 44807d848cf..5636312e82e 100644 --- a/dask/dataframe/rolling.py +++ b/dask/dataframe/rolling.py @@ -24,10 +24,10 @@ from dask.dataframe.io import from_pandas from dask.dataframe.multi import _maybe_align_partitions from dask.dataframe.utils import ( - insert_meta_param_description + insert_meta_param_description, is_dask_collection, is_dataframe_like, - is_series_like + is_series_like, ) from dask.delayed import unpack_collections from dask.highlevelgraph import HighLevelGraph