forked from dmlc/xgboost
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rabit-inl.h
229 lines (221 loc) · 6.66 KB
/
rabit-inl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
/*!
* Copyright (c) 2014-2019 by Contributors
* \file rabit-inl.h
* \brief implementation of inline template function for rabit interface
*
* \author Tianqi Chen
*/
#ifndef RABIT_INTERNAL_RABIT_INL_H_
#define RABIT_INTERNAL_RABIT_INL_H_
// use engine for implementation
#include <vector>
#include <string>
#include "rabit/internal/io.h"
#include "rabit/internal/utils.h"
#include "rabit/rabit.h"
namespace rabit {
namespace engine {
namespace mpi {
/*!
 * \brief translate a C++ element type into the engine's DataType tag.
 *  The primary template is only declared, never defined: only the explicit
 *  specializations below have an implementation, so requesting an
 *  unsupported DType has nothing to call and typically fails to build/link.
 */
template<typename DType>
inline DataType GetType();

// character types
template<> inline DataType GetType<char>() { return kChar; }
template<> inline DataType GetType<unsigned char>() { return kUChar; }

// int / unsigned int
template<> inline DataType GetType<int>() { return kInt; }
template<> inline DataType GetType<unsigned int>() { return kUInt; }  // NOLINT(*)

// long / unsigned long
template<> inline DataType GetType<long>() { return kLong; }  // NOLINT(*)
template<> inline DataType GetType<unsigned long>() { return kULong; }  // NOLINT(*)

// floating point
template<> inline DataType GetType<float>() { return kFloat; }
template<> inline DataType GetType<double>() { return kDouble; }

// long long / unsigned long long
template<> inline DataType GetType<long long>() { return kLongLong; }  // NOLINT(*)
template<> inline DataType GetType<unsigned long long>() { return kULongLong; }  // NOLINT(*)
}  // namespace mpi
}  // namespace engine
namespace op {
struct Max {
static const engine::mpi::OpType kType = engine::mpi::kMax;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
if (dst < src) dst = src;
}
};
struct Min {
static const engine::mpi::OpType kType = engine::mpi::kMin;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
if (dst > src) dst = src;
}
};
struct Sum {
static const engine::mpi::OpType kType = engine::mpi::kSum;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
dst += src;
}
};
struct BitOR {
static const engine::mpi::OpType kType = engine::mpi::kBitwiseOR;
template<typename DType>
inline static void Reduce(DType &dst, const DType &src) { // NOLINT(*)
dst |= src;
}
};
template<typename OP, typename DType>
inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &dtype) {
const DType* src = static_cast<const DType*>(src_);
DType* dst = (DType*)dst_; // NOLINT(*)
for (int i = 0; i < len; i++) {
OP::Reduce(dst[i], src[i]);
}
}
} // namespace op
/*!
 * \brief initialize the rabit engine by forwarding the command line to engine::Init
 * \param argc number of command-line arguments
 * \param argv command-line arguments
 * \return the value returned by engine::Init (presumably true on success — confirm in engine)
 */
inline bool Init(int argc, char *argv[]) {
  const bool initialized = engine::Init(argc, argv);
  return initialized;
}
/*!
 * \brief shut down the rabit engine via engine::Finalize
 * \return the value returned by engine::Finalize (presumably true on success — confirm in engine)
 */
inline bool Finalize() {
  const bool finalized = engine::Finalize();
  return finalized;
}
/*!
 * \brief rank of the previous worker in the ring topology, as reported by the engine
 */
inline int GetRingPrevRank() {
  const int prev_rank = engine::GetEngine()->GetRingPrevRank();
  return prev_rank;
}
/*!
 * \brief rank of the current process, as reported by the engine
 */
inline int GetRank() {
  const int my_rank = engine::GetEngine()->GetRank();
  return my_rank;
}
/*!
 * \brief total number of workers (the world size), as reported by the engine
 */
inline int GetWorldSize() {
  const int world_size = engine::GetEngine()->GetWorldSize();
  return world_size;
}
/*!
 * \brief whether rabit is running in distributed mode, as reported by the engine
 */
inline bool IsDistributed() {
  const bool distributed = engine::GetEngine()->IsDistributed();
  return distributed;
}
/*!
 * \brief name of the current processor, taken from the engine's GetHost()
 */
inline std::string GetProcessorName() {
  std::string host_name = engine::GetEngine()->GetHost();
  return host_name;
}
/*!
 * \brief broadcast a raw byte buffer from root to all other nodes
 * \param sendrecv_data buffer holding the data on root; receives it elsewhere
 * \param size number of bytes in the buffer
 * \param root rank of the node whose data is sent
 */
inline void Broadcast(void *sendrecv_data, size_t size, int root) {
  engine::GetEngine()->Broadcast(sendrecv_data,
                                 size,
                                 root);
}
/*!
 * \brief broadcast a std::vector from root to all other nodes.
 *  The element count is broadcast first; a node whose local vector has a
 *  different size resizes it, then the element bytes are broadcast.
 *  NOTE(review): elements are sent as raw bytes, so DType is assumed to be
 *  trivially copyable.
 * \param sendrecv_data vector to send (on root) or to fill (elsewhere)
 * \param root rank of the node that owns the data
 */
template<typename DType>
inline void Broadcast(std::vector<DType> *sendrecv_data, int root) {
  size_t count = sendrecv_data->size();
  Broadcast(&count, sizeof(count), root);
  if (count != sendrecv_data->size()) sendrecv_data->resize(count);
  // nothing to send for an empty vector (and &v[0] would be invalid)
  if (count == 0) return;
  Broadcast(&(*sendrecv_data)[0], count * sizeof(DType), root);
}
/*!
 * \brief broadcast a std::string from root to all other nodes.
 *  The length is broadcast first; a node whose local string has a different
 *  length resizes it, then the characters are broadcast.
 * \param sendrecv_data string to send (on root) or to fill (elsewhere)
 * \param root rank of the node that owns the data
 */
inline void Broadcast(std::string *sendrecv_data, int root) {
  size_t count = sendrecv_data->length();
  Broadcast(&count, sizeof(count), root);
  if (count != sendrecv_data->length()) sendrecv_data->resize(count);
  // nothing to send for an empty string
  if (count == 0) return;
  Broadcast(&(*sendrecv_data)[0], count * sizeof(char), root);
}
/*!
 * \brief in-place Allreduce of count DType elements with reducer OP
 * \param sendrecvbuf buffer of count elements; reduced in place
 * \param count number of elements
 * \param prepare_fun callback forwarded to engine::Allreduce_ (presumably may be NULL — confirm in engine)
 * \param prepare_arg argument forwarded alongside prepare_fun
 */
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
                      void (*prepare_fun)(void *arg),
                      void *prepare_arg) {
  const size_t elem_bytes = sizeof(DType);
  engine::Allreduce_(sendrecvbuf, elem_bytes, count,
                     op::Reducer<OP, DType>,
                     engine::mpi::GetType<DType>(),
                     OP::kType,
                     prepare_fun, prepare_arg);
}
// C++11 support for lambda prepare function
#if DMLC_USE_CXX11
/*!
 * \brief trampoline that lets a std::function<void()> be used where a
 *  plain void(*)(void*) callback is expected
 * \param fun type-erased pointer to a std::function<void()>; it is invoked once
 */
inline void InvokeLambda(void *fun) {
  std::function<void()> &callback = *static_cast<std::function<void()>*>(fun);
  callback();
}
/*!
 * \brief in-place Allreduce whose prepare step is a std::function.
 *  The std::function is handed to the engine by address and invoked through
 *  InvokeLambda. NOTE(review): this assumes engine::Allreduce_ does not keep
 *  the pointer after it returns, because prepare_fun is a by-value parameter.
 * \param sendrecvbuf buffer of count elements; reduced in place
 * \param count number of elements
 * \param prepare_fun callback forwarded to the engine
 */
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
                      std::function<void()> prepare_fun) {
  void *const closure = &prepare_fun;
  engine::Allreduce_(sendrecvbuf, sizeof(DType), count,
                     op::Reducer<OP, DType>,
                     engine::mpi::GetType<DType>(), OP::kType,
                     InvokeLambda, closure);
}
/*!
 * \brief in-place Allgather over a buffer of DType elements.
 *  All element counts and indices are converted to byte counts/offsets
 *  before they are passed to the engine.
 * \param sendrecvbuf buffer of totalSize elements
 * \param totalSize total number of elements in the buffer
 * \param beginIndex index of this node's first element
 * \param sizeNodeSlice number of elements contributed by this node
 * \param sizePrevSlice element count forwarded to the engine as sizePrevSlice bytes
 */
template<typename DType>
inline void Allgather(DType *sendrecvbuf,
                      size_t totalSize,
                      size_t beginIndex,
                      size_t sizeNodeSlice,
                      size_t sizePrevSlice) {
  const size_t elem = sizeof(DType);
  const size_t total_bytes = totalSize * elem;
  const size_t slice_begin = beginIndex * elem;
  const size_t slice_end = (beginIndex + sizeNodeSlice) * elem;
  const size_t prev_slice_bytes = sizePrevSlice * elem;
  engine::GetEngine()->Allgather(sendrecvbuf, total_bytes, slice_begin,
                                 slice_end, prev_slice_bytes);
}
#endif // C++11
/*!
 * \brief send a message to the tracker through the engine
 * \param msg text to print
 */
inline void TrackerPrint(const std::string &msg) {
  engine::GetEngine()->TrackerPrint(
      msg);
}
#ifndef RABIT_STRICT_CXX98_
/*!
 * \brief printf-style variant of TrackerPrint.
 *  The message is first formatted into a 1 KiB buffer. If vsnprintf reports
 *  that the full output needs more room, the message is formatted again into
 *  a buffer of exactly the required size, so long messages are no longer
 *  silently truncated. The string is then trimmed at its first NUL, as before.
 * \param fmt printf format string, followed by the values it references
 */
inline void TrackerPrintf(const char *fmt, ...) {
  const int kPrintBuffer = 1 << 10;
  std::string msg(kPrintBuffer, '\0');
  va_list args;
  va_start(args, fmt);
  const int needed = vsnprintf(&msg[0], kPrintBuffer, fmt, args);
  va_end(args);
  if (needed >= kPrintBuffer) {
    // The first pass was truncated. A fresh va_start restarts the argument
    // list without needing C++11 va_copy. The buffer gets one extra slot for
    // vsnprintf's terminating NUL.
    msg.assign(static_cast<size_t>(needed) + 1, '\0');
    va_start(args, fmt);
    vsnprintf(&msg[0], msg.size(), fmt, args);
    va_end(args);
  }
  // A negative return (a formatting error, or a non-C99 vsnprintf) keeps
  // whatever prefix was written, which matches the previous behavior.
  msg.resize(strlen(msg.c_str()));
  TrackerPrint(msg);
}
#endif // RABIT_STRICT_CXX98_
/*!
 * \brief load the latest checkpoint through the engine
 * \param global_model model object passed to the engine's LoadCheckPoint
 * \param local_model node-local object passed to the engine's LoadCheckPoint
 * \return whatever the engine's LoadCheckPoint returns (presumably the loaded version number — confirm in engine)
 */
inline int LoadCheckPoint(Serializable *global_model,
                          Serializable *local_model) {
  const int version = engine::GetEngine()->LoadCheckPoint(global_model, local_model);
  return version;
}
/*!
 * \brief checkpoint the model, marking that a stage of execution has finished
 * \param global_model model object passed to the engine's CheckPoint
 * \param local_model node-local object passed to the engine's CheckPoint
 */
inline void CheckPoint(const Serializable *global_model,
                       const Serializable *local_model) {
  engine::GetEngine()->CheckPoint(global_model,
                                  local_model);
}
/*!
 * \brief lazily checkpoint the model: only the pointer to global_model is remembered.
 *  NOTE(review): because only the pointer is kept, global_model presumably must
 *  stay alive until the next checkpoint — confirm in engine.
 * \param global_model model object passed to the engine's LazyCheckPoint
 */
inline void LazyCheckPoint(const Serializable *global_model) {
  engine::GetEngine()->LazyCheckPoint(
      global_model);
}
/*!
 * \brief version number of the currently stored model, as reported by the engine
 */
inline int VersionNumber() {
  const int version = engine::GetEngine()->VersionNumber();
  return version;
}
} // namespace rabit
#endif // RABIT_INTERNAL_RABIT_INL_H_