diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index 3477733cf2d8..9dfd8d93a65d 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -637,7 +637,7 @@ polymorphic_shapes = ["a, 2*a, b"] polymorphic_shapes = ["a * a, a"] ``` -Furthermore, in the native lowering the inputs that are not needed in the computation +Furthermore, when native lowering the inputs that are not needed in the computation are ignored, so the dimension variables must be derivable only from used inputs. In the following example, the `x_unused` is not part of the computation so its input shapes cannot be used for deriving the dimension variables, and you will