[Docos] Add new FAQ entry: "Type-check framework-agnostic tensors with jaxtyping! It is glorious!" #357

Open
danielgafni opened this issue Apr 4, 2024 · 6 comments


@danielgafni

danielgafni commented Apr 4, 2024

Passing a normal numpy array to @jaxtyped(typechecker=beartype)-decorated function raises an error.
Passing a jax.numpy array of the same shape and type doesn't.

from beartype import beartype
from jaxtyping import Array, Float, jaxtyped
import numpy as np
import jax.numpy as jnp

@jaxtyped(typechecker=beartype)
def func(a: Float[Array, "dim"]):
    return a

func(np.zeros(5, dtype=float))

Traceback (most recent call last):
  File ".../.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 418, in wrapped_fn
    param_fn(*args, **kwargs)
  File "<@beartype(__main__.check_params) at 0x7628df1f9f80>", line 29, in check_params
beartype.roar.BeartypeCallHintParamViolation: Function __main__.check_params() parameter a="array([0., 0., 0., 0., 0.])" violates type hint <class 'jaxtyping.Float[Array, 'dim']'>, as this value is not an instance of the underlying array type <class 'jax.Array'>.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File ".../.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-9-f53de0c809f7>", line 1, in <module>
    func(np.zeros(5, dtype=float))
  File ".../.venv/lib/python3.11/site-packages/jaxtyping/_decorator.py", line 447, in wrapped_fn
    raise TypeCheckError(msg) from e
jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of func.
The problem arose whilst typechecking parameter 'a'.
Actual value: array([0., 0., 0., 0., 0.])
Expected type: <class 'Float[Array, 'dim']'>.
----------------------
Called with parameters: {'a': array([0., 0., 0., 0., 0.])}
Parameter annotations: (a: Float[Array, 'dim']).

func(jnp.zeros(5))
Array([0., 0., 0., 0., 0.], dtype=float32)

Versions:

jaxtyping = 0.2.28
beartype = 0.18.2
numpy = 1.26.4

I'm not sure whether this is an issue with jaxtyping or beartype.

@leycec
Member

leycec commented Apr 5, 2024

Thanks a heap overflow for the detailed issue, @danielgafni. Let's ping @patrick-kidger, for he is the king of jaxtyping.

If I had to venture a guess as to the underlying issue, I suspect it's your use of Array in the type hint Float[Array, "dim"]. What is Array, specifically? Per the traceback, it resolves to jax.Array, and a plain NumPy array is not an instance of jax.Array. So, I suspect you'll find joy by replacing Array with the standard numpy.ndarray type: e.g.,

from beartype import beartype
from jaxtyping import Float, jaxtyped
import numpy as np

@jaxtyped(typechecker=beartype)
def func(a: Float[np.ndarray, "dim"]):
    return a

func(np.zeros(5, dtype=float))  # <-- this succeeds! praise be to dr. kidger 
func("Pretty sure this ain't a NumPy array... but not certain.")  # <-- this fails! that's good. failure is good.

I can personally confirm that the second call raises the expected jaxtyping type-checking violation:

Traceback (most recent call last):
  File "/tmp/jaxtyping/lib/python3.12/site-packages/jaxtyping/_decorator.py", line 418, in wrapped_fn
    param_fn(*args, **kwargs)
  File "<@beartype(__main__.check_params) at 0x7fd8a674e160>", line 29, in check_params
beartype.roar.BeartypeCallHintParamViolation: Function __main__.check_params() parameter a="Pretty sure this ain't a NumPy array... but not certain." violates type hint <class 'jaxtyping.Float[ndarray, 'dim']'>, as this value is not an instance of the underlying array type <class 'numpy.ndarray'>.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/leycec/tmp/mopy.py", line 12, in <module>
    func("Pretty sure this ain't a NumPy array... but not certain.")
  File "/tmp/jaxtyping/lib/python3.12/site-packages/jaxtyping/_decorator.py", line 447, in wrapped_fn
    raise TypeCheckError(msg) from e
jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of func.
The problem arose whilst typechecking parameter 'a'.
Actual value: "Pretty sure this ain't a NumPy array... but not certain."
Expected type: <class 'Float[ndarray, 'dim']'>.
----------------------
Called with parameters: {'a': "Pretty sure this ain't a NumPy array... but not certain."}
Parameter annotations: (a: Float[ndarray, 'dim']).

You probably want to check out the official jaxtyping section on NumPy integration. Basically, the first object subscripting (indexing) jaxtyping.Float[...] is the expected tensor type. In the case of NumPy, that's numpy.ndarray. Right? So, jaxtyping is probably behaving as expected here. If all else fails, blame @leycec. 😁

Thankfully, everything is probably behaving as expected here. Let's quietly close this and pretend we actually did something!

@leycec leycec closed this as completed Apr 5, 2024
@danielgafni
Author

Hey @leycec, thanks for the explanation!

I had falsely assumed I could keep the type annotations framework-agnostic and reuse the same function signature for both NumPy and JAX. I now understand this is not really possible.

I probably had more joy from reading your answer (as always!) than I'll have from switching to numpy type annotations :)

Thanks!

@patrick-kidger
Contributor

If you want to be framework-agnostic then you can use a Union over the array types:

AnyArray: TypeAlias = Union[np.ndarray, jax.Array]
Float[AnyArray, "..."]

@danielgafni
Author

Oh, amazing! Would a TypeVar also work? That way I could check whether a function's inputs and outputs are either all NumPy or all torch, for example.

@danielgafni
Author

Seems like it doesn't:

from typing import TypeVar, Union

import numpy as np
import jax.numpy as jnp
from jaxtyping import Float

AnyArray = TypeVar(
    "AnyArray",
    bound=Union[np.ndarray, jnp.ndarray],
)

Float[AnyArray, "dim"]

Traceback (most recent call last):
  File ".../.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-8-719230540918>", line 1, in <module>
    Float[AnyArray, "dim"]
    ~~~~~^^^^^^^^^^^^^^^^^
  File ".../.venv/lib/python3.11/site-packages/jaxtyping/_array_types.py", line 623, in __getitem__
    out = _make_array(array_type, dim_str, cls.dtypes, cls.__name__)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../.venv/lib/python3.11/site-packages/jaxtyping/_array_types.py", line 567, in _make_array
    out = _make_array_cached(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../.venv/lib/python3.11/site-packages/jaxtyping/_array_types.py", line 530, in _make_array_cached
    if issubclass(array_type, AbstractArray):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: issubclass() arg 1 must be a class
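The final TypeError comes from the issubclass(array_type, AbstractArray) call shown in the traceback: a TypeVar object is an instance of typing.TypeVar, not a class, so issubclass() rejects it. A stdlib-only sketch of the same failure mode, using stand-in bound types (list, tuple) instead of the array types so it runs without NumPy or JAX:

```python
from typing import TypeVar, Union

# Stand-in bound (list/tuple) in place of Union[np.ndarray, jnp.ndarray].
AnyArray = TypeVar("AnyArray", bound=Union[list, tuple])

# A TypeVar object is an instance, not a class.
print(isinstance(AnyArray, type))  # False

# issubclass() requires a class as its first argument, the same requirement
# that jaxtyping's _make_array_cached ran into above.
try:
    issubclass(AnyArray, object)
except TypeError as err:
    print(f"TypeError: {err}")
```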

@leycec
Member

leycec commented Apr 5, 2024

...heh. Yeah. Runtime TypeVar support is super-hard. To my knowledge, nobody fully supports TypeVar at runtime – not even @beartype. So, this isn't so much the fault of jaxtyping as it is a wicked problem that humanity might never solve within the known lifetime of the Universe.

@patrick-kidger: That's... so hot. I should definitely document that in a new FAQ entry. Framework-agnostic tensor typing is something most of us are probably interested in. Let's reopen this under an appropriate new hypeworthy title. Build the hype!

@leycec leycec reopened this Apr 5, 2024
@leycec leycec changed the title jaxtyping doesn't work with numpy but does with jax.numpy [Docos] Add new FAQ entry documenting framework-agnostic tensor typing with jaxtyping Apr 5, 2024
@leycec leycec changed the title [Docos] Add new FAQ entry documenting framework-agnostic tensor typing with jaxtyping [Docos] Add new FAQ entry: "Type-check framework-agnostic tensors with jaxtyping! It is glorious!" Apr 6, 2024