Properly handle bfloat16 in jnp.load() #8499

Merged
merged 1 commit into google:main on Nov 16, 2021

Conversation

jakevdp
Collaborator

@jakevdp jakevdp commented Nov 9, 2021

Currently jnp.load is just a direct reference to np.load, but in #8494 a user reported that this does not properly handle bfloat16. This PR adds a simple wrapper that handles this case and also returns JAX arrays rather than NumPy arrays from jnp.load.
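
For illustration, here is a minimal sketch of what such a wrapper could look like. It is not the code from the merged commit. The name load_with_bfloat16 and the roundtrip helper are hypothetical, and the sketch assumes that NumPy, which has no native bfloat16 type, hands such arrays back from np.load as raw 2-byte void data (dtype 'V2').

```python
import io

import numpy as np
import jax.numpy as jnp


def load_with_bfloat16(file, **kwargs):
    """Hypothetical wrapper around np.load that restores bfloat16 arrays."""
    out = np.load(file, **kwargs)
    if isinstance(out, np.ndarray):
        # Assumption: bfloat16 data comes back as 2-byte void ('V2'),
        # so reinterpret those bytes as bfloat16 without copying.
        if out.dtype == np.dtype("V2"):
            out = out.view(np.dtype(jnp.bfloat16))
        # Return a JAX array rather than a NumPy array.
        out = jnp.asarray(out)
    return out


def roundtrip(x):
    """Save x with np.save to an in-memory buffer, then load it back."""
    buf = io.BytesIO()
    np.save(buf, np.asarray(x))
    buf.seek(0)
    return load_with_bfloat16(buf)
```

Because of the isinstance check, anything that np.load returns that is not a plain array (for example, an archive written with np.savez) is passed through unchanged in this sketch.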

@google-cla google-cla bot added the cla: yes label Nov 9, 2021
@jakevdp jakevdp requested a review from skye November 9, 2021 17:45
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Nov 16, 2021
@copybara-service copybara-service bot merged commit 9491414 into google:main Nov 16, 2021
@jakevdp jakevdp deleted the load-wrapper branch November 16, 2021 19:32
Labels
cla: yes pull ready Ready for copybara import and testing
3 participants