Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] better control of compilation of jitclass members #9565

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

andy-bell101
Copy link

@andy-bell101 andy-bell101 commented May 9, 2024

I'm attempting allowing users to control how the methods on their jitclasses are compiled. Currently everything is passed through njit and there's an opportunity to do something different.

Proposed behaviours

There are two ideas for the behaviour that I had that I've listed below. At this stage I prefer option 1 because it's less complex and introduces fewer changes to the public API. But I can't deny that the second option would be more convenient for the use-case that prompted me to open the PR, since I'd just like to pass cache=True to all my member functions.

Currently the code in the PR implements option 1.

Option 1: Options from jitclass or jitted method (but not both)

The first option is that you can "pre-jit" your method using njit and any njit options you give to jitclass are irrelevant for that method.

@jitclass(parallel=True)
class MyClass:
    def my_method_1(self, a):
        # this method is compiled with parallel=True
        ...

    @njit(cache=True)
    def my_method_2(self, b):
        # this method is compiled with *only* the arguments given in 
        # the njit decorator here
        # so just `cache=True` in this case
        ...

Option 2: Options from jitclass and jitted method are combined

The second option is that you can take the union of the two sets of arguments (with the method's keyword arguments taking priority) and compile the members that way. The easiest way I can think of to do this is to introduce a new jitmethod decorator that takes the same keyword arguments as njit.

@jitclass(parallel=True, cache=False)
class MyClass:
    def my_method_1(self, a):
        # this method is compiled with parallel=True and cache=False
        ...

    @jitmethod(cache=True)
    def my_method_2(self, b):
        # this method is compiled with the union of the arguments 
        # given in the jitmethod decorator and the arguments to jitclass, 
        # with the jitmethod arguments taking precedent
        # so in the end this method is compiled with `parallel=True, cache=True`
        ...

Notes

  • I had to bump the flake8 version in the pre-commit config because it was causing false-negatives with the mypy ignore comments

WIP

  • I need to figure out how to actually test that the compilation is working with the new args. I could do mocking but I'd prefer to do something real. Search the repo for testing functions that demonstrate:
    • cache=True
    • parallel=True
    • nogil=True
  • I should probably try to type-hint the files that I touch as part of this PR, or at least the functions I modify.
  • Modify StructRefProxy.__new__ to accept kwargs to control the njit of the constructor, and search for other njit uses in StructRef and StructRefProxy.
  • Documentation

@dlee992
Copy link
Contributor

dlee992 commented May 9, 2024

Ah, cool! Perhaps can you also consider how to do the similar thing for structref? It also has at least one location using njit directly without any options.

Here:

@njit
def ctor(*args):
return cls(*args)
# cache it to attribute to avoid recompilation
cls.__numba_ctor = ctor

@andy-bell101
Copy link
Author

It looks like it would be simple enough to roll a modification to StructRefProxy into this work too via a **kwargs in the __new__ method

@guilhermeleobas
Copy link
Collaborator

Hi @andy-bell101, thanks for your contribution.

What about, rather than using @njit, pass a decorator that would just change the set of keyword arguments on it?

@jitclass(parallel=True)
class MyClass:
    def my_method_1(self, a):
        # this method is compiled with parallel=True
        ...

    @update_jit_args(cache=True)  # or @patch_jit_option(cache=True)
    def my_method_2(self, b):
        # this method is compiled with *only* the arguments given in 
        # the njit decorator here
        # so just `cache=True` in this case
        ...

@andy-bell101
Copy link
Author

andy-bell101 commented May 10, 2024

If I understand your suggestion correctly I think it's the same behaviour as option 2 above, unless I missed something

Edit: actually no, I misunderstood.

I'm not sure what advantage using @update_jit_args is vs. calling njit directly

@sklam sklam self-requested a review May 14, 2024 14:29
@andy-bell101
Copy link
Author

I have been working on this in the background for a while. The current problem I'm stuck against is that the methods defined on the class are not the final methods on the object. There's an additional wrapper compiled around them (so you never hit the cache for the underlying method) and that wrapper itself is dynamically created via exec (meaning it can't be cached because it has no source that inspect can see). Currently looking for a workaround

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants