Skip to content

Commit

Permalink
Update TorchTests::test_horovod_reducescatter_duplicate_name_error fo…
Browse files Browse the repository at this point in the history
…llowing horovod#3300, horovod#3313

Signed-off-by: Max H. Gerlach <git@maxgerlach.de>
  • Loading branch information
maxhgerlach committed Dec 17, 2021
1 parent 6b7e976 commit b85e60d
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions test/parallel/test_torch.py
Expand Up @@ -3554,6 +3554,7 @@ def test_horovod_reducescatter_duplicate_name_error(self):
if _is_mac and hvd.gloo_built() and not hvd.mpi_built():
self.skipTest("ReducescatterGloo is not supported on macOS")
hvd.init()
rank = hvd.rank()
size = hvd.size()

if size == 1:
Expand All @@ -3562,13 +3563,22 @@ def test_horovod_reducescatter_duplicate_name_error(self):
dims = [17] * 3
tensor = torch.FloatTensor(*dims)

hvd.reducescatter_async(tensor, name='duplicate_name')
try:
for i in range(10):
hvd.reducescatter_async(tensor, name=f'duplicate_name')
assert False, 'hvd.reducescatter_async did not throw error'
except (torch.FatalError, ValueError):
pass
if rank == 0:
hvd.reducescatter_async(tensor, name='duplicate_name')
try:
hvd.reducescatter_async(tensor, name='duplicate_name')
assert False, 'hvd.reducescatter_async did not throw error'
except (torch.FatalError, ValueError):
pass
hvd.barrier()
if rank > 0:
hvd.reducescatter_async(tensor, name='duplicate_name')
try:
hvd.reducescatter_async(tensor, name='duplicate_name')
assert False, 'hvd.reducescatter_async did not throw error'
except (torch.FatalError, ValueError):
pass
hvd.barrier()


def test_horovod_reducescatter_grad(self):
Expand Down

0 comments on commit b85e60d

Please sign in to comment.