Skip to content

Commit

Permalink
Generalize array checking and remove pd.Index call in `_get_partitions` (#9634)
Browse files Browse the repository at this point in the history
  • Loading branch information
quasiben committed Nov 8, 2022
1 parent fe68659 commit 8be183c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
8 changes: 4 additions & 4 deletions dask/dataframe/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dask.dataframe.core import Series, new_dd_object
from dask.dataframe.utils import is_index_like, is_series_like, meta_nonempty
from dask.highlevelgraph import HighLevelGraph
from dask.utils import is_arraylike


class _IndexerBase:
Expand Down Expand Up @@ -116,10 +117,10 @@ def _loc(self, iindexer, cindexer):

if isinstance(iindexer, slice):
return self._loc_slice(iindexer, cindexer)
elif isinstance(iindexer, (list, np.ndarray)):
return self._loc_list(iindexer, cindexer)
elif is_series_like(iindexer) and not is_bool_dtype(iindexer.dtype):
return self._loc_list(iindexer.values, cindexer)
elif isinstance(iindexer, list) or is_arraylike(iindexer):
return self._loc_list(iindexer, cindexer)
else:
# element should raise KeyError
return self._loc_element(iindexer, cindexer)
Expand Down Expand Up @@ -209,7 +210,7 @@ def _loc_element(self, iindexer, cindexer):
return new_dd_object(graph, name, meta=meta, divisions=[iindexer, iindexer])

def _get_partitions(self, keys):
if isinstance(keys, (list, np.ndarray)):
if isinstance(keys, list) or is_arraylike(keys):
return _partitions_of_index_values(self.obj.divisions, keys)
else:
# element
Expand Down Expand Up @@ -343,7 +344,6 @@ def _partitions_of_index_values(divisions, values):
raise ValueError(msg)

results = defaultdict(list)
values = pd.Index(values, dtype=object)
for val in values:
i = bisect.bisect_right(divisions, val)
div = min(len(divisions) - 2, max(0, i - 1))
Expand Down
15 changes: 15 additions & 0 deletions dask/dataframe/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,3 +724,18 @@ def test_deterministic_hashing_dataframe():

ddf2 = dask_df.iloc[:, [0, 2]]
assert tokenize(ddf1) != tokenize(ddf2)


@pytest.mark.gpu
def test_gpu_loc():
    """Check ``.loc`` on a dask DataFrame built from a cudf DataFrame.

    A ``cudf.Series`` and a ``cupy`` array are each used as the row indexer,
    and the result is compared against ``.loc`` on the source cudf DataFrame.
    The test is skipped when ``cudf`` or ``cupy`` cannot be imported.
    """
    cudf = pytest.importorskip("cudf")
    cupy = pytest.importorskip("cupy")

    labels = [1, 5, 10, 11, 12, 100, 200, 300]
    wanted = [1, 100, 300]

    gdf = cudf.DataFrame({"a": range(8), "index": labels})
    gdf = gdf.set_index("index")
    ddf = dd.from_pandas(gdf, npartitions=3)

    series_indexer = cudf.Series(wanted)
    array_indexer = cupy.array(wanted)

    # In both checks the expected frame is selected with the cupy indexer.
    assert_eq(ddf.loc[series_indexer], gdf.loc[array_indexer])
    assert_eq(ddf.loc[array_indexer], gdf.loc[array_indexer])

0 comments on commit 8be183c

Please sign in to comment.