CPU backend failure on ARM due to XLA/LLVM (and a potential fix) #5679

Closed
girgink opened this issue Feb 8, 2021 · 14 comments
Labels: bug, open

Comments

@girgink

girgink commented Feb 8, 2021

Hi,

JAX CPU backend fails on ARM architecture (e.g. NVIDIA Jetson AGX, ARMv8.2) with the following errors:

Python 3.6.12 (default, Aug 17 2020, 23:45:20) 
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from jax.config import config;
>>> config.update('jax_platform_name', 'cpu')
>>> import jax.numpy as jnp
>>> x = jnp.ones((1000, 1000))
2021-02-08 17:35:23.661294: W external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:116] Failed to find bogomips or clock in /proc/cpuinfo; cannot determine CPU frequency
>>> x*x
'carmel' is not a recognized processor for this target (ignoring processor)
'+neon' is not a recognized feature for this target (ignoring feature)
'+fp-armv8' is not a recognized feature for this target (ignoring feature)
'+crypto' is not a recognized feature for this target (ignoring feature)
'+crc' is not a recognized feature for this target (ignoring feature)
'carmel' is not a recognized processor for this target (ignoring processor)
'carmel' is not a recognized processor for this target (ignoring processor)
'+neon' is not a recognized feature for this target (ignoring feature)
'+fp-armv8' is not a recognized feature for this target (ignoring feature)
'+crypto' is not a recognized feature for this target (ignoring feature)
'+crc' is not a recognized feature for this target (ignoring feature)
'carmel' is not a recognized processor for this target (ignoring processor)
'carmel' is not a recognized processor for this target (ignoring processor)
'+neon' is not a recognized feature for this target (ignoring feature)
'+fp-armv8' is not a recognized feature for this target (ignoring feature)
'+crypto' is not a recognized feature for this target (ignoring feature)
'+crc' is not a recognized feature for this target (ignoring feature)
'carmel' is not a recognized processor for this target (ignoring processor)
'carmel' is not a recognized processor for this target (ignoring processor)
'+neon' is not a recognized feature for this target (ignoring feature)
'+fp-armv8' is not a recognized feature for this target (ignoring feature)
'+crypto' is not a recognized feature for this target (ignoring feature)
'+crc' is not a recognized feature for this target (ignoring feature)
'carmel' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!
Aborted (core dumped)
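
The warnings above repeat the same handful of names many times. As a reading aid, here is a minimal sketch that collapses such a log into one count per unrecognized processor or feature. The helper name `summarize_llvm_warnings` is made up for this example and is not part of JAX, XLA, or LLVM.

```python
import re
from collections import Counter

# Matches lines such as:
#   'carmel' is not a recognized processor for this target (ignoring processor)
_WARNING = re.compile(
    r"^'(?P<name>[^']+)' is not a recognized "
    r"(?P<kind>processor|feature) for this target"
)


def summarize_llvm_warnings(log_text):
    """Count each unrecognized processor/feature name found in a log."""
    counts = Counter()
    for line in log_text.splitlines():
        match = _WARNING.match(line.strip())
        if match:
            counts[(match.group("kind"), match.group("name"))] += 1
    return counts


# A shortened excerpt of the log above, used only to exercise the helper.
sample = """\
'carmel' is not a recognized processor for this target (ignoring processor)
'+neon' is not a recognized feature for this target (ignoring feature)
'carmel' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!
"""

for (kind, name), n in sorted(summarize_llvm_warnings(sample).items()):
    print(f"{kind:9s} {name:8s} x{n}")
```

Lines that do not match the warning pattern, such as the final LLVM ERROR, are ignored by the counter.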

JAX was reported to work on similar architectures (e.g. NVIDIA Jetson TX2), but I think those reports only used the GPU backend, which works fine.

We have traced the problem to missing AArch64 entries in the XLA and LLVM build files. The following changes seem to fix it:

  • Adding the following lines to cpu_compiler:deps in org_tensorflow/tensorflow/compiler/xla/service/cpu/BUILD
    "@llvm-project//llvm:ARMCodeGen",
    "@llvm-project//llvm:AArch64CodeGen",
  • Changing llvm_host_triple in org_tensorflow/third_party/llvm/llvm.autogenerated.BUILD (this is required because the default fallback architecture is set to X86_64, which is flagged as something to be fixed)
    llvm_host_triple = "aarch64-unknown-linux_gnu"
  • Adding a linux_aarch64 target to llvm_all_cmake_vars in org_tensorflow/third_party/llvm/llvm.bzl
    "@org_tensorflow//tensorflow:linux_aarch64": cmake_var_string(
        _dict_add(
            cmake_vars,
            llvm_target_cmake_vars("AArch64", "aarch64-unknown-linux_gnu"),
            posix_cmake_vars,
            linux_cmake_vars,
        ),
    ),
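
To illustrate why the fallback matters, here is a small Python sketch of the selection logic: it maps the host machine name to an LLVM target and host triple, and fails loudly instead of silently defaulting to X86. This is not how the Bazel files do it. The table and the function name `llvm_target_for` are hypothetical, only the AArch64 entry is taken from the changes above, and the X86 triple is an assumption added for contrast.

```python
import platform

# Host machine name -> (LLVM target, host triple).
# The AArch64 entry mirrors the values proposed in this issue; the X86
# entry is an assumption added only for contrast.
LLVM_TARGETS = {
    "aarch64": ("AArch64", "aarch64-unknown-linux_gnu"),
    "x86_64": ("X86", "x86_64-unknown-linux_gnu"),  # assumed, not from the issue
}


def llvm_target_for(machine=None):
    """Return (target, triple) for a machine name, or raise if unknown.

    Raising instead of falling back to X86 avoids the failure mode seen
    here, where an ARM host ended up with an x86 configuration.
    """
    machine = (machine or platform.machine()).lower()
    try:
        return LLVM_TARGETS[machine]
    except KeyError:
        raise ValueError(f"no LLVM target configured for {machine!r}") from None


print(llvm_target_for("aarch64"))
```

Calling `llvm_target_for()` without an argument uses whatever `platform.machine()` reports on the current host.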

With these changes the result is as follows:


>>> from jax.config import config;
>>> config.update('jax_platform_name', 'cpu')
>>> import jax.numpy as jnp
>>> x = jnp.ones((1000, 1000))
>>> x*x
DeviceArray([[1., 1., 1., ..., 1., 1., 1.],
             [1., 1., 1., ..., 1., 1., 1.],
             [1., 1., 1., ..., 1., 1., 1.],
             ...,
             [1., 1., 1., ..., 1., 1., 1.],
             [1., 1., 1., ..., 1., 1., 1.],
             [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)

@zhangqiaorjc added the enhancement label Feb 9, 2021
@zhangqiaorjc
Member

Thanks @girgink! We will try to fix this internally in the TensorFlow tree.

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Feb 9, 2021
I don't have an environment to test this, but I believe it should be enough to
get the JIT to work.

This came up in
google/jax#5679

PiperOrigin-RevId: 356578451
Change-Id: I99a2aa0e87739b9edce81074fce1ca5c0bd25115
@d0k
Member

d0k commented Feb 9, 2021

tensorflow/tensorflow@942f315 links in AArch64 codegen and sets the cmake vars. Let us know if that's sufficient to get JAX to work on your machine.

@girgink
Author

girgink commented Feb 10, 2021

Thanks @d0k, I will rebuild and let you know the result.

@hawkinsp
Member

@girgink Note we haven't bumped the TF version in the JAX WORKSPACE file. You'll either need to do that or point the WORKSPACE file to an up-to-date TF checkout that includes that change.

@girgink
Author

girgink commented Feb 10, 2021

Ok!

@girgink
Author

girgink commented Feb 11, 2021

@d0k @hawkinsp I confirm that the fix works as expected.

I also ran the tests with JAX_NUM_GENERATED_CASES=5 pytest -n 1; the summary is: 5 failed, 241 passed, 11 skipped, 16 errors.

All errors are related to the call to the deprecated create function FileDescriptor().

For the failed tests, I suspect they might be related to available memory. The unit that I use (NVIDIA Jetson AGX Xavier) has unified CPU-GPU memory. I'm not sure if this is the case for JAX, but some frameworks (e.g. TensorFlow) allocate most of the GPU memory up front to speed up processing; because Xavier has unified memory, this leaves very little memory for the CPU. In fact, running the tests in parallel (e.g. -n auto) quickly resulted in a segmentation fault.
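
On the memory question: JAX documents an environment variable, XLA_PYTHON_CLIENT_PREALLOCATE, that turns off up-front GPU memory preallocation when set to false. Whether preallocation is actually behind these failures has not been checked here. The sketch below only shows one way to build the environment and command for a re-run. The helper names are hypothetical, and the defaults match the command used above.

```python
import os


def low_memory_test_env(base_env=None, generated_cases=5):
    """Return a copy of the environment with GPU preallocation disabled.

    XLA_PYTHON_CLIENT_PREALLOCATE and JAX_NUM_GENERATED_CASES are read by
    JAX and its test suite; everything else is passed through unchanged.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    env["JAX_NUM_GENERATED_CASES"] = str(generated_cases)
    return env


def pytest_command(workers=1):
    """Build the pytest-xdist command line; one worker keeps memory use low."""
    return ["pytest", "-n", str(workers)]


# Show what would be run, without actually starting the test suite.
print(pytest_command(1))
print(low_memory_test_env({})["XLA_PYTHON_CLIENT_PREALLOCATE"])
```

To actually run it, the two pieces can be passed to `subprocess.run(pytest_command(), env=low_memory_test_env(), cwd=...)`, with `cwd` pointing at the JAX checkout.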

Please find below the test results:

/usr/local/lib/python3.8/dist-packages/pytest_benchmark/logger.py:44: PytestBenchmarkWarning: Benchmarks are automatically disabled because xdist plugin is active.Benchmarks cannot be performed reliably in a parallelized environment.
  warner(PytestBenchmarkWarning(text))
=========================================================================================== test session starts ============================================================================================
platform linux -- Python 3.8.5, pytest-6.2.2, py-1.10.0, pluggy-0.13.1
benchmark: 3.2.3 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /tmp/jax, configfile: pytest.ini
plugins: benchmark-3.2.3, xdist-2.2.1, forked-1.3.0, cov-2.11.1
gw0 [7201] 0 items / 4 errors
....................s.................ss.............................................ss........................................................................................s.................... [  2%]
..........s.......ss..........s................s........Fatal Python error: Segmentation fault

Thread 0x0000007fb41bb160 (most recent call first):
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 400 in read
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 432 in from_io
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 967 in _thread_receiver
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 220 in run
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 285 in _perform_spawn

Thread 0x0000007fb4afb430 (most recent call first):
  File "/tmp/jax/jax/interpreters/xla.py", line 848 in _execute_compiled
  File "/tmp/jax/jax/interpreters/xla.py", line 590 in _xla_call_impl
  File "/tmp/jax/jax/core.py", line 631 in process_call
  File "/tmp/jax/jax/core.py", line 1410 in process
  File "/tmp/jax/jax/core.py", line 1398 in call_bind
  File "/tmp/jax/jax/core.py", line 1407 in bind
  File "/tmp/jax/jax/interpreters/batching.py", line 175 in process_call
  File "/tmp/jax/jax/core.py", line 1410 in process
  File "/tmp/jax/jax/core.py", line 1398 in call_bind
  File "/tmp/jax/jax/core.py", line 1407 in bind
  File "/tmp/jax/jax/api.py", line 292 in cache_miss
  File "/tmp/jax/jax/api.py", line 398 in f_jitted
  File "/tmp/jax/jax/_src/traceback_util.py", line 139 in reraise_with_filtered_traceback
  File "/tmp/jax/jax/_src/random.py", line 1082 in gamma
  File "/tmp/jax/tests/api_test.py", line 3636 in f
  File "/tmp/jax/jax/linear_util.py", line 166 in call_wrapped
  File "/tmp/jax/jax/core.py", line 636 in process_custom_jvp_call
  File "/tmp/jax/jax/custom_derivatives.py", line 279 in bind
  File "/tmp/jax/jax/custom_derivatives.py", line 213 in __call__
  File "/tmp/jax/tests/api_test.py", line 3645 in sample
  File "/tmp/jax/jax/linear_util.py", line 166 in call_wrapped
  File "/tmp/jax/jax/api.py", line 1219 in batched_fun
  File "/tmp/jax/jax/_src/traceback_util.py", line 139 in reraise_with_filtered_traceback
  File "/tmp/jax/tests/api_test.py", line 3647 in test_closure_with_vmap
  File "/usr/lib/python3.8/unittest/case.py", line 633 in _callTestMethod
  File "/usr/lib/python3.8/unittest/case.py", line 676 in run
  File "/usr/lib/python3.8/unittest/case.py", line 736 in __call__
  File "/usr/local/lib/python3.8/dist-packages/_pytest/unittest.py", line 321 in runtest
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 162 in pytest_runtest_call
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 255 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 311 in from_call
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 254 in call_runtest_hook
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 215 in call_and_report
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 126 in runtestprotocol
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 109 in pytest_runtest_protocol
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/xdist/remote.py", line 89 in run_one_test
  File "/usr/local/lib/python3.8/dist-packages/xdist/remote.py", line 72 in pytest_runtestloop
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/_pytest/main.py", line 323 in _main
  File "/usr/local/lib/python3.8/dist-packages/_pytest/main.py", line 269 in wrap_session
  File "/usr/local/lib/python3.8/dist-packages/_pytest/main.py", line 316 in pytest_cmdline_main
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/xdist/remote.py", line 244 in <module>
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 1084 in executetask
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 220 in run
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 285 in _perform_spawn
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 267 in integrate_as_primary_thread
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 1060 in serve
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 1554 in serve
  File "<string>", line 8 in <module>
  File "<string>", line 1 in <module>
[gw0] node down: Not properly terminated
F
replacing crashed worker gw0
gw1 [7201]Fatal Python error: Segmentation fault

Thread 0x0000007f85c3a160 (most recent call first):
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 400 in read
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 432 in from_io
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 967 in _thread_receiver
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 220 in run
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 285 in _perform_spawn

Current thread 0x0000007f8657a430 (most recent call first):
  File "/tmp/jax/jax/interpreters/xla.py", line 361 in _execute_compiled_primitive
  File "/tmp/jax/jax/interpreters/xla.py", line 243 in apply_primitive
  File "/tmp/jax/jax/core.py", line 628 in process_primitive
  File "/tmp/jax/jax/core.py", line 282 in bind
  File "/tmp/jax/jax/_src/lax/lax.py", line 1493 in iota
  File "/tmp/jax/jax/_src/lax/lax.py", line 4881 in _scatter_jvp
  File "/tmp/jax/jax/interpreters/ad.py", line 285 in process_primitive
  File "/tmp/jax/jax/core.py", line 282 in bind
  File "/tmp/jax/jax/_src/lax/lax.py", line 1082 in scatter
  File "/tmp/jax/tests/lax_autodiff_test.py", line 907 in <lambda>
  File "/tmp/jax/jax/linear_util.py", line 166 in call_wrapped
  File "/tmp/jax/jax/api.py", line 1694 in _jvp
  File "/tmp/jax/jax/api.py", line 1666 in jvp
  File "/tmp/jax/jax/test_util.py", line 239 in check_jvp
  File "/tmp/jax/jax/test_util.py", line 301 in _check_grads
  File "/tmp/jax/jax/test_util.py", line 314 in check_grads
  File "/tmp/jax/tests/lax_autodiff_test.py", line 910 in testScatterGrad
  File "/usr/local/lib/python3.8/dist-packages/absl/testing/parameterized.py", line 282 in bound_param_test
  File "/usr/lib/python3.8/unittest/case.py", line 633 in _callTestMethod
  File "/usr/lib/python3.8/unittest/case.py", line 676 in run
  File "/usr/lib/python3.8/unittest/case.py", line 736 in __call__
  File "/usr/local/lib/python3.8/dist-packages/_pytest/unittest.py", line 321 in runtest
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 162 in pytest_runtest_call
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 255 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 311 in from_call
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 254 in call_runtest_hook
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 215 in call_and_report
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 126 in runtestprotocol
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 109 in pytest_runtest_protocol
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/xdist/remote.py", line 89 in run_one_test
  File "/usr/local/lib/python3.8/dist-packages/xdist/remote.py", line 72 in pytest_runtestloop
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/_pytest/main.py", line 323 in _main
  File "/usr/local/lib/python3.8/dist-packages/_pytest/main.py", line 269 in wrap_session
  File "/usr/local/lib/python3.8/dist-packages/_pytest/main.py", line 316 in pytest_cmdline_main
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/xdist/remote.py", line 244 in <module>
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 1084 in executetask
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 220 in run
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 285 in _perform_spawn
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 267 in integrate_as_primary_thread
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 1060 in serve
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 1554 in serve
  File "<string>", line 8 in <module>
  File "<string>", line 1 in <module>
[gw1] node down: Not properly terminated
F
replacing crashed worker gw1
gw2 [7201]Fatal Python error: Segmentation fault

Thread 0x0000007fab0c5160 (most recent call first):
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 400 in read
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 432 in from_io
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 967 in _thread_receiver
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 220 in run
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 285 in _perform_spawn

Thread 0x0000007faba05430 (most recent call first):
  File "/tmp/jax/jax/interpreters/xla.py", line 361 in _execute_compiled_primitive
  File "/tmp/jax/jax/interpreters/xla.py", line 243 in apply_primitive
  File "/tmp/jax/jax/core.py", line 628 in process_primitive
  File "/tmp/jax/jax/core.py", line 282 in bind
  File "/tmp/jax/jax/_src/lax/lax.py", line 309 in integer_pow
  File "/tmp/jax/jax/_src/lax/lax.py", line 1908 in reciprocal
  File "/tmp/jax/jax/test_util.py", line 927 in _CheckAgainstNumpy
  File "/tmp/jax/tests/lax_test.py", line 204 in testOpAgainstNumpy
  File "/usr/local/lib/python3.8/dist-packages/absl/testing/parameterized.py", line 282 in bound_param_test
  File "/usr/lib/python3.8/unittest/case.py", line 633 in _callTestMethod
  File "/usr/lib/python3.8/unittest/case.py", line 676 in run
  File "/usr/lib/python3.8/unittest/case.py", line 736 in __call__
  File "/usr/local/lib/python3.8/dist-packages/_pytest/unittest.py", line 321 in runtest
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 162 in pytest_runtest_call
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 255 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 311 in from_call
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 254 in call_runtest_hook
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 215 in call_and_report
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 126 in runtestprotocol
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 109 in pytest_runtest_protocol
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/xdist/remote.py", line 89 in run_one_test
  File "/usr/local/lib/python3.8/dist-packages/xdist/remote.py", line 72 in pytest_runtestloop
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/_pytest/main.py", line 323 in _main
  File "/usr/local/lib/python3.8/dist-packages/_pytest/main.py", line 269 in wrap_session
  File "/usr/local/lib/python3.8/dist-packages/_pytest/main.py", line 316 in pytest_cmdline_main
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/xdist/remote.py", line 244 in <module>
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 1084 in executetask
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 220 in run
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 285 in _perform_spawn
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 267 in integrate_as_primary_thread
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 1060 in serve
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 1554 in serve
  File "<string>", line 8 in <module>
  File "<string>", line 1 in <module>
[gw2] node down: Not properly terminated
F
replacing crashed worker gw2
gw3 [7201]Fatal Python error: Segmentation fault

Thread 0x0000007fab598160 (most recent call first):
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 400 in read
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 432 in from_io
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 967 in _thread_receiver
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 220 in run
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 285 in _perform_spawn

Current thread 0x0000007fabed8430 (most recent call first):
  File "/tmp/jax/jax/interpreters/xla.py", line 361 in _execute_compiled_primitive
  File "/tmp/jax/jax/interpreters/xla.py", line 243 in apply_primitive
  File "/tmp/jax/jax/core.py", line 628 in process_primitive
  File "/tmp/jax/jax/core.py", line 282 in bind
  File "/tmp/jax/jax/_src/lax/lax.py", line 1493 in iota
  File "/tmp/jax/jax/_src/lax/lax.py", line 4881 in _scatter_jvp
  File "/tmp/jax/jax/interpreters/ad.py", line 285 in process_primitive
  File "/tmp/jax/jax/core.py", line 282 in bind
  File "/tmp/jax/jax/_src/lax/lax.py", line 1082 in scatter
  File "/tmp/jax/tests/lax_autodiff_test.py", line 907 in <lambda>
  File "/tmp/jax/jax/linear_util.py", line 166 in call_wrapped
  File "/tmp/jax/jax/api.py", line 1694 in _jvp
  File "/tmp/jax/jax/api.py", line 1666 in jvp
  File "/tmp/jax/jax/test_util.py", line 239 in check_jvp
  File "/tmp/jax/jax/test_util.py", line 301 in _check_grads
  File "/tmp/jax/jax/test_util.py", line 314 in check_grads
  File "/tmp/jax/tests/lax_autodiff_test.py", line 910 in testScatterGrad
  File "/usr/local/lib/python3.8/dist-packages/absl/testing/parameterized.py", line 282 in bound_param_test
  File "/usr/lib/python3.8/unittest/case.py", line 633 in _callTestMethod
  File "/usr/lib/python3.8/unittest/case.py", line 676 in run
  File "/usr/lib/python3.8/unittest/case.py", line 736 in __call__
  File "/usr/local/lib/python3.8/dist-packages/_pytest/unittest.py", line 321 in runtest
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 162 in pytest_runtest_call
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 255 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 311 in from_call
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 254 in call_runtest_hook
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 215 in call_and_report
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 126 in runtestprotocol
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 109 in pytest_runtest_protocol
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/xdist/remote.py", line 89 in run_one_test
  File "/usr/local/lib/python3.8/dist-packages/xdist/remote.py", line 72 in pytest_runtestloop
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/_pytest/main.py", line 323 in _main
  File "/usr/local/lib/python3.8/dist-packages/_pytest/main.py", line 269 in wrap_session
  File "/usr/local/lib/python3.8/dist-packages/_pytest/main.py", line 316 in pytest_cmdline_main
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/xdist/remote.py", line 244 in <module>
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 1084 in executetask
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 220 in run
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 285 in _perform_spawn
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 267 in integrate_as_primary_thread
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 1060 in serve
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 1554 in serve
  File "<string>", line 8 in <module>
  File "<string>", line 1 in <module>
[gw3] node down: Not properly terminated
F
replacing crashed worker gw3
gw4 [7201]Fatal Python error: Segmentation fault

Thread 0x0000007f9342f160 (most recent call first):
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 400 in read
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 432 in from_io
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 967 in _thread_receiver
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 220 in run
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 285 in _perform_spawn

Thread 0x0000007f93d6f430 (most recent call first):
  File "/tmp/jax/jax/interpreters/xla.py", line 361 in _execute_compiled_primitive
  File "/tmp/jax/jax/interpreters/xla.py", line 243 in apply_primitive
  File "/tmp/jax/jax/core.py", line 628 in process_primitive
  File "/tmp/jax/jax/core.py", line 282 in bind
  File "/tmp/jax/jax/_src/lax/lax.py", line 309 in integer_pow
  File "/tmp/jax/jax/_src/lax/lax.py", line 1908 in reciprocal
  File "/tmp/jax/jax/test_util.py", line 927 in _CheckAgainstNumpy
  File "/tmp/jax/tests/lax_test.py", line 204 in testOpAgainstNumpy
  File "/usr/local/lib/python3.8/dist-packages/absl/testing/parameterized.py", line 282 in bound_param_test
  File "/usr/lib/python3.8/unittest/case.py", line 633 in _callTestMethod
  File "/usr/lib/python3.8/unittest/case.py", line 676 in run
  File "/usr/lib/python3.8/unittest/case.py", line 736 in __call__
  File "/usr/local/lib/python3.8/dist-packages/_pytest/unittest.py", line 321 in runtest
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 162 in pytest_runtest_call
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 255 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 311 in from_call
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 254 in call_runtest_hook
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 215 in call_and_report
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 126 in runtestprotocol
  File "/usr/local/lib/python3.8/dist-packages/_pytest/runner.py", line 109 in pytest_runtest_protocol
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/xdist/remote.py", line 89 in run_one_test
  File "/usr/local/lib/python3.8/dist-packages/xdist/remote.py", line 72 in pytest_runtestloop
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/_pytest/main.py", line 323 in _main
  File "/usr/local/lib/python3.8/dist-packages/_pytest/main.py", line 269 in wrap_session
  File "/usr/local/lib/python3.8/dist-packages/_pytest/main.py", line 316 in pytest_cmdline_main
  File "/usr/local/lib/python3.8/dist-packages/pluggy/callers.py", line 187 in _multicall
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 84 in <lambda>
  File "/usr/local/lib/python3.8/dist-packages/pluggy/manager.py", line 93 in _hookexec
  File "/usr/local/lib/python3.8/dist-packages/pluggy/hooks.py", line 286 in __call__
  File "/usr/local/lib/python3.8/dist-packages/xdist/remote.py", line 244 in <module>
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 1084 in executetask
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 220 in run
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 285 in _perform_spawn
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 267 in integrate_as_primary_thread
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 1060 in serve
  File "/usr/local/lib/python3.8/dist-packages/execnet/gateway_base.py", line 1554 in serve
  File "<string>", line 8 in <module>
  File "<string>", line 1 in <module>
[gw4] node down: Not properly terminated
F
maximum crashed workers reached: 4

================================================================================================== ERRORS ==================================================================================================
____________________________________________________________________________ ERROR collecting tests/host_callback_to_tf_test.py ____________________________________________________________________________
tests/host_callback_to_tf_test.py:36: in <module>
    import tensorflow as tf
/usr/local/lib/python3.8/dist-packages/tensorflow/__init__.py:41: in <module>
    from tensorflow.python.tools import module_util as _module_util
/usr/local/lib/python3.8/dist-packages/tensorflow/python/__init__.py:41: in <module>
    from tensorflow.python.eager import context
/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/context.py:32: in <module>
    from tensorflow.core.framework import function_pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/function_pb2.py:16: in <module>
    from tensorflow.core.framework import attr_value_pb2 as tensorflow_dot_core_dot_framework_dot_attr__value__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/attr_value_pb2.py:16: in <module>
    from tensorflow.core.framework import tensor_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/tensor_pb2.py:16: in <module>
    from tensorflow.core.framework import resource_handle_pb2 as tensorflow_dot_core_dot_framework_dot_resource__handle__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/resource_handle_pb2.py:16: in <module>
    from tensorflow.core.framework import tensor_shape_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__shape__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/tensor_shape_pb2.py:18: in <module>
    DESCRIPTOR = _descriptor.FileDescriptor(
/usr/local/lib/python3.8/dist-packages/google/protobuf/descriptor.py:952: in __init__
    _Deprecated('FileDescriptor')
/usr/local/lib/python3.8/dist-packages/google/protobuf/descriptor.py:98: in _Deprecated
    warnings.warn(
E   DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors from generated code or query the descriptor_pool.
___________________________________________________________________________________ ERROR collecting tests/image_test.py ___________________________________________________________________________________
tests/image_test.py:34: in <module>
    import tensorflow as tf
/usr/local/lib/python3.8/dist-packages/tensorflow/__init__.py:41: in <module>
    from tensorflow.python.tools import module_util as _module_util
/usr/local/lib/python3.8/dist-packages/tensorflow/python/__init__.py:41: in <module>
    from tensorflow.python.eager import context
/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/context.py:32: in <module>
    from tensorflow.core.framework import function_pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/function_pb2.py:16: in <module>
    from tensorflow.core.framework import attr_value_pb2 as tensorflow_dot_core_dot_framework_dot_attr__value__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/attr_value_pb2.py:16: in <module>
    from tensorflow.core.framework import tensor_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/tensor_pb2.py:16: in <module>
    from tensorflow.core.framework import resource_handle_pb2 as tensorflow_dot_core_dot_framework_dot_resource__handle__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/resource_handle_pb2.py:16: in <module>
    from tensorflow.core.framework import tensor_shape_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__shape__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/tensor_shape_pb2.py:18: in <module>
    DESCRIPTOR = _descriptor.FileDescriptor(
/usr/local/lib/python3.8/dist-packages/google/protobuf/descriptor.py:952: in __init__
    _Deprecated('FileDescriptor')
/usr/local/lib/python3.8/dist-packages/google/protobuf/descriptor.py:98: in _Deprecated
    warnings.warn(
E   DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors from generated code or query the descriptor_pool.
_________________________________________________________________________________ ERROR collecting tests/lax_vmap_test.py __________________________________________________________________________________
ImportError while importing test module '/tmp/jax/tests/lax_vmap_test.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
/usr/lib/python3.8/importlib/__init__.py:127: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
tests/lax_vmap_test.py:33: in <module>
    from tests.lax_test import LAX_OPS
E   ModuleNotFoundError: No module named 'tests.lax_test'
_________________________________________________________________________________ ERROR collecting tests/profiler_test.py __________________________________________________________________________________
tests/profiler_test.py:35: in <module>
    from tensorflow.python.profiler import profiler_client
/usr/local/lib/python3.8/dist-packages/tensorflow/__init__.py:41: in <module>
    from tensorflow.python.tools import module_util as _module_util
/usr/local/lib/python3.8/dist-packages/tensorflow/python/__init__.py:41: in <module>
    from tensorflow.python.eager import context
/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/context.py:32: in <module>
    from tensorflow.core.framework import function_pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/function_pb2.py:16: in <module>
    from tensorflow.core.framework import attr_value_pb2 as tensorflow_dot_core_dot_framework_dot_attr__value__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/attr_value_pb2.py:16: in <module>
    from tensorflow.core.framework import tensor_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/tensor_pb2.py:16: in <module>
    from tensorflow.core.framework import resource_handle_pb2 as tensorflow_dot_core_dot_framework_dot_resource__handle__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/resource_handle_pb2.py:16: in <module>
    from tensorflow.core.framework import tensor_shape_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__shape__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/tensor_shape_pb2.py:18: in <module>
    DESCRIPTOR = _descriptor.FileDescriptor(
/usr/local/lib/python3.8/dist-packages/google/protobuf/descriptor.py:952: in __init__
    _Deprecated('FileDescriptor')
/usr/local/lib/python3.8/dist-packages/google/protobuf/descriptor.py:98: in _Deprecated
    warnings.warn(
E   DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors from generated code or query the descriptor_pool.
_________________________________________________________________________________ ERROR collecting tests/profiler_test.py __________________________________________________________________________________
tests/profiler_test.py:35: in <module>
    from tensorflow.python.profiler import profiler_client
/usr/local/lib/python3.8/dist-packages/tensorflow/__init__.py:41: in <module>
    from tensorflow.python.tools import module_util as _module_util
/usr/local/lib/python3.8/dist-packages/tensorflow/python/__init__.py:41: in <module>
    from tensorflow.python.eager import context
/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/context.py:32: in <module>
    from tensorflow.core.framework import function_pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/function_pb2.py:16: in <module>
    from tensorflow.core.framework import attr_value_pb2 as tensorflow_dot_core_dot_framework_dot_attr__value__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/attr_value_pb2.py:16: in <module>
    from tensorflow.core.framework import tensor_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/tensor_pb2.py:16: in <module>
    from tensorflow.core.framework import resource_handle_pb2 as tensorflow_dot_core_dot_framework_dot_resource__handle__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/resource_handle_pb2.py:16: in <module>
    from tensorflow.core.framework import tensor_shape_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__shape__pb2
/usr/local/lib/python3.8/dist-packages/tensorflow/core/framework/tensor_shape_pb2.py:18: in <module>
    DESCRIPTOR = _descriptor.FileDescriptor(
/usr/local/lib/python3.8/dist-packages/google/protobuf/descriptor.py:952: in __init__
    _Deprecated('FileDescriptor')
/usr/local/lib/python3.8/dist-packages/google/protobuf/descriptor.py:98: in _Deprecated
    warnings.warn(
E   DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors from generated code or query the descriptor_pool.
================================================================================================= FAILURES =================================================================================================
____________________________________________________________________________________________ tests/api_test.py _____________________________________________________________________________________________
[gw0] linux -- Python 3.8.5 /usr/bin/python
worker 'gw0' crashed while running 'tests/api_test.py::CustomJVPTest::test_closure_with_vmap'
________________________________________________________________________________________ tests/lax_autodiff_test.py ________________________________________________________________________________________
[gw1] linux -- Python 3.8.5 /usr/bin/python
worker 'gw1' crashed while running 'tests/lax_autodiff_test.py::LaxAutodiffTest::testScatterGrad_shape=float32[10]_idxs=[[0]\n [0]\n [0]]_update=(3, 2)_dnums=ScatterDimensionNumbers(update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))'
____________________________________________________________________________________________ tests/lax_test.py _____________________________________________________________________________________________
[gw2] linux -- Python 3.8.5 /usr/bin/python
worker 'gw2' crashed while running 'tests/lax_test.py::LaxTest::testOpAgainstNumpyReciprocal_bfloat16[3]'
________________________________________________________________________________________ tests/lax_autodiff_test.py ________________________________________________________________________________________
[gw3] linux -- Python 3.8.5 /usr/bin/python
worker 'gw3' crashed while running 'tests/lax_autodiff_test.py::LaxAutodiffTest::testScatterGrad_shape=float32[10,5]_idxs=[[0]\n [2]\n [1]]_update=(3, 3)_dnums=ScatterDimensionNumbers(update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))'
____________________________________________________________________________________________ tests/lax_test.py _____________________________________________________________________________________________
[gw4] linux -- Python 3.8.5 /usr/bin/python
worker 'gw4' crashed while running 'tests/lax_test.py::LaxTest::testOpAgainstNumpyReciprocal_bfloat16[3,1]'
================================================================================ xdist: maximum crashed workers reached: 4 =================================================================================
========================================================================================= short test summary info ==========================================================================================
FAILED tests/api_test.py::CustomJVPTest::test_closure_with_vmap
FAILED tests/lax_autodiff_test.py::LaxAutodiffTest::testScatterGrad_shape=float32[10]_idxs=[[0]
 [0]
 [0]]_update=(3, 2)_dnums=ScatterDimensionNumbers(update_window_dims=(1,), inserted_window_dims=(), scatter_dims_to_operand_dims=(0,))
FAILED tests/lax_test.py::LaxTest::testOpAgainstNumpyReciprocal_bfloat16[3]
FAILED tests/lax_autodiff_test.py::LaxAutodiffTest::testScatterGrad_shape=float32[10,5]_idxs=[[0]
 [2]
 [1]]_update=(3, 3)_dnums=ScatterDimensionNumbers(update_window_dims=(1,), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,))
FAILED tests/lax_test.py::LaxTest::testOpAgainstNumpyReciprocal_bfloat16[3,1]
ERROR tests/host_callback_to_tf_test.py - DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors...
ERROR tests/image_test.py - DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors from generate...
ERROR tests/lax_vmap_test.py
ERROR tests/profiler_test.py - DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors from gener...
ERROR tests/host_callback_to_tf_test.py - DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors...
ERROR tests/image_test.py - DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors from generate...
ERROR tests/profiler_test.py - DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors from gener...
ERROR tests/host_callback_to_tf_test.py - DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors...
ERROR tests/image_test.py - DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors from generate...
ERROR tests/profiler_test.py - DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors from gener...
ERROR tests/host_callback_to_tf_test.py - DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors...
ERROR tests/image_test.py - DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors from generate...
ERROR tests/profiler_test.py - DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors from gener...
ERROR tests/host_callback_to_tf_test.py - DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors...
ERROR tests/image_test.py - DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors from generate...
ERROR tests/profiler_test.py - DeprecationWarning: Call to deprecated create function FileDescriptor(). Note: Create unlinked descriptors is going to go away. Please use get/find descriptors from gener...
===================================================================== 5 failed, 241 passed, 11 skipped, 16 errors in 337.44s (0:05:37) =====================================================================
root@4ff274d37bc6:/tmp/jax#

@hawkinsp
Member

The warning about FileDescriptor() actually comes from TensorFlow, not JAX. The tests in question either use TensorFlow features (e.g., its profiler), test interoperability with TensorFlow, or check that a JAX implementation returns the same output as a TensorFlow implementation.

The segmentation faults are more concerning. Are you running these tests on CPU or GPU? By default, they are probably GPU tests. You can prevent GPU usage by setting CUDA_VISIBLE_DEVICES= (i.e., empty). If you want to run GPU tests, we could test the hypothesis about unified memory by setting XLA_PYTHON_CLIENT_ALLOCATOR=platform (see https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html )?
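
Both environment variables mentioned above need to be in place before the test process starts. A minimal sketch of setting them from Python follows; the helper names `jax_test_env` and `run_tests`, and the test file in the usage note, are hypothetical and only illustrate the idea.

```python
import os
import subprocess
import sys


def jax_test_env(use_gpu, base=None):
    """Return a copy of the environment set up for a CPU-only or a GPU test run."""
    env = dict(os.environ if base is None else base)
    if use_gpu:
        # Use the platform allocator instead of JAX's preallocated GPU memory pool.
        env["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
    else:
        # An empty value hides every CUDA device, so the tests run on CPU.
        env["CUDA_VISIBLE_DEVICES"] = ""
    return env


def run_tests(test_path, use_gpu):
    """Run pytest on test_path in a child process with the chosen environment."""
    cmd = [sys.executable, "-m", "pytest", test_path]
    return subprocess.run(cmd, env=jax_test_env(use_gpu)).returncode
```

For example, `run_tests("tests/lax_test.py", use_gpu=False)` would rerun one of the crashing test files on CPU only.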

@jakevdp added the open and bug labels and removed the enhancement label on Mar 3, 2021
@hawkinsp
Member

I tried running the tests that failed with a segfault under ARM emulation and they don't fail at head for me. I'm going to guess that they are fixed at head.

If you can still reproduce these segfaults at head, please reopen the bug!

@shawwn
Contributor

shawwn commented Oct 17, 2021

Hiya,

Looks like this bug is back in the latest JAX.

On a fresh MacBook Pro M1 (native, not rosetta), I installed numpy and scipy, then pip3 install -U jax.

I'm getting a crash with the message 'cyclone' is not a recognized processor for this target (ignoring processor):

spresser@Shawns-MBP ml % python3
Python 3.9.7 (default, Sep  3 2021, 04:31:11)
[Clang 12.0.5 (clang-1205.0.22.9)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as np
/Users/spresser/ml/jax/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
>>> jnp.zeros((1,2))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NameError: name 'jnp' is not defined
>>> import jax.numpy as jnp
>>> jnp.zeros((1,2))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!
zsh: abort      python3

Since this is the only GitHub issue that comes up for JAX and this cyclone crash, I decided to post here. What steps should I take to debug this?

@tigerneil
Contributor

[I 15:10:29.696 NotebookApp] Replaying 3 buffered messages
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!


@d0k
Member

d0k commented Oct 26, 2021

The cyclone error you're seeing means that LLVM was configured to target X86 instead of arm64.

Sadly I don't have an M1 mac to verify, but it's very likely that you're hitting the BUILD file problem that tensorflow/tensorflow@cd76ed3 is supposed to fix. Does it still occur after that change?

The proper fix for that is blocked by TF supporting older versions of bazel, but the workaround should be enough so it recognizes your machine as arm64.
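
A quick first check when the target is suspected to be x86 rather than arm64 is to see which architecture the running Python interpreter reports. The sketch below does only that; it does not inspect how the LLVM inside jaxlib was configured, and the helper name `classify_machine` is hypothetical. On Apple Silicon, an interpreter running under Rosetta reports `x86_64`, while a native one reports `arm64`.

```python
import platform


def classify_machine(machine):
    """Map a platform.machine() string to a coarse architecture family."""
    m = machine.lower()
    if m in ("arm64", "aarch64"):
        return "arm64"
    if m in ("x86_64", "amd64"):
        return "x86_64"
    return "other"


if __name__ == "__main__":
    # Report the OS, the raw machine string, and the family it maps to.
    machine = platform.machine()
    print(platform.system(), machine, "->", classify_machine(machine))
```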

@shawwn
Contributor

shawwn commented Oct 27, 2021

After many hours, I'm reasonably confident that neither tensorflow/tensorflow@cd76ed3 nor tensorflow/tensorflow@e24a3b5 fixes the problem.

How can I hack the tensorflow codebase to force arm64? Are you sure that third_party/llvm/macos_build_fix.patch is being incorporated into the python3 build/build.py process?


My build process was:

  1. git clone https://github.com/tensorflow/tensorflow ~/ml/tensorflow
  2. git clone https://github.com/google/jax ~/ml/jax
  3. ln -s $HOME/ml/tensorflow ~/ml/jax/tensorflow
  4. Patch ~/ml/jax/WORKSPACE as follows:
diff --git a/WORKSPACE b/WORKSPACE
index f4d50e5c..83b56c7c 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -1,24 +1,24 @@
 load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
 
-# To update TensorFlow to a new revision,
-# a) update URL and strip_prefix to the new git commit hash
-# b) get the sha256 hash of the commit by running:
-#    curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
-#    and update the sha256 with the result.
-http_archive(
-    name = "org_tensorflow",
-    sha256 = "b2c8b912e7be71306ab6fee063fb4ec1dfbe7158e7e8469d674f8af6583434d4",
-    strip_prefix = "tensorflow-e98b052c08e5d1e7906ac2f6caf95c51a1e04985",
-    urls = [
-        "https://github.com/tensorflow/tensorflow/archive/e98b052c08e5d1e7906ac2f6caf95c51a1e04985.tar.gz",
-    ],
-)
+# # To update TensorFlow to a new revision,
+# # a) update URL and strip_prefix to the new git commit hash
+# # b) get the sha256 hash of the commit by running:
+# #    curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
+# #    and update the sha256 with the result.
+# http_archive(
+#     name = "org_tensorflow",
+#     sha256 = "b2c8b912e7be71306ab6fee063fb4ec1dfbe7158e7e8469d674f8af6583434d4",
+#     strip_prefix = "tensorflow-e98b052c08e5d1e7906ac2f6caf95c51a1e04985",
+#     urls = [
+#         "https://github.com/tensorflow/tensorflow/archive/e98b052c08e5d1e7906ac2f6caf95c51a1e04985.tar.gz",
+#     ],
+# )
 
 # For development, one can use a local TF repository instead.
-# local_repository(
-#    name = "org_tensorflow",
-#    path = "tensorflow",
-# )
+local_repository(
+   name = "org_tensorflow",
+   path = "tensorflow",
+)
 
 load("//third_party/pocketfft:workspace.bzl", pocketfft = "repo")
 pocketfft()

Lastly:

  • cd ~/ml/jax/tensorflow
  • git checkout cd76ed3114f5d3e5f387dbc04de63891da958861
  • cd ~/ml/jax
  • USE_BAZEL_VERSION=4.2.0rc3 python3 build/build.py
  • pip3 install ~/ml/jax/dist/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl --force-reinstall

Then python3 -c 'import jax.numpy as jnp; jnp.zeros((1,2))' crashes with the same problem:

$ python3 -c 'import jax.numpy as jnp; jnp.zeros((1,2))'
/Users/spresser/ml/jax/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
'cyclone' is not a recognized processor for this target (ignoring processor)
LLVM ERROR: 64-bit code requested on a subtarget that doesn't support it!
zsh: abort      python3 -c 'import jax.numpy as jnp; jnp.zeros((1,2))'

I tried the above steps with commit tensorflow/tensorflow@e24a3b5 as well, with no luck.

I'm pretty sure that tensorflow commit tensorflow/tensorflow@8cc3ffa was working on an M1, because about 4 months ago I built Jax on my old M1 Air laptop for Python 3.8, and the Jax repo at that time was using that tensorflow commit: https://github.com/shawwn/jax/blob/m1/WORKSPACE#L11

So whatever the problem is, it happened somewhere between tensorflow/tensorflow@8cc3ffa and HEAD.

... which is a gargantuan diff: tensorflow/tensorflow@8cc3ffa...master

@shawwn
Contributor

shawwn commented Oct 31, 2021

I've confirmed that jaxlib built from tensorflow/tensorflow@8cc3ffa works perfectly on an M1.

If anyone suffering from this problem (like @tigerneil) wants a temporary fix, and assuming you trust me, you can install my jaxlib wheel for Python 3.9 on M1 Macs:

wget https://battle.shawwn.com/dist/jaxlib/jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl
pip3 install jaxlib-0.1.74-cp39-none-macosx_11_0_arm64.whl --force-reinstall

At that point, you should be able to run python3 -c 'import jax.numpy as jnp; print(jnp.zeros((1,2)))' without any crashes:

$ python3 -c 'import jax.numpy as jnp; print(jnp.zeros((1,2)))'
/Users/spresser/ml/jax/jax/_src/lib/__init__.py:32: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[[0. 0.]]

Note that this is actually jaxlib version 0.1.68, even though it advertises itself as 0.1.74. Meaning, you can install it, but I have no idea what might break. Things seem to work well enough for me, which is all I care about.
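
To see which version an installed wheel advertises, the package metadata can be queried directly. This is a minimal sketch using the standard library; the helper name `installed_version` is hypothetical. It reports only the version string recorded in the wheel's metadata (0.1.74 for the wheel above), not the source revision the wheel was actually built from.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the version recorded in a distribution's metadata, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Print the advertised versions of jax and jaxlib (None if not installed).
for name in ("jax", "jaxlib"):
    print(name, installed_version(name))
```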

(I'll post about any problems I run into.)


Once I get some time, I'll try doing a git bisect between tensorflow/tensorflow@8cc3ffa and tensorflow HEAD to track down which commit introduced the problem.
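
For reference, git bisect is a binary search over the commit range. The sketch below shows that search in plain Python under stated assumptions: the commit list is ordered oldest first, the first entry is known good, the last is known bad, and every commit after the first bad one is also bad. The function name `first_bad_commit` and the `is_bad` callback are hypothetical; in practice `is_bad` would build jaxlib against the given TensorFlow commit and run the `jnp.zeros((1,2))` smoke test.

```python
def first_bad_commit(commits, is_bad):
    """Binary-search commits (oldest first) for the first one where is_bad is True."""
    lo, hi = 0, len(commits) - 1  # invariant: commits[lo] is good, commits[hi] is bad
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if is_bad(commits[mid]):
            hi = mid  # the regression is at mid or earlier
        else:
            lo = mid  # the regression is after mid
    return commits[hi]
```

With N commits in the range this needs roughly log2(N) builds, which is what makes a very large diff tractable.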

@shawwn
Contributor

shawwn commented Feb 15, 2022

Happy to report that the latest jaxlib seems to build and run fine on M1!

I built from jax commit d569440 and tensorflow commit tensorflow/tensorflow@071a34e.

My build process was:

git clone https://github.com/tensorflow/tensorflow ~/ml/tensorflow
git clone https://github.com/google/jax ~/ml/jax
ln -s $HOME/ml/tensorflow ~/ml/jax/tensorflow
cd ~/ml/jax/tensorflow
git checkout 071a34e9ff1ed97a9eb1d323645728841fbcb7f1
cd ~/ml/jax
git checkout d5694402bc1d9c494513e51f825591584d66157e

patch WORKSPACE:

diff --git a/WORKSPACE b/WORKSPACE
index 70f310f9..dc6c271c 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -1,24 +1,24 @@
 load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
 
-# To update TensorFlow to a new revision,
-# a) update URL and strip_prefix to the new git commit hash
-# b) get the sha256 hash of the commit by running:
-#    curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
-#    and update the sha256 with the result.
-http_archive(
-    name = "org_tensorflow",
-    sha256 = "7756b69b4a2a036ad7a4a8478f8bd7d69d0026d9c8c7fe8e8f1ae6205e978719",
-    strip_prefix = "tensorflow-968a1751ef6ccadc30ac6bd0f0be5056ac0e9288",
-    urls = [
-        "https://github.com/tensorflow/tensorflow/archive/968a1751ef6ccadc30ac6bd0f0be5056ac0e9288.tar.gz",
-    ],
-)
+# # To update TensorFlow to a new revision,
+# # a) update URL and strip_prefix to the new git commit hash
+# # b) get the sha256 hash of the commit by running:
+# #    curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
+# #    and update the sha256 with the result.
+# http_archive(
+#     name = "org_tensorflow",
+#     sha256 = "7756b69b4a2a036ad7a4a8478f8bd7d69d0026d9c8c7fe8e8f1ae6205e978719",
+#     strip_prefix = "tensorflow-968a1751ef6ccadc30ac6bd0f0be5056ac0e9288",
+#     urls = [
+#         "https://github.com/tensorflow/tensorflow/archive/968a1751ef6ccadc30ac6bd0f0be5056ac0e9288.tar.gz",
+#     ],
+# )
 
 # For development, one can use a local TF repository instead.
-# local_repository(
-#    name = "org_tensorflow",
-#    path = "tensorflow",
-# )
+local_repository(
+   name = "org_tensorflow",
+   path = "tensorflow",
+)
 
 load("//third_party/pocketfft:workspace.bzl", pocketfft = "repo")
 pocketfft()

Build jaxlib with caching:

cd ~/ml/jax/build
USE_BAZEL_VERSION=5.0.0 bazel run --action_env=PATH --remote_accept_cached=true --spawn_strategy=standalone --remote_local_fallback=false --remote_timeout=600 --verbose_failures=true --config=mkl_open_source_only :build_wheel -- --output_path=../dist --cpu=arm64

Install jaxlib:

pip3 install ~/ml/jax/dist/jaxlib-0.3.1-cp39-none-macosx_11_0_arm64.whl --no-deps --force-reinstall

Use the local jax repo:

cd ~/ml/jax
python3 setup.py develop --no-deps

Install scipy using the workaround here: scipy/scipy#13409

brew install openblas
pip install cython pybind11 pythran numpy
OPENBLAS=$(brew --prefix openblas) CFLAGS="-falign-functions=8 ${CFLAGS}" pip3 install --no-use-pep517 scipy==1.7.0

Test it out:

$ python3
Python 3.9.10 (main, Jan 15 2022, 11:40:36)
[Clang 13.0.0 (clang-1300.0.29.3)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
~/ml/jax/jax/_src/lib/__init__.py:33: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.
  warnings.warn("JAX on Mac ARM machines is experimental and minimally tested. "
>>> jax.__version__
'0.3.1'
>>> import jaxlib
>>> jaxlib.__version__
'0.3.1'
>>> import jax.numpy as jnp; print(jnp.zeros((1,2)))
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[[0. 0.]]

Thanks for fixing this!

nluehr pushed a commit to NVIDIA/tensorflow that referenced this issue May 20, 2022
I don't have an environment to test this, but I believe it should be enough to
get the JIT to work.

This came up in
google/jax#5679

PiperOrigin-RevId: 356578451
Change-Id: I99a2aa0e87739b9edce81074fce1ca5c0bd25115