Skip to content

Commit

Permalink
Fix conflict with horovod#2963 after rebase, fix edge case exception …
Browse files Browse the repository at this point in the history
…in hvd.init()

Signed-off-by: Max H. Gerlach <git@maxgerlach.de>
  • Loading branch information
maxhgerlach committed Jul 6, 2021
1 parent 4a3fed3 commit 8e89e25
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
10 changes: 9 additions & 1 deletion horovod/common/basics.py
Expand Up @@ -129,7 +129,15 @@ def init(self, comm: Optional[Union[Sequence[int], MPI.Comm]] = None,
raise ValueError(
"Horovod initialization failed. Please check log messages above for a more descriptive error.")

_init_process_sets(process_sets)
try:
_init_process_sets(process_sets)
except ValueError as e:
if (len(e.args) > 0 and isinstance(e.args[0], str) and
"Horovod has not been initialized properly" in e.args[0]):
# Horovod is already shutting down
return
else:
raise e

for ps_idx, ps in enumerate(process_sets):
if ps.process_set_id is None:
Expand Down
2 changes: 1 addition & 1 deletion horovod/common/ops/collective_operations.cc
Expand Up @@ -294,7 +294,7 @@ AlltoallOp::AlltoallOp(HorovodGlobalState* global_state)
: HorovodOp(global_state) {}

// Join
JoinOp::JoinOp(HorovodGlobalState* global_state) : global_state_(global_state) {}
JoinOp::JoinOp(HorovodGlobalState* global_state) : HorovodOp(global_state) {}

Status JoinOp::Execute(std::vector<TensorTableEntry>& entries,
const Response& response, ProcessSet& process_set) {
Expand Down
14 changes: 8 additions & 6 deletions horovod/common/ops/collective_operations.h
Expand Up @@ -272,19 +272,21 @@ class AlltoallOp : public HorovodOp {
}
};

class JoinOp {
// JoinOp does not derive from HorovodOp as its function Execute has a
// different signature and because it does comparatively little.
class JoinOp : public HorovodOp {
public:
explicit JoinOp(HorovodGlobalState* global_state);

virtual ~JoinOp() = default;

Status Execute(std::vector<TensorTableEntry>& entries,
const Response& response) override {
throw std::logic_error(
"Call JoinOp::Execute() overload with extra process_set argument.");
}

// Note the different signature because we need a process_set argument.
virtual Status Execute(std::vector<TensorTableEntry>& entries,
const Response& response, ProcessSet& process_set);

protected:
HorovodGlobalState* global_state_;
};

class ErrorOp : public HorovodOp {
Expand Down

0 comments on commit 8e89e25

Please sign in to comment.