Skip to content

Commit

Permalink
update np.argsort to use @overload
Browse files Browse the repository at this point in the history
  • Loading branch information
guilhermeleobas committed Oct 6, 2022
1 parent 586410b commit 6426a01
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 44 deletions.
17 changes: 0 additions & 17 deletions numba/core/typing/arraydecl.py
Expand Up @@ -414,23 +414,6 @@ def sentry_shape_scalar(ty):
retty = ary.copy(ndim=len(args))
return signature(retty, *args)

@bound_function("array.argsort")
def resolve_argsort(self, ary, args, kws):
assert not args
kwargs = dict(kws)
kind = kwargs.pop('kind', types.StringLiteral('quicksort'))
if not isinstance(kind, types.StringLiteral):
raise TypingError('"kind" must be a string literal')
if kwargs:
msg = "Unsupported keywords: {!r}"
raise TypingError(msg.format([k for k in kwargs.keys()]))
if ary.ndim == 1:
def argsort_stub(kind='quicksort'):
pass
pysig = utils.pysignature(argsort_stub)
sig = signature(types.Array(types.intp, 1, 'C'), kind).replace(pysig=pysig)
return sig

@bound_function("array.view")
def resolve_view(self, ary, args, kws):
from .npydecl import parse_dtype
Expand Down
6 changes: 1 addition & 5 deletions numba/core/typing/npydecl.py
Expand Up @@ -389,10 +389,6 @@ def sum_stub(arr, dtype):
def sum_stub(arr, axis, dtype):
pass
pysig = utils.pysignature(sum_stub)
elif self.method_name == 'argsort':
def argsort_stub(arr, kind='quicksort'):
pass
pysig = utils.pysignature(argsort_stub)
else:
fmt = "numba doesn't support kwarg for {}"
raise TypingError(fmt.format(self.method_name))
Expand All @@ -414,7 +410,7 @@ def _numpy_redirect(fname):
infer_global(numpy_function, types.Function(cls))

for func in ['min', 'max', 'sum', 'prod', 'mean', 'var', 'std',
'cumsum', 'cumprod', 'argsort', 'nonzero']:
'cumsum', 'cumprod', 'nonzero']:
_numpy_redirect(func)


Expand Down
23 changes: 12 additions & 11 deletions numba/np/arrayobj.py
Expand Up @@ -6008,21 +6008,22 @@ def np_sort_impl(a):
return np_sort_impl


@lower_builtin("array.argsort", types.Array, types.StringLiteral)
@lower_builtin(np.argsort, types.Array, types.StringLiteral)
def array_argsort(context, builder, sig, args):
arytype, kind = sig.args
@overload(np.argsort)
@overload_method(types.Array, "argsort")
def impl_arr_argsort(arr, kind=None):
if kind is None or isinstance(kind, types.Omitted):
kind = types.StringLiteral(value='quicksort')
if not isinstance(kind, types.StringLiteral):
msg = '"kind" must be a string literal'
raise errors.RequireLiteralValue(msg)

sort_func = get_sort_func(kind=kind.literal_value,
is_float=isinstance(arytype.dtype, types.Float),
is_float=isinstance(arr.dtype, types.Float),
is_argsort=True)

def array_argsort_impl(arr):
def impl(arr, kind=None):
return sort_func(arr)

innersig = sig.replace(args=sig.args[:1])
innerargs = args[:1]
return context.compile_internal(builder, array_argsort_impl,
innersig, innerargs)
return impl


# ------------------------------------------------------------------------------
Expand Down
12 changes: 1 addition & 11 deletions numba/tests/test_sort.py
Expand Up @@ -935,7 +935,7 @@ def check(pyfunc):
check(argsort_usecase)
check(np_argsort_usecase)

def test_argsort_float(self):
def test_argsort_float_stable(self):
def check(pyfunc, is_stable):
cfunc = jit(nopython=True)(pyfunc)
for orig in self.float_arrays():
Expand Down Expand Up @@ -1195,16 +1195,6 @@ def nonliteral_kind(kind):
expect = '"kind" must be a string literal'
self.assertIn(expect, str(raises.exception))

@njit
def unsupported_kwarg():
np.arange(5).argsort(foo='')

with self.assertRaises(errors.TypingError) as raises:
unsupported_kwarg()

expect = "Unsupported keywords: ['foo']"
self.assertIn(expect, str(raises.exception))


if __name__ == '__main__':
unittest.main()

0 comments on commit 6426a01

Please sign in to comment.