Skip to content

Commit

Permalink
Merge pull request #8467 from gmarkall/gjit-to-overload
Browse files Browse the repository at this point in the history
Convert implementations using generated_jit to overload
  • Loading branch information
sklam committed Sep 28, 2022
2 parents 96989df + e101b48 commit 061f21a
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 29 deletions.
77 changes: 50 additions & 27 deletions numba/np/arraymath.py
Expand Up @@ -13,7 +13,6 @@
import llvmlite.ir
import numpy as np

from numba import generated_jit
from numba.core import types, cgutils
from numba.core.extending import overload, overload_method, register_jitable
from numba.np.numpy_support import as_dtype, type_can_asarray
Expand Down Expand Up @@ -1316,34 +1315,50 @@ def prepare_ptp_input(a):
return arr


def _compute_current_val_impl_gen(op):
def _compute_current_val_impl(current_val, val):
if isinstance(current_val, types.Complex):
# The sort order for complex numbers is lexicographic. If both the
# real and imaginary parts are non-nan then the order is determined
# by the real parts except when they are equal, in which case the
# order is determined by the imaginary parts.
# https://github.com/numpy/numpy/blob/577a86e/numpy/core/fromnumeric.py#L874-L877 # noqa: E501
def impl(current_val, val):
if op(val.real, current_val.real):
return val
elif (val.real == current_val.real
and op(val.imag, current_val.imag)):
return val
return current_val
else:
def impl(current_val, val):
return val if op(val, current_val) else current_val
return impl
return _compute_current_val_impl
def _compute_current_val_impl_gen(op, current_val, val):
if isinstance(current_val, types.Complex):
# The sort order for complex numbers is lexicographic. If both the
# real and imaginary parts are non-nan then the order is determined
# by the real parts except when they are equal, in which case the
# order is determined by the imaginary parts.
# https://github.com/numpy/numpy/blob/577a86e/numpy/core/fromnumeric.py#L874-L877 # noqa: E501
def impl(current_val, val):
if op(val.real, current_val.real):
return val
elif (val.real == current_val.real
and op(val.imag, current_val.imag)):
return val
return current_val
else:
def impl(current_val, val):
return val if op(val, current_val) else current_val
return impl


def _compute_a_max(current_val, val):
pass


def _compute_a_min(current_val, val):
pass


@overload(_compute_a_max)
def _compute_a_max_impl(current_val, val):
return _compute_current_val_impl_gen(operator.gt, current_val, val)

_compute_a_max = generated_jit(_compute_current_val_impl_gen(greater_than))
_compute_a_min = generated_jit(_compute_current_val_impl_gen(less_than))

@overload(_compute_a_min)
def _compute_a_min_impl(current_val, val):
return _compute_current_val_impl_gen(operator.lt, current_val, val)


@generated_jit
def _early_return(val):
pass


@overload(_early_return)
def _early_return_impl(val):
UNUSED = 0
if isinstance(val, types.Complex):
def impl(val):
Expand Down Expand Up @@ -4527,7 +4542,11 @@ def _cross_preprocessing(x):
out[..., 2] = cp2


@generated_jit
def _cross(a, b):
pass


@overload(_cross)
def _cross_impl(a, b):
dtype = np.promote_types(as_dtype(a.dtype), as_dtype(b.dtype))
if a.ndim == 1 and b.ndim == 1:
Expand Down Expand Up @@ -4559,7 +4578,7 @@ def impl(a, b):
))

if a_.shape[-1] == 3 or b_.shape[-1] == 3:
return _cross_impl(a_, b_)
return _cross(a_, b_)
else:
raise ValueError((
"Dimensions for both inputs is 2.\n"
Expand Down Expand Up @@ -4589,8 +4608,12 @@ def _cross_preprocessing(x):
return np.asarray(cp)


@generated_jit
def cross2d(a, b):
pass


@overload(cross2d)
def cross2d_impl(a, b):
if not type_can_asarray(a) or not type_can_asarray(b):
raise TypingError("Inputs must be array-like.")

Expand Down
8 changes: 6 additions & 2 deletions numba/tests/test_np_functions.py
Expand Up @@ -382,6 +382,10 @@ def np_cross(a, b):
return np.cross(a, b)


def nb_cross2d(a, b):
return cross2d(a, b)


def flip_lr(a):
return np.fliplr(a)

Expand Down Expand Up @@ -4450,7 +4454,7 @@ def test_cross_exceptions(self):

def test_cross2d(self):
pyfunc = np_cross
cfunc = cross2d
cfunc = njit(nb_cross2d)
pairs = [
# 2x2 (n-dims)
(
Expand Down Expand Up @@ -4495,7 +4499,7 @@ def test_cross2d(self):
self.assertPreciseEqual(expected, got)

def test_cross2d_exceptions(self):
cfunc = cross2d
cfunc = njit(nb_cross2d)
self.disable_leak_check()

# test incompatible dimensions for ndim == 1
Expand Down

0 comments on commit 061f21a

Please sign in to comment.