Skip to content

Commit

Permalink
Raise a runtime error when trying to convert the jax.Array wrapped …
Browse files Browse the repository at this point in the history
…by `jax.core.Token` to a numpy array, as it is an internal implementation detail and the buffer has XLA token shape.

PiperOrigin-RevId: 632682906
  • Loading branch information
yueshengys authored and tensorflower-gardener committed May 11, 2024
1 parent acb4ea8 commit e7532cd
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
7 changes: 7 additions & 0 deletions third_party/xla/xla/python/py_array.cc
Expand Up @@ -1412,6 +1412,13 @@ StatusOr<nb::object> PyHostValue::AsNumPyArray(
if (ifrt_array->IsDeleted()) {
return InvalidArgument("DeviceArray has been deleted.");
}
// The only `jax.Array` with token-shape buffer is the one wrapped by
// `jax.core.Token`. Since it is an internal implementation detail, we
// don't support converting it to a numpy array.
if (ifrt_array->dtype().kind() == ifrt::DType::kToken) {
return InvalidArgument(
"Cannot convert a token-shape buffer to a numpy array.");
}
auto* arr = llvm::dyn_cast_or_null<ifrt::PjRtCompatibleArray>(ifrt_array);
if (arr != nullptr) {
auto* pjrt_buffer = arr->pjrt_buffers().front().get();
Expand Down
2 changes: 1 addition & 1 deletion third_party/xla/xla/python/xla_client.py
Expand Up @@ -49,7 +49,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 263
_version = 264

# Version number for MLIR:Python components.
mlir_api_version = 56
Expand Down

0 comments on commit e7532cd

Please sign in to comment.