-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
shape_poly_test.py
2177 lines (1926 loc) · 90.4 KB
/
shape_poly_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the shape-polymorphic jax2tf conversion."""
import unittest
from absl.testing import absltest, parameterized
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
import collections
import functools
from functools import partial
import operator
import re
import jax
from jax import core
from jax.experimental import jax2tf
from jax.experimental.jax2tf import shape_poly
from jax import lax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow
import numpy as np
from jax.experimental.jax2tf.tests import tf_test_util
import tensorflow as tf # type: ignore[import]
from jax.config import config
config.parse_flags_with_absl()
# Import after parsing flags
from jax.experimental.jax2tf.tests import primitive_harness
from jax.experimental.jax2tf.tests.primitive_harness import Harness, CustomArg, RandArg, StaticArg
from jax.experimental.jax2tf.tests.jax2tf_limitations import Jax2TfLimitation
# Short alias for jax2tf.PolyShape, used below to build polymorphic shape
# specifications such as PS("h", "h") or PS("b", ...).
PS = jax2tf.PolyShape
class DimPolynomialTest(tf_test_util.JaxToTfTestCase):
  """Tests for parsing, comparing and doing arithmetic on dimension polynomials."""

  def test_parse_poly_spec(self):
    """Parses specs made of constants, `_`, `...` and dimension variables."""
    self.assertEqual((2, 3), shape_poly._parse_spec(None, (2, 3)))
    self.assertEqual((2, 3), shape_poly._parse_spec("2, 3", (2, 3)))
    self.assertEqual((2, 3), shape_poly._parse_spec("2, _", (2, 3)))
    self.assertEqual((2, 3), shape_poly._parse_spec("2, ...", (2, 3)))
    self.assertEqual((2, 3), shape_poly._parse_spec("...", (2, 3)))
    self.assertEqual((2, 3), shape_poly._parse_spec(" ( 2 , 3 ) ", (2, 3)))
    a, _ = shape_poly._parse_spec("a, b", (2, 3))
    self.assertEqual((a, 3), shape_poly._parse_spec("(a, ...) ", (None, 3)))
    tshape = tf.TensorShape([None, 3])
    self.assertEqual((a, 3), shape_poly._parse_spec("(a, ...) ", tshape))

  # Dimension variables used by the parameterized test cases below. They are
  # defined in the class body because the decorators (and the default argument
  # values) are evaluated when the class is created.
  a, b = shape_poly._parse_spec("a, b", (2, 3))

  # (dim_spec, expected dim_poly) pairs, shared by test_parse_poly_spec_poly
  # and test_parse_poly_spec_shapeenv so the table is written only once.
  _POLY_SPEC_CASES = [
      ("2*a*b", 2 * a * b),
      ("-2 * a^2 * b + b^2", -2 * a * a * b + b * b),
      ("-2 * a^2 * b + -1 *b^2*a", -2 * a * a * b - a * b * b),
      ("3 * a * b * a + -2", 3 * a * b * a - 2),
      ("a + 1", a + 1),
      ("a + -1", a - 1),
  ]

  def _check_parse_poly_spec(self, dim_spec, dim_poly):
    """Checks that both `dim_spec` and `str(dim_poly)` parse to `dim_poly`."""
    self.assertEqual((dim_poly,), shape_poly._parse_spec(dim_spec, (2,)))
    self.assertEqual((dim_poly,), shape_poly._parse_spec(str(dim_poly), (2,)))

  @parameterized.named_parameters(
      dict(testcase_name=f"_{dim_spec=}",
           dim_spec=dim_spec, dim_poly=dim_poly)
      for dim_spec, dim_poly in _POLY_SPEC_CASES)
  def test_parse_poly_spec_poly(self,
                                dim_spec="3 * a * b * a + -2",
                                dim_poly=3 * a * b * a - 2):
    # For internal usage only (the polymorphic_shapes of VJP) we need to
    # parse polynomials.
    self._check_parse_poly_spec(dim_spec, dim_poly)

  @parameterized.named_parameters(
      dict(testcase_name=f"_{dim_spec=}",
           dim_spec=dim_spec, dim_poly=dim_poly)
      for dim_spec, dim_poly in _POLY_SPEC_CASES)
  def test_parse_poly_spec_shapeenv(self,
                                    dim_spec="3 * a * b * a + -2",
                                    dim_poly=3 * a * b * a - 2):
    # NOTE(review): this runs exactly the same checks as
    # test_parse_poly_spec_poly; its name suggests it was meant to exercise a
    # shape environment. Confirm the intended coverage.
    self._check_parse_poly_spec(dim_spec, dim_poly)

  def test_dim_vars(self):
    """Equality, hashing and container membership of dimension variables."""
    a, b, a1 = shape_poly._parse_spec("a, b, a", (2, 3, 2))
    # assertEqual(True, ...) (rather than assertTrue) requires an actual bool.
    self.assertEqual(True, a == a)
    self.assertEqual(True, a == a1)
    self.assertEqual(False, a != a)
    with self.assertRaisesRegex(
        core.InconclusiveDimensionOperation,
        "Dimension polynomial comparison 'a' == 'b' is inconclusive"):
      a.eq(b)
    with self.assertRaisesRegex(
        core.InconclusiveDimensionOperation,
        "Dimension polynomial comparison 'a' == 'b' is inconclusive"):
      a == b
    with self.assertRaisesRegex(
        core.InconclusiveDimensionOperation,
        "Dimension polynomial comparison 'a' == 'b' is inconclusive"):
      a != b
    self.assertLen({a, a}, 1)
    self.assertLen({a, b}, 2)
    self.assertIn(a, {a, b})
    self.assertIn(b, {a, b})
    self.assertIn(a, [a, b])
    # List membership checks identity and then `==` element by element, so
    # `b` is first compared with `a`, which is inconclusive.
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                "Dimension polynomial comparison .* is inconclusive"):
      b in [a, b]

  def test_get_vars(self):
    """get_vars returns the names of the variables occurring in a polynomial."""
    a, b = shape_poly._parse_spec("a, b", (2, 3))
    self.assertEqual({"a"}, a.get_vars())
    self.assertEqual({"a", "b"}, (a * b * a).get_vars())

  def test_evaluate(self):
    """evaluate substitutes concrete values for the dimension variables."""
    a, b = shape_poly._parse_spec("a, b", (2, 3))
    self.assertEqual(1, (a * a - b).evaluate(dict(a=2, b=3)))
    self.assertEqual(2, (a * a - b + 1).evaluate(dict(a=-2, b=3)))

  def test_dim_vars_symbolic_equal(self):
    """core.symbolic_equal_dim / symbolic_equal_one_of_dim on mixed dims."""
    a, b = shape_poly._parse_spec("a, b", (2, 3))
    self.assertTrue(core.symbolic_equal_dim(a, a))
    self.assertFalse(core.symbolic_equal_dim(a, 1))
    self.assertFalse(core.symbolic_equal_dim(a, b))
    self.assertTrue(core.symbolic_equal_one_of_dim(a, [2, a]))
    self.assertFalse(core.symbolic_equal_one_of_dim(a, [2, b]))
    self.assertFalse(core.symbolic_equal_one_of_dim(a, []))
    self.assertTrue(core.symbolic_equal_one_of_dim(2, [a, 3, 2]))
    self.assertFalse(core.symbolic_equal_one_of_dim(1, [2, b]))
    self.assertFalse(core.symbolic_equal_one_of_dim(3, []))
    self.assertTrue(core.symbolic_equal_dim(1, jnp.add(0, 1)))  # A DeviceArray
    with self.assertRaisesRegex(TypeError,
                                re.escape("Shapes must be 1D sequences of concrete values of integer type, got (1, 'a').")):
      self.assertTrue(core.symbolic_equal_dim(1, "a"))

  def test_poly_bounds(self):
    """bounds() returns (lower, upper), with None meaning unbounded."""
    a, b = shape_poly._parse_spec("a, b", (2, 3))
    self.assertEqual(a.bounds(), (1, None))
    self.assertEqual((2 * a).bounds(), (2, None))
    self.assertEqual((2 * a - 3).bounds(), (-1, None))
    self.assertEqual((-2 * a - 3).bounds(), (None, -5))
    self.assertEqual((3 * a * b * b + 5 * a - 7).bounds(), (1, None))
    self.assertEqual((3 * a * b * b - 5 * a - 7).bounds(), (None, None))
    self.assertEqual((a + b - a * b + a * b * a).bounds(), (None, None))
    self.assertEqual((a + 2 * b - a).bounds(), (2, None))

  def test_poly_equal(self):
    """Equality of polynomials with constants, numpy scalars and polynomials."""
    a, b = shape_poly._parse_spec("a, b", (2, 3))
    poly3 = a + 3 - a
    self.assertTrue(poly3 == 3)
    self.assertTrue(poly3 == np.array(3, np.int64))
    self.assertTrue(poly3 == np.array(3, np.int64)[()])
    self.assertFalse((poly3 + 1) == 3)
    self.assertFalse(poly3 == poly3 + 1)
    self.assertTrue((2 * a * b * a + 3).eq(1 + b * a * a + a * a * b + 2))
    self.assertFalse((2 * a * b * a + 3).eq(a * b * a + 3))
    self.assertFalse((a * b * a + 3).eq(a * b * a + 4))
    self.assertFalse((2 * a * b * a).eq(a * b * a))
    self.assertFalse((2 * a * b * a + 1).eq(a * b * a))
    self.assertFalse((3 * a * b * a - 1).eq(a * b * a))
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                re.escape("Dimension polynomial comparison '3*a^2*b + -2' == 'a^2*b' is inconclusive")):
      (3 * a * b * a - 2).eq(a * b * a)

  def test_poly_compare(self):
    """ge() succeeds when the bounds decide it and raises otherwise."""
    a, b = shape_poly._parse_spec("a, b", (2, 3))
    poly = 4 * a + b + 3
    self.assertTrue(poly.ge(0))
    self.assertTrue(poly.ge(8))
    self.assertTrue(poly.ge(poly))
    self.assertTrue(poly.ge(poly - 1))
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"):
      poly.ge(9)
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"):
      (4 * a - b).ge(0)

  def test_poly_compare_overload(self):
    """Same as test_poly_compare, through the overloaded >= and > operators."""
    a, b = shape_poly._parse_spec("a, b", (2, 3))
    poly = 4 * a + b + 3
    self.assertTrue(poly >= 0)
    self.assertTrue(poly >= 8)
    self.assertTrue(poly > 7)
    self.assertTrue(poly >= poly)
    self.assertTrue(poly >= poly - 1)
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"):
      poly >= 9
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation, "inconclusive"):
      (4 * a - b) >= 0

  def test_core_greater_equal(self):
    """core.greater_equal_dim / greater_equal_shape on polynomial dims."""
    a, b = shape_poly._parse_spec("a, b", (2, 3))
    self.assertTrue(core.greater_equal_dim(a, a))
    self.assertTrue(core.greater_equal_dim(a, 0))
    self.assertTrue(core.greater_equal_dim(a, 1))
    self.assertTrue(core.greater_equal_shape((a, 2), (1, 1)))
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                "Dimension polynomial comparison .* is inconclusive"):
      core.greater_equal_dim(a, 2)
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                "Dimension polynomial comparison .* is inconclusive"):
      core.greater_equal_dim(a, b)

  def test_poly_int_results(self):
    """Arithmetic that cancels all variables yields a plain Python int."""
    a, b = shape_poly._parse_spec("a, b", (2, 3))
    self.assertEqual(a + 2 - a, 2)
    self.assertIsInstance(a + 2 - a, int)
    self.assertEqual(a + (2 - a), 2)
    self.assertIsInstance(a + (2 - a), int)
    self.assertEqual(a * 2 // a, 2)
    self.assertIsInstance(a * 2 // a, int)

  @parameterized.named_parameters(
      dict(testcase_name=f"_D={dividend}_d={divisor}_q={quotient}_r={remainder}",
           dividend=dividend, divisor=divisor, quotient=quotient,
           remainder=remainder)
      for dividend, divisor, quotient, remainder in [
          (a, 1, a, 0),
          (3 * a, 3, a, 0),
          (3 * a + 3, 3, a + 1, 0),
          (3 * a + 2, 3, a, 2),
          (3 * a + 5, 3, a + 1, 2),
          (3 * a - 2, 3, a - 1, 1),
          (3 * a * a * b + 2 * b * b * a, a * b, 3 * a + 2 * b, 0),
          (a * a - b * b, a + b, a - b, 0),
          (a, b, None, None),
          (3 * a, 2, None, None),
          (2 * a * b + b * b, a + b, None, None),
          (3, a, None, None),
      ])
  def test_poly_divmod(self, *, dividend, quotient, divisor, remainder):
    """divmod on polynomials; quotient=None marks an expected failure."""
    if quotient is None:
      with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                  "Cannot divide .* by .*"):
        divmod(dividend, divisor)
    else:
      self.assertEqual((quotient, remainder), divmod(dividend, divisor))

  @parameterized.named_parameters(
      dict(testcase_name=f"_D={dividend}_d={divisor}_q={quotient}",
           dividend=dividend, divisor=divisor, quotient=quotient)
      for dividend, divisor, quotient in [
          (a, 1, a),
          (3 * a, 3, a),
          (3 * a + 3, 3, a + 1),
          (3 * a + 2, 3, None),
          (3 * a + 5, 3, None),
          (3 * a - 2, 3, None),
          (3 * a * a * b + 2 * b * b * a, a * b, 3 * a + 2 * b),
          (a * a - b * b, a + b, a - b),
          (a, b, None),
          (3 * a, 2, None),
          (2 * a * b + b * b, a + b, None),
      ])
  def test_poly_truediv(self, *, dividend, divisor, quotient):
    """True division succeeds only when it is exact; None marks a failure."""
    if quotient is None:
      with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                  "Cannot divide .* by .*"):
        dividend / divisor
    else:
      self.assertEqual(quotient, dividend / divisor)

  def test_poly_truediv_error(self):
    """Dividing a constant by a polynomial is rejected."""
    a, = shape_poly._parse_spec("a,", (2,))
    with self.assertRaisesRegex(core.InconclusiveDimensionOperation,
                                "Cannot divide .* by .*"):
      3 / a

  def test_dilate_shape(self):
    """core.dilate_shape: 0 if d == 0 else 1 + dilation * (d - 1)."""
    a, = shape_poly._parse_spec("a,", (2,))
    self.assertEqual((4, 7), core.dilate_shape((2, 3), (3, 3)))
    self.assertEqual((0, 7), core.dilate_shape((0, 3), (3, 3)))
    self.assertEqual((a, 7), core.dilate_shape((a, 3), (1, 3)))
    self.assertEqual((2 * a - 1, 7), core.dilate_shape((a, 3), (2, 3)))

  def test_stride_shape(self):
    """core.stride_shape: (s - window_size) // window_stride + 1."""
    a, stride = shape_poly._parse_spec("a, s", (2, 3))
    self.assertEqual((8, 9), core.stride_shape((10, 20), window_size=(3, 3), window_stride=(1, 2)))
    self.assertEqual((a, 9), core.stride_shape((a, 20), (1, 3), (1, 2)))
    self.assertEqual((a - 1, 9), core.stride_shape((a, 20), (2, 3), (1, 2)))
    self.assertEqual((a + 1, 9), core.stride_shape((a * stride + 2, 20), (2, 3), (stride, 2)))
    with self.assertRaisesRegex(
        core.InconclusiveDimensionOperation,
        re.escape(
            "Cannot compute stride for dimension 'a', window_size '1', stride '2'.\nDetails: Cannot divide 'a + -1' by '2'")):
      core.stride_shape((a, 20), (1, 3), (2, 2))
class ShapePolyTest(tf_test_util.JaxToTfTestCase):
def test_simple_unary(self):
  """Shape polymorphism for a simple unary function, under several specs."""
  def f_jax(x):
    return x + jnp.sin(x)

  # Each case: (input shape, polymorphic_shapes, expected output shape).
  cases = [
      ([2, 3], None, [2, 3]),
      ([2, None], ["_, h"], [2, None]),
      ([None, None], ["h, h"], [None, None]),
      ([None, None], "h, h", [None, None]),
  ]
  for in_shape, poly_shapes, out_shape in cases:
    self.CheckShapePolymorphism(
        f_jax,
        input_signature=[tf.TensorSpec(in_shape)],
        polymorphic_shapes=poly_shapes,
        expected_output_signature=tf.TensorSpec(out_shape))
def test_simple_binary(self):
  """Shape polymorphism for a simple binary function, under several specs."""
  def f_jax(x, y):
    return x + jnp.sin(y)

  # Each case: ((x shape, y shape), polymorphic_shapes, expected output shape).
  cases = [
      (([2, 3], [2, 3]), None, [2, 3]),
      (([2, None], [2, 3]), "_, h", [2, 3]),
      (([None, None], [None, None]), PS("h", "h"), [None, None]),
  ]
  for (x_shape, y_shape), poly_shapes, out_shape in cases:
    self.CheckShapePolymorphism(
        f_jax,
        input_signature=[tf.TensorSpec(x_shape), tf.TensorSpec(y_shape)],
        polymorphic_shapes=poly_shapes,
        expected_output_signature=tf.TensorSpec(out_shape))
def test_arange(self):
  """jnp.arange whose length is a polymorphic dimension."""
  def f_jax(x):
    return x + jnp.arange(x.shape[0], dtype=np.float32)

  x = np.ones((3,), dtype=np.float32)
  f_tf = jax2tf.convert(f_jax, polymorphic_shapes="b")
  actual = f_tf(x)
  self.assertAllClose(actual, f_jax(x))
def test_argmax(self):
  """lax.argmax over a static axis of an array with a polymorphic batch."""
  def f_jax(x):  # x: f32[b, 4, 5]
    return lax.argmax(x, axis=1, index_dtype=np.int32)

  x = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))
  f_tf = jax2tf.convert(f_jax, polymorphic_shapes="(b, _, _)")
  actual = f_tf(x)
  self.assertAllClose(actual, f_jax(x))
def test_static_shape_result(self):
  """The result has static shape."""
  def f_jax(x):
    return jnp.sum(x + jnp.sin(x), axis=0)

  # Summing over axis 0 removes the (possibly polymorphic) leading dimension.
  for in_shape, poly_shapes in (([2, 3], None), ([None, 3], "b, _")):
    self.CheckShapePolymorphism(
        f_jax,
        input_signature=[tf.TensorSpec(in_shape)],
        polymorphic_shapes=poly_shapes,
        expected_output_signature=tf.TensorSpec([3]))
def test_forgot_polymorphic_shapes_error(self):
  """An unknown TF dimension with polymorphic_shapes=None is rejected."""
  expected_error = "polymorphic shape None in axis .* must contain a dimension variable for unknown dimension in argument shape .*. Perhaps you forgot to add the polymorphic_shapes"
  unknown_dim_signature = [tf.TensorSpec([1, None])]
  with self.assertRaisesRegex(ValueError, expected_error):
    self.CheckShapePolymorphism(
        jnp.sin,
        input_signature=unknown_dim_signature,
        polymorphic_shapes=None)
def test_kwargs(self):
  """Test shape polymorphism for a function with kwargs."""
  x = np.ones(3, dtype=np.float32)
  y = np.ones(1, dtype=np.float32)

  def f_jax(x, *, y):
    return x + jnp.sin(y)

  # polymorphic_shapes has a single spec, for the positional argument `x`.
  f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["b, ..."])
  # Compare against JAX so the test checks the numerical result and not only
  # that the converted function can be called with a keyword argument.
  self.assertAllClose(f_jax(x, y=y), f_tf(x, y=y))
def test_arg_avals(self):
  """Test conversion of actual arguments to abstract values."""
  def check_avals(*, arg_shapes: Sequence[Sequence[Optional[int]]],
                  polymorphic_shapes: Sequence[Optional[Union[str, PS]]],
                  expected_avals: Optional[Sequence[core.ShapedArray]] = None,
                  expected_shapeenv: Optional[Dict[str, int]] = None,
                  eager_mode: bool = False):
    """Checks the avals (and dimension values) computed for `arg_shapes`.

    Args:
      arg_shapes: one shape per argument; `None` marks an unknown dimension.
      polymorphic_shapes: one polymorphic shape spec per argument.
      expected_avals: if given, the abstract values that must be computed.
      expected_shapeenv: if given, maps dimension variable names to the values
        they must be solved to. Requires `eager_mode=True`.
      eager_mode: call the TF function on concrete `tf.ones` arguments instead
        of tracing it with `tf.function`. Use it only when all `arg_shapes`
        are known, in order to check `expected_shapeenv`.
    """
    arg_dtypes = (_f32,) * len(arg_shapes)

    def f_tf(*args_tf):
      avals = tuple(map(shape_poly.arg_aval, arg_shapes, arg_dtypes, polymorphic_shapes))
      dim_vars, get_dim_values_jax = shape_poly.prepare_dim_var_env(avals)
      dim_values, _ = jax2tf.jax2tf._interpret_fun_jax(get_dim_values_jax,
                                                       args_tf, avals, "")
      if expected_avals is not None:
        self.assertEqual(expected_avals, avals)
      return dict(zip(dim_vars, dim_values))

    if eager_mode:
      # If we want to check the shape_env then all arg_shapes must be known.
      # Use a test assertion rather than a bare `assert`, which `python -O`
      # would strip.
      self.assertTrue(
          all(all(d is not None for d in a_s) for a_s in arg_shapes),
          "eager_mode=True requires all arg_shapes to be fully known")
      shape_env = f_tf(*[tf.ones(a_s, dtype=_f32) for a_s in arg_shapes])
      if expected_shapeenv is not None:
        for v, val in expected_shapeenv.items():
          self.assertEqual(val, shape_env.get(v))
    else:
      f_tf = tf.function(autograph=False)(f_tf)
      f_tf.get_concrete_function(*[tf.TensorSpec(a_s, _f32)
                                   for a_s in arg_shapes])
      self.assertFalse(expected_shapeenv, "Should use eager_mode=True")

  def shaped_array(shape_spec: str, actual_shape: core.Shape):
    """Builds the float32 ShapedArray whose shape `shape_spec` denotes."""
    return core.ShapedArray(
        shape_poly._parse_spec(shape_spec, actual_shape), np.float32)

  # Known shapes for the arguments
  check_avals(
      arg_shapes=[(2, 3)],
      polymorphic_shapes=[None],
      expected_avals=(shaped_array("2, 3", [2, 3]),))

  check_avals(
      arg_shapes=[(2, 3)],
      polymorphic_shapes=["(2, 3)"],
      expected_avals=(shaped_array("2, 3", [2, 3]),))

  check_avals(
      arg_shapes=[(2, 3)],
      polymorphic_shapes=["(_, 3)"],
      expected_avals=(shaped_array("2, 3", [2, 3]),))

  check_avals(
      arg_shapes=[(2, 3)],
      polymorphic_shapes=[PS("_", 3)],
      expected_avals=(shaped_array("2, 3", [2, 3]),))

  check_avals(
      arg_shapes=[(2, 3)],
      polymorphic_shapes=["..."],
      expected_avals=(shaped_array("2, 3", [2, 3]),))

  check_avals(
      arg_shapes=[(2, 3)],
      polymorphic_shapes=[PS(...)],
      expected_avals=(shaped_array("2, 3", [2, 3]),))

  # Partially known shapes for the arguments
  check_avals(
      arg_shapes=[(None, 3)],
      polymorphic_shapes=[PS("b", ...)],
      expected_avals=(shaped_array("(b, 3)", (2, 3)),))

  check_avals(
      arg_shapes=[(None, None)],
      polymorphic_shapes=["h, h"],
      expected_avals=(shaped_array("(h, h)", (2, 2)),))

  check_avals(
      arg_shapes=[(2, None)],
      polymorphic_shapes=["h, h"],
      expected_avals=(shaped_array("(h, h)", (2, 2)),))

  check_avals(
      arg_shapes=[(None, 3, 4)],
      polymorphic_shapes=["(c, b, a)"],
      expected_avals=(shaped_array("(c, b, a)", (2, 3, 4)),),
  )

  # Check cases when the specifications are polynomials
  check_avals(
      arg_shapes=[(2, 3)],
      polymorphic_shapes=[PS("a + 1", "b + 2")],
      eager_mode=True,
      expected_shapeenv=dict(a=1, b=1))

  check_avals(
      arg_shapes=[(7, 5)],
      polymorphic_shapes=[PS("2 * a + b", "b + 2")],
      eager_mode=True,
      expected_shapeenv=dict(a=2, b=3))

  check_avals(
      arg_shapes=[(7, 11, 4)],
      polymorphic_shapes=[PS("2 * a + b", "b * b + 2", "b + 1")],
      eager_mode=True,
      expected_shapeenv=dict(a=2, b=3))

  check_avals(
      arg_shapes=[(7, 11, 19, 7)],
      polymorphic_shapes=[PS("2 * a + b", "b * b + 2", "b + c * c", "2 * c + -1")],
      eager_mode=True,
      expected_shapeenv=dict(a=2, b=3, c=4))

  with self.assertRaisesRegex(ValueError,
                              "Cannot solve for values of dimension variables {'b'}"):
    check_avals(
        arg_shapes=[(4, 36, 3)],
        polymorphic_shapes=[PS("b * b", "b * d * d", "d")])

  with self.assertRaisesRegex(ValueError,
                              "Dimension variable b must have integer value >= 1"):
    check_avals(
        arg_shapes=[(5, 36)],
        polymorphic_shapes=[PS("3 * b", ...)],
        eager_mode=True)

  with self.assertRaisesRegex(ValueError,
                              "Dimension variable b must have integer value >= 1"):
    check_avals(
        arg_shapes=[(10, 3)],
        polymorphic_shapes=[PS("3 * b + 10", ...)],
        eager_mode=True)

  with self.assertRaisesRegex(ValueError,
                              "Dimension variable b must have integer value >= 1"):
    check_avals(
        arg_shapes=[(7, 3)],
        polymorphic_shapes=[PS("3 * b + 10", ...)],
        eager_mode=True)

  # subTest reports which spec failed instead of stopping at the first one.
  for invalid_syntax in [")(", "2a", "a@", "a - 2", "'a'", "('a', ...)"]:
    with self.subTest(invalid_syntax=invalid_syntax):
      with self.assertRaisesRegex(ValueError,
                                  re.escape("has invalid syntax")):
        check_avals(
            arg_shapes=[(2,)], polymorphic_shapes=[invalid_syntax])

  for invalid_syntax in [5.0, ["a list"], ("a tuple",), re.compile(".")]:
    with self.subTest(invalid_syntax=invalid_syntax):
      with self.assertRaisesRegex(ValueError,
                                  re.escape("Invalid polymorphic shape element")):
        check_avals(
            arg_shapes=[(2,)], polymorphic_shapes=[PS([invalid_syntax])])

  with self.assertRaisesRegex(
      ValueError,
      re.escape("polymorphic shape '..., 3' can contain Ellipsis only at the end.")):
    check_avals(
        arg_shapes=[(2, 3)],
        polymorphic_shapes=["..., 3"])

  with self.assertRaisesRegex(
      ValueError,
      re.escape(
          "polymorphic shape '2, 3, 4, ...' of rank 3 must match the rank 2 of argument shape (2, 3).")
  ):
    check_avals(
        arg_shapes=[(2, 3)],
        polymorphic_shapes=["2, 3, 4, ..."])

  with self.assertRaisesRegex(
      ValueError,
      re.escape(
          "polymorphic shape (Ellipsis, 3) can contain Ellipsis only at the end.")):
    check_avals(
        arg_shapes=[(2, 3)],
        polymorphic_shapes=[PS(..., 3)])

  with self.assertRaisesRegex(
      ValueError,
      re.escape(
          "polymorphic shape None in axis 1 must contain a dimension variable for unknown dimension in argument shape (2, None)"
      )):
    check_avals(
        arg_shapes=[(2, None)],
        polymorphic_shapes=[None])

  with self.assertRaisesRegex(
      ValueError,
      re.escape("polymorphic shape '()' of rank 0 must match the rank 2 of argument shape (2, 3)")):
    check_avals(
        arg_shapes=[(2, 3)], polymorphic_shapes=["()"])

  with self.assertRaisesRegex(
      ValueError,
      re.escape(
          "polymorphic shape '(_, _)' in axis 1 must contain a dimension variable "
          "for unknown dimension in argument shape (2, None)"
      )):
    check_avals(
        arg_shapes=[(2, None)],
        polymorphic_shapes=["(_, _)"])

  with self.assertRaisesRegex(
      ValueError,
      re.escape(
          "polymorphic shape '(2, 13)' in axis 1 must match the known dimension size 3 "
          "for argument shape (2, 3)"
      )):
    check_avals(
        arg_shapes=[(2, 3)],
        polymorphic_shapes=["(2, 13)"])

  with self.assertRaisesRegex(
      ValueError,
      re.escape(
          "polymorphic shape '(2, 3)' in axis 1 must contain a dimension variable for "
          "unknown dimension in argument shape (2, None)"
      )):
    check_avals(
        arg_shapes=[(2, None)],
        polymorphic_shapes=["(2, 3)"])

  with self.assertRaisesRegex(
      ValueError,
      "Found inconsistency when solving.*"):
    check_avals(
        arg_shapes=[(2, 3)],
        polymorphic_shapes=["(a, a)"],
        eager_mode=True)

  # Same error across multiple arguments
  with self.assertRaisesRegex(
      ValueError,
      "Found inconsistency when solving.*"):
    check_avals(
        arg_shapes=[(2, 3), (5,)],
        polymorphic_shapes=["a, ...", "a"],
        eager_mode=True)
def test_pytree(self):
  """Arguments and polymorphic_shapes are pytrees."""
  # Arguments are of the form [([x00, x01], [x10]), dict(a=ya, b=yb)]
  def add_all_jax(x_pair_of_list, y_dict):
    first_list, second_list = x_pair_of_list
    operands = first_list + second_list + [y_dict["a"], y_dict["b"]]
    return functools.reduce(operator.add, operands)

  def vec_spec(size):
    return tf.TensorSpec([size])

  # Fully polymorphic: every leaf shares the dimension variable "v".
  self.CheckShapePolymorphism(
      add_all_jax,
      input_signature=[([vec_spec(None), vec_spec(None)], [vec_spec(None)]),
                       dict(a=vec_spec(None), b=vec_spec(None))],
      polymorphic_shapes=[(["v", "v"], ["v"]),
                          dict(a="v", b="v")],
      expected_output_signature=tf.TensorSpec([None]))

  # Now partial polymorphic_shapes; the parts of the polymorphic_shapes that
  # are not specified must have full input_signatures.
  self.CheckShapePolymorphism(
      add_all_jax,
      input_signature=[([vec_spec(4), vec_spec(4)], [vec_spec(4)]),
                       dict(a=vec_spec(4), b=vec_spec(4))],
      polymorphic_shapes=[(["(4,)", "(_,)"], ["4,"]),
                          dict(a="(_,)", b="(4,)")],
      expected_output_signature=tf.TensorSpec([4]))
def test_with_nested_jit(self):
  """Shape polymorphism through a nested jax.jit, including the result value."""
  def f_jax(x):  # x: f32[w, h]
    return jnp.sin(x) + jnp.arange(x.shape[1], dtype=x.dtype)

  def f_outer(x):
    return x + jax.jit(f_jax)(x)

  f_tf = self.CheckShapePolymorphism(
      f_outer,
      input_signature=[tf.TensorSpec([None, None])],
      polymorphic_shapes=["w, h"])

  # Numerically this computes
  #   x + (np.sin(x) + np.broadcast_to(np.arange(x.shape[1]), x.shape)).
  # Compare the converted function against JAX rather than only tracing it.
  x = np.ones((3, 4), dtype=np.float32)
  self.assertAllClose(f_outer(x), f_tf(x))
def test_non_trivial_polynomials(self):
  """Non-trivial polynomials (e.g. `b * b`) in the input shapes."""
  if config.jax_dynamic_shapes:
    raise unittest.SkipTest("--jax_dynamic_shapes supports only trivial polynomials")
  # We can handle non-trivial polynomials in the input shape,
  # as long as all variables also occur in trivial polynomials.
  self.CheckShapePolymorphism(
      lambda x, y: x + y.reshape((-1,)),
      input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None, None])],
      polymorphic_shapes=["b * b", "b, b"])
def test_unused_args(self):
  """Tests with functions that do not use (some of) their inputs."""
  def double_second(x_unused, y):
    return y * 2.0

  # First arg unused, not polymorphic
  self.CheckShapePolymorphism(
      double_second,
      input_signature=[tf.TensorSpec([]), tf.TensorSpec([None])],
      polymorphic_shapes=[None, "b"])

  # Some args unused, not polymorphic
  self.CheckShapePolymorphism(
      lambda x_unused, y, z_unused, w: jnp.concatenate([y, w]),
      input_signature=[tf.TensorSpec([]), tf.TensorSpec([None]),
                       tf.TensorSpec([]), tf.TensorSpec([None])],
      polymorphic_shapes=[None, "b1", None, "b2"])

  # A polymorphic arg is not used, but the dimension var appears
  # in a used arg also
  self.CheckShapePolymorphism(
      double_second,
      input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
      polymorphic_shapes=["b", "b"])

  # A polymorphic arg is not used and its dimension var cannot be recovered:
  # either it does not appear elsewhere (["b1", "b2"]), or it appears
  # elsewhere but not as a trivial monomial (["b1", "b1 * b1"]).
  for poly_shapes in (["b1", "b2"], ["b1", "b1 * b1"]):
    with self.assertRaisesRegex(ValueError,
                                "The following dimension variables cannot be computed"):
      self.CheckShapePolymorphism(
          double_second,
          input_signature=[tf.TensorSpec([None]), tf.TensorSpec([None])],
          polymorphic_shapes=poly_shapes)
def test_with_custom_vjp(self):
  """Shape-polymorphic custom VJP."""
  @jax.custom_vjp
  def f(x):
    # x: [b1, b2, d1, d2] (a batch of matrices)
    # res: [b1, b2, d1, d1]
    return jnp.matmul(x, jnp.transpose(x, axes=(0, 1, 3, 2)))

  def f_fwd(x):
    # Returns (primal output [b1, b2, d1, d1], residual [b1, b2, d1, d2]).
    return f(x), 3. * x

  def f_bwd(residual, ct_b):
    # residual: [b1, b2, d1, d2]; ct_b: [b1, b2, d1, d1].
    # Returns a 1-tuple holding the cotangent of x: [b1, b2, d1, d2].
    return (jnp.matmul(ct_b, residual),)

  f.defvjp(f_fwd, f_bwd)

  x = np.ones((2, 3, 4, 5), dtype=np.float32)
  expected = f(x)
  expected_grad = jax.grad(lambda x: jnp.sum(f(x)))(x)

  f_tf = self.CheckShapePolymorphism(
      f,
      input_signature=[tf.TensorSpec([None, None, None, None])],
      polymorphic_shapes=["(batch1, batch2, d1, d2)"],
      expected_output_signature=tf.TensorSpec([None, None, None, None]))
  self.assertAllClose(expected, f_tf(x))

  def tf_value_and_grad(xv):
    with tf.GradientTape() as tape:
      tape.watch(xv)
      value = f_tf(xv)
    return value, tape.gradient(value, xv)

  xv = tf.Variable(x, dtype=np.float32)
  res_tf, res_tf_grad = tf_value_and_grad(xv)
  self.assertAllClose(expected, res_tf)
  self.assertAllClose(expected_grad, res_tf_grad)

  # Now use TF tracing for the gradient, at a different concrete shape.
  traced = tf.function(tf_value_and_grad, autograph=False)
  tf_grad = traced.get_concrete_function(tf.TensorSpec([3, 4, 8, 9]))
  value_shape, grad_shape = tf_grad.output_shapes[0], tf_grad.output_shapes[1]
  self.assertEqual((3, 4, 8, 8), tuple(value_shape))
  self.assertEqual((3, 4, 8, 9), tuple(grad_shape))
def test_gradients_pytree(self):
  """Shape polymorphism with gradients and pytrees for inputs and outputs."""
  def f(x):
    # x: dict(x=[b, 3, 4])
    # res: dict(res=[b, 3, 4])
    return dict(res=x["x"] * 2.)

  f_tf = self.CheckShapePolymorphism(
      f,
      input_signature=[dict(x=tf.TensorSpec([None, 3, 4]))],
      polymorphic_shapes=[dict(x=("b, 3, 4"))],
      expected_output_signature=None)

  x = dict(x=np.ones((2, 3, 4), dtype=np.float32))
  xv = tf.Variable(x["x"], dtype=np.float32)

  def tf_value_and_grad(xv):
    # xv: [b, 3, 4]
    # res_value: dict(res=[b, 3, 4])
    # res_grad: dict(grad=[b, 3, 4])
    with tf.GradientTape() as tape:
      tape.watch(xv)
      res_tf = f_tf(dict(x=xv))
    res_tf_grad = tape.gradient(res_tf, xv)
    return res_tf, dict(grad=res_tf_grad)

  res_tf, res_tf_grad = tf_value_and_grad(xv)
  # Check the eager value and gradient, not only the traced shapes below.
  self.assertAllClose(f(x)["res"], res_tf["res"])
  # tape.gradient differentiates the sum of all targets, so d(sum(2 * x))/dx
  # is 2 everywhere.
  self.assertAllClose(np.full(x["x"].shape, 2., dtype=np.float32),
                      res_tf_grad["grad"])

  # Now use TF tracing for the gradient
  tf_grad = tf.function(
      tf_value_and_grad,
      autograph=False).get_concrete_function(tf.TensorSpec([None, 3, 4]))
  # The shape of the value
  self.assertEqual((None, 3, 4), tuple(tf_grad.output_shapes[0]["res"]))
  # The shape of the gradient should match the input
  self.assertEqual((None, 3, 4), tuple(tf_grad.output_shapes[1]["grad"]))
def test_grad_not_var_output(self):
  """Gradient when the output has a polymorphic, non-variable shape."""
  def f_jax(x):  # x: [b, 3]
    return jnp.reshape(x, (-1,))  # result: [3b]

  x = np.arange(12, dtype=np.float32).reshape((4, 3))
  xv = tf.Variable(x)
  f_tf = jax2tf.convert(f_jax, with_gradient=True,
                        polymorphic_shapes=["b, ..."])
  with tf.GradientTape() as tape:
    flat = f_tf(xv)
  grad_tf = tape.gradient(flat, xv)
  expected_grad = np.ones(x.shape, dtype=np.float32)
  self.assertAllClose(expected_grad, grad_tf.numpy())
def test_cond(self):
  """Test the primitive under conditional."""
  def f(x, y):
    # x: f32[B, H], y : f32[H]
    def true_branch(_):
      return x + y

    def false_branch(_):
      return jnp.zeros_like(x)

    return lax.cond(jnp.sum(x) > 0., true_branch, false_branch, operand=None)

  x = np.ones((2, 3))
  y = np.ones((3,))
  expected = f(x, y)
  f_tf = jax2tf.convert(f, polymorphic_shapes=["(b, h)", "h"])
  self.assertAllClose(expected, f_tf(x, y))
def test_while(self):
  """lax.while_loop whose carry contains a polymorphic array."""
  def f(x):
    # x: f32[B], iter: i32
    def keep_going(carry):
      return carry[1] < 5

    def step(carry):
      vec, count = carry
      return (vec + jnp.arange(vec.shape[0], dtype=np.float32), count + 1)

    return lax.while_loop(keep_going, step, (x, 0))

  x = np.ones((3,), dtype=np.float32)
  f_tf = jax2tf.convert(f, polymorphic_shapes=["(b,)"])
  res_tf = f_tf(x)
  self.assertAllClose(f(x), res_tf)
@jtu.sample_product(with_function=[False, True])
def test_grad_int(self, with_function=True):
  """Gradients with respect to integer and boolean inputs are zeros.

  See https://github.com/google/jax/issues/7093 and issue #6975.
  """
  x_shape = (2, 3, 4)
  xi = np.arange(np.prod(x_shape), dtype=np.int16).reshape(x_shape)
  yf = xi.astype(np.float32)
  zb = np.array([True, False], dtype=np.bool_)

  def f_jax(xi_yf, zb):  # xi: s16[2, 3, 4], yf: f32[2, 3, 4], zb: bool[2]
    xi, yf = xi_yf
    # Returns a tuple:
    #   (1) a float constant, with 0 tangent;
    #   (2) a tuple with the integer input, the boolean input, and a float
    #       depending on both inputs.
    # TODO: there is a problem if we add a None output
    constant = jnp.zeros(xi.shape, dtype=jnp.float32)
    mixed = xi.astype(np.float32) * 2. * yf
    return (constant, (xi, zb, mixed))

  f_tf = jax2tf.convert(f_jax, polymorphic_shapes=[("b1, b2, 4", "b1, b2, 4"), "b1"])
  if with_function:
    f_tf = tf.function(f_tf, autograph=False)
  _, g_tf = tf_test_util.ComputeTfValueAndGrad(f_tf, ((xi, yf), zb))
  (g_xi, g_yf), g_zb = g_tf[0], g_tf[1]
  self.assertAllClose(g_xi, np.zeros_like(xi))
  self.assertAllClose(g_yf, (xi * 2).astype(yf.dtype))
  self.assertAllClose(g_zb, np.zeros_like(zb))
def test_saved_model(self):
  """A polymorphic conversion works after a SavedModel round trip."""
  f_jax = jnp.sin
  f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
  x = np.array([0.7, 0.8], dtype=np.float32)
  restored_f, _ = tf_test_util.SaveAndLoadFunction(
      f_tf, input_signature=[tf.TensorSpec([None], x.dtype)])
  # Check at the original batch size and then at a doubled one.
  for arg in (x, np.concatenate([x, x])):
    self.assertAllClose(f_jax(arg), restored_f(arg))
def test_saved_model_int_function(self):
  """Saving and reloading a polymorphic function with integer inputs."""
  def f_jax(x):  # x: s32[b, 3, 4]
    return jnp.reshape(x, (-1,))  # result: s32[b * 12]

  converted = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
  f_tf = tf.function(converted, autograph=False)
  x_shape = (2, 3, 4)
  x = np.arange(np.prod(x_shape), dtype=np.int32).reshape(x_shape)
  # When saving the model with gradients, we trace the gradient function
  # and we used to get an error when creating zeros_like_aval for a
  # polymorphic shape
  batch_poly_spec = tf.TensorSpec((None,) + x.shape[1:], x.dtype)
  restored_f, _ = tf_test_util.SaveAndLoadFunction(
      f_tf, input_signature=[batch_poly_spec])
  res_jax_rt = jax2tf.call_tf(restored_f)(x)
  self.assertAllClose(f_jax(x), res_jax_rt)
def test_saved_model_constant_gradient(self):
  """Saving a function whose output does not depend on its input."""
  def f_jax(x):  # A constant function, so its gradient is zero
    return 3.

  f_tf = jax2tf.convert(f_jax, polymorphic_shapes=["(b, ...)"])
  x = np.array([0.7, 0.8], dtype=np.float32)
  restored_f, _ = tf_test_util.SaveAndLoadFunction(
      f_tf, input_signature=[tf.TensorSpec([None], x.dtype)])
  self.assertAllClose(3., restored_f(x))
  zero_grad = np.array([0., 0.], dtype=np.float32)
  self.assertAllClose(zero_grad, jax.grad(f_jax)(x))
def test_readme_examples(self):
"""Some of the examples from the README."""
jax2tf.convert(lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],)),
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
jax2tf.convert(lambda x: jnp.reshape(x, (np.prod(x.shape),)),
polymorphic_shapes=["(b, 4)"])(np.ones((3, 4)))
jax2tf.convert(lambda x: x + x.shape[0] + jnp.sin(x.shape[0]),
polymorphic_shapes=["b"])(np.ones(3))
jax2tf.convert(lambda x: jnp.sum(x, axis=0) / x.shape[0],
polymorphic_shapes=["(v, _)"])(np.ones((3, 4)))
jax2tf.convert(lambda x: jnp.array(x.shape[0]),