From 8eb3da2be9bbf562f4cb5fcb4c172d928fe2f080 Mon Sep 17 00:00:00 2001
From: George Necula
Date: Fri, 9 Dec 2022 17:47:56 +0200
Subject: [PATCH] [jax2tf] Fix native lowering for modules that don't use some inputs.

We must drop the corresponding actual arguments when invoking the function,
and we must also be careful to compute the dim_args_spec based only on the
kept inputs.

As a side benefit, we now allow dimension variables that occur in more complex
polynomials, as long as they also occur as trivial monomials somewhere in the
input shapes for the kept arguments.
---
 jax/experimental/jax2tf/jax2tf.py             | 115 ++++++++++++------
 jax/experimental/jax2tf/tests/jax2tf_test.py  |  22 ++++
 .../jax2tf/tests/shape_poly_test.py           |  55 ++++++++-
 3 files changed, 150 insertions(+), 42 deletions(-)

diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index dbf227ba1a94..0a4d37460753 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -595,46 +595,28 @@ def _lower_native_and_run(fun_jax: Callable,
   Special care must be taken in presence of shape polymorphism.
   """
   # Look for shape polymorphism
-
-  # For each dimension variable, encode how to compute its value from the
-  # shape of the explicit arguments. E.g., "2.1" denotes args_tf[2].shape[1].
-  # The order of the dimension variables must match the order of the first N
-  # arguments of the lowered function.
   # We now have two implementations for the native lowering. If --jax_dynamic_shapes
-  # then we use JAX's in-progress support for native dynamic shapes. In that
-  # case we assume that the dimension variables are listed in the order in which
-  # they are encountered by scanning the arguments and their shapes in order.
-  # If we don't use --jax_dynamic_shapes then the dimension variables are passed
-  # in the alphabetical order of their names.
-  abstracted_axes: Sequence[Dict[int, str]] = []
-  dim_args_spec_dict: Dict[str, str] = {}  # map dim var name to dim_args_spec
-  dim_vars_seen: List[str] = []  # the dim var names in order
-  for arg_idx, aval in enumerate(args_avals):
-    one_abstract_axes = {}
-    for axis_idx, d in enumerate(aval.shape):
-      if not core.is_constant_dim(d):
-        d_var = d.to_var()
-        if d_var is None:
-          raise ValueError(f"Only simple dimension variables supported: {aval.shape}")
-        if not d_var in dim_vars_seen:
-          dim_args_spec_dict[d_var] = f"{arg_idx}.{axis_idx}"
-          dim_vars_seen.append(d_var)
-        one_abstract_axes[axis_idx] = d_var
-    abstracted_axes.append(one_abstract_axes)
-
-  if any(abstracted_axes):
-    if config.jax_dynamic_shapes:
+  # then we use JAX's in-progress support for native dynamic shapes, and we pass
+  # abstracted_axes to lowering functions. Otherwise, we just lower using
+  # abstract values whose shapes may include polynomials (already in args_avals).
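+  # For example (illustrative only; assuming an argument with aval f32[b, 4]):
+  # the abstracted_axes entry computed below for that argument would be
+  # {0: 'b'}, i.e. a map from axis index to dimension-variable name, with
+  # constant axes omitted.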
+ if config.jax_dynamic_shapes: + abstracted_axes: Sequence[Dict[int, str]] = [] + for arg_idx, aval in enumerate(args_avals): + one_abstract_axes = {} + for axis_idx, d in enumerate(aval.shape): + if not core.is_constant_dim(d): + d_var = d.to_var() + if d_var is None: + raise ValueError(f"Only trivial dimension polynomials on input: {aval.shape}") + one_abstract_axes[axis_idx] = d_var + abstracted_axes.append(one_abstract_axes) + + if any(abstracted_axes): abstracted_axes = tuple(abstracted_axes) - # In the order we have seen them - dim_args_spec = [dim_args_spec_dict[d_var] for d_var in dim_vars_seen] else: abstracted_axes = None # type: ignore - # In sorted order by name - dim_args_spec = [dim_args_spec_dict[d_var] for d_var in sorted(dim_vars_seen)] else: abstracted_axes = None # type: ignore - dim_args_spec = [] arg_specs_jax = [ jax.ShapeDtypeStruct(aval.shape, aval.dtype, named_shape=aval.named_shape) @@ -647,7 +629,6 @@ def _lower_native_and_run(fun_jax: Callable, # convert(f_jax), in which case a "jit" is implied. We also add a jit when # we need to pass the abstracted axes. fun_jax_lower = jax.jit(fun_jax, backend=backend, - keep_unused=True, # TODO: allow dropping unused abstracted_axes=abstracted_axes).lower else: fun_jax_lower = fun_jax.lower @@ -658,10 +639,6 @@ def _lower_native_and_run(fun_jax: Callable, else: mhlo_module = lowered.mhlo() xla_call_module_version = 1 - if logging.vlog_is_on(3): - mhlo_module_text = mlir.module_to_string(mhlo_module) - logging.vlog(3, "XlaCallModule (version=%d)\n%s", xla_call_module_version, - mhlo_module_text) mhlo_serialized_module = mlir.module_to_bytecode(mhlo_module) # Figure out the result types and shapes @@ -685,6 +662,60 @@ def _out_type(jax_type): return jax_type out_types = tuple(_out_type(out_aval.dtype) for out_aval in out_avals) + module_kept_var_idx = lowered.compile_args["kept_var_idx"] + # We must compute the dim_args_spec: for each dimension variable, encode how + # to compute its value from the shape of the explicit arguments. E.g., "2.1" + # denotes args_tf[2].shape[1]. The order of the dimension variables must match + # the order of the first N arguments of the lowered function. + # If we use jax_dynamic_shapes, the dimension variables are listed in the + # order in which they are encountered by scanning the arguments and their + # shapes in order. Otherwise, the dimension variables are passed in the + # alphabetical order of their names. 
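+  # Worked example (illustrative only; assuming the kept arguments have avals
+  # f32[b, 4] and f32[c, b]): dim_args_spec_dict becomes
+  # {"b": "0.0", "c": "1.0"}, so dim_args_spec is ["0.0", "1.0"] under either
+  # ordering, since here the variables happen to be first seen in alphabetical
+  # order.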
+ dim_args_spec_dict: Dict[str, str] = {} # map dim var name to dim_args_spec + dim_vars_order: List[str] = [] + all_dim_vars: Set[str] = set() + current_kept_arg_idx = -1 # The index among the kept arguments + for arg_idx, aval in enumerate(args_avals): + is_kept = arg_idx in module_kept_var_idx + if is_kept: + current_kept_arg_idx += 1 + + for axis_idx, d in enumerate(aval.shape): + if not core.is_constant_dim(d): + # We collect dimension variables even from dropped args + all_dim_vars = all_dim_vars.union(d.get_vars()) + if not is_kept: continue + d_var = d.to_var() + # We don't yet know how to compute dim vars from non-trivial polynomials + if d_var is None: continue + if not d_var in dim_args_spec_dict: + dim_vars_order.append(d_var) + dim_args_spec_dict[d_var] = f"{current_kept_arg_idx}.{axis_idx}" + + if all_dim_vars: + dim_args_spec_set = set(dim_vars_order) + if dim_args_spec_set != all_dim_vars: + missing = all_dim_vars.difference(dim_args_spec_set) + args_list = [f" Arg[{arg_idx}] - {'KEPT ' if arg_idx in module_kept_var_idx else 'DROPPED'}: {aval}" + for arg_idx, aval in enumerate(args_avals)] + raise ValueError( + "The following dimension variables cannot be computed from the static " + f"shapes of the kept lowered arguments: {missing}. These are the " + "argument shapes:\n" + + "\n".join(args_list)) + + if config.jax_dynamic_shapes: + # In the order we have seen them + dim_args_spec = [dim_args_spec_dict[d_var] for d_var in dim_vars_order] + else: + # In sorted order by name + dim_args_spec = [dim_args_spec_dict[d_var] for d_var in sorted(dim_vars_order)] + else: + dim_args_spec = [] + + args_avals = [aval for i, aval in enumerate(args_avals) if i in module_kept_var_idx] + args_tf = [atf for i, atf in enumerate(args_tf) if i in module_kept_var_idx] + # Apply the shardings on arguments and results for pjit. This is redundant # because the mhlo_module_text will already contain the shardings, but it # makes it easier for tools like the TPU inference converter to see the @@ -693,6 +724,12 @@ def _out_type(jax_type): if "in_shardings" in lowered.compile_args: args_tf = tuple( map(_shard_value, args_tf, args_avals, lowered.compile_args["in_shardings"])) + + if logging.vlog_is_on(3): + mhlo_module_text = mlir.module_to_string(mhlo_module) + logging.vlog(3, "XlaCallModule (version=%d, dim_args_spec=%s)\n%s", + xla_call_module_version, ", ".join(dim_args_spec), + mhlo_module_text) res = tfxla.call_module( args_tf, version=xla_call_module_version, diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 39d6eb8f4a36..89a29dbdbae6 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -818,6 +818,28 @@ def inner1(y): jax2tf.convert(func)(2.) # No error + def test_jit_unused(self): + def f_jax(x, y_unused): + return x * 2. + x, y_unused = np.float32(5.), np.arange(7, dtype=np.int32) + res_tf = jax2tf.convert(jax.jit(f_jax, keep_unused=False))(x, y_unused) + self.assertAllClose(f_jax(x, None), res_tf) + + def test_jit_unused_grad(self): + def f_jax(x, y_unused): + return x * 2. 
+
+    x, y_unused = np.float32(5.), np.arange(7, dtype=np.int32)
+    f_tf = jax2tf.convert(jax.jit(f_jax, keep_unused=False))
+    xv, y_unused_v = tf.Variable(x), tf.Variable(y_unused)
+    with tf.GradientTape() as tape:
+      res_tf = f_tf(xv, y_unused_v)
+    grad_tf_x, grad_tf_y = tape.gradient(res_tf, (xv, y_unused_v))
+
+    self.assertAllClose(f_jax(x, None), res_tf)
+    self.assertAllClose(2., grad_tf_x)
+    self.assertIsNone(grad_tf_y)
+
   def test_nested_convert_error(self):
     def outer(y):
       return jax2tf.convert(jnp.sin)(y)  # Inner convert takes tracer args
diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py
index 3455f3711bc2..52a788ef5354 100644
--- a/jax/experimental/jax2tf/tests/shape_poly_test.py
+++ b/jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -20,7 +20,6 @@
 import collections
 import functools
 from functools import partial
-import logging
 import operator
 import re
 
@@ -29,12 +28,10 @@
 from jax.experimental import jax2tf
 from jax.experimental.jax2tf import shape_poly
 from jax import lax
-from jax import linear_util as lu
 import jax.numpy as jnp
 from jax._src import test_util as jtu
 from jax._src.lax import lax as lax_internal
 from jax._src.lax import control_flow as lax_control_flow
-from jax._src import util
 import numpy as np
 
 from jax.experimental.jax2tf.tests import tf_test_util
@@ -723,6 +720,58 @@ def f_jax(x):  # x: f32[w, h]
         input_signature=[tf.TensorSpec([None, None])],
         polymorphic_shapes=["w, h"])
 
+  def test_non_trivial_polynomials(self):
+    if config.jax_dynamic_shapes:
+      raise unittest.SkipTest("--jax_dynamic_shapes supports only trivial polynomials")
+    # We can handle non-trivial polynomials in the input shape,
+    # as long as all variables also occur in trivial polynomials.
+    self.CheckShapePolymorphism(
+        lambda x, y: x + y.reshape((-1,)),
+        input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None, None])],
+        polymorphic_shapes=["b * b", "b, b"])
+
+  def test_unused_args(self):
+    # Tests with functions that do not use their inputs.
+
+    # First arg unused, not polymorphic
+    self.CheckShapePolymorphism(
+        lambda x_unused, y: y * 2.0,
+        input_signature=[tf.TensorSpec([]), tf.TensorSpec([None])],
+        polymorphic_shapes=[None, "b"])
+
+    # Some args unused, not polymorphic
+    self.CheckShapePolymorphism(
+        lambda x_unused, y, z_unused, w: jnp.concatenate([y, w]),
+        input_signature=[tf.TensorSpec([]), tf.TensorSpec([None]),
+                         tf.TensorSpec([]), tf.TensorSpec([None])],
+        polymorphic_shapes=[None, "b1", None, "b2"])
+
+    # A polymorphic arg is not used, but the dimension var appears
+    # in a used arg also
+    self.CheckShapePolymorphism(
+        lambda x_unused, y: y * 2.0,
+        input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
+        polymorphic_shapes=["b", "b"])
+
+    # A polymorphic arg is not used, and the dimension var does not appear
+    # elsewhere.
+    with self.assertRaisesRegex(ValueError,
+                                "The following dimension variables cannot be computed"):
+      self.CheckShapePolymorphism(
+          lambda x_unused, y: y * 2.0,
+          input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
+          polymorphic_shapes=["b1", "b2"])
+
+    # A polymorphic arg is not used, and the dimension var does appear
+    # elsewhere but not as a trivial monomial.
+ with self.assertRaisesRegex(ValueError, + "The following dimension variables cannot be computed"): + self.CheckShapePolymorphism( + lambda x_unused, y: y * 2.0, + input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])], + polymorphic_shapes=["b1", "b1 * b1"]) + + def test_with_custom_vjp(self): """Shape-polymorphic custom VJP."""
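
For reference, a minimal standalone sketch of the scenario exercised by
test_jit_unused above (a sketch that assumes TensorFlow and NumPy are
installed and uses only the jax, jax.numpy-free, and jax2tf APIs already shown
in the new tests; it mirrors the test rather than adding behavior beyond it):

  import numpy as np
  import jax
  from jax.experimental import jax2tf

  def f_jax(x, y_unused):
    return x * 2.

  # With keep_unused=False the lowered module drops y_unused, so jax2tf must
  # drop the corresponding TF argument and compute dim_args_spec only from the
  # kept arguments.
  f_tf = jax2tf.convert(jax.jit(f_jax, keep_unused=False))
  x, y_unused = np.float32(5.), np.arange(7, dtype=np.int32)
  print(f_tf(x, y_unused))  # tf.Tensor(10.0, shape=(), dtype=float32)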