Skip to content

Commit

Permalink
Merge pull request #8259 from guilhermeleobas/guilhermeleobas/broadca…
Browse files Browse the repository at this point in the history
…st_to_0darray

Add `np.broadcast_to(scalar_array, ())`
  • Loading branch information
sklam committed Oct 11, 2022
2 parents deeeaa2 + 87e51de commit c666425
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 19 deletions.
8 changes: 4 additions & 4 deletions numba/np/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

from numba.core import types, cgutils
from numba.core.extending import overload, overload_method, register_jitable
from numba.np.numpy_support import as_dtype, type_can_asarray
from numba.np.numpy_support import numpy_version
from numba.np.numpy_support import is_nonelike, check_is_integer
from numba.np.numpy_support import (as_dtype, type_can_asarray, type_is_scalar,
numpy_version, is_nonelike,
check_is_integer)
from numba.core.imputils import (lower_builtin, impl_ret_borrowed,
impl_ret_new_ref, impl_ret_untracked)
from numba.np.arrayobj import make_array, load_item, store_item, _empty_nd_impl
Expand Down Expand Up @@ -1042,7 +1042,7 @@ def impl(x):

@overload(np.isscalar)
def np_isscalar(num):
res = isinstance(num, (types.Number, types.UnicodeType, types.Boolean))
res = type_is_scalar(num)

def impl(num):
return res
Expand Down
67 changes: 55 additions & 12 deletions numba/np/arrayobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from numba import pndindex, literal_unroll
from numba.core import types, utils, typing, errors, cgutils, extending
from numba.np.numpy_support import (as_dtype, carray, farray, is_contiguous,
is_fortran, check_is_integer)
is_fortran, check_is_integer,
type_is_scalar)
from numba.np.numpy_support import type_can_asarray, is_nonelike, numpy_version
from numba.core.imputils import (lower_builtin, lower_getattr,
lower_getattr_generic,
Expand Down Expand Up @@ -1342,6 +1343,24 @@ def codegen(context, builder, sig, args):
return sig, codegen


@intrinsic
def get_readonly_array(typingctx, arr):
# returns a copy of arr which is readonly
ret = arr.copy(readonly=True)
sig = ret(arr)

def codegen(context, builder, sig, args):
[src] = args
srcty = sig.args[0]

dest = make_array(srcty)(context, builder, src)
# Hack to return a read-only array
dest.parent = cgutils.get_null_value(dest.parent.type)
res = dest._getvalue()
return impl_ret_borrowed(context, builder, sig.return_type, res)
return sig, codegen


@register_jitable
def _can_broadcast(array, dest_shape):
src_shape = array.shape
Expand Down Expand Up @@ -1374,29 +1393,53 @@ def _can_broadcast(array, dest_shape):
'with remapped shapes')


def _default_broadcast_to_impl(array, shape):
array = np.asarray(array)
_can_broadcast(array, shape)
return _numpy_broadcast_to(array, shape)


@overload(np.broadcast_to)
def numpy_broadcast_to(array, shape):
if not type_can_asarray(array):
raise errors.TypingError('The first argument "array" must '
'be array-like')

if isinstance(shape, types.UniTuple):
if not isinstance(shape.dtype, types.Integer):
raise errors.TypingError('The second argument "shape" must '
'be a tuple of integers')

def impl(array, shape):
array = np.asarray(array)
_can_broadcast(array, shape)
return _numpy_broadcast_to(array, shape)
elif isinstance(shape, types.Integer):
if isinstance(shape, types.Integer):
def impl(array, shape):
return np.broadcast_to(array, (shape,))
return impl

elif isinstance(shape, types.UniTuple):
if not isinstance(shape.dtype, types.Integer):
msg = 'The second argument "shape" must be a tuple of integers'
raise errors.TypingError(msg)
return _default_broadcast_to_impl

elif isinstance(shape, types.Tuple) and shape.count > 0:
# check if all types are integers
if not all([isinstance(typ, types.IntegerLiteral) for typ in shape]):
msg = f'"{shape}" object cannot be interpreted as an integer'
raise errors.TypingError(msg)
return _default_broadcast_to_impl
elif isinstance(shape, types.Tuple) and shape.count == 0:
is_scalar_array = isinstance(array, types.Array) and array.ndim == 0
if type_is_scalar(array) or is_scalar_array:

def impl(array, shape): # broadcast_to(array, ())
# Array type must be supported by "type_can_asarray"
# Quick note that unicode types are not supported!
array = np.asarray(array)
return get_readonly_array(array)
return impl

else:
msg = 'Cannot broadcast a non-scalar to a scalar array'
raise errors.TypingError(msg)
else:
msg = ('The argument "shape" must be a tuple or an integer. '
'Got %s' % shape)
raise errors.TypingError(msg)
return impl


@register_jitable
Expand Down
11 changes: 11 additions & 0 deletions numba/np/numpy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,17 @@ def type_can_asarray(arr):
return isinstance(arr, ok)


def type_is_scalar(typ):
""" Returns True if the type of 'typ' is a scalar type, according to
NumPy rules. False otherwise.
https://numpy.org/doc/stable/reference/arrays.scalars.html#built-in-scalar-types
"""

ok = (types.Boolean, types.Number, types.UnicodeType, types.StringLiteral,
types.NPTimedelta, types.NPDatetime)
return isinstance(typ, ok)


def check_is_integer(v, name):
"""Raises TypingError if the value is not an integer."""
if not isinstance(v, (int, types.Integer)):
Expand Down
43 changes: 40 additions & 3 deletions numba/tests/test_array_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,28 @@ def test_broadcast_to(self):
got = cfunc(input_array, shape)
self.assertPreciseEqual(got, expected)

def test_broadcast_to_0d_array(self):
pyfunc = numpy_broadcast_to
cfunc = jit(nopython=True)(pyfunc)

inputs = [
np.array(123),
123,
True,
# can't do np.asarray() on the types below
# 'hello',
# np.timedelta64(10, 'Y'),
# np.datetime64(10, 'Y'),
]

shape = ()
for arr in inputs:
expected = pyfunc(arr, shape)
got = cfunc(arr, shape)
self.assertPreciseEqual(expected, got)
# ensure that np.broadcast_to returned a read-only array
self.assertFalse(got.flags['WRITEABLE'])

def test_broadcast_to_raises(self):
pyfunc = numpy_broadcast_to
cfunc = jit(nopython=True)(pyfunc)
Expand All @@ -828,11 +850,15 @@ def test_broadcast_to_raises(self):
# https://github.com/numpy/numpy/blob/75f852edf94a7293e7982ad516bee314d7187c2d/numpy/lib/tests/test_stride_tricks.py#L260-L276 # noqa: E501
data = [
[np.zeros((0,)), (), TypingError,
'The argument "shape" must be a tuple or an integer.'],
'Cannot broadcast a non-scalar to a scalar array'],
[np.zeros((1,)), (), TypingError,
'The argument "shape" must be a tuple or an integer.'],
'Cannot broadcast a non-scalar to a scalar array'],
[np.zeros((3,)), (), TypingError,
'The argument "shape" must be a tuple or an integer.'],
'Cannot broadcast a non-scalar to a scalar array'],
[(), (), TypingError,
'Cannot broadcast a non-scalar to a scalar array'],
[(123,), (), TypingError,
'Cannot broadcast a non-scalar to a scalar array'],
[np.zeros((3,)), (1,), ValueError,
'operands could not be broadcast together with remapped shapes'],
[np.zeros((3,)), (2,), ValueError,
Expand All @@ -855,13 +881,24 @@ def test_broadcast_to_raises(self):
'The second argument "shape" must be a tuple of integers'],
['hello', (3,), TypingError,
'The first argument "array" must be array-like'],
[3, (2, 'a'), TypingError,
'object cannot be interpreted as an integer'],
]
self.disable_leak_check()
for arr, target_shape, err, msg in data:
with self.assertRaises(err) as raises:
cfunc(arr, target_shape)
self.assertIn(msg, str(raises.exception))

def test_broadcast_to_corner_cases(self):
@njit
def _broadcast_to_1():
return np.broadcast_to('a', (2, 3))

expected = _broadcast_to_1.py_func()
got = _broadcast_to_1()
self.assertPreciseEqual(expected, got)

def test_broadcast_to_change_view(self):
pyfunc = numpy_broadcast_to
cfunc = jit(nopython=True)(pyfunc)
Expand Down
3 changes: 3 additions & 0 deletions numba/tests/test_np_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,9 @@ def values():
yield 4.234
yield True
yield None
yield np.timedelta64(10, 'Y')
yield np.datetime64('nat')
yield np.datetime64(1, 'Y')

pyfunc = isscalar
cfunc = jit(nopython=True)(pyfunc)
Expand Down

0 comments on commit c666425

Please sign in to comment.