Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Add to_jax methods to support Jax Array export from DataFrame and Series #16294

Merged
merged 11 commits into from
May 19, 2024

Conversation

alexander-beedie
Copy link
Collaborator

@alexander-beedie alexander-beedie commented May 17, 2024

Continuing the theme of streamlining ML preprocessing tasks using Polars (following the previous PR offering torch integration), this PR adds support for jax1 array export from DataFrame and/or Series.

Features

Supported Jax export modes:

  • df.to_jax(): export the entire frame to a single Array.
  • df.to_jax("dict"): export frame to a dictionary of column Arrays.
  • df.to_jax("dict", label=…, features=…): export frame to a dictionary of label/features Arrays (this mode is also now available for to_torch).

Additional options:

Can create arrays on a specific device2 (eg: CPU, GPU, TPU, METAL3, etc).
Can specify the memory format (eg: "c" or "fortran" order).

  • df.to_jax(device="cpu", order="c")
  • df.to_jax(device=jax.devices("gpu")[1])

Examples

import polars as pl

df = pl.DataFrame(
    {
        "age": [25, 32, 45, 22, 34],
        "income": [50000, 75000, 60000, 58000, 120000],
        "education_level": ["bachelor", "master", "phd", "bachelor", "phd"],
        "purchased": [False, True, True, False, True],
    }
).to_dummies("education_level", separator=":")
# ┌─────┬────────┬────────────────────┬──────────────────┬───────────────┬───────────┐
# │ age ┆ income ┆ education:bachelor ┆ education:master ┆ education:phd ┆ purchased │
# │ --- ┆ ---    ┆ ---                ┆ ---              ┆ ---           ┆ ---       │
# │ i64 ┆ i64    ┆ u8                 ┆ u8               ┆ u8            ┆ bool      │
# ╞═════╪════════╪════════════════════╪══════════════════╪═══════════════╪═══════════╡
# │ 25  ┆ 50000  ┆ 1                  ┆ 0                ┆ 0             ┆ false     │
# │ 32  ┆ 75000  ┆ 0                  ┆ 1                ┆ 0             ┆ true      │
# │ 45  ┆ 60000  ┆ 0                  ┆ 0                ┆ 1             ┆ true      │
# │ 22  ┆ 58000  ┆ 1                  ┆ 0                ┆ 0             ┆ false     │
# │ 34  ┆ 120000 ┆ 0                  ┆ 0                ┆ 1             ┆ true      │
# └─────┴────────┴────────────────────┴──────────────────┴───────────────┴───────────┘
  • As Array:

    df.to_jax()
    # Array([[  25,  50000,    1,    0,    0,    0],
    #        [  32,  75000,    0,    1,    0,    1],
    #        [  45,  60000,    0,    0,    1,    1],
    #        [  22,  58000,    1,    0,    0,    0],
    #        [  34, 120000,    0,    0,    1,    1]], dtype=int32)
  • As dict of column Arrays:

    df.to_jax("dict")
    # {'age': Array([25, 32, 45, 22, 34], dtype=int32),
    #  'income': Array([ 50000,  75000,  60000,  58000, 120000], dtype=int32),
    #  'education:bachelor': Array([1, 0, 0, 1, 0], dtype=uint8),
    #  'education:master': Array([0, 1, 0, 0, 0], dtype=uint8),
    #  'education:phd': Array([0, 0, 1, 0, 1], dtype=uint8),
    #  'purchased': Array([False,  True,  True, False,  True], dtype=bool)}
  • As dict of label/features Arrays:

    (note: if features are not specified they are implied as being "everything except the label")
    df.to_jax("dict", label="purchased")
    # {'label': Array([[False],
    #                  [ True],
    #                  [ True],
    #                  [False],
    #                  [ True]], dtype=bool),
    #  'features': Array([[  25,  50000,    1,    0,    0],
    #                     [  32,  75000,    0,    1,    0],
    #                     [  45,  60000,    0,    0,    1],
    #                     [  22,  58000,    1,    0,    0],
    #                     [  34, 120000,    0,    0,    1]], dtype=int32)}

Notes

As with the torch PR, Jax support is designated "CI-only" for unit tests and requirements, and you'll need to use make requirements-all if you want Polars to install the related libraries in your local development environment. The doctests are similarly gated by the presence of the jax library; they will run if you have it (or are executing on CI), and are omitted otherwise.

Footnotes

  1. Jax: https://jax.readthedocs.io/en/latest/index.html

  2. Placement on devices: https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices

  3. Accelerated JAX training on Mac: https://developer.apple.com/metal/jax/

@github-actions github-actions bot added enhancement New feature or an improvement of an existing feature python Related to Python Polars labels May 17, 2024
@alexander-beedie alexander-beedie added the A-interop Area: interoperability with other libraries label May 17, 2024
Copy link

codecov bot commented May 17, 2024

Codecov Report

Attention: Patch coverage is 96.42857% with 2 lines in your changes are missing coverage. Please review.

Project coverage is 81.34%. Comparing base (fa6b1ca) to head (362b8d9).

Files Patch % Lines
py-polars/polars/dataframe/frame.py 95.12% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main   #16294      +/-   ##
==========================================
- Coverage   81.35%   81.34%   -0.01%     
==========================================
  Files        1403     1403              
  Lines      183463   183515      +52     
  Branches     2929     2946      +17     
==========================================
+ Hits       149253   149288      +35     
- Misses      33707    33723      +16     
- Partials      503      504       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@alexander-beedie alexander-beedie force-pushed the jax-export branch 3 times, most recently from 92bfe04 to 5cf3de6 Compare May 17, 2024 15:16
@alexander-beedie alexander-beedie changed the title feat(python): Add new to_jax methods to support export to Jax arrays from DataFrame and Series feat(python): Add new to_jax methods to support Jax Array export from DataFrame and Series May 17, 2024
@alexander-beedie alexander-beedie changed the title feat(python): Add new to_jax methods to support Jax Array export from DataFrame and Series feat(python): Add to_jax methods to support Jax Array export from DataFrame and Series May 17, 2024
@alexander-beedie alexander-beedie force-pushed the jax-export branch 2 times, most recently from 5c182c3 to 4f14dbe Compare May 18, 2024 12:24
Copy link
Member

@stinodego stinodego left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I am not familiar with jax, but if you're enthousiastic about it that's good enough reason for me to accept an integration.

I left some minor comments, but overall looks great as usual 👍

py-polars/polars/dataframe/frame.py Show resolved Hide resolved
py-polars/polars/dataframe/frame.py Show resolved Hide resolved
py-polars/tests/unit/ml/test_to_jax.py Outdated Show resolved Hide resolved
@pola-rs pola-rs deleted a comment from codspeed-hq bot May 19, 2024
Copy link

codspeed-hq bot commented May 19, 2024

CodSpeed Performance Report

Merging #16294 will degrade performances by 26.57%

Comparing alexander-beedie:jax-export (362b8d9) with main (d544750)

Summary

❌ 11 regressions
✅ 26 untouched benchmarks

⚠️ Please fix the performance issues or acknowledge them on CodSpeed.

Benchmarks breakdown

Benchmark main alexander-beedie:jax-export Change
test_groupby_h2oai_q2 17.2 ms 20.4 ms -15.53%
test_groupby_h2oai_q6 51.5 ms 66.8 ms -22.88%
test_groupby_h2oai_q9 107.8 ms 135.4 ms -20.36%
test_tpch_q11 13.7 ms 17.2 ms -20.22%
test_tpch_q13 37 ms 44.9 ms -17.7%
test_tpch_q16 24.2 ms 32.6 ms -25.64%
test_tpch_q2 15.1 ms 19.9 ms -23.92%
test_tpch_q21 811.6 ms 1,105.3 ms -26.57%
test_tpch_q22 22.9 ms 30.3 ms -24.51%
test_tpch_q7 30.4 ms 38.3 ms -20.6%
test_tpch_q8 22.1 ms 26.1 ms -15.31%

Copy link
Member

@stinodego stinodego left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what CodSpeed is on about, but I don't see any reason why this PR should cause a regression.

@stinodego stinodego merged commit 18c5601 into pola-rs:main May 19, 2024
27 checks passed
@alexander-beedie
Copy link
Collaborator Author

Not sure what CodSpeed is on about, but I don't see any reason why this PR should cause a regression.

Indeed! And it was fine before adding the extra docstring info; must be having some transient issues on their side 🤷‍♂️

@alexander-beedie alexander-beedie deleted the jax-export branch May 20, 2024 05:06
@so-rose
Copy link

so-rose commented May 20, 2024

Just 2c, I happened to need this + find it literally right now. An immense thank you from me!

@alexander-beedie
Copy link
Collaborator Author

alexander-beedie commented May 20, 2024

Just 2c, I happened to need this + find it literally right now. An immense thank you from me!

Nice to hear; feedback welcome :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
A-interop Area: interoperability with other libraries enhancement New feature or an improvement of an existing feature python Related to Python Polars
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants