
toms748_scan doesn't work with JAX backend #2466

Open · TC01 opened this issue Apr 13, 2024 · 6 comments · May be fixed by #2467
Assignees
Labels
bug Something isn't working

Comments


TC01 commented Apr 13, 2024

Summary

Hello; perhaps this is known but I thought I'd file a bug report just in case. I was testing the upper_limits API and discovered that the example given in the documentation doesn't seem to work with the JAX backend. It fails with a complaint about an unhashable array type (see the traceback). If I switch to the numpy backend, as shown in the documentation, it runs fine.

I see this on both EL7 in an ATLAS environment (StatAnalysis,0.3,latest) and on my own desktop (Fedora 38); in both cases I have the same pyhf version (0.7.6) and I manually installed jax[CPU] == 0.4.26 on top of that.

I should add that things work fine with JAX if I use the version of upper_limits where I pass in a range of mu values to scan, so I guess some extra type conversion is needed to go from the JAX array type to a list or something hashable?
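For context, here is a minimal stdlib-only sketch of the failure mode. `ArrayLike` is a hypothetical stand-in for jaxlib's ArrayImpl, which defines an elementwise `__eq__` and sets `__hash__` to `None`, making it unusable as a dict key:

```python
# Stand-in for the failure mode: dict/cache membership tests hash the
# key first, so a key type without __hash__ raises TypeError even on
# an empty cache.
class ArrayLike:
    """Hypothetical stand-in for jaxlib's ArrayImpl (not the real class)."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == getattr(other, "value", other)

    __hash__ = None  # unhashable, like a JAX array


cache = {}
poi = ArrayLike(1.5)

try:
    poi in cache  # mirrors the `if poi not in cache:` check in toms748_scan
except TypeError as err:
    print(err)  # unhashable type: 'ArrayLike'

# Coercing to a built-in float restores hashability:
print(float(poi.value) in cache)  # False, but no exception
```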

OS / Environment

# Linux
$ cat /etc/os-release
NAME="Fedora Linux"
VERSION="38 (Thirty Eight)"
ID=fedora
VERSION_ID=38
VERSION_CODENAME=""
PLATFORM_ID="platform:f38"
PRETTY_NAME="Fedora Linux 38 (Thirty Eight)"
ANSI_COLOR="0;38;2;60;110;180"
LOGO=fedora-logo-icon
CPE_NAME="cpe:/o:fedoraproject:fedora:38"
DEFAULT_HOSTNAME="fedora"
HOME_URL="https://fedoraproject.org/"
DOCUMENTATION_URL="https://docs.fedoraproject.org/en-US/fedora/f38/system-administrators-guide/"
SUPPORT_URL="https://ask.fedoraproject.org/"
BUG_REPORT_URL="https://bugzilla.redhat.com/"
REDHAT_BUGZILLA_PRODUCT="Fedora"
REDHAT_BUGZILLA_PRODUCT_VERSION=38
REDHAT_SUPPORT_PRODUCT="Fedora"
REDHAT_SUPPORT_PRODUCT_VERSION=38
SUPPORT_END=2024-05-14

Steps to Reproduce

Install pyhf and JAX through pip; then try to run the example in the documentation, but with the JAX backend instead of numpy:

import numpy as np
import pyhf
pyhf.set_backend("JAX")
model = pyhf.simplemodels.uncorrelated_background(
    signal=[12.0, 11.0], bkg=[50.0, 52.0], bkg_uncertainty=[3.0, 7.0]
)
observations = [51, 48]
data = pyhf.tensorlib.astensor(observations + model.config.auxdata)
obs_limit, exp_limits = pyhf.infer.intervals.upper_limits.toms748_scan(
    data, model, 0., 5., rtol=0.01
)

File Upload (optional)

No response

Expected Results

Ideally the example would run without crashing (as it does with the numpy backend).

Actual Results

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/bjr/.local/lib/python3.11/site-packages/pyhf/infer/intervals/upper_limits.py", line 130, in toms748_scan
    toms748(f, bounds_low, bounds_up, args=(level, 0), k=2, xtol=atol, rtol=rtol)
  File "/usr/lib64/python3.11/site-packages/scipy/optimize/_zeros_py.py", line 1374, in toms748
    result = solver.solve(f, a, b, args=args, k=k, xtol=xtol, rtol=rtol,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/scipy/optimize/_zeros_py.py", line 1229, in solve
    fc = self._callf(c)
         ^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/scipy/optimize/_zeros_py.py", line 1083, in _callf
    fx = self.f(x, *self.args)
         ^^^^^^^^^^^^^^^^^^^^^
  File "/home/bjr/.local/lib/python3.11/site-packages/pyhf/infer/intervals/upper_limits.py", line 95, in f
    f_cached(poi)[0] - level
    ^^^^^^^^^^^^^
  File "/home/bjr/.local/lib/python3.11/site-packages/pyhf/infer/intervals/upper_limits.py", line 80, in f_cached
    if poi not in cache:
       ^^^^^^^^^^^^^^^^
TypeError: unhashable type: 'jaxlib.xla_extension.ArrayImpl'

pyhf Version

$ pyhf --version
pyhf, version 0.7.6

Code of Conduct

  • I agree to follow the Code of Conduct
@TC01 TC01 added bug Something isn't working needs-triage Needs a maintainer to categorize and assign labels Apr 13, 2024

TC01 commented Apr 13, 2024

It seems that this can be fixed by explicitly converting poi to a float before checking if it's in the hypotest cache here. I printed out the types and at some point in the test run poi switched from being a numpy float64 to being a jaxlib type; I can trace further to see exactly why that happens and then maybe submit a PR.
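That proposed fix can be sketched as follows; `expensive_hypotest` is a hypothetical stand-in for the real hypotest call, and the cache shape only approximates pyhf's actual `f_cached` in upper_limits.py:

```python
# Hedged sketch of the proposed fix: coerce the POI to a built-in float
# before the cache lookup, so the key is hashable for every tensor backend
# (numpy float64, JAX scalar, etc. all coerce cleanly via float()).
calls = []


def expensive_hypotest(poi):
    """Hypothetical stand-in for the real pyhf.infer.hypotest call."""
    calls.append(poi)
    return poi**2


cache = {}


def f_cached(poi):
    poi = float(poi)  # the explicit conversion suggested above
    if poi not in cache:
        cache[poi] = expensive_hypotest(poi)
    return cache[poi]


f_cached(2.0)
f_cached(2.0)  # served from the cache; expensive_hypotest ran only once
print(len(calls))  # 1
```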

@matthewfeickert (Member)

First thing we'll need to do is understand why there aren't test failures

pyhf/tests/test_infer.py

Lines 26 to 57 in 64ab264

def test_toms748_scan(tmp_path, hypotest_args):
    """
    Test the upper limit toms748 scan returns the correct structure and values
    """
    _, data, model = hypotest_args
    results = pyhf.infer.intervals.upper_limits.toms748_scan(
        data, model, 0, 5, rtol=1e-8
    )
    assert len(results) == 2
    observed_limit, expected_limits = results
    observed_cls = pyhf.infer.hypotest(
        observed_limit,
        data,
        model,
        model.config.suggested_init(),
        model.config.suggested_bounds(),
    )
    expected_cls = np.array(
        [
            pyhf.infer.hypotest(
                expected_limits[i],
                data,
                model,
                model.config.suggested_init(),
                model.config.suggested_bounds(),
                return_expected_set=True,
            )[1][i]
            for i in range(5)
        ]
    )
    assert observed_cls == pytest.approx(0.05)
    assert expected_cls == pytest.approx(0.05)

which will probably mean revisiting PR #1274. Writing a failing test would be a good start, so that a PR can make it pass.


TC01 commented Apr 15, 2024

Ah, I didn't realize there was a test for this! Does that get run with all the backends? When I get a chance I can try running that locally too.


kratsg commented Apr 16, 2024

> Ah, I didn't realize there was a test for this! Does that get run with all the backends? When I get a chance I can try running that locally too.

Nope, which likely explains why it wasn't caught. (Adding the backend fixture to the test will have it run on all the backends.)

@matthewfeickert matthewfeickert removed the needs-triage Needs a maintainer to categorize and assign label Apr 17, 2024
@matthewfeickert matthewfeickert self-assigned this Apr 17, 2024
@matthewfeickert matthewfeickert linked a pull request Apr 18, 2024 that will close this issue

kratsg commented Apr 18, 2024

@matthewfeickert i saw the PR, and I think we need to swap the way we're approaching this. Here's my suggestion instead of type-casting: we need to add shims across each lib and move some functions into our tensorlib to make them backend-dependent (or use a shim to swap them out as needed, like we do for scipy.optimize).

See this example:

from functools import lru_cache
import time
import timeit

import jax.numpy as jnp
import jax
import tensorflow as tf



def slow(n):
    time.sleep(1)
    return n**2

fast = lru_cache(maxsize=None)(slow)

fast_jax = jax.jit(slow)
fast_tflow = tf.function(jit_compile=True)(slow)

value = 5
print('slow')
print(timeit.timeit(lambda: [slow(value), slow(value), slow(value), slow(value), slow(value)], number=1))
print('fast')
print(timeit.timeit(lambda: [fast(value), fast(value), fast(value), fast(value), fast(value)], number=1))


value = jnp.array(5)
print('slow, jax')
print(timeit.timeit(lambda: [slow(value), slow(value), slow(value), slow(value), slow(value)], number=1))
print('fast, jax')
print(timeit.timeit(lambda: [fast_jax(value), fast_jax(value), fast_jax(value), fast_jax(value), fast_jax(value)], number=1))

value = tf.convert_to_tensor(5)
print('slow, tensorflow')
print(timeit.timeit(lambda: [slow(value), slow(value), slow(value), slow(value), slow(value)], number=1))
print('fast, tensorflow')
print(timeit.timeit(lambda: [fast_tflow(value), fast_tflow(value), fast_tflow(value), fast_tflow(value), fast_tflow(value)], number=1))

which outputs

$ python cache.py
slow
5.012567336
fast
1.0029977690000003
slow, jax
5.043927394000001
fast, jax
1.0195144690000006
slow, tensorflow
5.017408181999997
fast, tensorflow
1.0631543910000012

so we can definitely cache those values by JIT-ing for the toms748 scan here, and that's probably what we want to do. My suggestion might be that we support pyhf.tensor.jit with something similar to the signature of jax.jit across all backends (yes, even numpy, where it would be an lru_cache).
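A minimal stdlib sketch of what the numpy-backend side of such a `pyhf.tensor.jit` shim could look like; the name and behavior follow the suggestion above, not any existing pyhf API:

```python
from functools import lru_cache


def numpy_jit(func):
    """Hypothetical numpy-backend 'jit': an lru_cache keyed on a hashable
    float, mirroring the call-caching that jax.jit gives the JAX backend."""
    cached = lru_cache(maxsize=None)(func)

    def wrapper(x):
        # coerce any backend scalar to a hashable built-in float key
        return cached(float(x))

    return wrapper


calls = []


@numpy_jit
def slow_square(x):
    calls.append(x)  # record real invocations to show the cache working
    return x**2


slow_square(3.0)
slow_square(3.0)  # cache hit: the wrapped function ran only once
print(len(calls))  # 1
```

The other backends would dispatch to their native compilers (jax.jit, tf.function) behind the same entry point, which is the shim-swapping pattern described above.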

@matthewfeickert (Member)

> we need to add in shims across each lib and move some functions into our tensorlib instead to make them backend-dependent

Okay, sounds good. Let's start up a separate series of PRs to do this.
