[Docos] Add new FAQ entry: "Type-check framework-agnostic tensors with jaxtyping! It is glorious!"
#357
Comments
Thanks a heap overflow for the detailed issue, @danielgafni. Let's ping @patrick-kidger, for he is the king of

If I had to venture a guess as to the underlying issue, I suspect it's your use of

from beartype import beartype
from jaxtyping import Float, jaxtyped
import numpy as np

@jaxtyped(typechecker=beartype)
def func(a: Float[np.ndarray, "dim"]):
    return a

func(np.zeros(5, dtype=float))  # <-- this succeeds! praise be to dr. kidger
func("Pretty sure this ain't a NumPy array... but not certain.")  # <-- this fails! that's good. failure is good.

I can personally confirm that raises the expected

Traceback (most recent call last):
  File "/tmp/jaxtyping/lib/python3.12/site-packages/jaxtyping/_decorator.py", line 418, in wrapped_fn
    param_fn(*args, **kwargs)
  File "<@beartype(__main__.check_params) at 0x7fd8a674e160>", line 29, in check_params
beartype.roar.BeartypeCallHintParamViolation: Function __main__.check_params() parameter a="Pretty sure this ain't a NumPy array... but not certain." violates type hint <class 'jaxtyping.Float[ndarray, 'dim']'>, as this value is not an instance of the underlying array type <class 'numpy.ndarray'>.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/leycec/tmp/mopy.py", line 12, in <module>
    func("Pretty sure this ain't a NumPy array... but not certain.")
  File "/tmp/jaxtyping/lib/python3.12/site-packages/jaxtyping/_decorator.py", line 447, in wrapped_fn
    raise TypeCheckError(msg) from e
jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of func.
The problem arose whilst typechecking parameter 'a'.
Actual value: "Pretty sure this ain't a NumPy array... but not certain."
Expected type: <class 'Float[ndarray, 'dim']'>.
----------------------
Called with parameters: {'a': "Pretty sure this ain't a NumPy array... but not certain."}
Parameter annotations: (a: Float[ndarray, 'dim']).

You probably want to check out the official

Thankfully, everything is probably behaving as expected here. Let's quietly close this and pretend we actually did something!
Hey @leycec, thanks for the explanation! I had a false assumption I could kinda keep the type annotations framework-agnostic and reuse the same function signature for both

I probably had more joy from reading your answer (as always!) than I'll have from switching to

Thanks!
If you want to be framework-agnostic then you can use a

AnyArray: TypeAlias = Union[np.ndarray, jax.Array]
Float[AnyArray, "..."]
Oh amazing! Would
Seems like it doesn't:

from typing import TypeVar, Union
import numpy as np
import jax.numpy as jnp
from jaxtyping import Float

AnyArray = TypeVar(
    "AnyArray",
    bound=Union[np.ndarray, jnp.ndarray],
)
Float[AnyArray, "dim"]

Traceback (most recent call last):
  File ".../.venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-8-719230540918>", line 1, in <module>
    Float[AnyArray, "dim"]
    ~~~~~^^^^^^^^^^^^^^^^^
  File ".../.venv/lib/python3.11/site-packages/jaxtyping/_array_types.py", line 623, in __getitem__
    out = _make_array(array_type, dim_str, cls.dtypes, cls.__name__)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../.venv/lib/python3.11/site-packages/jaxtyping/_array_types.py", line 567, in _make_array
    out = _make_array_cached(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../.venv/lib/python3.11/site-packages/jaxtyping/_array_types.py", line 530, in _make_array_cached
    if issubclass(array_type, AbstractArray):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: issubclass() arg 1 must be a class
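The last frame of that traceback can be reproduced with the standard library alone, which may help show why it happens: a `TypeVar` is an ordinary object rather than a class, so `issubclass()` refuses it. This is only a sketch; `Base` and `describe` are hypothetical names, with `Base` standing in for jaxtyping's `AbstractArray`, and the bound uses builtin types to avoid third-party imports.

```python
from typing import TypeVar, Union


class Base:
    """Hypothetical stand-in for jaxtyping's AbstractArray."""


# A TypeVar bound to a union, analogous to the AnyArray attempt above.
AnyArray = TypeVar("AnyArray", bound=Union[int, float])


def describe(array_type):
    """Mimic the failing check: issubclass() only accepts classes."""
    try:
        return issubclass(array_type, Base)
    except TypeError as exc:
        return f"TypeError: {exc}"


print(isinstance(AnyArray, type))  # a TypeVar instance is not a class
print(describe(AnyArray))          # reports the same TypeError message as above
print(describe(int))               # a real class passes through the check
```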
...heh. Yeah. Runtime

@patrick-kidger: That's... so hot. I should definitely document that in a new FAQ entry. Framework-agnostic tensor typing is something most of us are probably interested in. Let's reopen this under an appropriate new hypeworthy title. Build the hype!
Passing a normal numpy array to a @jaxtyped(typechecker=beartype)-decorated function raises an error. Passing a jax.numpy array of the same shape and type doesn't.

Versions:

I'm not sure if this is an issue with jaxtyping or beartype.