
Prevent compilation when no_cpython_wrapper is set and restructure linking IR modules #9566

Draft
wants to merge 5 commits into base: main
Conversation

aseyboldt
Contributor

@aseyboldt aseyboldt commented May 9, 2024

For now this PR is meant for discussion and experimentation only.

Right now, numba fully compiles functions (i.e. generates the LLVM IR, runs all optimizations, and runs LLVM codegen) whenever a function is typed for the first time. This can lead to unnecessarily long compile times if a function is never called from Python but only from other jitted functions, because part of the compilation work is never used.

This was discussed previously here and in the last developer meeting.

As a pathological example, a chain of njit functions where we only call the last one directly has a compile time that is more or less quadratic in the chain length:

import numba
import numpy as np

options = dict(no_cfunc_wrapper=True, no_cpython_wrapper=True, cache=False)

def make_func(func):
    # A small function that calls a previous function and reshapes the result
    @numba.njit(**options)
    def foo(x):
        return func(2 * x.reshape((9,))).reshape((3, 3))

    return foo

@numba.njit(**options)
def first(x):
    return x

# We build a chain of 50 of those simple functions
current = first
for i in range(50):
    current = make_func(current)

# Wrap it in a final function with cpython wrapper
@numba.njit(cache=False)
def outer_logp_wrapper(x):
    return current(x)

# Compile the function
print(outer_logp_wrapper(np.zeros(9)))

This takes ~45s on my machine.
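A back-of-envelope model of why the eager scheme scales quadratically: compiling function i in the chain links (and re-optimizes) the modules of all i functions below it, while a lazy scheme links everything exactly once at the end. This is a pure-Python counting sketch, not a measurement of numba internals; `eager_links` and `lazy_links` are hypothetical names for the two schemes.

```python
# Eager scheme: compiling function i in the chain links the i modules
# below it, so total link work is 1 + 2 + ... + n, i.e. O(n^2).
def eager_links(n):
    return sum(range(1, n + 1))

# Lazy scheme: only the outermost function triggers codegen and links
# all n modules once, i.e. O(n).
def lazy_links(n):
    return n

print(eager_links(50), lazy_links(50))  # → 1275 50
```

For the 50-function chain above, that is 1275 link/optimize steps versus 50, which matches the rough shape of the 45s vs. 5s timings reported below.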

This PR splits compilation in the CPUCodeLibrary into two phases, "finalization" and "compilation". Previously, objects could be in one of two states: non-finalized and finalized. Non-finalized objects could still have their LLVM IR changed, while finalized objects were fully optimized and compiled. After the PR there are three states: non-finalized, finalized, and compiled. After finalization the LLVM IR can no longer be changed, and the module can be linked into the modules of other functions, but LLVM codegen has not happened yet and no callable function exists.
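The three-state lifecycle can be sketched as a small state machine. All names here (`LibState`, `CodeLibrarySketch`, the method names) are illustrative assumptions for this sketch, not numba's actual `CPUCodeLibrary` API:

```python
from enum import Enum, auto

class LibState(Enum):
    NON_FINALIZED = auto()  # IR may still be mutated
    FINALIZED = auto()      # IR frozen; linkable into other modules
    COMPILED = auto()       # LLVM codegen done; a callable exists

class CodeLibrarySketch:
    """Illustrative model of the three-state lifecycle (not numba's API)."""

    def __init__(self):
        self.state = LibState.NON_FINALIZED
        self.modules = []

    def add_ir_module(self, module):
        # IR can only change before finalization
        if self.state is not LibState.NON_FINALIZED:
            raise RuntimeError("IR can no longer be changed")
        self.modules.append(module)

    def finalize(self):
        # freeze the IR; run only cheap optimization passes here
        if self.state is LibState.NON_FINALIZED:
            self.state = LibState.FINALIZED

    def link_into(self, other):
        # linking requires frozen IR, but not full codegen
        if self.state is LibState.NON_FINALIZED:
            raise RuntimeError("finalize before linking")

    def compile(self):
        # full optimization + LLVM codegen happens lazily, on demand
        self.finalize()
        self.state = LibState.COMPILED
```

The key difference from the old two-state model is that a finalized library can be linked into callers without ever reaching the compiled state itself.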

(I think it is somewhat of an open question which optimizations should run at the end of the finalization phase. In this PR all but the cheap optimizations are delayed until compilation, so that unoptimized modules are linked into other modules. This means no destructive optimization passes are executed before the context of the function is known, but it also means that if a function is called from several different functions, some optimizations might now run several times. This might have a nice solution with the new LLVM pass manager, for instance using https://llvm.org/doxygen/classllvm_1_1PassBuilder.html#ad6f258d31ffa2d2e4dfaf990ba596d0d. Previously, all optimizations ran before linking, so modules that are linked into other modules were optimized a second time. I didn't really investigate the effects of this change, but from earlier experiments I think it probably increases compile times a bit overall, while sometimes making vectorization possible where it previously wasn't.)

If typing and compilation were independent in numba, this change might be enough to prevent unnecessary compilation (?), but because all typing currently requires the creation of a callable, this PR also adds a new jit kwarg no_wrapper (better names welcome...). If this is enabled for an njit function, we never produce a compiled function and always keep the corresponding CPUCodeLibrary in the finalized state.

If we change the compilation options in the above example to

options = dict(no_cfunc_wrapper=True, no_cpython_wrapper=True, cache=False, no_wrapper=True)

it now compiles only the last function instead of all intermediary functions as well, which takes 5s instead of the previous 45s. In some real-world functions in pytensor/pymc this speeds up compilation by about a factor of 2.

CC @stuartarchibald

@aseyboldt
Contributor Author

Some notes on the implementation:

  • For some reason I get slightly different final LLVM modules with and without no_wrapper, and performance is a bit better with no_wrapper enabled. I'm pretty confused by this; I don't see how there can be any difference between the two, so I guess there is probably a bug somewhere...
  • The typing system uses the compiled function as a key. But since there is no compiled function for a no_wrapper function, this can't work. In the PR I just use the fndesc as a key instead, but this doesn't sound like a proper solution to me...

@aseyboldt
Contributor Author

I realized there doesn't really seem to be a good reason for a separate no_wrapper argument; we can just use the existing no_cpython_wrapper...

@dlee992
Contributor

dlee992 commented May 13, 2024

Re-edit:
After I rebuilt the numba dispatcher, the code snippet in the first comment throws an error:

  File "/Users/dali/Code/open-numba/numba/core/types/common.py", line 52, in __init__
    if isinstance(dtype, Buffer):
  File "/opt/homebrew/anaconda3/envs/numbaenv/lib/python3.10/abc.py", line 119, in __instancecheck__
    return _abc_instancecheck(cls, instance)
RecursionError: maximum recursion depth exceeded in comparison

Any clue to solve this?

Updates:

  • Setting range(37) avoids this RecursionError; >=38 throws it.
  • On my machine, with vs. without this PR is 2.47s vs 10.41s. The compilation speedup is confirmed for this case.

@aseyboldt
Contributor Author

I was running it in Jupyter, which might be setting a higher recursion limit? I get the same error if I run it with plain Python (both with the default branch and the new one, by the way).
You can increase that limit using

import sys
sys.setrecursionlimit(1500)

@aseyboldt
Contributor Author

aseyboldt commented May 15, 2024

Sorry everyone, I had a bit too much fun making a few animations about the linking process :-)

Each square represents the LLVM IR module of a function, and each edge indicates that one function uses another, so that the corresponding module needs to be linked in.

numba-linking-current.mp4
numba-linking-proposed.mp4

(always wanted to try manim)

@aseyboldt aseyboldt changed the title Add no_wrapper jit option to disable independent compilation and delay codegen Prevent compilation when no_cpython_wrapper is set and restructure linking IR modules May 20, 2024