
Integration with Python performance tools like JAX #3432

Open
ThomasMGeo opened this issue Mar 11, 2024 · 10 comments
Labels
Type: Feature New functionality

Comments

@ThomasMGeo

ThomasMGeo commented Mar 11, 2024

What should we add?

There are a few ways to speed up raw numpy calculations. I looked at two options:

  1. JAX (documentation here)
  2. Numba (documentation here)

This was tested on an M2 MacBook Pro (so no NVIDIA GPU), but I didn't have any trouble installing either package. Overall, for two basic calculations, I saw speed-ups on the order of 3-10x. These results are just from a few hours of hacking, and not intended to be strict benchmarks.
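For reference, a minimal sketch of how such a comparison can be timed. The `vapor_pressure` function is a made-up stand-in (a Magnus-type formula), not one of the calculations actually tested in the notebook:

```python
import timeit
import numpy as np

def vapor_pressure(t_celsius):
    """Magnus-type saturation vapor pressure in hPa (illustrative only,
    not the formula from the notebook)."""
    return 6.112 * np.exp(17.67 * t_celsius / (t_celsius + 243.5))

# One million points, roughly the scale where JIT compilation pays off.
t = np.random.default_rng(0).uniform(-40.0, 40.0, size=1_000_000)

# Time the pure-numpy baseline; a JAX or Numba variant would be timed
# the same way, after one warm-up call to exclude JIT compilation cost.
elapsed = timeit.timeit(lambda: vapor_pressure(t), number=10)
print(f"numpy baseline: {elapsed / 10 * 1e3:.2f} ms per call")
```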

My Take

JAX was much easier to use, and faster. It felt much more like a 'drop-in' numpy replacement, as opposed to re-writing JIT'd functions that didn't support all of numpy's functionality. If I needed to write faster numpy code, it would be straightforward to do so with JAX, with or without a GPU.
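The 'drop-in' nature can be sketched like this. `wind_speed` is a hypothetical example, not MetPy's API, and the fallback lets the snippet run even without JAX installed:

```python
import numpy

try:
    # With JAX installed, jax.numpy is a near drop-in numpy replacement,
    # and jit() compiles the function with XLA.
    import jax.numpy as jnp
    from jax import jit
except ImportError:
    # Fall back to plain numpy so the sketch still runs without JAX.
    jnp = numpy
    def jit(f):
        return f

@jit
def wind_speed(u, v):
    """Magnitude of the wind vector (illustrative, not MetPy's API)."""
    return jnp.hypot(u, v)

u = jnp.asarray([3.0, 0.0])
v = jnp.asarray([4.0, 5.0])
print(wind_speed(u, v))  # ~[5.0, 5.0]
```

The function body is written purely against the numpy API, which is why switching between the two backends requires no code changes.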

Future packages or workflows to consider:

  • CuPy (downside: requires a CUDA-enabled GPU)
  • Cython might be another option
  • Multiprocessing?

Notebook

Simple test notebook is here.

Reference

No response

@ThomasMGeo ThomasMGeo added the Type: Feature New functionality label Mar 11, 2024
@winash12

How about adding an xarray + dask example? How would a Jupyter notebook that includes parallel processing work?

@jthielen
Collaborator

Just to add this for the sake of reference: about a year and a half ago I did some experiments with a Numba-based re-implementation of MetPy's CAPE calculations, as shown here: https://github.com/jthielen/cumulonumba/blob/main/examples/cumulonumba_v_metpy_rough_test.ipynb. Key takeaways were that the speed-up with Numba was substantial (two or three orders of magnitude), but that the JIT compilation costs were not insignificant (and perhaps a deal-breaker for some use cases). This was yet another factor favoring Cython over Numba for MetPy's purposes.
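The compile-cost effect described above is easy to observe: the first call to a JIT'd function pays the compilation price, while later calls run compiled code. `parcel_sum` is a made-up stand-in, not MetPy's CAPE code, and the fallback keeps the sketch runnable without Numba:

```python
import time
import numpy as np

try:
    from numba import njit          # real JIT if Numba is available
except ImportError:
    def njit(f):                    # no-op stand-in so the sketch runs anyway
        return f

@njit
def parcel_sum(profile):
    # An explicit loop of the kind Numba accelerates; a toy calculation,
    # not a real thermodynamic one.
    total = 0.0
    for i in range(profile.shape[0]):
        total += profile[i] * 0.5
    return total

profile = np.ones(1000)

t0 = time.perf_counter()
first = parcel_sum(profile)         # first call pays the JIT compile cost
t_first = time.perf_counter() - t0

t0 = time.perf_counter()
second = parcel_sum(profile)        # later calls run the compiled code
t_second = time.perf_counter() - t0

print(f"first call: {t_first:.4f}s, second call: {t_second:.6f}s")
```

With Numba installed, the first call is typically orders of magnitude slower than the second, which is the deal-breaker risk for short interactive sessions.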

@ThomasMGeo
Author

Thanks for the add @jthielen! Have you had the chance to mess around with JAX? I know you're quite busy :)

But overall I agree that Numba is not the solution.

@ThomasMGeo
Author

@winash12 , do you have a specific problem in mind to solve with xarray/dask?

@jthielen
Collaborator

@ThomasMGeo Only a little bit, and not in this context, unfortunately! That being said, for some of the underlying array operations (intersection finding, fixed-point iteration), my hunch is that a JAX-type approach (given its more functional way of doing things) would require more refactoring than Numba would. I could be mistaken on that, though!

@dopplershift
Member

@winash12 Interoperability with Dask is one of the major technical areas we are focusing on at the moment.

@winash12

winash12 commented Mar 16, 2024

Regarding the Cython usage, how do you propose to take it forward? Will the Cython code be written in Python and converted to C by the compiler, or can we add C or C++ functions? The second would need C makefiles for the build to go through, plus modifications to LDPATH, etc. From an implementation perspective, my question is: are you planning to permit the use of cdef functions, or purely def functions? Looking at the implementation of SciPy, they have many classes that do use cdef functions.

If we are planning to use cdef functions, then it is worthwhile to look at xtensor: https://xtensor.readthedocs.io/en/latest/

@ThomasMGeo As an example, let us assume I want to calculate potential vorticity (PV) for 4 different times, and the input data is present in a single netCDF file. Can I do the PV calculation for the four different time instances in parallel? Most definitely, as they are mutually independent data snapshots. For that I need to use dask arrays, if I am not mistaken. The last time I attended the con call, I recall everyone agreeing that there isn't a notebook yet that does this.
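A minimal sketch of that scenario, using only the standard library: four mutually independent time steps processed in parallel. `compute_pv` is a made-up placeholder for a real PV calculation; with xarray + dask one would instead chunk the dataset along the time dimension and let dask schedule the work:

```python
from concurrent.futures import ThreadPoolExecutor
import numpy as np

def compute_pv(snapshot):
    """Placeholder 'PV' calculation on a single time step (not real PV)."""
    return float(np.mean(snapshot) * 2.0)

# Four fake time steps; in the real case these come from one netCDF file.
snapshots = [np.full((10, 10), float(t)) for t in range(4)]

# Threads are enough for a sketch; CPU-bound numpy work would use a
# process pool or dask to sidestep the GIL.
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(compute_pv, snapshots))

print(results)  # [0.0, 2.0, 4.0, 6.0]
```

The key property is exactly the one noted above: because the snapshots are independent, the map can be distributed freely.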

@winash12

winash12 commented Mar 16, 2024

Actually, looking at it again: if all we want is a faster version of numpy, then I question the need for Cython. xtensor has a Python wrapper we could use: https://github.com/xtensor-stack/xtensor-python

https://xtensor.readthedocs.io/en/latest/numpy.html

@dopplershift dopplershift changed the title Potential ways to speed up MetPy Integration with Python performance tools like JAX Mar 21, 2024
@dopplershift
Member

@winash12 I really need to update our roadmap with this stuff (#1655) but the plan is to only update particular places in the code that are bottlenecks to doing calculations at scale--and to see what's slow through benchmarks. The top offender that comes to mind is CAPE/CIN, mostly due to moist_lapse(). CAPE/CIN is especially problematic because the nature of the calculation resists vectorization, so direct looping is the only option. Hence, we look at compiled solutions.
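The kind of loop described above, where each step depends on the previous one so numpy vectorization cannot help, looks roughly like this simplified Euler integration. The lapse-rate function is a toy stand-in, not MetPy's `moist_lapse()`:

```python
import numpy as np

def toy_lapse_rate(t, p):
    """Toy stand-in for a moist adiabatic lapse rate dT/dp (not MetPy's)."""
    return 0.03 * t / p

def integrate_profile(t0, pressures):
    """Euler integration along a pressure profile. Each temperature
    depends on the previous one, so the loop cannot be replaced by a
    single vectorized numpy call -- hence the appeal of compiling it."""
    temps = np.empty_like(pressures)
    temps[0] = t0
    for i in range(1, len(pressures)):
        dp = pressures[i] - pressures[i - 1]
        temps[i] = temps[i - 1] + toy_lapse_rate(temps[i - 1], pressures[i - 1]) * dp
    return temps

p = np.linspace(1000.0, 500.0, 6)   # hPa, surface upward
print(integrate_profile(290.0, p))  # monotonically cooling profile
```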

That does not imply we're looking at general solutions for a faster numpy. It is really important for ease of maintenance and contributions from the community that we stick to Python. The nature of @ThomasMGeo's investigation was really to look at how well people using tools like JAX or CuPy can pass data from those libraries (which are numpy-like) into MetPy and have things "just work". We have no plans to depend on them, however. The same can be said about our plans for supporting Dask--we want to make sure we facilitate workflows using Dask (like the one you described for multiple levels of PV analysis), but we will not be using Dask directly within MetPy.
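The "just work" goal can be illustrated with a function written purely against the numpy API: it operates on whatever numpy-like array it receives (numpy shown here; jax.numpy, cupy, or dask arrays behave the same as long as they support broadcasting arithmetic). `mixing_ratio` is an illustrative formula, not MetPy's actual implementation:

```python
import numpy as np

def mixing_ratio(partial_press, total_press, epsilon=0.622):
    """Mass mixing ratio from partial and total pressure (dimensionless).
    Uses only arithmetic the numpy API defines, so any numpy-like array
    library can supply the inputs."""
    return epsilon * partial_press / (total_press - partial_press)

e = np.array([10.0, 20.0])      # hPa
p = np.array([1000.0, 1000.0])  # hPa
print(mixing_ratio(e, p))
```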

Currently on the table are:

  • Cython
  • Numba
  • Rust
  • C++ with pybind11

The leader is Cython due to how commonplace it is within the scientific Python ecosystem. Also, I am heavily interested in the ability to run Python (with MetPy) within web browsers, so any solution chosen needs to be amenable to WASM (WebAssembly), which likely rules out Numba. Rust/C++ are included for completeness (Rust mainly because there's a lot of momentum there, but I'm unclear on the numpy integration story), but I'm 95% sure we're going down the Cython route.

@leaver2000

leaver2000 commented Apr 18, 2024

@Z-Richard mentioned I should drop my code into a public repo

The code has a single runtime dependency on numpy and requires Cython to compile the moist_lapse ODE. There is a notebook that pulls from the WeatherBench2 Zarr storage. The code needs to be compiled in a certain way to achieve code coverage on the compiled binary, which slows things down quite a bit.

This is dcape:

[image: benchmark screenshot]


5 participants