Skip to content

Commit

Permalink
update test_search.py
Browse files Browse the repository at this point in the history
  • Loading branch information
leofang committed Jun 21, 2020
1 parent 969c6a8 commit 8349818
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/cupy_tests/sorting_tests/test_search.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from unittest import mock

import numpy
import pytest
Expand Down Expand Up @@ -177,6 +178,18 @@ def test_cub_argmin(self, xp, dtype):
a = xp.ascontiguousarray(a)
else:
a = xp.asfortranarray(a)

if xp is numpy:
return a.argmin()

# xp is cupy, first ensure we really use CUB
full_scan = 'cupy.core._routines_statistics.cub.device_reduce'
full_raise = NotImplementedError('gotcha_full')
with mock.patch(full_scan, side_effect=full_raise), \
pytest.raises(NotImplementedError) as e:
a.argmin()
assert str(e.value) == 'gotcha_full'
# ...then perform the actual computation
return a.argmin()

@testing.for_dtypes('bhilBHILefdFD')
Expand All @@ -188,6 +201,18 @@ def test_cub_argmax(self, xp, dtype):
a = xp.ascontiguousarray(a)
else:
a = xp.asfortranarray(a)

if xp is numpy:
return a.argmax()

# xp is cupy, first ensure we really use CUB
full_scan = 'cupy.core._routines_statistics.cub.device_reduce'
full_raise = NotImplementedError('gotcha_full')
with mock.patch(full_scan, side_effect=full_raise), \
pytest.raises(NotImplementedError) as e:
a.argmax()
assert str(e.value) == 'gotcha_full'
# ...then perform the actual computation
return a.argmax()


Expand Down

0 comments on commit 8349818

Please sign in to comment.