# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental module transforms JAX functions to be executed by TensorFlow."""
from functools import partial, reduce
import contextlib
import os
import re
import threading
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union, cast
from absl import logging
import jax
from jax import lax
from jax import config
from jax import core, custom_derivatives
from jax import linear_util as lu
from jax import random, tree_util
from jax import numpy as jnp
from jax.experimental import maps
from jax.experimental import pjit
from jax._src import sharding
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import pxla
from jax.interpreters import xla
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import dispatch
from jax._src import dtypes
from jax._src import prng
from jax._src import random as random_internal
from jax._src import source_info_util
from jax._src import util
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import slicing as lax_slicing
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client
from jax._src.numpy.ufuncs import logaddexp
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.jax2tf import shape_poly
from jax.experimental.jax2tf import impl_no_xla
import numpy as np
import tensorflow as tf # type: ignore[import]
# These don't have public equivalents.
# pylint: disable=g-direct-tensorflow-import
from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]
from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import]
from tensorflow.core.framework import attr_value_pb2 # type: ignore[import]
try:
from tensorflow.python.compiler.xla.experimental import xla_sharding # type: ignore[import]
except ModuleNotFoundError:
# This can be removed when TF 2.10 support is no longer needed.
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding # type: ignore[import]
from tensorflow.python.framework import ops as tf_ops # type: ignore[import]
from tensorflow.python.eager import context as tf_context # type: ignore[import]
# pylint: enable=g-direct-tensorflow-import
NameStack = source_info_util.NameStack
PolyShape = shape_poly.PolyShape
# A temporary internal flag, to enable the wrapping of jax.jit functions
# with tf.function(jit_compile=True). See #7389. This change has triggered a
# number of failures in TF. We keep this until we are confident that it does
# not create problems.
# TODO(b/207464757): figure out why this change breaks test
_WRAP_JAX_JIT_WITH_TF_FUNCTION = False
# The scope name needs to be a valid TensorFlow name. See
# https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/core/framework/node_def_util.cc#L731
_VALID_SCOPE_REGEX = re.compile("^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$")
_INVALID_SCOPE_CHAR = re.compile("[^A-Za-z0-9_.\\/-]")
map = util.safe_map
zip = util.safe_zip
def _sanitize_scope_name(name):
scope_name = _INVALID_SCOPE_CHAR.sub("_", name)
if not _VALID_SCOPE_REGEX.match(scope_name):
scope_name = f".{scope_name}"
return scope_name
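# Illustrative sketch (the example names below are hypothetical, not from the
# original module): characters outside [A-Za-z0-9_./-] are replaced with "_",
# and a name whose first character is not in [A-Za-z0-9.] gets a "." prefix.
#
#   _sanitize_scope_name("jit(f)/while[cond]")  # -> "jit_f_/while_cond_"
#   _sanitize_scope_name("_private")            # -> "._private"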
# A value suitable in a TF tracing context: tf.Tensor, tf.Variable,
# or Python scalar or numpy.ndarray. (A tf.EagerTensor is a tf.Tensor.)
TfVal = Any
DType = Any
PrecisionType = int # Enum xla_data.PrecisionConfig.Precision
def _is_tfval(v: TfVal) -> bool:
if isinstance(v, (tf.Tensor, tf.Variable)):
return True
try:
# Include all convertible types, even if not supported on accelerators.
with tf.device("CPU"):
tf.constant(v)
return True
except:
return False
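# Illustrative sketch of values accepted as TfVal by _is_tfval above (the
# concrete values are made up): tf.Tensor and tf.Variable are accepted
# directly; everything else is accepted iff tf.constant can convert it.
#
#   _is_tfval(tf.constant(1.0))   # True
#   _is_tfval(np.ones((2, 3)))    # True
#   _is_tfval(3.14)               # True
#   _is_tfval(object())           # False: tf.constant raises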
# The implementation rules for primitives. The rule will be called with the
# arguments (TfVal) and must return TfVal (or a sequence thereof,
# if primitive.multiple_results). The exceptions are primarily the
# control-flow primitives.
tf_impl: Dict[core.Primitive, Callable[..., Any]] = {}
# Some primitive implementation rules need the abstract values of arguments
# and the results. This is the case for the primitives implemented using
# _convert_jax_impl and those that need to adjust the shape of the outputs
# due to missing TF shape inference rules for TFXLA ops. The rules for these
# primitives should be added to `tf_impl_with_avals`.
# The abstract values are passed to the implementation as two special kwargs
# `_in_avals` (a tuple of core.ShapedArray) and `_out_aval` (a
# core.ShapedArray, or a tuple thereof when primitive.multiple_results).
tf_impl_with_avals: Dict[core.Primitive, Callable[..., Any]] = {}
# XLA is not linked in all environments when converting a primitive. If this is
# the case, we first search for implementation rules for primitives in the
# following map. These implementations are workarounds, making use of TF ops
# that do work when XLA is not linked in.
tf_impl_no_xla = impl_no_xla.tf_impl_no_xla
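# Illustrative sketch of how the dictionaries above are populated later in
# this module (the exact primitives shown here are just examples):
#
#   tf_impl[lax.neg_p] = tf.math.negative
#
#   def _some_rule(*args_tf, _in_avals, _out_aval, **params):
#     ...  # may consult the avals, e.g. to fix shapes after a TFXLA op
#   tf_impl_with_avals[some_primitive] = _some_rule   # hypothetical primitive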
# In order to ensure that JAX picks up the proper user-frame for source
# locations we will register the TensorFlow source path as an internal
# path with source_info_util. The typical stack when a JAX primitive
# conversion happens is:
# jax2tf.process_primitive (top of stack)
# jax tracing machinery ...
# tf.custom_gradient machinery ...
# jax2tf.converted_fun
# tf function machinery ...
# user code invokes the converted function on TF tensors
#
# We need to skip over not only JAX internal frames, but TF internal frames
# also.
# We register the TensorFlow source path lazily
_has_registered_tf_source_path = False
class _ThreadLocalState(threading.local):
def __init__(self):
# XLA is not linked in all environments; when converting a primitive, if this
# variable is disabled, we try harder to use only standard TF ops if they are
# applicable to the concrete use case; if the resulting conversion path ends up
# requiring a TFXLA operation, an exception is thrown instead.
self.enable_xla = True
# Keep track if we are inside a call_tf. In that context we disable the
# safety check that we are not inside JAX transformations.
self.inside_call_tf = False
# Maps dimension variables to TF expressions
self.shape_env: Sequence[Tuple[str, TfVal]] = ()
# Whether to actually include XLA op metadata in the generated TF ops
# TODO(b/189306134): implement support for XLA metadata
self.include_xla_op_metadata = False
# A cache for the tf.convert_to_tensor for constants. We try to preserve
# sharing for constants, to enable tf.Graph to take advantage of it.
# See https://github.com/google/jax/issues/7992.
self.constant_cache = None # None means that we don't use a cache. We
# may be outside a conversion scope.
# A cache for the outside tf name_scope when the converted
# function is running. We will add this as the prefix to the generated tf op
# name. For example, the tf op name will be like
# "{tf_outer_name_scope}/JAX_NAME_STACKS"
self.tf_outer_name_scope = ""
_thread_local_state = _ThreadLocalState()
def _get_current_name_stack() -> Union[NameStack, str]:
return source_info_util.current_name_stack()
@contextlib.contextmanager
def inside_call_tf():
# Set the inside_call_tf flag for a context.
prev = _thread_local_state.inside_call_tf
_thread_local_state.inside_call_tf = True
try:
yield
finally:
_thread_local_state.inside_call_tf = prev
@partial(api_util.api_hook, tag="jax2tf_convert")
def convert(fun_jax: Callable,
*,
polymorphic_shapes=None,
with_gradient=True,
enable_xla=True,
experimental_native_lowering="default") -> Callable:
"""Lowers `fun_jax` into a function that uses only TensorFlow ops.
See
[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md)
for more details about usage and common problems.
Args:
fun_jax: JAX function to be lowered. Its arguments and return value should be
JAX arrays, or nested standard Python containers (tuple/list/dict) thereof
(pytrees).
polymorphic_shapes: Specifies input shapes to be treated polymorphically
during lowering.
.. warning:: The shape-polymorphic lowering is an experimental feature.
It is meant to be sound, but it is known to reject some JAX programs
that are shape polymorphic. The details of this feature can change.
It should be `None` (all arguments are monomorphic), a single PolyShape
or string (applies to all arguments), or a tuple/list of the same length
as the function arguments. For each argument the shape specification
should be `None` (monomorphic argument), or a Python object with the
same pytree structure as the argument.
See [how optional parameters are matched to
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
A shape specification for an array argument should be an object
`PolyShape(dim0, dim1, ..., dimn)`
where each `dim` is a dimension specification: a positive integer denoting
a monomorphic dimension of the given size, or a string denoting a
dimension variable assumed to range over non-zero dimension sizes, or
the special placeholder string "_" denoting a monomorphic dimension
whose size is given by the actual argument. As a shortcut, an Ellipsis
suffix in the list of dimension specifications stands for a list of "_"
placeholders.
For convenience, a shape specification can also be given as a string
representation, e.g.: "batch, ...", "batch, height, width, _", possibly
with surrounding parentheses: "(batch, ...)".
The lowering fails if it cannot ensure that it would produce the same
sequence of TF ops for any non-zero values of the dimension variables.
polymorphic_shapes are only supported for positional arguments; shape
polymorphism is not supported for keyword arguments.
See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion)
for more details.
with_gradient: if set (default), add a tf.custom_gradient to the lowered
function, by converting the ``jax.vjp(fun)``. This means that reverse-mode
TensorFlow AD is supported for the output TensorFlow function, and the
value of the gradient will be JAX-accurate.
enable_xla: if set (default), use the simplest conversion
and use XLA TF ops when necessary. These ops are known to create issues
for the TFLite and TFjs converters. For those cases, unset this parameter
so that the lowering tries harder to use non-XLA TF ops to lower the
function and aborts if this is not possible.
experimental_native_lowering: DO NOT USE, for experimental purposes only.
The value "default" defers to --jax2tf_default_experimental_native_lowering.
Returns:
A version of `fun_jax` that expects TfVals as arguments (or
tuple/lists/dicts thereof), and returns TfVals as outputs, and uses
only TensorFlow ops.
"""
if experimental_native_lowering == "default":
experimental_native_lowering = config.jax2tf_default_experimental_native_lowering
if experimental_native_lowering and not enable_xla:
raise ValueError(
"experimental_native_lowering is not supported with enable_xla=False")
api._check_callable(fun_jax)
fun_name = getattr(fun_jax, "__name__", "unknown")
name_stack = util.wrap_name(fun_name, "jax2tf")
def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal:
try:
prev_enable_xla = _thread_local_state.enable_xla
prev_include_xla_op_metadata = _thread_local_state.include_xla_op_metadata
prev_tf_outer_name_scope = _thread_local_state.tf_outer_name_scope
_thread_local_state.tf_outer_name_scope = tf.get_current_name_scope()
# TODO: is there a better way to check if we are inside a transformation?
if not core.trace_state_clean(
) and not _thread_local_state.inside_call_tf:
# It is Ok to nest convert when we are inside a call_tf
raise ValueError(
"convert must be used outside all JAX transformations." +
f"Trace state: {core.thread_local_state.trace_state.trace_stack}")
fun_flat_jax, args_flat_tf, in_tree, out_tree_thunk = flatten_fun_jax(
fun_jax, args_tf, kwargs_tf)
# out_tree_thunk will be ready after we call fun_flat_jax below.
# Expand the polymorphic_shapes to match the args_flat_tf. The polymorphic_shapes
# argument refers to positional arguments only.
if polymorphic_shapes is None or isinstance(polymorphic_shapes,
(PolyShape, str)):
polymorphic_shapes_ = (polymorphic_shapes,) * len(args_tf)
else:
if not (isinstance(polymorphic_shapes, Sequence) and
len(polymorphic_shapes) == len(args_tf)):
msg = (
"polymorphic_shapes must be a sequence with the same length as "
"the positional argument list "
f"({len(args_tf)}). Got polymorphic_shapes={repr(polymorphic_shapes)}."
)
raise TypeError(msg)
polymorphic_shapes_ = tuple(polymorphic_shapes)
polymorphic_shapes_flat = tuple(
api_util.flatten_axes(
"jax2tf.convert polymorphic_shapes", in_tree,
(polymorphic_shapes_, {k: None for k in kwargs_tf.keys()})))
args_and_avals = tuple(
map(preprocess_arg_tf, range(len(args_flat_tf)), args_flat_tf,
polymorphic_shapes_flat))
args_flat_tf, args_avals_flat = util.unzip2(args_and_avals)
dim_vars, get_dim_values_jax = shape_poly.prepare_dim_var_env(
args_avals_flat)
dim_values, _ = _interpret_fun_jax(get_dim_values_jax, args_flat_tf,
args_avals_flat, name_stack)
shape_env = zip(dim_vars, dim_values)
assert not _thread_local_state.shape_env, f"Unexpected shape environment {_thread_local_state.shape_env}"
_thread_local_state.enable_xla = enable_xla
# TODO(b/189306134): implement support for XLA metadata
_thread_local_state.include_xla_op_metadata = False
_thread_local_state.shape_env = shape_env
global _has_registered_tf_source_path
if not _has_registered_tf_source_path:
source_info_util.register_exclusion(os.path.dirname(tf.__file__))
_has_registered_tf_source_path = True
if with_gradient:
@tf.custom_gradient
def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
outs_tf, out_avals = _interpret_fun_jax(fun_flat_jax,
args_flat_tf, args_avals_flat,
name_stack,
fresh_constant_cache=True,
experimental_native_lowering=experimental_native_lowering)
return (tuple(outs_tf),
make_custom_gradient_fn_tf(
fun_flat_jax=fun_flat_jax,
args_flat_tf=args_flat_tf,
args_avals_flat=args_avals_flat,
polymorphic_shapes_flat=polymorphic_shapes_flat,
out_avals=out_avals))
out_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
else:
outs_tf, out_avals = _interpret_fun_jax(fun_flat_jax,
args_flat_tf, args_avals_flat,
name_stack,
fresh_constant_cache=True,
experimental_native_lowering=experimental_native_lowering)
message = ("The jax2tf-converted function does not support gradients. "
"Use `with_gradient` parameter to enable gradients")
# We use PreventGradient, which is propagated through a SavedModel.
out_flat_tf = [
tf.raw_ops.PreventGradient(input=o, message=message)
for o in outs_tf
]
finally:
_thread_local_state.shape_env = ()
_thread_local_state.enable_xla = prev_enable_xla
_thread_local_state.include_xla_op_metadata = prev_include_xla_op_metadata
_thread_local_state.tf_outer_name_scope = prev_tf_outer_name_scope
out_flat_tf = [tf.identity(x, "jax2tf_out") for x in out_flat_tf]
out_tf = tree_util.tree_unflatten(out_tree_thunk(), out_flat_tf)
return out_tf
return converted_fun_tf
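# Illustrative usage sketch for convert(); see the README for the
# authoritative examples. The function and shape spec below are made up:
#
#   def f_jax(x):
#     return jnp.sin(jnp.cos(x))
#
#   f_tf = convert(f_jax, polymorphic_shapes=["(batch, _)"])
#   x = tf.ones((8, 16), dtype=tf.float32)
#   y = f_tf(x)                                  # computed with TF ops only
#   f_graph = tf.function(f_tf, autograph=False) # can be traced and saved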
def dtype_of_val(val: TfVal) -> DType:
"""Computes the TensorFlow dtype using JAX's typing rules.
If the value is a tf.Tensor, it starts with its dtype. If the value is a
constant it uses JAX to infer its dtype. The resulting dtype follows the
JAX type inference rules, and depends on the value of the
JAX_ENABLE_X64 flag.
See README.md for how 64-bit values are treated.
"""
tval, _ = _tfval_to_tensor_jax_dtype(val)
return tval.dtype
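# Illustrative sketch of the typing rules above, assuming the default
# JAX_ENABLE_X64=False (the concrete values are made up):
#
#   dtype_of_val(3.14)                           # tf.float32 (JAX default for Python floats)
#   dtype_of_val(np.arange(3, dtype=np.int64))   # tf.int32 (64-bit demoted when x64 is off)
#   dtype_of_val(tf.constant(1.0, tf.float64))   # tf.float32 (cast per JAX canonicalization)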
# Internals
def flatten_fun_jax(fun_jax: Callable, args_tf: Sequence[TfVal],
kwargs_tf: Dict[str, TfVal]
) -> Tuple[Callable, Sequence[TfVal], Any, Callable]:
"""Wraps the function to take a (flat) list of positional args.
jax2tf works better and is simpler when the JAX function takes and returns
just a tuple of values (no pytrees, no kwargs). This is in part because
jax.vjp does not support kwargs and we can only set
tf.custom_gradient on functions with flat arguments and results.
Returns:
* the wrapped JAX function taking and returning a flat list of arguments
* the flat list of TF arguments
* the in_tree corresponding to the tuple (args_tf, kwargs_tf)
* a thunk that can be called after the wrapped function has been called
to return the output pytree.
"""
# TODO(necula): technically we should use TF's flattening and unflattening
# because we are working with TF values.
args_flat_tf, in_tree = tree_util.tree_flatten((args_tf, kwargs_tf))
out_tree_ref = None
def fun_flat_jax(*args_flat_jax):
tree_args, tree_kwargs = tree_util.tree_unflatten(in_tree, args_flat_jax)
tree_res = fun_jax(*tree_args, **tree_kwargs)
res_flat_jax, out_tree = tree_util.tree_flatten(tree_res)
nonlocal out_tree_ref
assert out_tree_ref is None or out_tree_ref == out_tree
out_tree_ref = out_tree
return res_flat_jax
if hasattr(fun_jax, "lower"):
# If the fun_jax is already a jit(f) or pjit(f), we must
# preserve the lowering function. This will be used in the _lower_native_and_run.
# We rely on the fact that the lowering is the same for the function
# taking pytrees, and the one taking flat args.
def fun_flat_jax_lower(*args_flat_jax):
tree_args, tree_kwargs = tree_util.tree_unflatten(in_tree, args_flat_jax)
lowered = fun_jax.lower(*tree_args, **tree_kwargs)
out_tree = lowered.out_tree
nonlocal out_tree_ref
assert out_tree_ref is None or out_tree_ref == out_tree
out_tree_ref = out_tree
return lowered
setattr(fun_flat_jax, "lower", fun_flat_jax_lower)
return fun_flat_jax, args_flat_tf, in_tree, lambda: out_tree_ref
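# Illustrative sketch of the flattening above (names and values are made up):
#
#   def f(a, *, scale):
#     return {"out": a["x"] * scale}
#
#   fun_flat, args_flat_tf, in_tree, out_tree_thunk = flatten_fun_jax(
#       f, ({"x": tf.constant(2.0)},), {"scale": tf.constant(3.0)})
#   # args_flat_tf is the flat list of the two tensors; fun_flat takes exactly
#   # those as positional arguments and returns a flat list of results;
#   # out_tree_thunk() is meaningful only after fun_flat has been called.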
def preprocess_arg_tf(arg_idx: int,
arg_tf: TfVal,
polymorphic_shape: Optional[str]
) -> Tuple[TfVal, core.ShapedArray]:
if not _is_tfval(arg_tf):
msg = (f"Argument {arg_tf} of type {type(arg_tf)} of jax2tf.convert(f) should "
"be NumPy array, scalar, tf.Variable, or tf.Tensor")
raise TypeError(msg)
# May cast the args_flat to JAX types, using JAX's interpretation
# of types of constants.
arg_tf, arg_jax_dtype = _tfval_to_tensor_jax_dtype(arg_tf)
# Name input tensors; do this after we have cast the arguments
arg_tf = tf.identity(arg_tf, f"jax2tf_arg_{arg_idx}")
# Fix the shape for TF1
tf_arg_shape = np.shape(arg_tf)
arg_shape = tuple(d.value if isinstance(d, tf.compat.v1.Dimension) else d for d in tf_arg_shape)
arg_aval = shape_poly.arg_aval(arg_shape, arg_jax_dtype, polymorphic_shape)
return arg_tf, arg_aval
# Prepare the grad_fn for tf.custom_gradient.
def make_custom_gradient_fn_tf(
fun_flat_jax: Callable,
args_flat_tf: Sequence[TfVal],
polymorphic_shapes_flat: Sequence[str],
args_avals_flat: Sequence[core.ShapedArray],
out_avals: Sequence[core.ShapedArray]):
def grad_fn_tf(*out_cts_flat_tf: TfVal,
variables=None):
if variables:
raise ValueError(
"Unexpected variables used in forward pass. "
"This should not happen for first-order differentiation. "
f"{variables=}")
out_cts_flat_polymorphic_shapes = tuple(str(out_aval.shape) # Note: may be polynomials, not just DimVar
for out_aval in out_avals) # type: ignore
vjp_polymorphic_shapes = [
polymorphic_shapes_flat, out_cts_flat_polymorphic_shapes
]
def fun_vjp_jax(args_flat_jax, out_cts_flat_jax):
# One may think that we can get the pullback while we are converting
# the main function in the first place. That is problematic, because the
# pullback may contain captured tracers from the conversion of the
# main function. Those tracers will confuse the conversion of the
# pullback. So, we construct the vjp anew and we convert it separately.
_, pullback_jax = jax.vjp(fun_flat_jax, *args_flat_jax)
def fix_out_ct(out_ct_jax, out_ct_aval: core.ShapedArray):
# If the primal function has outputs of integer or bool types, and if we are
# under a tf.function context, then TF will pass None in _out_cts_flat
# in place of these values. We should change these to float0 or
# else JAX gets unhappy. See issue #6975.
if out_ct_jax is not None:
return out_ct_jax
assert core.primal_dtype_to_tangent_dtype(out_ct_aval.dtype) == dtypes.float0, f"{out_ct_jax=}"
# Note that out_ct_aval.shape contains dimension variable from the
# primal function scope. It is Ok to use them here because we
# use the same shape variables for the VJP function.
return jnp.zeros(out_ct_aval.shape, dtype=_tf_np_dtype_for_float0)
out_cts_fixed_flat = list(map(fix_out_ct, out_cts_flat_jax, out_avals))
in_cts_flat_jax = pullback_jax(out_cts_fixed_flat)
def fix_in_ct(in_ct_jax, arg_aval: core.ShapedArray):
if jnp.issubdtype(arg_aval.dtype, jnp.inexact):
return in_ct_jax
else:
assert in_ct_jax.dtype == dtypes.float0
return jnp.zeros(arg_aval.shape, _tf_np_dtype_for_float0)
in_cts_fixed_flat_jax = tuple(map(fix_in_ct, in_cts_flat_jax, args_avals_flat))
return in_cts_fixed_flat_jax
# TODO: enable higher-order gradients
with tf.name_scope("jax2tf_vjp"):
in_cts_flat = convert(
fun_vjp_jax,
with_gradient=False,
polymorphic_shapes=vjp_polymorphic_shapes)(args_flat_tf, out_cts_flat_tf)
return in_cts_flat
return grad_fn_tf
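# Illustrative sketch: because convert() installs the custom gradient built
# above, reverse-mode TF AD works on the converted function (the concrete
# function below is made up; see the README for the authoritative examples):
#
#   f_tf = convert(lambda x: jnp.sum(jnp.sin(x)))
#   v = tf.Variable(np.arange(4.0, dtype=np.float32))
#   with tf.GradientTape() as tape:
#     y = f_tf(v)
#   dv = tape.gradient(y, v)   # computed via the converted jax.vjp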
@contextlib.contextmanager
def _extended_name_stack(extra_name_stack: Optional[str]):
name_ctx = (source_info_util.extend_name_stack(extra_name_stack)
if extra_name_stack
else contextlib.nullcontext())
with name_ctx:
yield
return
def _interpret_fun_jax(
fun_jax: Callable,
args_tf: Sequence[TfVal],
args_avals: Sequence[core.ShapedArray],
extra_name_stack: Optional[str],
fresh_constant_cache: bool = False,
experimental_native_lowering: bool = False
) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
if experimental_native_lowering:
del extra_name_stack
return _lower_native_and_run(fun_jax, args_avals, args_tf)
else:
with core.new_base_main(TensorFlowTrace) as main: # type: ignore
subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), main, args_avals)
with _extended_name_stack(extra_name_stack):
with core.new_sublevel():
out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
_call_wrapped_with_new_constant_cache(subtrace_fun, args_tf,
fresh_constant_cache=fresh_constant_cache)
del main
return util.unzip2(out_vals)
def _lower_native_and_run(fun_jax: Callable,
args_avals: Sequence[core.ShapedArray],
args_tf: Sequence[TfVal],
) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
"""Lowers the function using native lowering and then invokes it.
Work-in-progress.
Uses JAX native lowering to MHLO, and then wraps the result in an
XlaCallModule TF op. This op does not have backward compatibility yet.
Special care must be taken in presence of shape polymorphism.
"""
# Look for shape polymorphism
# We now have two implementations for the native lowering. If --jax_dynamic_shapes
# then we use JAX's in-progress support for native dynamic shapes, and we pass
# abstracted_axes to lowering functions. Otherwise, we just lower using
# abstract values whose shapes may include polynomials (already in args_avals).
if config.jax_dynamic_shapes:
abstracted_axes: Sequence[Dict[int, str]] = []
for arg_idx, aval in enumerate(args_avals):
one_abstract_axes = {}
for axis_idx, d in enumerate(aval.shape):
if not core.is_constant_dim(d):
d_var = d.to_var()
if d_var is None:
raise ValueError(f"Only trivial dimension polynomials on input: {aval.shape}")
one_abstract_axes[axis_idx] = d_var
abstracted_axes.append(one_abstract_axes)
if any(abstracted_axes):
abstracted_axes = tuple(abstracted_axes)
else:
abstracted_axes = None # type: ignore
else:
abstracted_axes = None # type: ignore
arg_specs_jax = [
jax.ShapeDtypeStruct(aval.shape, aval.dtype, named_shape=aval.named_shape)
for aval in args_avals
]
# TODO: specify the backend for experimental_native_lowering
backend = jax.default_backend()
if not hasattr(fun_jax, "lower") or abstracted_axes:
# We support convert(pjit(f_jax, ...)) and convert(jit(f_jax)) but also
# convert(f_jax), in which case a "jit" is implied. We also add a jit when
# we need to pass the abstracted axes.
fun_jax_lower = jax.jit(fun_jax, backend=backend,
abstracted_axes=abstracted_axes).lower
else:
fun_jax_lower = fun_jax.lower
lowered = fun_jax_lower(*arg_specs_jax)._lowering
if config.jax2tf_use_stablehlo:
mhlo_module = lowered.stablehlo()
xla_call_module_version = 2
else:
mhlo_module = lowered.mhlo()
xla_call_module_version = 1
mhlo_serialized_module = mlir.module_to_bytecode(mhlo_module)
# Figure out the result types and shapes
if "global_out_avals" in lowered.compile_args:
# This is currently the case for pjit
out_avals = lowered.compile_args["global_out_avals"]
else:
out_avals = lowered.compile_args["out_avals"]
if lowered.compile_args["host_callbacks"]:
raise NotImplementedError("host_callbacks are not yet implemented for the jax2tf native lowering")
# TODO(necula): handle d being InDBIdx
out_shapes = tuple(
tuple(d if type(d) is int else None
for d in out_aval.shape)
for out_aval in out_avals)
def _out_type(jax_type):
if jax_type == dtypes.float0:
return dtypes.bool_
return jax_type
out_types = tuple(_out_type(out_aval.dtype) for out_aval in out_avals)
module_kept_var_idx = lowered.compile_args["kept_var_idx"]
# We must compute the dim_args_spec: for each dimension variable, encode how
# to compute its value from the shape of the explicit arguments. E.g., "2.1"
# denotes args_tf[2].shape[1]. The order of the dimension variables must match
# the order of the first N arguments of the lowered function.
# If we use --jax_dynamic_shapes, the dimension variables are listed in the
# order in which they are encountered by scanning the arguments and their
# shapes in order. Otherwise, the dimension variables are passed in the
# alphabetical order of their names.
dim_args_spec_dict: Dict[str, str] = {} # map dim var name to dim_args_spec
dim_vars_order: List[str] = []
all_dim_vars: Set[str] = set()
current_kept_arg_idx = -1 # The index among the kept arguments
for arg_idx, aval in enumerate(args_avals):
is_kept = arg_idx in module_kept_var_idx
if is_kept:
current_kept_arg_idx += 1
for axis_idx, d in enumerate(aval.shape):
if not core.is_constant_dim(d):
# We collect dimension variables even from dropped args
all_dim_vars = all_dim_vars.union(d.get_vars())
if not is_kept: continue
d_var = d.to_var()
# We can compute dim vars only from trivial polynomials
if d_var is None: continue
if not d_var in dim_args_spec_dict:
dim_vars_order.append(d_var)
dim_args_spec_dict[d_var] = f"{current_kept_arg_idx}.{axis_idx}"
if all_dim_vars:
dim_args_spec_set = set(dim_vars_order)
if dim_args_spec_set != all_dim_vars:
missing = all_dim_vars.difference(dim_args_spec_set)
args_list = [f" Arg[{arg_idx}] - {'KEPT ' if arg_idx in module_kept_var_idx else 'DROPPED'}: {aval}"
for arg_idx, aval in enumerate(args_avals)]
raise ValueError(
"The following dimension variables cannot be computed from the static "
f"shapes of the kept lowered arguments: {missing}. These are the "
"argument shapes:\n" +
"\n".join(args_list) +
"\n"
"Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details.")
if config.jax_dynamic_shapes:
# In the order we have seen them
dim_args_spec = [dim_args_spec_dict[d_var] for d_var in dim_vars_order]
else:
# In sorted order by name
dim_args_spec = [dim_args_spec_dict[d_var] for d_var in sorted(dim_vars_order)]
else:
dim_args_spec = []
args_avals = [aval for i, aval in enumerate(args_avals) if i in module_kept_var_idx]
args_tf = [atf for i, atf in enumerate(args_tf) if i in module_kept_var_idx]
# Apply the shardings on arguments and results for pjit. This is redundant
# because the mhlo_module_text will already contain the shardings, but it
# makes it easier for tools like the TPU inference converter to see the
# sharding without digging into the `module` attribute of the `XlaCallModule`
# op, in the same way as it is done for the legacy jax2tf conversion.
if "in_shardings" in lowered.compile_args:
args_tf = tuple(
map(_shard_value, args_tf, args_avals, lowered.compile_args["in_shardings"]))
if logging.vlog_is_on(3):
mhlo_module_text = mlir.module_to_string(mhlo_module)
logging.vlog(3, "XlaCallModule (version=%d, dim_args_spec=%s)\n%s",
xla_call_module_version, ", ".join(dim_args_spec),
mhlo_module_text)
res = tfxla.call_module(
args_tf,
version=xla_call_module_version,
module=mhlo_serialized_module,
Tout=out_types,
Sout=out_shapes,
dim_args_spec=dim_args_spec)
if "out_shardings" in lowered.compile_args:
res = list(map(_shard_value, res, out_avals, lowered.compile_args["out_shardings"]))
# Convert the results to the needed TF types
def _convert_res(res_val, res_jax_type):
conversion_dtype = _to_tf_dtype(res_jax_type)
if conversion_dtype != res_jax_type:
return tf.cast(res_val, conversion_dtype)
else:
return res_val
res = tuple(
_convert_res(res_val, out_aval.dtype)
for res_val, out_aval in zip(res, out_avals))
return res, out_avals
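# Illustrative sketch of the dim_args_spec encoding computed above (the
# shapes are made up): for kept lowered arguments with shapes f32[b, 8] and
# f32[b, w], the dimension variables are encoded, in alphabetical order when
# --jax_dynamic_shapes is off, as
#
#   dim_args_spec = ["0.0", "1.1"]  # b = args_tf[0].shape[0], w = args_tf[1].shape[1]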
def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
in_vals: Sequence[TfVal],
fresh_constant_cache: bool = False
) -> Sequence[Tuple[TfVal, core.ShapedArray]]:
try:
prev_constant_cache = _thread_local_state.constant_cache
prev_constant_cache_keys = set(prev_constant_cache.keys()) if prev_constant_cache is not None else set()
# Start a new cache, so that we don't share constants across tf.function
# boundaries.
if fresh_constant_cache:
_thread_local_state.constant_cache = {}
out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
fun.call_wrapped(*in_vals)
finally:
if prev_constant_cache is not None and not fresh_constant_cache:
newly_added_keys = set(prev_constant_cache.keys()) - prev_constant_cache_keys
# Delete the newly added keys
for k in newly_added_keys:
del prev_constant_cache[k]
_thread_local_state.constant_cache = prev_constant_cache
return out_vals
def _convert_jax_impl(impl_jax: Callable, *,
multiple_results=True,
with_physical_avals=False,
extra_name_stack: Optional[str] = None) -> Callable:
"""Convert the JAX implementation of a primitive.
Args:
impl_jax: typically the impl-rule for a primitive, with signature
`(*args_jax: JaxVal, **kwargs) -> Sequence[JaxVal]`. This function implements
a primitive in terms of other primitives.
multiple_results: whether `impl_jax` returns a sequence of results.
extra_name_stack: additional element to add to the name stack for the
converted ops.
Returns:
a function with signature `(*args_tf: TfVal, _in_avals, _out_aval, **kwargs)
-> Sequence[TfVal]`.
"""
def wrapped_tf(*args_tf: TfVal, _in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray,
**kwargs) -> Sequence[TfVal]:
if with_physical_avals:
_in_avals = map(_jax_physical_aval, _in_avals)
_out_aval = _jax_physical_aval(_out_aval)
# We wrap the impl_jax to always return a tuple of results.
def impl_multiple_results_jax(*args_jax):
results_jax = impl_jax(*args_jax, **kwargs)
return results_jax if multiple_results else [results_jax]
results_tf, _ = _interpret_fun_jax(
impl_multiple_results_jax, args_tf, _in_avals,
extra_name_stack)
return results_tf if multiple_results else results_tf[0]
return wrapped_tf
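# Illustrative sketch (hypothetical primitive and impl names): rules that
# reuse an existing JAX implementation are registered roughly as
#
#   tf_impl_with_avals[some_prim] = _convert_jax_impl(
#       some_prim_jax_impl, multiple_results=False,
#       extra_name_stack="some_prim")
#
# The wrapped rule receives TF values plus `_in_avals`/`_out_aval` and
# re-enters _interpret_fun_jax to convert the JAX implementation to TF ops.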
@lu.transformation
def _interpret_subtrace(main: core.MainTrace,
in_avals: Sequence[core.ShapedArray],
*in_vals: TfVal):
trace = TensorFlowTrace(main, core.cur_sublevel())
in_tracers = tuple(
TensorFlowTracer(trace, val, aval)
for val, aval in zip(in_vals, in_avals))
outs = yield in_tracers, {} # type: Sequence[TfVal]
out_tracers: Iterable[TensorFlowTracer] = (
map(trace.full_raise, outs)) # type: ignore
out_vals_with_avals: Sequence[Tuple[TfVal, core.ShapedArray]] = (
tuple((t.val, t.aval) for t in out_tracers))
yield out_vals_with_avals
def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal,
extra_name_stack: Optional[str]) -> Sequence[TfVal]:
"""Evaluates a Jaxpr with tf.Tensor arguments.
The output is a sequence of TfVal, suitable for use with TF.
"""
outs_tf, _ = _interpret_fun_jax(core.jaxpr_as_fun(jaxpr),
args_tf, jaxpr.in_avals, extra_name_stack)
return outs_tf
def _jax_physical_aval(aval: core.ShapedArray) -> core.ShapedArray:
"""Converts JAX avals from logical to physical, if relevant.
JAX might have avals whose logical vs physical shape/dtype may
differ, and only the physical view is expected to possibly
relate to TF. TF impl rules should operate on the physical form.
A JAX logical aval might even correspond, in principle, to several
physical avals, but we don't support those here. Instead we assert
there is only one and return it.
"""
if core.is_opaque_dtype(aval.dtype):
aval, = aval.dtype._rules.physical_avals(aval)
return aval
return aval
def _jax_physical_dtype(dtype):
# assuming () is a fine stand-in shape
return _jax_physical_aval(core.ShapedArray((), dtype)).dtype
def _aval_to_tf_shape(aval: core.ShapedArray) -> Tuple[Optional[int], ...]:
"""Generate a TF shape, possibly containing None for polymorphic dimensions."""
aval = _jax_physical_aval(aval)
return tuple(map(lambda d: None if shape_poly.is_poly_dim(d) else d,
aval.shape)) # type: ignore[attr-defined]
# In the TF world, we represent float0 as zeros of this type.
_tf_np_dtype_for_float0 = np.int32
def _to_tf_dtype(jax_dtype):
# Note that converting _to_tf_dtype and _to_jax_dtype are not inverses,
# due to float0 and 64-bit behavior.
try:
jax_dtype = _jax_physical_dtype(jax_dtype)
except TypeError:
# `jax_dtype` isn't actually a valid jax dtype (e.g. it is
# tf.float32), so there is no physical dtype anyway
pass
if jax_dtype == dtypes.float0:
jax_dtype = _tf_np_dtype_for_float0
return tf.dtypes.as_dtype(jax_dtype)
def _to_jax_dtype(tf_dtype):
# Note that converting _to_tf_dtype and _to_jax_dtype are not inverses,
# due to float0 and 64-bit behavior.
return dtypes.canonicalize_dtype(tf_dtype.as_numpy_dtype)
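# Illustrative sketch of why the two conversions above are not inverses,
# assuming the default JAX_ENABLE_X64=False:
#
#   _to_tf_dtype(dtypes.float0)   # tf.int32: float0 is unknown to TF
#   _to_jax_dtype(tf.int32)       # int32
#   _to_jax_dtype(tf.float64)     # float32, due to 64-bit canonicalization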
def _maybe_decode_gda(gda_or_py_object: Any):
"""Convert GlobalDeviceArray into numpy object."""
if isinstance(gda_or_py_object, GlobalDeviceArray):
if jax.process_count() != 1:
raise RuntimeError("GlobalDeviceArray does not support multi-process"
f" currently. Process num = {jax.process_count()}")
return gda_or_py_object._value
return gda_or_py_object
def _tfval_to_tensor_jax_dtype(val: TfVal,
jax_dtype: Optional[DType] = None,
memoize_constants=False) -> Tuple[TfVal, DType]:
"""Converts a scalar, ndarray, or tf.Tensor to a tf.Tensor with proper type.
If `jax_dtype` is missing, uses JAX typing rules.
See README.md for details regarding 64-bit values.
Args:
val: a scalar, ndarray, tf.Tensor, or tf.Variable
jax_dtype: an optional dtype to use. If missing, uses JAX type inference
rules for constants.
memoize_constants: whether to memoize TF constants. We can't do this
everywhere, we may be outside of a conversion scope.
Returns:
a tuple with a tf.Tensor with the type as needed by JAX, and the JAX type.
"""
if isinstance(val, (tf.Tensor, tf.Variable)):
jax_dtype = jax_dtype or _to_jax_dtype(val.dtype) # Give JAX a chance to pick the type
conversion_dtype = _to_tf_dtype(jax_dtype)
if conversion_dtype != val.dtype: # May need to cast for 64-bit values
return tf.cast(val, conversion_dtype), jax_dtype
else:
return val, jax_dtype
else: # A constant
jax_dtype = jax_dtype or xla.abstractify(val).dtype
# TODO(document): We assume that the value of a constant does not
# change through the scope of the function. But it may be an ndarray, ...
# JAX has the same problem when generating HLO.
const_key = (id(val), jax_dtype)
# Since we use id(val) as a cache key, we have to make sure that we keep
# the previous `val` alive. Otherwise, for an ndarray, it can get garbage
# collected and reused for a different value, which would create correctness
# issues. We keep the `val` alive by storing in the cache the pair
# `(val, tf_val)`.
do_memoize = (memoize_constants and np.shape(val) and _thread_local_state.constant_cache is not None)
if do_memoize:
_, tf_val = _thread_local_state.constant_cache.get(const_key, (None, None))
else:
tf_val = None
if tf_val is None:
conversion_dtype = _to_tf_dtype(jax_dtype)
# The float0 type is not known to TF.
if jax_dtype == dtypes.float0:
val = np.zeros(np.shape(val), conversion_dtype.as_numpy_dtype)
tf_val = tf.convert_to_tensor(
_maybe_decode_gda(val), dtype=conversion_dtype)
if do_memoize:
_thread_local_state.constant_cache[const_key] = (val, tf_val)
return tf_val, jax_dtype
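# Illustrative sketch of the constant memoization above (values are made up):
# within a conversion scope that installs a constant cache, converting the
# same ndarray twice returns the same tf.Tensor, so a tf.Graph stores the
# constant only once (see google/jax#7992):
#
#   c = np.ones((1000, 1000), np.float32)
#   t1, _ = _tfval_to_tensor_jax_dtype(c, memoize_constants=True)
#   t2, _ = _tfval_to_tensor_jax_dtype(c, memoize_constants=True)
#   # t1 is t2, provided _thread_local_state.constant_cache is not None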
def _eval_shape(shape: Sequence[shape_poly.DimSize], dtype=None) -> Sequence[TfVal]:
assert all(map(lambda x: x is not None, shape)), (
f"Argument shape should be a valid JAX shape but got {shape}")
if dtype is not None:
shape = _jax_physical_aval(core.ShapedArray(shape, dtype)).shape
dim_vars, dim_values = util.unzip2(_thread_local_state.shape_env)
eval_shape_jax, dim_avals = shape_poly.get_shape_evaluator(dim_vars, shape)
shape_values_tf, _ = _interpret_fun_jax(eval_shape_jax,
dim_values, dim_avals, "") # type: ignore
return shape_values_tf
def _assert_matching_abstract_shape(x: TfVal, shape: Sequence[shape_poly.DimSize]):
"""Asserts that shape matches x.shape in the known dimensions and has
dimension polynomials elsewhere."""
# Ensures that the shape does not contain None; it should contain polynomials
assert (len(x.shape) == len(shape) and
all((xd is None and isinstance(sd, shape_poly._DimPolynomial) or
core.is_constant_dim(sd) and xd == sd)
for xd, sd in zip(x.shape, shape))), \
f"Shape {shape} does not match x.shape {x.shape}"
# TODO(b/26854495): pylint doesn't understand slots and inheritance.
# pylint: disable=assigning-non-slot
class TensorFlowTracer(core.Tracer):
"""Tracer class that boxes a TF value and a JAX abstract value.
In addition to the TF value we carry the JAX abstract value because
there are some cases when it cannot be recovered from the value:
when we are converting with polymorphic shapes or when the JAX aval
has a custom element type. In these cases the shape of the value may
have dimensions set to `None`, or it may only correspond to the JAX
"physical" (TF/lowering-compatible) shape, so the JAX abstract value
may contain more precise information.
When the value has a partially-known shape, the dimensions marked as `None`
must correspond to non-constant dimensions in the abstract value.