Hello, I have a large codebase that runs in float32, but I want to run one operation in float64. It looks like jax.experimental.enable_x64() does exactly what I want; however, it fails as soon as I try to differentiate through it with reverse-mode autodiff (forward mode works). I am not sure whether I am using it wrong or whether this is a bug. Here is a minimal reproducing example:
Replies: 1 comment
jax.experimental.enable_x64 will not work in all contexts, and you've hit one context where it is known not to work. There is no workaround; the mechanism it uses is fundamentally flawed, which is one reason it's remained in jax.experimental.
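This does not make the context manager work under reverse mode. A different approach that may fit the "float32 codebase, one float64 operation" case is to turn on 64-bit support globally and keep dtypes explicit. The sketch below is only an illustration under that assumption: precise_op and loss are hypothetical names, not code from the question, and the rest of the codebase would need to create its arrays with explicit float32 dtypes, since enabling x64 changes the default dtype.

```python
import jax
import jax.numpy as jnp

# Assumption: 64-bit support is enabled once at startup, before any arrays exist.
jax.config.update("jax_enable_x64", True)

def precise_op(x):
    # Hypothetical operation: upcast explicitly, compute in float64, cast back.
    x64 = x.astype(jnp.float64)
    y64 = jnp.sum(jnp.log1p(jnp.exp(x64)))
    return y64.astype(jnp.float32)

def loss(x):
    # The surrounding computation stays in float32; the Python scalar 2.0
    # is weakly typed and does not promote x to float64.
    return precise_op(x * 2.0)

# Explicit float32 dtype, since the default float dtype is now float64.
x = jnp.arange(3, dtype=jnp.float32)

# Reverse-mode differentiation goes through the dtype casts.
g = jax.grad(loss)(x)
print(g.dtype, g.shape)
```

The trade-off is that dtype discipline moves from a scoped context manager to every array creation site, so any array built without an explicit dtype may silently become float64.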