
Add support for providing signatures for jitted functions that accept functions as arguments #9557

Open
ShaiAvr opened this issue May 5, 2024 · 0 comments


ShaiAvr commented May 5, 2024

Feature request

I use numba a lot to optimize my numerical simulations, and most of the time I give my jitted functions an explicit signature so they are compiled eagerly (as explained here). Sometimes my jitted functions accept a function as a parameter (for example, a numerical integration scheme that accepts the function to integrate), and I couldn't find a way to provide a signature for such a function. For example:

from collections.abc import Callable

import numpy as np
import numba as nb

jit_opts = dict(
    nopython=True, nogil=True, cache=False, error_model="numpy", fastmath=True
)


@nb.jit([nb.float64(nb.float64)], **jit_opts)
def quarter_circle(x: float) -> float:
    return np.sqrt(1 - x**2)


# How to annotate an argument that is a function?
@nb.jit([nb.float64(???, nb.float64, nb.float64)], **jit_opts)
def integrate(f: Callable[[float], float], a: float, b: float) -> float:
    # Numerical integration algorithm


print("pi =", 4 * integrate(quarter_circle, 0, 1))

I couldn't find any example in the documentation showing how to annotate such a function to enable eager compilation. I could use nb.typeof(quarter_circle), but the resulting signature only works for quarter_circle itself, as the following code demonstrates (taken from https://stackoverflow.com/questions/64776569/numba-signature-for-jitted-function-as-argument):

import numba as nb


@nb.jit(
    nb.int32(nb.int32),
    nopython=True,
    nogil=True,
)
def bar(a):
    return 2 * a


@nb.jit(
    nb.int32(nb.int32),
    nopython=True,
    nogil=True,
)
def baz(a):
    return 3 * a


@nb.jit(
    nb.int32(nb.typeof(bar), nb.int32),
    nopython=True,
    nogil=True,
)
def foo(fn, a):
    return fn(a)


print(foo(bar, 2))
print(foo(baz, 2))

This code prints 4 for the first call and raises a TypeError on the second. So it seems I have to add a signature for every individual function I want to use with foo, which is tedious and very limiting.

I think there should be an API for annotating functions that accept functions as arguments and some examples in the documentation. If such an API exists and I failed to find it, then I think the documentation should be updated to include examples for this use case.
