From 2f9dd041f64e26c60294c42f11971b16e65b2179 Mon Sep 17 00:00:00 2001 From: George Necula Date: Fri, 9 Dec 2022 17:47:56 +0200 Subject: [PATCH] [jax2tf] Fix native lowering for modules that don't use some inputs. We must drop the corresponding actual arguments when invoking the function and we must also be careful to compute the dim_args_specs based only on the kept inputs. As a side benefit, we now allow dimension variables that occur in more complex polynomials, as long as they also occur as trivial monomials somewhere in the input shapes for the kept arguments. --- jax/experimental/jax2tf/README.md | 36 ++++++ jax/experimental/jax2tf/jax2tf.py | 116 ++++++++++++------ jax/experimental/jax2tf/shape_poly.py | 4 +- jax/experimental/jax2tf/tests/jax2tf_test.py | 22 ++++ .../jax2tf/tests/shape_poly_test.py | 70 ++++++++++- 5 files changed, 207 insertions(+), 41 deletions(-) diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index 352d51fe46e1..010e6a780dd8 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -381,6 +381,10 @@ lowered with the batch dimension polymorphic and the remaining dimensions concre It is reasonable to expect that there will be JAX programs for which there is a shape-polymorphic TensorFlow graph, but which will give an error when lowering with jax2tf. +In general, you should expect that shape polymorphism can handle those programs for which +all the intermediate shapes can be expressed as polynomials in the dimension variables +appearing in the input shapes. In particular, this does not include programs whose +intermediate shapes depend on the data. ### Details @@ -613,6 +617,38 @@ jax2tf.convert(lambda x: jnp.reshape(x, (2, -1)), polymorphic_shapes=["(2*b, ...)"])(np.ones((4, 5, 7))) ``` +### Dimension variables must be solvable from the input shapes + +`jax2tf` will generate code to derive the values of the dimension variables +from the input shapes. This works only if dimension polynomials in the input shapes are linear. +For example, the following `polymorphic_shapes` will result in errors: + +```python +polymorphic_shapes = ["a * a"] # Not a linear polynomial +polymorphic_shapes = ["a + b"] # Too few equations to derive both `a` and `b` +``` + +If you are using native lowering, the restrictions are stronger: every dimension +variable must occur as the value of some dimension of some input, e.g., +the following will work: + +```python +polymorphic_shapes = ["a, 2*a, b"] +polymorphic_shapes = ["a * a, a"] +``` + +Furthermore, when using the native lowering the inputs that are not needed in the computation +are ignored, so the dimension variables must be derivable only from used inputs. +In the following example, the `x_unused` is not part of the computation so its +input shapes cannot be used for deriving the dimension variables, and you will +get an error that `a` cannot be derived: + +```python +jax2tf.convert(lambda x_unused, y: y * 2., + polymorphic_shapes=["b, a", "b, 2 * a"])(x, y) +``` + + ## Known issues `jax2tf` has been in use since 2020 and the vast majority of users encounter diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index b5441da301dc..9cc48382e1d3 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -595,46 +595,28 @@ def _lower_native_and_run(fun_jax: Callable, Special care must be taken in presence of shape polymorphism. """ # Look for shape polymorphism - - # For each dimension variable, encode how to compute its value from the - # shape of the explicit arguments. E.g., "2.1" denotes args_tf[2].shape[1]. - # The order of the dimension variables must match the order of the first N - # arguments of the lowered function. - # We now have two implementations for the native lowering. If --jax_dynamic_shapes - # then we use JAX's in-progress support for native dynamic shapes. In that - # case we assume that the dimension variables are listed in the order in which - # they are encountered by scanning the arguments and their shapes in order. - # If we don't use --jax_dynamic_shapes then the dimension variables are passed - # in the alphabetical order of their names. - abstracted_axes: Sequence[Dict[int, str]] = [] - dim_args_spec_dict: Dict[str, str] = {} # map dim var name to dim_args_spec - dim_vars_seen: List[str] = [] # the dim var names in order - for arg_idx, aval in enumerate(args_avals): - one_abstract_axes = {} - for axis_idx, d in enumerate(aval.shape): - if not core.is_constant_dim(d): - d_var = d.to_var() - if d_var is None: - raise ValueError(f"Only simple dimension variables supported: {aval.shape}") - if not d_var in dim_vars_seen: - dim_args_spec_dict[d_var] = f"{arg_idx}.{axis_idx}" - dim_vars_seen.append(d_var) - one_abstract_axes[axis_idx] = d_var - abstracted_axes.append(one_abstract_axes) - - if any(abstracted_axes): - if config.jax_dynamic_shapes: + # then we use JAX's in-progress support for native dynamic shapes, and we pass + # abstracted_axes to lowering functions. Otherwise, we just lower using + # abstract values whose shapes may include polynomials (already in args_avals). + if config.jax_dynamic_shapes: + abstracted_axes: Sequence[Dict[int, str]] = [] + for arg_idx, aval in enumerate(args_avals): + one_abstract_axes = {} + for axis_idx, d in enumerate(aval.shape): + if not core.is_constant_dim(d): + d_var = d.to_var() + if d_var is None: + raise ValueError(f"Only trivial dimension polynomials on input: {aval.shape}") + one_abstract_axes[axis_idx] = d_var + abstracted_axes.append(one_abstract_axes) + + if any(abstracted_axes): abstracted_axes = tuple(abstracted_axes) - # In the order we have seen them - dim_args_spec = [dim_args_spec_dict[d_var] for d_var in dim_vars_seen] else: abstracted_axes = None # type: ignore - # In sorted order by name - dim_args_spec = [dim_args_spec_dict[d_var] for d_var in sorted(dim_vars_seen)] else: abstracted_axes = None # type: ignore - dim_args_spec = [] arg_specs_jax = [ jax.ShapeDtypeStruct(aval.shape, aval.dtype, named_shape=aval.named_shape) @@ -647,7 +629,6 @@ def _lower_native_and_run(fun_jax: Callable, # convert(f_jax), in which case a "jit" is implied. We also add a jit when # we need to pass the abstracted axes. fun_jax_lower = jax.jit(fun_jax, backend=backend, - keep_unused=True, # TODO: allow dropping unused abstracted_axes=abstracted_axes).lower else: fun_jax_lower = fun_jax.lower @@ -658,10 +639,6 @@ def _lower_native_and_run(fun_jax: Callable, else: mhlo_module = lowered.mhlo() xla_call_module_version = 1 - if logging.vlog_is_on(3): - mhlo_module_text = mlir.module_to_string(mhlo_module) - logging.vlog(3, "XlaCallModule (version=%d)\n%s", xla_call_module_version, - mhlo_module_text) mhlo_serialized_module = mlir.module_to_bytecode(mhlo_module) # Figure out the result types and shapes @@ -685,6 +662,62 @@ def _out_type(jax_type): return jax_type out_types = tuple(_out_type(out_aval.dtype) for out_aval in out_avals) + module_kept_var_idx = lowered.compile_args["kept_var_idx"] + # We must compute the dim_args_spec: for each dimension variable, encode how + # to compute its value from the shape of the explicit arguments. E.g., "2.1" + # denotes args_tf[2].shape[1]. The order of the dimension variables must match + # the order of the first N arguments of the lowered function. + # If we use --jax_dynamic_shapes, the dimension variables are listed in the + # order in which they are encountered by scanning the arguments and their + # shapes in order. Otherwise, the dimension variables are passed in the + # alphabetical order of their names. + dim_args_spec_dict: Dict[str, str] = {} # map dim var name to dim_args_spec + dim_vars_order: List[str] = [] + all_dim_vars: Set[str] = set() + current_kept_arg_idx = -1 # The index among the kept arguments + for arg_idx, aval in enumerate(args_avals): + is_kept = arg_idx in module_kept_var_idx + if is_kept: + current_kept_arg_idx += 1 + + for axis_idx, d in enumerate(aval.shape): + if not core.is_constant_dim(d): + # We collect dimension variables even from dropped args + all_dim_vars = all_dim_vars.union(d.get_vars()) + if not is_kept: continue + d_var = d.to_var() + # We can compute dim vars only from trivial polynomials + if d_var is None: continue + if not d_var in dim_args_spec_dict: + dim_vars_order.append(d_var) + dim_args_spec_dict[d_var] = f"{current_kept_arg_idx}.{axis_idx}" + + if all_dim_vars: + dim_args_spec_set = set(dim_vars_order) + if dim_args_spec_set != all_dim_vars: + missing = all_dim_vars.difference(dim_args_spec_set) + args_list = [f" Arg[{arg_idx}] - {'KEPT ' if arg_idx in module_kept_var_idx else 'DROPPED'}: {aval}" + for arg_idx, aval in enumerate(args_avals)] + raise ValueError( + "The following dimension variables cannot be computed from the static " + f"shapes of the kept lowered arguments: {missing}. These are the " + "argument shapes:\n" + + "\n".join(args_list) + + "\n" + "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") + + if config.jax_dynamic_shapes: + # In the order we have seen them + dim_args_spec = [dim_args_spec_dict[d_var] for d_var in dim_vars_order] + else: + # In sorted order by name + dim_args_spec = [dim_args_spec_dict[d_var] for d_var in sorted(dim_vars_order)] + else: + dim_args_spec = [] + + args_avals = [aval for i, aval in enumerate(args_avals) if i in module_kept_var_idx] + args_tf = [atf for i, atf in enumerate(args_tf) if i in module_kept_var_idx] + # Apply the shardings on arguments and results for pjit. This is redundant # because the mhlo_module_text will already contain the shardings, but it # makes it easier for tools like the TPU inference converter to see the @@ -694,6 +727,11 @@ def _out_type(jax_type): args_tf = tuple( map(_shard_value, args_tf, args_avals, lowered.compile_args["in_shardings"])) + if logging.vlog_is_on(3): + mhlo_module_text = mlir.module_to_string(mhlo_module) + logging.vlog(3, "XlaCallModule (version=%d, dim_args_spec=%s)\n%s", + xla_call_module_version, ", ".join(dim_args_spec), + mhlo_module_text) res = tfxla.call_module( args_tf, version=xla_call_module_version, diff --git a/jax/experimental/jax2tf/shape_poly.py b/jax/experimental/jax2tf/shape_poly.py index ca33b1690ebb..032b1fec735e 100644 --- a/jax/experimental/jax2tf/shape_poly.py +++ b/jax/experimental/jax2tf/shape_poly.py @@ -886,5 +886,7 @@ def process_one_eqn(eqn: DimEquation) -> bool: err_msg = ( f"Cannot solve for values of dimension variables {unsolved_vars} from " f"the remaining dimension polynomials\n {eqns_str}.{_shapeenv_to_str()} " - "Dimension variables can be solved only from linear polynomials.") + "Dimension variables can be solved only from linear polynomials.\n" + "\n" + "Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") raise ValueError(err_msg) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 39d6eb8f4a36..b5f1525b5d37 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -818,6 +818,28 @@ def inner1(y): jax2tf.convert(func)(2.) # No error + def test_jit_unused(self): + def f_jax(x, y_unused): + return x * np.float32(2.) + x, y_unused = np.float32(5.), np.arange(7, dtype=np.int32) + res_tf = jax2tf.convert(jax.jit(f_jax, keep_unused=False))(x, y_unused) + self.assertAllClose(f_jax(x, None), res_tf) + + def test_jit_unused_grad(self): + def f_jax(x, y_unused): + return x * np.float32(2.) + + x, y_unused = np.float32(5.), np.arange(7, dtype=np.int32) + f_tf = jax2tf.convert(jax.jit(f_jax, keep_unused=False)) + xv, y_unused_v = tf.Variable(x), tf.Variable(y_unused) + with tf.GradientTape() as tape: + res_tf = f_tf(xv, y_unused_v) + grad_tf_x, grad_tf_y = tape.gradient(res_tf, (xv, y_unused_v)) + + self.assertAllClose(f_jax(x, None), res_tf) + self.assertAllClose(np.float32(2.), grad_tf_x) + self.assertIsNone(grad_tf_y) + def test_nested_convert_error(self): def outer(y): return jax2tf.convert(jnp.sin)(y) # Inner convert takes tracer args diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index e4e1f8a19193..e34a47db389a 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -20,7 +20,6 @@ import collections import functools from functools import partial -import logging import operator import re @@ -722,6 +721,70 @@ def f_jax(x): # x: f32[w, h] input_signature=[tf.TensorSpec([None, None])], polymorphic_shapes=["w, h"]) + def test_non_trivial_polynomials(self): + if config.jax_dynamic_shapes: + raise unittest.SkipTest("--jax_dynamic_shapes supports only trivial polynomials") + # We can handle non-trivial polynomials in the input shape, + # as long as all variables also occur in trivial polynoamials + self.CheckShapePolymorphism( + lambda x, y: x + y.reshape((-1,)), + input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None, None])], + polymorphic_shapes=["b * b", "b, b"]) + + def test_unused_args(self): + # Tests with functions that do not use their inputs. + + # First arg unused, not polymorphic + self.CheckShapePolymorphism( + lambda x_unused, y: y * 2.0, + input_signature=[tf.TensorSpec([]), tf.TensorSpec([None])], + polymorphic_shapes=[None, "b"]) + + # Some args unused, not polymorphic + self.CheckShapePolymorphism( + lambda x_unused, y, z_unused, w: jnp.concatenate([y, w]), + input_signature=[tf.TensorSpec([]), tf.TensorSpec([None]), + tf.TensorSpec([]), tf.TensorSpec([None])], + polymorphic_shapes=[None, "b1", None, "b2"]) + + # A polymorphic arg is not used, but the dimension var appears + # in a used arg also + self.CheckShapePolymorphism( + lambda x_unused, y: y * 2.0, + input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])], + polymorphic_shapes=["b", "b"]) + + # A polymorphic arg is not used, and the dimension var does not appear + # elsewhere. + if config.jax2tf_default_experimental_native_lowering: + with self.assertRaisesRegex(ValueError, + "The following dimension variables cannot be computed"): + self.CheckShapePolymorphism( + lambda x_unused, y: y * 2.0, + input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])], + polymorphic_shapes=["b1", "b2"]) + else: + self.CheckShapePolymorphism( + lambda x_unused, y: y * 2.0, + input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])], + polymorphic_shapes=["b1", "b2"]) + + # A polymorphic arg is not used, and the dimension var does appear + # elsewhere but not as a trivial monomial. + if config.jax2tf_default_experimental_native_lowering: + with self.assertRaisesRegex(ValueError, + "The following dimension variables cannot be computed"): + self.CheckShapePolymorphism( + lambda x_unused, y: y * 2.0, + input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])], + polymorphic_shapes=["b1", "b1 * b1"]) + else: + self.CheckShapePolymorphism( + lambda x_unused, y: y * 2.0, + input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])], + polymorphic_shapes=["b1", "b1 * b1"]) + + def test_with_custom_vjp(self): """Shape-polymorphic custom VJP.""" @@ -1065,6 +1128,11 @@ def f_jax(x): jax2tf.convert(f_jax, polymorphic_shapes=["b, b"])).get_concrete_function(tf.TensorSpec([None, None], dtype=np.float32)) self.assertEqual(1, f_tf(x45)) + x = np.ones((5,), dtype=np.float32) + with self.assertRaisesRegex(ValueError, + "Cannot solve for values of dimension variables"): + jax2tf.convert(lambda x: x, polymorphic_shapes=["a + b"])(x) + class DimAsValueTest(tf_test_util.JaxToTfTestCase): """Dimension polynomials used as values in the JAX computation."""