Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removes context.compile_internal where easy #8493

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 0 additions & 37 deletions numba/core/typing/arraydecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,29 +414,6 @@ def sentry_shape_scalar(ty):
retty = ary.copy(ndim=len(args))
return signature(retty, *args)

@bound_function("array.sort")
def resolve_sort(self, ary, args, kws):
assert not args
assert not kws
return signature(types.none)

@bound_function("array.argsort")
def resolve_argsort(self, ary, args, kws):
assert not args
kwargs = dict(kws)
kind = kwargs.pop('kind', types.StringLiteral('quicksort'))
if not isinstance(kind, types.StringLiteral):
raise TypingError('"kind" must be a string literal')
if kwargs:
msg = "Unsupported keywords: {!r}"
raise TypingError(msg.format([k for k in kwargs.keys()]))
if ary.ndim == 1:
def argsort_stub(kind='quicksort'):
pass
pysig = utils.pysignature(argsort_stub)
sig = signature(types.Array(types.intp, 1, 'C'), kind).replace(pysig=pysig)
return sig

@bound_function("array.view")
def resolve_view(self, ary, args, kws):
from .npydecl import parse_dtype
Expand Down Expand Up @@ -469,20 +446,6 @@ def resolve_astype(self, ary, args, kws):
retty = ary.copy(dtype=dtype, layout=layout, readonly=False)
return signature(retty, *args)

@bound_function("array.ravel")
def resolve_ravel(self, ary, args, kws):
# Only support no argument version (default order='C')
assert not kws
assert not args
return signature(ary.copy(ndim=1, layout='C'))

@bound_function("array.flatten")
def resolve_flatten(self, ary, args, kws):
# Only support no argument version (default order='C')
assert not kws
assert not args
return signature(ary.copy(ndim=1, layout='C'))

def generic_resolve(self, ary, attr):
# Resolution of other attributes, for record arrays
if isinstance(ary.dtype, types.Record):
Expand Down
11 changes: 0 additions & 11 deletions numba/core/typing/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,17 +654,6 @@ def generic(self, args, kws):
sig = signature(ret, *args)
return sig

# Generic implementation for "not in"

@infer
class GenericNotIn(AbstractTemplate):
key = "not in"

def generic(self, args, kws):
args = args[::-1]
sig = self.context.resolve_function_type(operator.contains, args, kws)
return signature(sig.return_type, *sig.args[::-1])


#-------------------------------------------------------------------------------

Expand Down
6 changes: 1 addition & 5 deletions numba/core/typing/npydecl.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,6 @@ def sum_stub(arr, dtype):
def sum_stub(arr, axis, dtype):
pass
pysig = utils.pysignature(sum_stub)
elif self.method_name == 'argsort':
def argsort_stub(arr, kind='quicksort'):
pass
pysig = utils.pysignature(argsort_stub)
else:
fmt = "numba doesn't support kwarg for {}"
raise TypingError(fmt.format(self.method_name))
Expand All @@ -414,7 +410,7 @@ def _numpy_redirect(fname):
infer_global(numpy_function, types.Function(cls))

for func in ['min', 'max', 'sum', 'prod', 'mean', 'var', 'std',
'cumsum', 'cumprod', 'argsort', 'nonzero', 'ravel']:
'cumsum', 'cumprod', 'nonzero']:
_numpy_redirect(func)


Expand Down
24 changes: 12 additions & 12 deletions numba/cpython/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,16 @@ def round_impl_unary(context, builder, sig, args):
res = builder.fptosi(res, context.get_value_type(sig.return_type))
return impl_ret_untracked(context, builder, sig.return_type, res)

@lower_builtin(round, types.Float, types.Integer)
def round_impl_binary(context, builder, sig, args):
fltty = sig.args[0]

@overload(round)
def round_impl_binary(x, ndigits):
if not (isinstance(x, types.Float) and isinstance(ndigits, types.Integer)):
return

# Allow calling the intrinsic from the Python implementation below.
# This avoids the conversion to an int in Python 3's unary round().
_round = types.ExternalFunction(
_round_intrinsic(fltty), typing.signature(fltty, fltty))
_round_intrinsic(x), typing.signature(x, x))

def round_ndigits(x, ndigits):
if math.isinf(x) or math.isnan(x):
Expand All @@ -281,8 +284,7 @@ def round_ndigits(x, ndigits):
y = x / pow1
return _round(y) * pow1

res = context.compile_internal(builder, round_ndigits, sig, args)
return impl_ret_untracked(context, builder, sig.return_type, res)
return round_ndigits


#-------------------------------------------------------------------------------
Expand Down Expand Up @@ -401,13 +403,11 @@ def next_impl(context, builder, sig, args):

# -----------------------------------------------------------------------------

@lower_builtin("not in", types.Any, types.Any)
def not_in(context, builder, sig, args):
def in_impl(a, b):
@overload("not in")
def impl_not_in(a, b):
def impl(a, b):
return operator.contains(b, a)

res = context.compile_internal(builder, in_impl, sig, args)
return builder.not_(res)
return impl
stuartarchibald marked this conversation as resolved.
Show resolved Hide resolved


# -----------------------------------------------------------------------------
Expand Down
88 changes: 43 additions & 45 deletions numba/np/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -1992,46 +1992,40 @@ def impl(arr, values, axis=None):
return impl


@lower_builtin('array.ravel', types.Array)
def array_ravel(context, builder, sig, args):
@overload_method(types.Array, "ravel")
def ol_array_ravel(arr):
# Only support no argument version (default order='C')
def imp_nocopy(ary):
def imp_nocopy(arr):
"""No copy version"""
return ary.reshape(ary.size)
return arr.reshape(arr.size)

def imp_copy(ary):
def imp_copy(arr):
"""Copy version"""
return ary.flatten()
return arr.flatten()

# If the input array is C layout already, use the nocopy version
if sig.args[0].layout == 'C':
if arr.layout == 'C':
imp = imp_nocopy
# otherwise, use flatten under-the-hood
else:
imp = imp_copy

res = context.compile_internal(builder, imp, sig, args)
res = impl_ret_new_ref(context, builder, sig.return_type, res)
return res
return imp


@lower_builtin(np.ravel, types.Array)
def np_ravel(context, builder, sig, args):
def np_ravel_impl(a):
return a.ravel()

return context.compile_internal(builder, np_ravel_impl, sig, args)
@overload(np.ravel)
def ol_np_ravel(a):
if isinstance(a, types.Array):
def np_ravel_impl(a):
return a.ravel()
return np_ravel_impl


@lower_builtin('array.flatten', types.Array)
def array_flatten(context, builder, sig, args):
@overload_method(types.Array, "flatten")
def ol_array_flatten(arr):
# Only support flattening to C layout currently.
def imp(ary):
return ary.copy().reshape(ary.size)

res = context.compile_internal(builder, imp, sig, args)
res = impl_ret_new_ref(context, builder, sig.return_type, res)
return res
def imp(arr):
return arr.copy().reshape(arr.size)
return imp


@register_jitable
Expand Down Expand Up @@ -5971,9 +5965,10 @@ def get_sort_func(kind, is_float, is_argsort=False):
Get a sort implementation of the given kind.
"""
key = kind, is_float, is_argsort
try:
if key in _sorts:
return _sorts[key]
except KeyError:
else:
_supported_kind_values = ('quicksort', 'mergesort')
if kind == 'quicksort':
sort = quicksort.make_jit_quicksort(
lt=lt_floats if is_float else None,
Expand All @@ -5985,21 +5980,23 @@ def get_sort_func(kind, is_float, is_argsort=False):
lt=lt_floats if is_float else None,
is_argsort=is_argsort)
func = sort.run_mergesort
else:
msg = (f'sort func "{kind}" is not supported. Allowed values '
f'are {_supported_kind_values}')
raise errors.TypingError(msg)
_sorts[key] = func
return func
guilhermeleobas marked this conversation as resolved.
Show resolved Hide resolved


@lower_builtin("array.sort", types.Array)
def array_sort(context, builder, sig, args):
arytype = sig.args[0]
@overload_method(types.Array, "sort")
def impl_array_sort(arr):
sort_func = get_sort_func(kind='quicksort',
is_float=isinstance(arytype.dtype, types.Float))
is_float=isinstance(arr.dtype, types.Float))

def array_sort_impl(arr):
def impl(arr):
# Note we clobber the return value
sort_func(arr)

return context.compile_internal(builder, array_sort_impl, sig, args)
return impl


@overload(np.sort)
Expand All @@ -6015,21 +6012,22 @@ def np_sort_impl(a):
return np_sort_impl


@lower_builtin("array.argsort", types.Array, types.StringLiteral)
@lower_builtin(np.argsort, types.Array, types.StringLiteral)
def array_argsort(context, builder, sig, args):
arytype, kind = sig.args
@overload(np.argsort)
@overload_method(types.Array, "argsort")
def impl_arr_argsort(arr, kind=None):
if is_nonelike(kind):
kind = types.StringLiteral(value='quicksort')
stuartarchibald marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(kind, types.StringLiteral):
msg = '"kind" must be a string literal'
raise errors.RequireLiteralValue(msg)

sort_func = get_sort_func(kind=kind.literal_value,
is_float=isinstance(arytype.dtype, types.Float),
is_float=isinstance(arr.dtype, types.Float),
is_argsort=True)

def array_argsort_impl(arr):
def impl(arr, kind=None):
return sort_func(arr)

innersig = sig.replace(args=sig.args[:1])
innerargs = args[:1]
return context.compile_internal(builder, array_argsort_impl,
innersig, innerargs)
return impl


# ------------------------------------------------------------------------------
Expand Down
22 changes: 11 additions & 11 deletions numba/tests/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def check(pyfunc):
check(argsort_usecase)
check(np_argsort_usecase)

def test_argsort_float(self):
def test_argsort_float_stable(self):
def check(pyfunc, is_stable):
cfunc = jit(nopython=True)(pyfunc)
for orig in self.float_arrays():
Expand All @@ -953,6 +953,16 @@ def test_bad_array(self):
with self.assertRaisesRegex(errors.TypingError, msg) as raises:
cfunc(None)

def test_argsort_bad_kind(self):
def func(val):
return val.argsort(kind='somesort')

cfunc = jit(nopython=True)(func)
msg = '.*sort func "somesort" is not supported.*'
with self.assertRaisesRegex(errors.TypingError, msg) as raises:
arr = np.arange(10, dtype=float)
cfunc(arr)


class TestPythonSort(TestCase):

Expand Down Expand Up @@ -1195,16 +1205,6 @@ def nonliteral_kind(kind):
expect = '"kind" must be a string literal'
self.assertIn(expect, str(raises.exception))

@njit
def unsupported_kwarg():
np.arange(5).argsort(foo='')

with self.assertRaises(errors.TypingError) as raises:
unsupported_kwarg()

expect = "Unsupported keywords: ['foo']"
self.assertIn(expect, str(raises.exception))


if __name__ == '__main__':
unittest.main()