Horovod Ops not XLA compatible #2590

Open · jtchilders opened this issue Jan 12, 2021 · 7 comments

@jtchilders

Environment:

  1. Framework: TensorFlow
  2. Framework version: 2.4.0
  3. Horovod version: 0.21.1
  4. MPI version: openmpi-4.0.5
  5. CUDA version: 11.0
  6. NCCL version: nccl_2.8.3-1+cuda11.0_x86_64
  7. Python version: 3.8.5
  8. Spark / PySpark version: NA
  9. OS and version: 5.3.0-62-generic #56~18.04.1-Ubuntu
  10. GCC version: 7.5.0
  11. CMake version: 3.18.2

I've run into errors when trying to XLA-compile my TensorFlow train/test steps. In my custom model, if I use

@tf.function(jit_compile=True)
def train_step(...):

to force compilation of the training operations, the model runs successfully without Horovod on a single process. When I then run with Horovod, I receive errors like:

The op is created at:
File "main.py", line 368, in <module>
  main()
File "main.py", line 188, in main
  epoch_loop.one_train_epoch(config,trainds,net,
File "/gpfs/mira-home/parton/git/atlas_dgcnn/epoch_loop.py", line 9, in one_train_epoch
  return one_epoch(config,dataset,net,train_step,loss_func,opt,epoch_num,tbwriter,batches_per_epoch,True)
File "/gpfs/mira-home/parton/git/atlas_dgcnn/epoch_loop.py", line 74, in one_epoch
  loss_value,logits = step_func(net,loss_func,inputs,labels,weights,opt,first_batch,hvd)
File "/gpfs/mira-home/parton/git/atlas_dgcnn/epoch_loop.py", line 244, in train_step
  if hvd and first_batch:
File "/gpfs/mira-home/parton/git/atlas_dgcnn/epoch_loop.py", line 246, in train_step
  hvd.broadcast_variables(opt.variables(), root_rank=root_rank)
File "/home/parton/.local/lib/python3.8/site-packages/horovod/tensorflow/functions.py", line 56, in broadcast_variables
  return broadcast_group(variables, root_rank)
File "/tmp/tmp4h_gfbt8.py", line 53, in broadcast_group
  retval__2 = ag__.converted_call(ag__.ld(tf).group, tuple([ag__.converted_call(ag__.ld(var).assign, (ag__.converted_call(ag__.ld(broadcast), (ag__.ld(var), ag__.ld(root_rank)), None, fscope_2),), None, fscope_2) for var in ag__.ld(variables)]), None, fscope_2)
File "/tmp/tmp4h_gfbt8.py", line 53, in <listcomp>
  retval__2 = ag__.converted_call(ag__.ld(tf).group, tuple([ag__.converted_call(ag__.ld(var).assign, (ag__.converted_call(ag__.ld(broadcast), (ag__.ld(var), ag__.ld(root_rank)), None, fscope_2),), None, fscope_2) for var in ag__.ld(variables)]), None, fscope_2)
File "/home/parton/.local/lib/python3.8/site-packages/horovod/tensorflow/mpi_ops.py", line 251, in broadcast
  return MPI_LIB.horovod_broadcast(tensor, name=name, root_rank=root_rank,
File "<string>", line 423, in horovod_broadcast
HorovodBroadcast_Adam_dgcnn_conv_bn_layer_12_batch_normalization_14_beta_v_0: unsupported op: No registered 'HorovodBroadcast' OpKernel for XLA_GPU_JIT devices compatible with node {{node HorovodBroadcast_Adam_dgcnn_conv_bn_layer_12_batch_normalization_14_beta_v_0}}

You can see my code here:
https://github.com/jtchilders/atlas_dgcnn

I can reproduce the issue using your example:
examples/tensorflow2_mnist.py
by simply changing @tf.function to @tf.function(jit_compile=True)
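
For concreteness, the modified training step looks roughly like this (my sketch of the change; mnist_model, loss, and opt are defined earlier in that example, and details may differ):

# Sketch of the decorator change in examples/tensorflow2_mnist.py
# (mnist_model, loss, and opt come from earlier in that example).
@tf.function(jit_compile=True)  # was: @tf.function
def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    # Horovod: wrap the tape so gradients are allreduced across workers.
    tape = hvd.DistributedGradientTape(tape)
    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    # Horovod: broadcast initial state on the first batch. This is the
    # hvd.broadcast_variables call that emits the unsupported
    # HorovodBroadcast op in the traceback below.
    if first_batch:
        hvd.broadcast_variables(mnist_model.variables, root_rank=0)
        hvd.broadcast_variables(opt.variables(), root_rank=0)
    return loss_value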
And if I run
mpirun -n $RANKS -npernode $PPN python tensorflow2_mnist.py

I see a similar error:

The op is created at:
File "tensorflow2_mnist.py", line 84, in <module>
  loss_value = training_step(images, labels, batch == 0)
File "tensorflow2_mnist.py", line 75, in training_step
  if first_batch:
File "tensorflow2_mnist.py", line 77, in training_step
  hvd.broadcast_variables(opt.variables(), root_rank=0)
File "/home/parton/.local/lib/python3.8/site-packages/horovod/tensorflow/functions.py", line 56, in broadcast_variables
  return broadcast_group(variables, root_rank)
File "/tmp/tmp_15ypkpb.py", line 53, in broadcast_group
  retval__2 = ag__.converted_call(ag__.ld(tf).group, tuple([ag__.converted_call(ag__.ld(var).assign, (ag__.converted_call(ag__.ld(broadcast), (ag__.ld(var), ag__.ld(root_rank)), None, fscope_2),), None, fscope_2) for var in ag__.ld(variables)]), None, fscope_2)
File "/tmp/tmp_15ypkpb.py", line 53, in <listcomp>
  retval__2 = ag__.converted_call(ag__.ld(tf).group, tuple([ag__.converted_call(ag__.ld(var).assign, (ag__.converted_call(ag__.ld(broadcast), (ag__.ld(var), ag__.ld(root_rank)), None, fscope_2),), None, fscope_2) for var in ag__.ld(variables)]), None, fscope_2)
File "/home/parton/.local/lib/python3.8/site-packages/horovod/tensorflow/mpi_ops.py", line 251, in broadcast
  return MPI_LIB.horovod_broadcast(tensor, name=name, root_rank=root_rank,
File "<string>", line 423, in horovod_broadcast [Op:__inference_training_step_933]
jtchilders added the bug label Jan 12, 2021
@tgaddair (Collaborator)

Hey @jtchilders, thanks for raising this issue. This is a known incompatibility between XLA and Horovod at the moment; we're actively working with Nvidia on a way to make AsyncOps work with Horovod. We'll use this issue to track progress.

cc @DEKHTIARJonathan

@tradingjunkie

What is the current best practice for distributed training with Horovod on XLA-compiled code? Is there any workaround?

@trentlo (Contributor)

trentlo commented Jun 17, 2021

What is the current best practice for distributed training with Horovod on XLA-compiled code? Is there any workaround?

Currently, @tf.function(jit_compile=True) does not work with Horovod. However, you can run XLA with Horovod by setting TF_XLA_FLAGS="--tf_xla_auto_jit=2". In this auto-clustering mode, Horovod ops remain in TensorFlow, while XLA-compilable ops are auto-clustered and compiled by XLA. This mode works today.
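
For example, a minimal sketch (exporting the flag in the shell before mpirun works just as well):

# Sketch: enable XLA auto-clustering. Set the flag before TensorFlow
# initializes (here via os.environ; exporting it in the shell also works).
import os
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=2"

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

@tf.function  # plain tf.function: Horovod ops stay in TF, the rest is clustered
def train_step(images, labels):
    ...  # the usual hvd.DistributedGradientTape logic goes here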

FYI, I have already finished XLA implementations of some Horovod ops (not yet upstreamed). With those implementations, @tf.function(jit_compile=True) can work with Horovod, as long as the corresponding XLA ops are implemented. My changes depend on an open PR that enables event-based syncs; I will post my PR after that one is merged.

@tradingjunkie

I still hit the following error, even with that environment setting:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Detected unsupported operations when trying to compile graph __inference_train_step_keras_83604[_XlaMustCompile=true,config_proto="\n\007\n\0...02\001\000",executor_type=""] on XLA_GPU_JIT: HorovodAllreduce (No registered 'HorovodAllreduce' OpKernel for XLA_GPU_JIT devices compatible with node {{node HorovodAllreduce_grads_0}}){{node HorovodAllreduce_grads_0}}

@trentlo (Contributor)

trentlo commented Jun 17, 2021

I still hit the following error, even with that environment setting:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Detected unsupported operations when trying to compile graph __inference_train_step_keras_83604[_XlaMustCompile=true,config_proto="\n\007\n\0...02\001\000",executor_type=""] on XLA_GPU_JIT: HorovodAllreduce (No registered 'HorovodAllreduce' OpKernel for XLA_GPU_JIT devices compatible with node {{node HorovodAllreduce_grads_0}}){{node HorovodAllreduce_grads_0}}

You need to replace every @tf.function(jit_compile=True) with plain @tf.function() to disable explicit compilation and rely only on auto-clustering (via TF_XLA_FLAGS="--tf_xla_auto_jit=2").
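
Concretely, the substitution looks like this (a sketch; train.py and the step name are placeholders):

# Before: explicit XLA compilation, which fails on Horovod ops.
#   @tf.function(jit_compile=True)
#   def training_step(images, labels, first_batch): ...
#
# After: plain tf.function, launched with auto-clustering enabled, e.g.
#   TF_XLA_FLAGS="--tf_xla_auto_jit=2" mpirun -n $RANKS python train.py
@tf.function
def training_step(images, labels, first_batch):
    ...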

@DEKHTIARJonathan (Collaborator)

@tgaddair shall we close?

@trentlo (Contributor)

trentlo commented Oct 22, 2021

For clarity, this PR completes only HorovodAllreduce for XLA (along with many needed infrastructure changes).

There are other Horovod ops still to be added. ^^" I noticed in the issue description that HorovodBroadcast is used, and it is not yet supported. On the other hand, we might not want to wait for all ops to be completed before closing this issue.

Currently I don't have a plan to implement all of the Horovod ops for XLA. However, I have two more ops implemented in my local repo, HorovodBroadcast and HorovodGroupedAllreduces (with contributions from @romerojosh). Perhaps I can upstream them and close this issue afterwards, as these are the most commonly used ops. The code is ready to upstream; I just need to find some slack time to do it.
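
As an aside, until the remaining kernels are upstreamed, one way to avoid the missing HorovodBroadcast kernel (my sketch, not something proposed in this thread; mnist_model, loss, opt, and dataset are placeholders, and it assumes the HorovodAllreduce XLA kernel from the PR above is available) is to keep the broadcast outside the compiled function so XLA never sees it:

# Sketch: compile only the allreduce-bearing step; broadcast eagerly.
@tf.function(jit_compile=True)
def compiled_step(images, labels):
    with tf.GradientTape() as tape:
        loss_value = loss(labels, mnist_model(images, training=True))
    tape = hvd.DistributedGradientTape(tape)  # HorovodAllreduce: XLA-supported
    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))
    return loss_value

for batch, (images, labels) in enumerate(dataset):
    loss_value = compiled_step(images, labels)
    if batch == 0:
        # Runs eagerly, so no HorovodBroadcast op enters the XLA cluster.
        hvd.broadcast_variables(mnist_model.variables, root_rank=0)
        hvd.broadcast_variables(opt.variables(), root_rank=0)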
