From 327d8a1328222761650b9f0d5dc536ead82a5875 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Tue, 27 Sep 2022 10:47:17 +0100 Subject: [PATCH 1/2] Convert implementations using generated_jit to overload Mixing `generated_jit()` and `overload()` causes fallbacks to object mode - this should not be happening in Numba internal implementations. This commit converts all uses of `generated_jit()` into `overload()` for function implementations in Numba. --- numba/np/arraymath.py | 77 ++++++++++++++++++++++++++++--------------- 1 file changed, 50 insertions(+), 27 deletions(-) diff --git a/numba/np/arraymath.py b/numba/np/arraymath.py index b0d2a68bcab..e4a70374165 100644 --- a/numba/np/arraymath.py +++ b/numba/np/arraymath.py @@ -13,7 +13,6 @@ import llvmlite.ir import numpy as np -from numba import generated_jit from numba.core import types, cgutils from numba.core.extending import overload, overload_method, register_jitable from numba.np.numpy_support import as_dtype, type_can_asarray @@ -1317,34 +1316,50 @@ def prepare_ptp_input(a): return arr -def _compute_current_val_impl_gen(op): - def _compute_current_val_impl(current_val, val): - if isinstance(current_val, types.Complex): - # The sort order for complex numbers is lexicographic. If both the - # real and imaginary parts are non-nan then the order is determined - # by the real parts except when they are equal, in which case the - # order is determined by the imaginary parts. - # https://github.com/numpy/numpy/blob/577a86e/numpy/core/fromnumeric.py#L874-L877 # noqa: E501 - def impl(current_val, val): - if op(val.real, current_val.real): - return val - elif (val.real == current_val.real - and op(val.imag, current_val.imag)): - return val - return current_val - else: - def impl(current_val, val): - return val if op(val, current_val) else current_val - return impl - return _compute_current_val_impl +def _compute_current_val_impl_gen(op, current_val, val): + if isinstance(current_val, types.Complex): + # The sort order for complex numbers is lexicographic. If both the + # real and imaginary parts are non-nan then the order is determined + # by the real parts except when they are equal, in which case the + # order is determined by the imaginary parts. + # https://github.com/numpy/numpy/blob/577a86e/numpy/core/fromnumeric.py#L874-L877 # noqa: E501 + def impl(current_val, val): + if op(val.real, current_val.real): + return val + elif (val.real == current_val.real + and op(val.imag, current_val.imag)): + return val + return current_val + else: + def impl(current_val, val): + return val if op(val, current_val) else current_val + return impl + + +def _compute_a_max(current_val, val): + pass + + +def _compute_a_min(current_val, val): + pass + +@overload(_compute_a_max) +def _compute_a_max_impl(current_val, val): + return _compute_current_val_impl_gen(operator.gt, current_val, val) -_compute_a_max = generated_jit(_compute_current_val_impl_gen(greater_than)) -_compute_a_min = generated_jit(_compute_current_val_impl_gen(less_than)) + +@overload(_compute_a_min) +def _compute_a_min_impl(current_val, val): + return _compute_current_val_impl_gen(operator.lt, current_val, val) -@generated_jit def _early_return(val): + pass + + +@overload(_early_return) +def _early_return_impl(val): UNUSED = 0 if isinstance(val, types.Complex): def impl(val): @@ -4535,7 +4550,11 @@ def _cross_preprocessing(x): out[..., 2] = cp2 -@generated_jit +def _cross(a, b): + pass + + +@overload(_cross) def _cross_impl(a, b): dtype = np.promote_types(as_dtype(a.dtype), as_dtype(b.dtype)) if a.ndim == 1 and b.ndim == 1: @@ -4567,7 +4586,7 @@ def impl(a, b): )) if a_.shape[-1] == 3 or b_.shape[-1] == 3: - return _cross_impl(a_, b_) + return _cross(a_, b_) else: raise ValueError(( "Dimensions for both inputs is 2.\n" @@ -4597,8 +4616,12 @@ def _cross_preprocessing(x): return np.asarray(cp) -@generated_jit def cross2d(a, b): + pass + + +@overload +def cross2d_impl(a, b): if not type_can_asarray(a) or not type_can_asarray(b): raise TypingError("Inputs must be array-like.") From e101b488216b83ffa4fe885e57baf434f418def2 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Tue, 27 Sep 2022 12:41:01 +0100 Subject: [PATCH 2/2] Fix cross2d There were two issues: - The overload decorator didn't overload `cross2d` - The tests called the overload directly, rather than calling `cross2d` in a jitted function. --- numba/np/arraymath.py | 2 +- numba/tests/test_np_functions.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/numba/np/arraymath.py b/numba/np/arraymath.py index e4a70374165..eb657807788 100644 --- a/numba/np/arraymath.py +++ b/numba/np/arraymath.py @@ -4620,7 +4620,7 @@ def cross2d(a, b): pass -@overload +@overload(cross2d) def cross2d_impl(a, b): if not type_can_asarray(a) or not type_can_asarray(b): raise TypingError("Inputs must be array-like.") diff --git a/numba/tests/test_np_functions.py b/numba/tests/test_np_functions.py index 4cdaf548b1b..b7e5822932e 100644 --- a/numba/tests/test_np_functions.py +++ b/numba/tests/test_np_functions.py @@ -382,6 +382,10 @@ def np_cross(a, b): return np.cross(a, b) +def nb_cross2d(a, b): + return cross2d(a, b) + + def flip_lr(a): return np.fliplr(a) @@ -4432,7 +4436,7 @@ def test_cross_exceptions(self): def test_cross2d(self): pyfunc = np_cross - cfunc = cross2d + cfunc = njit(nb_cross2d) pairs = [ # 2x2 (n-dims) ( @@ -4477,7 +4481,7 @@ def test_cross2d(self): self.assertPreciseEqual(expected, got) def test_cross2d_exceptions(self): - cfunc = cross2d + cfunc = njit(nb_cross2d) self.disable_leak_check() # test incompatible dimensions for ndim == 1