Fix barrier seg fault and add a test mixing it with multiple collectives #3313

Merged · 2 commits · Dec 14, 2021
13 changes: 5 additions & 8 deletions horovod/common/controller.cc
@@ -893,6 +893,10 @@ void Controller::FuseResponses(std::deque<Response>& responses,
   while (!responses.empty()) {
 
     auto& new_response = responses.front();
+    if (new_response.response_type() == Response::ResponseType::BARRIER ||
+        new_response.response_type() == Response::ResponseType::JOIN) {
+      break;
Collaborator:
Would it be safer to have a continue here, rather than a break?

Collaborator Author (@Tixxx, Dec 13, 2021):
Actually I think using break is safer here than continue, since continue would keep the fusion logic going. Once we see a barrier response, we have reached the end of a control block, so we don't want to fuse the responses after the barrier (if there are any). Let me know if this makes sense.
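To make the trade-off concrete, here is a minimal standalone sketch of a fusion-style loop (simplified types, not the actual Horovod FuseResponses code) showing why break stops fusion at the barrier while continue would keep fusing the responses queued behind it:

```cpp
#include <deque>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-ins for Horovod's Response type (illustrative only).
enum class ResponseType { ALLREDUCE, BARRIER };
struct Response {
  ResponseType type;
  std::string name;
};

// Fuse the leading non-barrier responses into one batch; stop at the first
// BARRIER so nothing queued behind it gets pulled ahead of it.
std::vector<std::string> FuseUntilBarrier(std::deque<Response>& responses) {
  std::vector<std::string> fused;
  while (!responses.empty()) {
    auto& next = responses.front();
    if (next.type == ResponseType::BARRIER) {
      // 'break' ends the fusion pass here: responses behind the barrier stay
      // queued and keep their ordering relative to it. A 'continue' (after
      // popping the barrier) would keep the loop running and fuse the later
      // responses, effectively moving them ahead of the barrier.
      break;
    }
    fused.push_back(next.name);
    responses.pop_front();
  }
  return fused;
}

int main() {
  std::deque<Response> queue = {{ResponseType::ALLREDUCE, "grad_a"},
                                {ResponseType::BARRIER, "barrier"},
                                {ResponseType::ALLREDUCE, "grad_b"}};
  for (const auto& name : FuseUntilBarrier(queue)) {
    std::cout << "fused: " << name << '\n';  // prints only "grad_a"
  }
  return 0;
}
```

With break, "grad_b" stays queued behind the barrier and is handled on a later pass; with continue it would be fused ahead of the barrier, defeating the ordering the barrier is meant to enforce.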

Collaborator:
Good point, and this applies to join as well.

This loop is specific to allgather. Would it make sense to break the fusion loop in the same way for allreduce and adasum? (Even if it's not necessary there to shield us from accessing invalid pointers)

Collaborator Author (@Tixxx):
The logic for determining the output size of allreduce operations is fairly straightforward: it directly uses the tensor_sizes field in the response object, which is safe. Allgather is a special case since we need to inspect each dimensionality, so it needs a reference to the tensor itself.
We need to revisit this logic once we support fusion for other ops.
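As a rough illustration of that asymmetry, here is a standalone sketch using made-up stand-in types (FakeResponse, FakeTensor; not Horovod's actual Response or TensorEntry API): the allreduce output size can be read straight from the size metadata in the response, while the allgather output size also depends on the per-rank first-dimension extents and the tensor's remaining dimensions, which is why a reference to the tensor itself is needed:

```cpp
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Made-up stand-ins, not Horovod's actual Response / TensorEntry types.
struct FakeResponse {
  std::vector<int64_t> tensor_sizes;  // element count per tensor
};
struct FakeTensor {
  std::vector<int64_t> shape;  // full local shape, e.g. {5, 5}
};

// Allreduce: the output has the same size as the input, so the size metadata
// carried in the response is enough -- no reference to the tensor is needed.
int64_t AllreduceOutputElements(const FakeResponse& response, size_t i) {
  return response.tensor_sizes[i];
}

// Allgather: each rank contributes a slice along dimension 0, so the output
// size depends on the gathered first-dimension extents *and* the remaining
// dimensions of the tensor itself -- hence the tensor reference.
int64_t AllgatherOutputElements(const std::vector<int64_t>& first_dim_per_rank,
                                const FakeTensor& tensor) {
  int64_t gathered_dim0 = std::accumulate(
      first_dim_per_rank.begin(), first_dim_per_rank.end(), int64_t{0});
  int64_t elements_per_slice_row = 1;
  for (size_t d = 1; d < tensor.shape.size(); ++d) {
    elements_per_slice_row *= tensor.shape[d];
  }
  return gathered_dim0 * elements_per_slice_row;
}

int main() {
  FakeResponse response{{25}};                     // one tensor with 25 elements
  FakeTensor tensor{{5, 5}};                       // local allgather input [5, 5]
  std::vector<int64_t> first_dims = {5, 5, 5, 5};  // dim-0 extents from 4 ranks
  std::cout << AllreduceOutputElements(response, 0) << '\n';        // 25
  std::cout << AllgatherOutputElements(first_dims, tensor) << '\n';  // 100
  return 0;
}
```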

+    }
     assert(new_response.tensor_names().size() == 1);
     const auto& new_entry =
         tensor_queue_.GetTensorEntry(new_response.tensor_names()[0]);
@@ -981,14 +985,7 @@ bool Controller::IncrementTensorCount(const Request& msg, int joined_size) {
     timeline_.NegotiateStart(name, msg.request_type());
   } else {
     std::vector<Request>& messages = table_iter->second;
-    if(msg.request_type() == Request::BARRIER) {
-      if(tensor_queue_.IsTensorPresentInTable(name)) {
-        messages.push_back(msg);
-      }
-    }
-    else {
-      messages.push_back(msg);
-    }
+    messages.push_back(msg);
maxhgerlach marked this conversation as resolved.
   }
 
   timeline_.NegotiateRankReady(name, msg.request_rank());
33 changes: 27 additions & 6 deletions test/parallel/test_torch.py
@@ -623,15 +623,15 @@ def test_horovod_allreduce_duplicate_name_error(self):
                 assert False, 'hvd.allreduce_async did not throw error'
             except (torch.FatalError, ValueError):
                 pass
-        hvd.allreduce(torch.FloatTensor([1]), name="synch1")
+        hvd.barrier()
         if rank > 0:
             hvd.allreduce_async(tensor, name='duplicate_name')
             try:
                 hvd.allreduce_async(tensor, name='duplicate_name')
                 assert False, 'hvd.allreduce_async did not throw error'
             except (torch.FatalError, ValueError):
                 pass
-        hvd.allreduce(torch.FloatTensor([2]), name="synch2")
+        hvd.barrier()
 
     def test_horovod_allreduce_grad(self):
         """Test the correctness of the allreduce gradient."""
@@ -1239,15 +1239,15 @@ def test_horovod_allgather_duplicate_name_error(self):
                 assert False, 'hvd.allgather_async did not throw error'
             except (torch.FatalError, ValueError):
                 pass
-        hvd.allreduce(torch.FloatTensor([1]), name="synch1")
+        hvd.barrier()
         if rank > 0:
             hvd.allgather_async(tensor, name='duplicate_name')
             try:
                 hvd.allgather_async(tensor, name='duplicate_name')
                 assert False, 'hvd.allgather_async did not throw error'
             except (torch.FatalError, ValueError):
                 pass
-        hvd.allreduce(torch.FloatTensor([2]), name="synch2")
+        hvd.barrier()
 
     def test_horovod_allgather_grad(self):
         """Test the correctness of the allgather gradient."""
@@ -1559,15 +1559,15 @@ def test_horovod_broadcast_duplicate_name_error(self):
                 assert False, 'hvd.broadcast_async did not throw error'
             except (torch.FatalError, ValueError):
                 pass
-        hvd.allreduce(torch.FloatTensor([1]), name="synch1")
+        hvd.barrier()
         if rank > 0:
             hvd.broadcast_async(tensor, name='duplicate_name', root_rank=0)
             try:
                 hvd.broadcast_async(tensor, name='duplicate_name', root_rank=0)
                 assert False, 'hvd.broadcast_async did not throw error'
             except (torch.FatalError, ValueError):
                 pass
-        hvd.allreduce(torch.FloatTensor([2]), name="synch2")
+        hvd.barrier()
 
     def test_horovod_broadcast_grad(self):
         """Test the correctness of the broadcast gradient."""
@@ -3295,5 +3295,26 @@ def test_global_barrier_op(self):

         self.assertTrue(barrier_time >= 5)
 
+    def test_barrier_with_multiple_collectives(self):
+        """Test barrier mixed with other collectives"""
+        hvd.init()
+        rank = hvd.rank()
+
+        bcast_tensor = torch.eye(3)
+        bcast_handle = hvd.broadcast_async(bcast_tensor, root_rank=0)
+
+        allgather_tensor_1 = torch.eye(5)
+        allgather_tensor_2 = torch.zeros([5, 5])
+        allgather1_handle = hvd.allgather_async(allgather_tensor_1)
+        allgather2_handle = hvd.allgather_async(allgather_tensor_2)
+
+        allreduce_tensor = torch.eye(5)
+        allreduce_handle = hvd.allreduce_async(allreduce_tensor)
+
+        hvd.barrier()
+
+        result = hvd.synchronize(allreduce_handle)
+        self.assertTrue(torch.equal(result, allreduce_tensor))
+
 if __name__ == "__main__":
     unittest.main()