Skip to content

Commit

Permalink
Remove stop process. (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Aug 5, 2020
1 parent e6cd74e commit 4acdd7c
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 54 deletions.
23 changes: 5 additions & 18 deletions include/rabit/internal/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,6 @@ namespace utils {
/*! \brief error message buffer length */
const int kPrintBuffer = 1 << 12;

/*! \brief we may want to keep the process alive when multiple workers are
* co-located in the same process */
extern bool STOP_PROCESS_ON_ERROR;

/* \brief Case-insensitive string comparison */
inline int CompareStringsCaseInsensitive(const char* s1, const char* s2) {
#ifdef _MSC_VER
Expand All @@ -89,26 +85,17 @@ inline bool StringToBool(const char* s) {
* \param msg error message
*/
inline void HandleAssertError(const char *msg) {
if (STOP_PROCESS_ON_ERROR) {
fprintf(stderr, "AssertError:%s, shutting down process\n", msg);
exit(-1);
} else {
fprintf(stderr, "AssertError:%s, rabit is configured to keep process running\n", msg);
throw dmlc::Error(msg);
}
fprintf(stderr,
"AssertError:%s, rabit is configured to keep process running\n", msg);
throw dmlc::Error(msg);
}
/*!
* \brief handling of Check error, caused by inappropriate input
* \param msg error message
*/
inline void HandleCheckError(const char *msg) {
if (STOP_PROCESS_ON_ERROR) {
fprintf(stderr, "%s, shutting down process\n", msg);
exit(-1);
} else {
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
throw dmlc::Error(msg);
}
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
throw dmlc::Error(msg);
}
inline void HandlePrint(const char *msg) {
printf("%s", msg);
Expand Down
15 changes: 0 additions & 15 deletions src/allreduce_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@
#include <map>

namespace rabit {

namespace utils {
bool STOP_PROCESS_ON_ERROR = true;
}

namespace engine {
// constructor
AllreduceBase::AllreduceBase(void) {
Expand Down Expand Up @@ -48,7 +43,6 @@ AllreduceBase::AllreduceBase(void) {
env_vars.push_back("DMLC_TRACKER_URI");
env_vars.push_back("DMLC_TRACKER_PORT");
env_vars.push_back("DMLC_WORKER_CONNECT_RETRY");
env_vars.push_back("DMLC_WORKER_STOP_PROCESS_ON_ERROR");
}

// initialization function
Expand Down Expand Up @@ -200,15 +194,6 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) {
connect_retry = atoi(val);
}
if (!strcmp(name, "DMLC_WORKER_STOP_PROCESS_ON_ERROR")) {
if (!strcmp(val, "true")) {
rabit::utils::STOP_PROCESS_ON_ERROR = true;
} else if (!strcmp(val, "false")) {
rabit::utils::STOP_PROCESS_ON_ERROR = false;
} else {
throw std::runtime_error("invalid value of DMLC_WORKER_STOP_PROCESS_ON_ERROR");
}
}
if (!strcmp(name, "rabit_bootstrap_cache")) {
rabit_bootstrap_cache = utils::StringToBool(val);
}
Expand Down
13 changes: 7 additions & 6 deletions src/allreduce_robust.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
* \param _caller caller function name used to generate unique cache key
*/
void AllreduceRobust::Allgather(void *sendrecvbuf,
size_t total_size,
size_t slice_begin,
Expand Down Expand Up @@ -518,8 +518,8 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model,
}
// execute checkpoint; note: when a checkpoint already exists, load will not happen
_assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint,
ActionSummary::kSpecialOp, cur_cache_seq),
"check point must return true");
ActionSummary::kSpecialOp, cur_cache_seq),
"check point must return true");
// this is the critical region where we will change all the stored models
// increase version number
version_number += 1;
Expand Down Expand Up @@ -550,8 +550,9 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model,
delta = utils::GetTime() - start;
// log checkpoint ack latency
if (rabit_debug) {
utils::HandleLogInfo("[%d] checkpoint ack finished version %d, take %f seconds\n",
rank, version_number, delta);
utils::HandleLogInfo(
"[%d] checkpoint ack finished version %d, take %f seconds\n", rank,
version_number, delta);
}
}
/*!
Expand Down
5 changes: 0 additions & 5 deletions src/engine_empty.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@
#include "rabit/internal/engine.h"

namespace rabit {

namespace utils {
bool STOP_PROCESS_ON_ERROR = true;
}

namespace engine {
/*! \brief EmptyEngine */
class EmptyEngine : public IEngine {
Expand Down
5 changes: 0 additions & 5 deletions src/engine_mpi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@
#include "rabit/internal/utils.h"

namespace rabit {

namespace utils {
bool STOP_PROCESS_ON_ERROR = true;
}

namespace engine {
/*! \brief implementation of engine using MPI */
class MPIEngine : public IEngine {
Expand Down
1 change: 1 addition & 0 deletions test/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ find_package(GTest REQUIRED)
add_executable(
unit_tests
test_io.cc
test_utils.cc
allreduce_robust_test.cc
allreduce_base_test.cc
allreduce_mock_test.cc
Expand Down
4 changes: 2 additions & 2 deletions test/cpp/allreduce_mock_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ TEST(allreduce_mock, mock_allreduce)
char* argv[] = {cmd};
m.Init(1, argv);
m.rank = 0;
EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
EXPECT_THROW(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), dmlc::Error);
}

TEST(allreduce_mock, mock_broadcast)
Expand All @@ -32,5 +32,5 @@ TEST(allreduce_mock, mock_broadcast)
m.rank = 0;
m.version_number=1;
m.seq_counter=2;
EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
EXPECT_THROW(m.Broadcast(nullptr,0,0), dmlc::Error);
}
7 changes: 4 additions & 3 deletions test/cpp/allreduce_mock_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <string>
#include <iostream>
#include <dmlc/logging.h>
#include "../../src/allreduce_mock.h"

TEST(allreduce_mock, mock_allreduce)
Expand All @@ -17,7 +18,7 @@ TEST(allreduce_mock, mock_allreduce)
char* argv[] = {cmd};
m.Init(1, argv);
m.rank = 0;
EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
EXPECT_THROW({m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr);}, dmlc::Error);
}

TEST(allreduce_mock, mock_broadcast)
Expand All @@ -32,7 +33,7 @@ TEST(allreduce_mock, mock_broadcast)
m.rank = 0;
m.version_number=1;
m.seq_counter=2;
EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
EXPECT_THROW({m.Broadcast(nullptr,0,0);}, dmlc::Error);
}

TEST(allreduce_mock, mock_gather)
Expand All @@ -47,5 +48,5 @@ TEST(allreduce_mock, mock_gather)
m.rank = 3;
m.version_number=13;
m.seq_counter=22;
EXPECT_EXIT(m.Allgather(nullptr,0,0,0,0), ::testing::ExitedWithCode(255), "");
EXPECT_THROW({m.Allgather(nullptr,0,0,0,0);}, dmlc::Error);
}
6 changes: 6 additions & 0 deletions test/cpp/test_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#include <gtest/gtest.h>
#include <rabit/internal/utils.h>

// Expects rabit::utils::Assert, called with a false condition, to raise
// dmlc::Error.
TEST(Utils, Assert) {
  EXPECT_THROW(rabit::utils::Assert(false, "foo"), dmlc::Error);
}

0 comments on commit 4acdd7c

Please sign in to comment.