-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
/
socket.h
547 lines (483 loc) · 15.1 KB
/
socket.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
/*!
* Copyright (c) 2022 by XGBoost Contributors
*/
#pragma once
#if !defined(NOMINMAX) && defined(_WIN32)
#define NOMINMAX
#endif // !defined(NOMINMAX)
#include <cerrno> // errno, EINTR, EBADF
#include <climits> // HOST_NAME_MAX
#include <cstddef> // std::size_t
#include <cstdint> // std::int32_t, std::uint16_t
#include <cstring> // memset
#include <limits> // std::numeric_limits
#include <string> // std::string
#include <system_error> // std::error_code, std::system_category
#include <utility> // std::swap
#if !defined(xgboost_IS_MINGW)
#define xgboost_IS_MINGW() defined(__MINGW32__)
#endif // xgboost_IS_MINGW
#if defined(_WIN32)
#include <winsock2.h>
#include <ws2tcpip.h>
using in_port_t = std::uint16_t;
#ifdef _MSC_VER
#pragma comment(lib, "Ws2_32.lib")
#endif // _MSC_VER
#if !xgboost_IS_MINGW()
using ssize_t = int;
#endif // !xgboost_IS_MINGW()
#else // UNIX
#include <arpa/inet.h> // inet_ntop
#include <fcntl.h> // fcntl, F_GETFL, O_NONBLOCK
#include <netinet/in.h> // sockaddr_in6, sockaddr_in, in_port_t, INET6_ADDRSTRLEN, INET_ADDRSTRLEN
#include <netinet/in.h> // IPPROTO_TCP
#include <netinet/tcp.h> // TCP_NODELAY
#include <sys/socket.h> // socket, SOL_SOCKET, SO_ERROR, MSG_WAITALL, recv, send, AF_INET6, AF_INET
#include <unistd.h> // close
#if defined(__sun) || defined(sun)
#include <sys/sockio.h>
#endif // defined(__sun) || defined(sun)
#endif // defined(_WIN32)
#include "xgboost/base.h" // XGBOOST_EXPECT
#include "xgboost/logging.h" // LOG
#include "xgboost/string_view.h" // StringView
#if !defined(HOST_NAME_MAX)
#define HOST_NAME_MAX 256 // macos
#endif
namespace xgboost {
#if xgboost_IS_MINGW()
// Failure stub invoked by the MinGW fallbacks in this header (the dummy inet_ntop and
// TCPSocket::Create). See the dummy implementation of `poll` in rabit for more info.
inline void MingWError() {
  LOG(FATAL) << "Distributed training on mingw is not supported.";
}
#endif  // xgboost_IS_MINGW()
namespace system {
/**
 * \brief Fetch the error code left behind by the most recent failed system/socket call.
 *
 * On Windows socket failures are reported through WSAGetLastError(); on every other
 * platform the code is read from `errno`.
 */
inline std::int32_t LastError() {
#if !defined(_WIN32)
  return errno;
#else
  return WSAGetLastError();
#endif
}
#if defined(__GLIBC__)
/**
 * \brief Report a failed system call through LOG(FATAL), prefixed with the caller's location.
 *
 * \param fn_name Name (or source text) of the call that failed.
 * \param errsv   Error code to describe; by default LastError(), evaluated at the call site.
 * \param line    Caller's line number, captured via __builtin_LINE().
 * \param file    Caller's file name, captured via __builtin_FILE().
 */
inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
                         std::int32_t line = __builtin_LINE(),
                         char const *file = __builtin_FILE()) {
  std::error_code const ec{errsv, std::system_category()};
  LOG(FATAL) << "\n"
             << file << "(" << line << "): Failed to call `" << fn_name << "`: " << ec.message()
             << std::endl;
}
#else
/**
 * \brief Report a failed system call through LOG(FATAL); this variant carries no caller
 *        location because the __builtin_LINE/__builtin_FILE path is only used with glibc.
 *
 * \param fn_name Name (or source text) of the call that failed.
 * \param errsv   Error code to describe; by default LastError(), evaluated at the call site.
 */
inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError()) {
  std::error_code const ec{errsv, std::system_category()};
  LOG(FATAL) << "Failed to call `" << fn_name << "`: " << ec.message() << std::endl;
}
#endif  // defined(__GLIBC__)
#if !defined(_WIN32)
using SocketT = int;  // POSIX sockets are plain file descriptors.
#else
using SocketT = SOCKET;  // Native Winsock handle type.
#endif  // !defined(_WIN32)
#if !defined(xgboost_CHECK_SYS_CALL)
// Evaluate `call` exactly once; when its result differs from `ok_value`, report the
// failure (stringified call text plus the last system error) via system::ThrowAtError.
// The comparison is hinted as unlikely via XGBOOST_EXPECT.
#define xgboost_CHECK_SYS_CALL(call, ok_value)             \
  do {                                                     \
    if (XGBOOST_EXPECT((call) != (ok_value), false)) {     \
      ::xgboost::system::ThrowAtError(#call);              \
    }                                                      \
  } while (false)
#endif  // !defined(xgboost_CHECK_SYS_CALL)
/**
 * \brief Close a socket handle with the platform's native call.
 * \return The value returned by close() (POSIX) or closesocket() (Windows); 0 on success.
 */
inline std::int32_t CloseSocket(SocketT fd) {
#if !defined(_WIN32)
  return close(fd);
#else
  return closesocket(fd);
#endif
}
/**
 * \brief True when the last error (see LastError()) means "operation would block" on a
 *        non-blocking socket: WSAEWOULDBLOCK on Windows, EAGAIN/EWOULDBLOCK elsewhere.
 */
inline bool LastErrorWouldBlock() {
  auto const code = LastError();
#if defined(_WIN32)
  bool const would_block = (code == WSAEWOULDBLOCK);
#else
  bool const would_block = (code == EAGAIN) || (code == EWOULDBLOCK);
#endif  // _WIN32
  return would_block;
}
/**
 * \brief Initialise the platform socket library.
 *
 * On Windows this requests Winsock 2.2 via WSAStartup and fails if that call errors or
 * if the negotiated version is not 2.2. On other platforms it does nothing.
 */
inline void SocketStartup() {
#if defined(_WIN32)
  WSADATA wsa_data;
  // WSAStartup returns 0 on success and the error code itself on failure; it never
  // returns -1 and the code is not meant to be read back through WSAGetLastError().
  auto const rc = WSAStartup(MAKEWORD(2, 2), &wsa_data);
  if (rc != 0) {
    ThrowAtError("WSAStartup", rc);
  }
  if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
    WSACleanup();
    LOG(FATAL) << "Could not find a usable version of Winsock.dll";
  }
#endif  // defined(_WIN32)
}
/**
 * \brief Counterpart of SocketStartup(): calls WSACleanup() on Windows and does nothing
 *        on other platforms.
 */
inline void SocketFinalize() {
#ifdef _WIN32
  WSACleanup();
#endif  // _WIN32
}
#if defined(_WIN32) && xgboost_IS_MINGW()
// dummy definition for old mysys32.
inline const char *inet_ntop(int, const void *, char *, socklen_t) { // NOLINT
MingWError();
return nullptr;
}
#else
using ::inet_ntop;
#endif
} // namespace system
namespace collective {
class SockAddress;
/** \brief Address family of a socket; the enumerator values are the native AF_* constants. */
enum class SockDomain : std::int32_t {
  kV4 = AF_INET,   // IPv4
  kV6 = AF_INET6,  // IPv6
};
/**
 * \brief Parse a host address and return a SockAddress instance. Supports both IPv4 and
 *        IPv6 hosts.
 *
 * \param host Host address to parse (IPv4 or IPv6).
 * \param port Port to attach to the resulting address. NOTE(review): the expected byte
 *             order (host vs network) is decided by the out-of-view definition; confirm there.
 */
SockAddress MakeSockAddress(StringView host, in_port_t port);
/**
 * \brief Thin value wrapper around a native IPv6 socket address (`sockaddr_in6`).
 */
class SockAddrV6 {
  sockaddr_in6 addr_;

 public:
  explicit SockAddrV6(sockaddr_in6 addr) : addr_{addr} {}
  /** \brief Zero-filled address. */
  SockAddrV6() { std::memset(&addr_, '\0', sizeof(addr_)); }

  static SockAddrV6 Loopback();
  static SockAddrV6 InaddrAny();

  /** \brief Port, converted from network to host byte order with ntohs. */
  in_port_t Port() const { return ntohs(addr_.sin6_port); }

  /**
   * \brief Presentation (text) form of the address produced by inet_ntop; a failure is
   *        reported through system::ThrowAtError.
   */
  std::string Addr() const {
    char text[INET6_ADDRSTRLEN];
    auto const family = static_cast<std::int32_t>(SockDomain::kV6);
    if (system::inet_ntop(family, &addr_.sin6_addr, text, INET6_ADDRSTRLEN) == nullptr) {
      system::ThrowAtError("inet_ntop");
    }
    return std::string{text};
  }

  /** \brief Read-only access to the underlying `sockaddr_in6`. */
  sockaddr_in6 const &Handle() const { return addr_; }
};
/**
 * \brief Thin value wrapper around a native IPv4 socket address (`sockaddr_in`).
 */
class SockAddrV4 {
 private:
  sockaddr_in addr_;

 public:
  explicit SockAddrV4(sockaddr_in addr) : addr_{addr} {}
  /** \brief Zero-filled address. */
  SockAddrV4() { std::memset(&addr_, '\0', sizeof(addr_)); }

  static SockAddrV4 Loopback();
  static SockAddrV4 InaddrAny();

  /** \brief Port, converted from network to host byte order with ntohs. */
  in_port_t Port() const { return ntohs(addr_.sin_port); }

  /**
   * \brief Presentation (text) form of the address produced by inet_ntop; a failure is
   *        reported through system::ThrowAtError.
   */
  std::string Addr() const {
    char text[INET_ADDRSTRLEN];
    auto const family = static_cast<std::int32_t>(SockDomain::kV4);
    auto const *written = system::inet_ntop(family, &addr_.sin_addr, text, INET_ADDRSTRLEN);
    if (written == nullptr) {
      system::ThrowAtError("inet_ntop");
    }
    return std::string{text};
  }

  /** \brief Read-only access to the underlying `sockaddr_in`. */
  sockaddr_in const &Handle() const { return addr_; }
};
/**
 * \brief Address for a TCP socket, holding either an IPv4 or an IPv6 address.
 *
 * Storage exists for both variants; Domain() reports which one was set by the
 * constructor (a default-constructed address reports kV4).
 */
class SockAddress {
 private:
  SockAddrV6 v6_;
  SockAddrV4 v4_;
  SockDomain domain_{SockDomain::kV4};

 public:
  SockAddress() = default;
  explicit SockAddress(SockAddrV6 const &addr) : v6_{addr}, domain_{SockDomain::kV6} {}
  explicit SockAddress(SockAddrV4 const &addr) : v4_{addr} {}

  auto Domain() const { return domain_; }
  bool IsV4() const { return domain_ == SockDomain::kV4; }
  bool IsV6() const { return domain_ != SockDomain::kV4; }

  auto const &V4() const { return v4_; }
  auto const &V6() const { return v6_; }
};
/**
 * \brief TCP socket for simple communication.
 *
 * Owns a native socket handle. The destructor closes it and copying is disabled.
 * Move construction/assignment swap the handle (and any cached per-socket state) between
 * the two objects.
 */
class TCPSocket {
 public:
  using HandleT = system::SocketT;

 private:
  HandleT handle_{InvalidSocket()};
  // There's no reliable way to extract the domain from a socket on macOS without first
  // binding that socket, so the domain is cached when the socket is created or accepted.
#if defined(__APPLE__)
  SockDomain domain_{SockDomain::kV4};
#endif

  constexpr static HandleT InvalidSocket() { return -1; }

  explicit TCPSocket(HandleT newfd) : handle_{newfd} {}

  // Exchange all per-socket state with `that`. Shared by both move operations so the
  // cached macOS domain travels with the handle instead of being reset to kV4.
  void SwapState(TCPSocket &that) noexcept {
    std::swap(this->handle_, that.handle_);
#if defined(__APPLE__)
    std::swap(this->domain_, that.domain_);
#endif
  }

 public:
  TCPSocket() = default;

  /**
   * \brief Return the socket domain. On macOS this is the cached value; elsewhere it is
   *        queried from the socket.
   */
  auto Domain() const -> SockDomain {
    auto ret_iafamily = [](std::int32_t domain) {
      switch (domain) {
        case AF_INET:
          return SockDomain::kV4;
        case AF_INET6:
          return SockDomain::kV6;
        default: {
          LOG(FATAL) << "Unknown IA family.";
        }
      }
      return SockDomain::kV4;
    };

#if defined(_WIN32)
    WSAPROTOCOL_INFOA info;
    socklen_t len = sizeof(info);
    // Ask for the ANSI option explicitly so it matches the WSAPROTOCOL_INFOA buffer:
    // plain SO_PROTOCOL_INFO maps to SO_PROTOCOL_INFOW (a larger struct) under UNICODE.
    xgboost_CHECK_SYS_CALL(
        getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFOA, reinterpret_cast<char *>(&info), &len),
        0);
    return ret_iafamily(info.iAddressFamily);
#elif defined(__APPLE__)
    return domain_;
#elif defined(__unix__)
#ifndef __PASE__
    std::int32_t domain;
    socklen_t len = sizeof(domain);
    xgboost_CHECK_SYS_CALL(
        getsockopt(handle_, SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len), 0);
    return ret_iafamily(domain);
#else
    struct sockaddr sa;
    socklen_t sizeofsa = sizeof(sa);
    xgboost_CHECK_SYS_CALL(
        getsockname(handle_,&sa,&sizeofsa), 0);
    if (sizeofsa < sizeof(uchar_t) * 2) {
      return ret_iafamily(AF_INET);
    }
    return ret_iafamily(sa.sa_family);
#endif  // __PASE__
#else
    LOG(FATAL) << "Unknown platform.";
    return ret_iafamily(AF_INET);
#endif  // platforms
  }

  bool IsClosed() const { return handle_ == InvalidSocket(); }

  /** \brief get last error code if any (the SO_ERROR socket option). */
  std::int32_t GetSockError() const {
    std::int32_t error = 0;
    socklen_t len = sizeof(error);
    xgboost_CHECK_SYS_CALL(
        getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&error), &len), 0);
    return error;
  }

  /** \brief check if anything bad happens: closed, or SO_ERROR is EBADF or EINTR. */
  bool BadSocket() const {
    if (IsClosed()) return true;
    std::int32_t err = GetSockError();
    if (err == EBADF || err == EINTR) return true;
    return false;
  }

  /**
   * \brief Switch the socket between non-blocking and blocking mode.
   * \param non_block true (the default) enables non-blocking I/O; false restores blocking I/O.
   */
  void SetNonBlock(bool non_block = true) {
#if defined(_WIN32)
    u_long mode = non_block ? 1 : 0;
    xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR);
#else
    std::int32_t flag = fcntl(handle_, F_GETFL, 0);
    if (flag == -1) {
      system::ThrowAtError("fcntl");
    }
    if (non_block) {
      flag |= O_NONBLOCK;
    } else {
      flag &= ~O_NONBLOCK;
    }
    if (fcntl(handle_, F_SETFL, flag) == -1) {
      system::ThrowAtError("fcntl");
    }
#endif  // _WIN32
  }

  /** \brief Enable SO_KEEPALIVE on the socket. */
  void SetKeepAlive() {
    std::int32_t keepalive = 1;
    xgboost_CHECK_SYS_CALL(setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
                                      reinterpret_cast<char *>(&keepalive), sizeof(keepalive)),
                           0);
  }

  /** \brief Enable TCP_NODELAY on the socket. */
  void SetNoDelay() {
    std::int32_t tcp_no_delay = 1;
    xgboost_CHECK_SYS_CALL(
        setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
                   sizeof(tcp_no_delay)),
        0);
  }

  /**
   * \brief Accept new connection, returns a new TCP socket for the new connection.
   */
  TCPSocket Accept() {
    HandleT newfd = accept(handle_, nullptr, nullptr);
    if (newfd == InvalidSocket()) {
      system::ThrowAtError("accept");
    }
    TCPSocket newsock{newfd};
#if defined(__APPLE__)
    // accept() creates the new socket with the listening socket's address family.
    newsock.domain_ = domain_;
#endif  // defined(__APPLE__)
    return newsock;
  }

  ~TCPSocket() {
    if (!IsClosed()) {
      Close();
    }
  }

  TCPSocket(TCPSocket const &that) = delete;
  TCPSocket(TCPSocket &&that) noexcept(true) { SwapState(that); }
  TCPSocket &operator=(TCPSocket const &that) = delete;
  // Swap-based: `that` receives this object's previous handle and closes it on destruction.
  TCPSocket &operator=(TCPSocket &&that) noexcept {
    SwapState(that);
    return *this;
  }

  /**
   * \brief Return the native socket file descriptor.
   */
  HandleT const &Handle() const { return handle_; }

  /**
   * \brief Listen to incoming requests. Should be called after bind.
   */
  void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }

  /**
   * \brief Bind socket to INADDR_ANY, return the port selected by the OS (host byte order).
   */
  in_port_t BindHost() {
    if (Domain() == SockDomain::kV6) {
      auto addr = SockAddrV6::InaddrAny();
      auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
      xgboost_CHECK_SYS_CALL(
          bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);

      sockaddr_in6 res_addr;
      socklen_t addrlen = sizeof(res_addr);
      xgboost_CHECK_SYS_CALL(
          getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
      return ntohs(res_addr.sin6_port);
    } else {
      auto addr = SockAddrV4::InaddrAny();
      auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
      xgboost_CHECK_SYS_CALL(
          bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);

      sockaddr_in res_addr;
      socklen_t addrlen = sizeof(res_addr);
      xgboost_CHECK_SYS_CALL(
          getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
      return ntohs(res_addr.sin_port);
    }
  }

  /**
   * \brief Send data, without error then all data should be sent.
   *
   * Loops on send() until `len` bytes are written. If the socket would block, returns
   * early with the count written so far; any other error goes to ThrowAtError.
   * \return Number of bytes sent.
   */
  auto SendAll(void const *buf, std::size_t len) {
    char const *_buf = reinterpret_cast<const char *>(buf);
    std::size_t ndone = 0;
    while (ndone < len) {
      ssize_t ret = send(handle_, _buf, len - ndone, 0);
      if (ret == -1) {
        if (system::LastErrorWouldBlock()) {
          return ndone;
        }
        system::ThrowAtError("send");
      }
      _buf += ret;
      ndone += ret;
    }
    return ndone;
  }

  /**
   * \brief Receive data, without error then all data should be received.
   *
   * Loops on recv(MSG_WAITALL) until `len` bytes are read. Returns early with the count
   * read so far if the socket would block or recv() returns 0 (peer shut down); any
   * other error goes to ThrowAtError.
   * \return Number of bytes received.
   */
  auto RecvAll(void *buf, std::size_t len) {
    char *_buf = reinterpret_cast<char *>(buf);
    std::size_t ndone = 0;
    while (ndone < len) {
      ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
      if (ret == -1) {
        if (system::LastErrorWouldBlock()) {
          return ndone;
        }
        system::ThrowAtError("recv");
      }
      if (ret == 0) {
        return ndone;
      }
      _buf += ret;
      ndone += ret;
    }
    return ndone;
  }

  /**
   * \brief Send data using the socket
   * \param buf the pointer to the buffer
   * \param len the size of the buffer
   * \param flags extra flags
   * \return size of data actually sent return -1 if error occurs
   */
  auto Send(const void *buf, std::size_t len, std::int32_t flags = 0) {
    const char *data = reinterpret_cast<const char *>(buf);
    return send(handle_, data, len, flags);
  }

  /**
   * \brief receive data using the socket
   * \param buf the pointer to the buffer
   * \param len the size of the buffer
   * \param flags extra flags
   * \return size of data actually received return -1 if error occurs
   */
  auto Recv(void *buf, std::size_t len, std::int32_t flags = 0) {
    char *_buf = reinterpret_cast<char *>(buf);
    return recv(handle_, _buf, len, flags);
  }

  /**
   * \brief Send string, format is matched with the Python socket wrapper in RABIT.
   */
  std::size_t Send(StringView str);

  /**
   * \brief Receive string, format is matched with the Python socket wrapper in RABIT.
   */
  std::size_t Recv(std::string *p_str);

  /**
   * \brief Close the socket, called automatically in destructor if the socket is not closed.
   */
  void Close() {
    if (InvalidSocket() != handle_) {
      xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0);
      handle_ = InvalidSocket();
    }
  }

  /**
   * \brief Create a TCP socket on specified domain.
   */
  static TCPSocket Create(SockDomain domain) {
#if xgboost_IS_MINGW()
    MingWError();
    return {};
#else
    auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
    if (fd == InvalidSocket()) {
      system::ThrowAtError("socket");
    }
    TCPSocket socket{fd};
#if defined(__APPLE__)
    socket.domain_ = domain;
#endif  // defined(__APPLE__)
    return socket;
#endif  // xgboost_IS_MINGW()
  }
};
/**
 * \brief Connect to a remote address. Returns the error code on failure instead of
 *        raising an exception, so that callers can retry.
 *
 * \param addr Remote address (IPv4 or IPv6) to connect to.
 * \param out  Socket used for the connection. NOTE(review): whether `out` must already be
 *             created by the caller or is (re)created here is decided by the out-of-view
 *             definition; confirm there.
 */
std::error_code Connect(SockAddress const &addr, TCPSocket *out);
/**
 * \brief Get the local host name.
 *
 * The buffer holds HOST_NAME_MAX + 1 bytes because on Linux HOST_NAME_MAX excludes the
 * terminating NUL, so a maximum-length name would not fit. POSIX also leaves
 * NUL-termination unspecified when the name is truncated, so the last byte is forced to '\0'.
 */
inline std::string GetHostName() {
  char buf[HOST_NAME_MAX + 1] = {};
  xgboost_CHECK_SYS_CALL(gethostname(buf, HOST_NAME_MAX + 1), 0);
  buf[HOST_NAME_MAX] = '\0';
  return std::string{buf};
}
} // namespace collective
} // namespace xgboost
#undef xgboost_CHECK_SYS_CALL
#undef xgboost_IS_MINGW