
Commit 70c7259

Fix barrier seg fault and add test to mix it with multiple collectives (horovod#3313)

Signed-off-by: TJ <tix@uber.com>
TJ Xu authored and tkhanna1996 committed Dec 16, 2021
1 parent 6f6b33a commit 70c7259
Showing 2 changed files with 32 additions and 14 deletions.
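In brief, the fix keeps the controller from fusing BARRIER (and JOIN) responses with tensor responses in FuseResponses and drops the special-casing of barrier requests in IncrementTensorCount; the new test then issues a barrier while several asynchronous collectives are still in flight. A minimal sketch of that usage pattern, not part of the patch, assuming Horovod with PyTorch on two or more ranks (the tensor name below is illustrative only):

    # Sketch only: a barrier issued while an async collective is still outstanding.
    import torch
    import horovod.torch as hvd

    hvd.init()
    handle = hvd.allreduce_async(torch.eye(5), name="pending_allreduce")
    hvd.barrier()                     # barrier while the allreduce is still in flight
    result = hvd.synchronize(handle)  # the async collective still completes afterwards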
13 changes: 5 additions & 8 deletions horovod/common/controller.cc
@@ -893,6 +893,10 @@ void Controller::FuseResponses(std::deque<Response>& responses,
   while (!responses.empty()) {
 
     auto& new_response = responses.front();
+    if (new_response.response_type() == Response::ResponseType::BARRIER ||
+        new_response.response_type() == Response::ResponseType::JOIN) {
+      break;
+    }
     assert(new_response.tensor_names().size() == 1);
     const auto& new_entry =
         tensor_queue_.GetTensorEntry(new_response.tensor_names()[0]);
@@ -981,14 +985,7 @@ bool Controller::IncrementTensorCount(const Request& msg, int joined_size) {
     timeline_.NegotiateStart(name, msg.request_type());
   } else {
     std::vector<Request>& messages = table_iter->second;
-    if(msg.request_type() == Request::BARRIER) {
-      if(tensor_queue_.IsTensorPresentInTable(name)) {
-        messages.push_back(msg);
-      }
-    }
-    else {
-      messages.push_back(msg);
-    }
+    messages.push_back(msg);
   }
 
   timeline_.NegotiateRankReady(name, msg.request_rank());
33 changes: 27 additions & 6 deletions test/parallel/test_torch.py
@@ -623,15 +623,15 @@ def test_horovod_allreduce_duplicate_name_error(self):
                 assert False, 'hvd.allreduce_async did not throw error'
             except (torch.FatalError, ValueError):
                 pass
-            hvd.allreduce(torch.FloatTensor([1]), name="synch1")
+            hvd.barrier()
         if rank > 0:
             hvd.allreduce_async(tensor, name='duplicate_name')
             try:
                 hvd.allreduce_async(tensor, name='duplicate_name')
                 assert False, 'hvd.allreduce_async did not throw error'
             except (torch.FatalError, ValueError):
                 pass
-            hvd.allreduce(torch.FloatTensor([2]), name="synch2")
+            hvd.barrier()
 
     def test_horovod_allreduce_grad(self):
         """Test the correctness of the allreduce gradient."""
@@ -1239,15 +1239,15 @@ def test_horovod_allgather_duplicate_name_error(self):
                 assert False, 'hvd.allgather_async did not throw error'
             except (torch.FatalError, ValueError):
                 pass
-            hvd.allreduce(torch.FloatTensor([1]), name="synch1")
+            hvd.barrier()
         if rank > 0:
             hvd.allgather_async(tensor, name='duplicate_name')
             try:
                 hvd.allgather_async(tensor, name='duplicate_name')
                 assert False, 'hvd.allgather_async did not throw error'
             except (torch.FatalError, ValueError):
                 pass
-            hvd.allreduce(torch.FloatTensor([2]), name="synch2")
+            hvd.barrier()
 
     def test_horovod_allgather_grad(self):
         """Test the correctness of the allgather gradient."""
@@ -1559,15 +1559,15 @@ def test_horovod_broadcast_duplicate_name_error(self):
                 assert False, 'hvd.broadcast_async did not throw error'
             except (torch.FatalError, ValueError):
                 pass
-            hvd.allreduce(torch.FloatTensor([1]), name="synch1")
+            hvd.barrier()
         if rank > 0:
             hvd.broadcast_async(tensor, name='duplicate_name', root_rank=0)
             try:
                 hvd.broadcast_async(tensor, name='duplicate_name', root_rank=0)
                 assert False, 'hvd.broadcast_async did not throw error'
             except (torch.FatalError, ValueError):
                 pass
-            hvd.allreduce(torch.FloatTensor([2]), name="synch2")
+            hvd.barrier()
 
     def test_horovod_broadcast_grad(self):
         """Test the correctness of the broadcast gradient."""
@@ -3295,5 +3295,26 @@ def test_global_barrier_op(self):
 
         self.assertTrue(barrier_time >= 5)
 
+    def test_barrier_with_multiple_collectives(self):
+        """Test barrier mixed with other collectives"""
+        hvd.init()
+        rank = hvd.rank()
+
+        bcast_tensor = torch.eye(3)
+        bcast_handle = hvd.broadcast_async(bcast_tensor, root_rank=0)
+
+        allgather_tensor_1 = torch.eye(5)
+        allgather_tensor_2 = torch.zeros([5, 5])
+        allgather1_handle = hvd.allgather_async(allgather_tensor_1)
+        allgather2_handle = hvd.allgather_async(allgather_tensor_2)
+
+        allreduce_tensor = torch.eye(5)
+        allreduce_handle = hvd.allreduce_async(allreduce_tensor)
+
+        hvd.barrier()
+
+        result = hvd.synchronize(allreduce_handle)
+        self.assertTrue(torch.equal(result, allreduce_tensor))
+
 if __name__ == "__main__":
     unittest.main()
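A usage note, not part of the patch: the broadcast and allgather handles created in the new test are never synchronized there; under the same assumptions they could be drained the same way the allreduce handle is, for example:

    # Hypothetical follow-up for the handles from test_barrier_with_multiple_collectives.
    bcast_result = hvd.synchronize(bcast_handle)      # tensor broadcast from rank 0
    gathered_1 = hvd.synchronize(allgather1_handle)   # shape [5 * hvd.size(), 5]
    gathered_2 = hvd.synchronize(allgather2_handle)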
