Skip to content

Commit

Permalink
fix(integrations): fix type check for jax.Array introduced in jax==0.…
Browse files Browse the repository at this point in the history
…4.1 (#4718)

* fix(integrations): fix type check for jax.Array introduced in jax==0.4.1
  • Loading branch information
dmitryduev committed Jan 4, 2023
1 parent 4f272d3 commit 3c04cf5
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ docker
stable_baselines3
tensorboard
gym
jax[cpu]<0.4; sys_platform == 'darwin' or sys_platform == 'linux'
jax[cpu]; sys_platform == 'darwin' or sys_platform == 'linux'
fastcore; python_version > '3.6'
fastcore==1.3.29; python_version == '3.6'
pyarrow
Expand Down
2 changes: 1 addition & 1 deletion wandb/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def is_pytorch_tensor_typename(typename: str) -> bool:


def is_jax_tensor_typename(typename: str) -> bool:
return typename.startswith("jaxlib.") and "DeviceArray" in typename
return typename.startswith("jaxlib.") and "Array" in typename


def get_jax_tensor(obj: Any) -> Optional[Any]:
Expand Down

0 comments on commit 3c04cf5

Please sign in to comment.