diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index ded96bcbab4c..434428d3ea42 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -90,6 +90,9 @@ #include "../src/common/threading_utils.cc" #include "../src/common/version.cc" +// collective +#include "../src/collective/socket.cc" + // c_api #include "../src/c_api/c_api.cc" #include "../src/c_api/c_api_error.cc" diff --git a/doc/conf.py b/doc/conf.py index c362709d5bcb..7e1126331d7f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -204,7 +204,7 @@ ] intersphinx_mapping = { - "python": ("https://docs.python.org/3.6", None), + "python": ("https://docs.python.org/3.8", None), "numpy": ("https://docs.scipy.org/doc/numpy/", None), "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None), "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst index 2eebeda11534..fa10c04e7528 100644 --- a/doc/tutorials/dask.rst +++ b/doc/tutorials/dask.rst @@ -474,7 +474,6 @@ interface, including callback functions, custom evaluation metric and objective: callbacks=[early_stop], ) - .. _tracker-ip: *************** @@ -504,6 +503,35 @@ dask config is used: reg = dxgb.DaskXGBRegressor() + +************ +IPv6 Support +************ + +.. versionadded:: 2.0.0 + +XGBoost has initial IPv6 support for the dask interface on Linux. Due to most of the +cluster support for IPv6 is partial (dual stack instead of IPv6 only), we require +additional user configuration similar to :ref:`tracker-ip` to help XGBoost obtain the +correct address information: + +.. code-block:: python + + import dask + from distributed import Client + from xgboost import dask as dxgb + # let xgboost know the scheduler address, use the same bracket format as dask. + with dask.config.set({"xgboost.scheduler_address": "[fd20:b6f:f759:9800::]"}): + with Client("[fd20:b6f:f759:9800::]") as client: + reg = dxgb.DaskXGBRegressor(tree_method="hist") + + +When GPU is used, XGBoost employs `NCCL `_ as the +underlying communication framework, which may require some additional configuration via +environment variable depending on the setting of the cluster. Please note that IPv6 +support is Unix only. + + ***************************************************************************** Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors ***************************************************************************** diff --git a/include/xgboost/collective/socket.h b/include/xgboost/collective/socket.h new file mode 100644 index 000000000000..e058e58a843f --- /dev/null +++ b/include/xgboost/collective/socket.h @@ -0,0 +1,536 @@ +/*! + * Copyright (c) 2022 by XGBoost Contributors + */ +#pragma once + +#if !defined(NOMINMAX) && defined(_WIN32) +#define NOMINMAX +#endif // !defined(NOMINMAX) + +#include // errno, EINTR, EBADF +#include // HOST_NAME_MAX +#include // std::size_t +#include // std::int32_t, std::uint16_t +#include // memset +#include // std::numeric_limits +#include // std::string +#include // std::error_code, std::system_category +#include // std::swap + +#if !defined(xgboost_IS_MINGW) +#define xgboost_IS_MINGW() defined(__MINGW32__) +#endif // xgboost_IS_MINGW + +#if defined(_WIN32) + +#include +#include + +using in_port_t = std::uint16_t; + +#ifdef _MSC_VER +#pragma comment(lib, "Ws2_32.lib") +#endif // _MSC_VER + +#if !xgboost_IS_MINGW() +using ssize_t = int; +#endif // !xgboost_IS_MINGW() + +#else // UNIX + +#include // inet_ntop +#include // fcntl, F_GETFL, O_NONBLOCK +#include // sockaddr_in6, sockaddr_in, in_port_t, INET6_ADDRSTRLEN, INET_ADDRSTRLEN +#include // IPPROTO_TCP +#include // TCP_NODELAY +#include // socket, SOL_SOCKET, SO_ERROR, MSG_WAITALL, recv, send, AF_INET6, AF_INET +#include // close + +#if defined(__sun) || defined(sun) +#include +#endif // defined(__sun) || defined(sun) + +#endif // defined(_WIN32) + +#include "xgboost/base.h" // XGBOOST_EXPECT +#include "xgboost/logging.h" // LOG +#include "xgboost/string_view.h" // StringView + +#if !defined(HOST_NAME_MAX) +#define HOST_NAME_MAX 256 // macos +#endif + +namespace xgboost { + +#if xgboost_IS_MINGW() +// see the dummy implementation of `poll` in rabit for more info. +inline void MingWError() { LOG(FATAL) << "Distributed training on mingw is not supported."; } +#endif // xgboost_IS_MINGW() + +namespace system { +inline std::int32_t LastError() { +#if defined(_WIN32) + return WSAGetLastError(); +#else + int errsv = errno; + return errsv; +#endif +} + +#if defined(__GLIBC__) +inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(), + std::int32_t line = __builtin_LINE(), + char const *file = __builtin_FILE()) { + auto err = std::error_code{errsv, std::system_category()}; + LOG(FATAL) << "\n" + << file << "(" << line << "): Failed to call `" << fn_name << "`: " << err.message() + << std::endl; +} +#else +inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError()) { + auto err = std::error_code{errsv, std::system_category()}; + LOG(FATAL) << "Failed to call `" << fn_name << "`: " << err.message() << std::endl; +} +#endif // defined(__GLIBC__) + +#if defined(_WIN32) +using SocketT = SOCKET; +#else +using SocketT = int; +#endif // defined(_WIN32) + +#if !defined(xgboost_CHECK_SYS_CALL) +#define xgboost_CHECK_SYS_CALL(exp, expected) \ + do { \ + if (XGBOOST_EXPECT((exp) != (expected), false)) { \ + ::xgboost::system::ThrowAtError(#exp); \ + } \ + } while (false) +#endif // !defined(xgboost_CHECK_SYS_CALL) + +inline std::int32_t CloseSocket(SocketT fd) { +#if defined(_WIN32) + return closesocket(fd); +#else + return close(fd); +#endif +} + +inline bool LastErrorWouldBlock() { + int errsv = LastError(); +#ifdef _WIN32 + return errsv == WSAEWOULDBLOCK; +#else + return errsv == EAGAIN || errsv == EWOULDBLOCK; +#endif // _WIN32 +} + +inline void SocketStartup() { +#if defined(_WIN32) + WSADATA wsa_data; + if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { + ThrowAtError("WSAStartup"); + } + if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { + WSACleanup(); + LOG(FATAL) << "Could not find a usable version of Winsock.dll"; + } +#endif // defined(_WIN32) +} + +inline void SocketFinalize() { +#if defined(_WIN32) + WSACleanup(); +#endif // defined(_WIN32) +} + +#if defined(_WIN32) && xgboost_IS_MINGW() +// dummy definition for old mysys32. +inline const char *inet_ntop(int, const void *, char *, socklen_t) { // NOLINT + MingWError(); + return nullptr; +} +#else +using ::inet_ntop; +#endif + +} // namespace system + +namespace collective { +class SockAddress; + +enum class SockDomain : std::int32_t { kV4 = AF_INET, kV6 = AF_INET6 }; + +/** + * \brief Parse host address and return a SockAddress instance. Supports IPv4 and IPv6 + * host. + */ +SockAddress MakeSockAddress(StringView host, in_port_t port); + +class SockAddrV6 { + sockaddr_in6 addr_; + + public: + explicit SockAddrV6(sockaddr_in6 addr) : addr_{addr} {} + SockAddrV6() { std::memset(&addr_, '\0', sizeof(addr_)); } + + static SockAddrV6 Loopback(); + static SockAddrV6 InaddrAny(); + + in_port_t Port() const { return ntohs(addr_.sin6_port); } + + std::string Addr() const { + char buf[INET6_ADDRSTRLEN]; + auto const *s = system::inet_ntop(static_cast(SockDomain::kV6), &addr_.sin6_addr, + buf, INET6_ADDRSTRLEN); + if (s == nullptr) { + system::ThrowAtError("inet_ntop"); + } + return {buf}; + } + sockaddr_in6 const &Handle() const { return addr_; } +}; + +class SockAddrV4 { + private: + sockaddr_in addr_; + + public: + explicit SockAddrV4(sockaddr_in addr) : addr_{addr} {} + SockAddrV4() { std::memset(&addr_, '\0', sizeof(addr_)); } + + static SockAddrV4 Loopback(); + static SockAddrV4 InaddrAny(); + + in_port_t Port() const { return ntohs(addr_.sin_port); } + + std::string Addr() const { + char buf[INET_ADDRSTRLEN]; + auto const *s = system::inet_ntop(static_cast(SockDomain::kV4), &addr_.sin_addr, + buf, INET_ADDRSTRLEN); + if (s == nullptr) { + system::ThrowAtError("inet_ntop"); + } + return {buf}; + } + sockaddr_in const &Handle() const { return addr_; } +}; + +/** + * \brief Address for TCP socket, can be either IPv4 or IPv6. + */ +class SockAddress { + private: + SockAddrV6 v6_; + SockAddrV4 v4_; + SockDomain domain_{SockDomain::kV4}; + + public: + SockAddress() = default; + explicit SockAddress(SockAddrV6 const &addr) : v6_{addr}, domain_{SockDomain::kV6} {} + explicit SockAddress(SockAddrV4 const &addr) : v4_{addr} {} + + auto Domain() const { return domain_; } + + bool IsV4() const { return Domain() == SockDomain::kV4; } + bool IsV6() const { return !IsV4(); } + + auto const &V4() const { return v4_; } + auto const &V6() const { return v6_; } +}; + +/** + * \brief TCP socket for simple communication. + */ +class TCPSocket { + public: + using HandleT = system::SocketT; + + private: + HandleT handle_{InvalidSocket()}; + // There's reliable no way to extract domain from a socket without first binding that + // socket on macos. +#if defined(__APPLE__) + SockDomain domain_{SockDomain::kV4}; +#endif + + constexpr static HandleT InvalidSocket() { return -1; } + + explicit TCPSocket(HandleT newfd) : handle_{newfd} {} + + public: + TCPSocket() = default; + /** + * \brief Return the socket domain. + */ + auto Domain() const -> SockDomain { + auto ret_iafamily = [](std::int32_t domain) { + switch (domain) { + case AF_INET: + return SockDomain::kV4; + case AF_INET6: + return SockDomain::kV6; + default: { + LOG(FATAL) << "Unknown IA family."; + } + } + return SockDomain::kV4; + }; + +#if defined(_WIN32) + WSAPROTOCOL_INFOA info; + socklen_t len = sizeof(info); + xgboost_CHECK_SYS_CALL( + getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO, reinterpret_cast(&info), &len), + 0); + return ret_iafamily(info.iAddressFamily); +#elif defined(__APPLE__) + return domain_; +#elif defined(__unix__) + std::int32_t domain; + socklen_t len = sizeof(domain); + xgboost_CHECK_SYS_CALL( + getsockopt(handle_, SOL_SOCKET, SO_DOMAIN, reinterpret_cast(&domain), &len), 0); + return ret_iafamily(domain); +#else + LOG(FATAL) << "Unknown platform."; + return ret_iafamily(AF_INET); +#endif // platforms + } + + bool IsClosed() const { return handle_ == InvalidSocket(); } + + /** \brief get last error code if any */ + std::int32_t GetSockError() const { + std::int32_t error = 0; + socklen_t len = sizeof(error); + xgboost_CHECK_SYS_CALL( + getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len), 0); + return error; + } + /** \brief check if anything bad happens */ + bool BadSocket() const { + if (IsClosed()) return true; + std::int32_t err = GetSockError(); + if (err == EBADF || err == EINTR) return true; + return false; + } + + void SetNonBlock() { + bool non_block{true}; +#if defined(_WIN32) + u_long mode = non_block ? 1 : 0; + xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR); +#else + std::int32_t flag = fcntl(handle_, F_GETFL, 0); + if (flag == -1) { + system::ThrowAtError("fcntl"); + } + if (non_block) { + flag |= O_NONBLOCK; + } else { + flag &= ~O_NONBLOCK; + } + if (fcntl(handle_, F_SETFL, flag) == -1) { + system::ThrowAtError("fcntl"); + } +#endif // _WIN32 + } + + void SetKeepAlive() { + std::int32_t keepalive = 1; + xgboost_CHECK_SYS_CALL(setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, + reinterpret_cast(&keepalive), sizeof(keepalive)), + 0); + } + + void SetNoDelay() { + std::int32_t tcp_no_delay = 1; + xgboost_CHECK_SYS_CALL( + setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&tcp_no_delay), + sizeof(tcp_no_delay)), + 0); + } + + /** + * \brief Accept new connection, returns a new TCP socket for the new connection. + */ + TCPSocket Accept() { + HandleT newfd = accept(handle_, nullptr, nullptr); + if (newfd == InvalidSocket()) { + system::ThrowAtError("accept"); + } + TCPSocket newsock{newfd}; + return newsock; + } + + ~TCPSocket() { + if (!IsClosed()) { + Close(); + } + } + + TCPSocket(TCPSocket const &that) = delete; + TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); } + TCPSocket &operator=(TCPSocket const &that) = delete; + TCPSocket &operator=(TCPSocket &&that) { + std::swap(this->handle_, that.handle_); + return *this; + } + /** + * \brief Return the native socket file descriptor. + */ + HandleT const &Handle() const { return handle_; } + /** + * \brief Listen to incoming requests. Should be called after bind. + */ + void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); } + /** + * \brief Bind socket to INADDR_ANY, return the port selected by the OS. + */ + in_port_t BindHost() { + if (Domain() == SockDomain::kV6) { + auto addr = SockAddrV6::InaddrAny(); + auto handle = reinterpret_cast(&addr.Handle()); + xgboost_CHECK_SYS_CALL( + bind(handle_, handle, sizeof(std::remove_reference_t)), 0); + + sockaddr_in6 res_addr; + socklen_t addrlen = sizeof(res_addr); + xgboost_CHECK_SYS_CALL( + getsockname(handle_, reinterpret_cast(&res_addr), &addrlen), 0); + return ntohs(res_addr.sin6_port); + } else { + auto addr = SockAddrV4::InaddrAny(); + auto handle = reinterpret_cast(&addr.Handle()); + xgboost_CHECK_SYS_CALL( + bind(handle_, handle, sizeof(std::remove_reference_t)), 0); + + sockaddr_in res_addr; + socklen_t addrlen = sizeof(res_addr); + xgboost_CHECK_SYS_CALL( + getsockname(handle_, reinterpret_cast(&res_addr), &addrlen), 0); + return ntohs(res_addr.sin_port); + } + } + /** + * \brief Send data, without error then all data should be sent. + */ + auto SendAll(void const *buf, std::size_t len) { + char const *_buf = reinterpret_cast(buf); + std::size_t ndone = 0; + while (ndone < len) { + ssize_t ret = send(handle_, _buf, len - ndone, 0); + if (ret == -1) { + if (system::LastErrorWouldBlock()) { + return ndone; + } + system::ThrowAtError("send"); + } + _buf += ret; + ndone += ret; + } + return ndone; + } + /** + * \brief Receive data, without error then all data should be received. + */ + auto RecvAll(void *buf, std::size_t len) { + char *_buf = reinterpret_cast(buf); + std::size_t ndone = 0; + while (ndone < len) { + ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL); + if (ret == -1) { + if (system::LastErrorWouldBlock()) { + return ndone; + } + system::ThrowAtError("recv"); + } + if (ret == 0) { + return ndone; + } + _buf += ret; + ndone += ret; + } + return ndone; + } + /** + * \brief Send data using the socket + * \param buf the pointer to the buffer + * \param len the size of the buffer + * \param flags extra flags + * \return size of data actually sent return -1 if error occurs + */ + auto Send(const void *buf_, std::size_t len, std::int32_t flags = 0) { + const char *buf = reinterpret_cast(buf_); + return send(handle_, buf, len, flags); + } + /** + * \brief receive data using the socket + * \param buf the pointer to the buffer + * \param len the size of the buffer + * \param flags extra flags + * \return size of data actually received return -1 if error occurs + */ + auto Recv(void *buf, std::size_t len, std::int32_t flags = 0) { + char *_buf = reinterpret_cast(buf); + return recv(handle_, _buf, len, flags); + } + /** + * \brief Send string, format is matched with the Python socket wrapper in RABIT. + */ + std::size_t Send(StringView str); + /** + * \brief Receive string, format is matched with the Python socket wrapper in RABIT. + */ + std::size_t Recv(std::string *p_str); + /** + * \brief Close the socket, called automatically in destructor if the socket is not closed. + */ + void Close() { + if (InvalidSocket() != handle_) { + xgboost_CHECK_SYS_CALL(system::CloseSocket(handle_), 0); + handle_ = InvalidSocket(); + } + } + /** + * \brief Create a TCP socket on specified domain. + */ + static TCPSocket Create(SockDomain domain) { +#if xgboost_IS_MINGW() + MingWError(); + return {}; +#else + auto fd = socket(static_cast(domain), SOCK_STREAM, 0); + if (fd == InvalidSocket()) { + system::ThrowAtError("socket"); + } + + TCPSocket socket{fd}; +#if defined(__APPLE__) + socket.domain_ = domain; +#endif // defined(__APPLE__) + return socket; +#endif // xgboost_IS_MINGW() + } +}; + +/** + * \brief Connect to remote address, returns the error code if failed (no exception is + * raised so that we can retry). + */ +std::error_code Connect(SockAddress const &addr, TCPSocket *out); + +/** + * \brief Get the local host name. + */ +inline std::string GetHostName() { + char buf[HOST_NAME_MAX]; + xgboost_CHECK_SYS_CALL(gethostname(&buf[0], HOST_NAME_MAX), 0); + return buf; +} +} // namespace collective +} // namespace xgboost + +#undef xgboost_CHECK_SYS_CALL +#undef xgboost_IS_MINGW diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 9a74d0143681..3daf2b44b3c6 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -52,6 +52,7 @@ Sequence, Set, Tuple, + TypedDict, TypeVar, Union, ) @@ -102,19 +103,13 @@ _DaskCollection = Union["da.Array", "dd.DataFrame", "dd.Series"] _DataT = Union["da.Array", "dd.DataFrame"] # do not use series as predictor - -try: - from mypy_extensions import TypedDict - - TrainReturnT = TypedDict( - "TrainReturnT", - { - "booster": Booster, - "history": Dict, - }, - ) -except ImportError: - TrainReturnT = Dict[str, Any] # type:ignore +TrainReturnT = TypedDict( + "TrainReturnT", + { + "booster": Booster, + "history": Dict, + }, +) __all__ = [ "RabitContext", @@ -832,11 +827,15 @@ async def _get_rabit_args( if k not in valid_config: raise ValueError(f"Unknown configuration: {k}") host_ip = dconfig.get("scheduler_address", None) + if host_ip is not None and host_ip.startswith("[") and host_ip.endswith("]"): + # convert dask bracket format to proper IPv6 address. + host_ip = host_ip[1:-1] if host_ip is not None: try: host_ip, port = distributed.comm.get_address_host_port(host_ip) except ValueError: pass + if host_ip is not None: user_addr = (host_ip, port) else: diff --git a/python-package/xgboost/testing.py b/python-package/xgboost/testing.py new file mode 100644 index 000000000000..9e1b54276037 --- /dev/null +++ b/python-package/xgboost/testing.py @@ -0,0 +1,41 @@ +"""Utilities for defining Python tests.""" + +import socket +from platform import system +from typing import TypedDict + +PytestSkip = TypedDict("PytestSkip", {"condition": bool, "reason": str}) + + +def has_ipv6() -> bool: + """Check whether IPv6 is enabled on this host.""" + # connection error in macos, still need some fixes. + if system() not in ("Linux", "Windows"): + return False + + if socket.has_ipv6: + try: + with socket.socket( + socket.AF_INET6, socket.SOCK_STREAM + ) as server, socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as client: + server.bind(("::1", 0)) + port = server.getsockname()[1] + server.listen() + + client.connect(("::1", port)) + conn, _ = server.accept() + + client.sendall("abc".encode()) + msg = conn.recv(3).decode() + # if the code can be executed to this point, the message should be + # correct. + assert msg == "abc" + return True + except OSError: + pass + return False + + +def skip_ipv6() -> PytestSkip: + """PyTest skip mark for IPv6.""" + return {"condition": not has_ipv6(), "reason": "IPv6 is required to be enabled."} diff --git a/python-package/xgboost/tracker.py b/python-package/xgboost/tracker.py index 6dc6167d9517..169f303ccfe9 100644 --- a/python-package/xgboost/tracker.py +++ b/python-package/xgboost/tracker.py @@ -112,7 +112,7 @@ def assign_rank( """Assign the rank for current entry.""" self.rank = rank nnset = set(tree_map[rank]) - rprev, rnext = ring_map[rank] + rprev, next_rank = ring_map[rank] self.sock.sendint(rank) # send parent rank self.sock.sendint(parent_map[rank]) @@ -129,9 +129,9 @@ def assign_rank( else: self.sock.sendint(-1) # send next link - if rnext not in (-1, rank): - nnset.add(rnext) - self.sock.sendint(rnext) + if next_rank not in (-1, rank): + nnset.add(next_rank) + self.sock.sendint(next_rank) else: self.sock.sendint(-1) @@ -157,6 +157,7 @@ def _get_remote( self.sock.sendstr(wait_conn[r].host) port = wait_conn[r].port assert port is not None + # send port of this node to other workers so that they can call connect self.sock.sendint(port) self.sock.sendint(r) nerr = self.sock.recvint() diff --git a/rabit/include/rabit/internal/socket.h b/rabit/include/rabit/internal/socket.h index d42df39d73be..cb7d4a0784c1 100644 --- a/rabit/include/rabit/internal/socket.h +++ b/rabit/include/rabit/internal/socket.h @@ -1,65 +1,48 @@ /*! - * Copyright (c) 2014-2019 by Contributors + * Copyright (c) 2014-2022 by XGBoost Contributors * \file socket.h - * \brief this file aims to provide a wrapper of sockets * \author Tianqi Chen */ #ifndef RABIT_INTERNAL_SOCKET_H_ #define RABIT_INTERNAL_SOCKET_H_ +#include "xgboost/collective/socket.h" + #if defined(_WIN32) #include #include -#ifdef _MSC_VER -#pragma comment(lib, "Ws2_32.lib") -#endif // _MSC_VER - #else +#include #include #include -#include -#include -#include #include -#include #include +#include +#include -#if defined(__sun) || defined(sun) -#include -#endif // defined(__sun) || defined(sun) +#include #endif // defined(_WIN32) -#include -#include -#include #include +#include +#include #include -#include "utils.h" - -#if defined(_WIN32) && !defined(__MINGW32__) -typedef int ssize_t; -#endif // defined(_WIN32) || defined(__MINGW32__) +#include -#if defined(_WIN32) -using sock_size_t = int; +#include "utils.h" -#else +#if !defined(_WIN32) #include + using SOCKET = int; using sock_size_t = size_t; // NOLINT -#endif // defined(_WIN32) +#endif // !defined(_WIN32) #define IS_MINGW() defined(__MINGW32__) -#if IS_MINGW() -inline void MingWError() { - throw dmlc::Error("Distributed training on mingw is not supported."); -} -#endif // IS_MINGW() - #if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND) /* * On later mingw versions poll should be supported (with bugs). See: @@ -88,23 +71,17 @@ typedef struct pollfd { // POLLWRNORM #define POLLOUT 0x0010 -inline const char *inet_ntop(int, const void *, char *, size_t) { - MingWError(); - return nullptr; -} #endif // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND) namespace rabit { namespace utils { -static constexpr int kInvalidSocket = -1; - template int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) { #if defined(_WIN32) #if IS_MINGW() - MingWError(); + xgboost::MingWError(); return -1; #else return WSAPoll(pfd, nfds, std::chrono::milliseconds(timeout).count()); @@ -115,458 +92,6 @@ int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) { #endif // IS_MINGW() } -/*! \brief data structure for network address */ -struct SockAddr { - sockaddr_in addr; - // constructor - SockAddr() = default; - SockAddr(const char *url, int port) { - this->Set(url, port); - } - inline static std::string GetHostName() { - std::string buf; buf.resize(256); -#if !IS_MINGW() - utils::Check(gethostname(&buf[0], 256) != -1, "fail to get host name"); -#endif // IS_MINGW() - return std::string(buf.c_str()); - } - /*! - * \brief set the address - * \param url the url of the address - * \param port the port of address - */ - inline void Set(const char *host, int port) { -#if !IS_MINGW() - addrinfo hints; - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_INET; - hints.ai_protocol = SOCK_STREAM; - addrinfo *res = nullptr; - int sig = getaddrinfo(host, nullptr, &hints, &res); - Check(sig == 0 && res != nullptr, "cannot obtain address of %s", host); - Check(res->ai_family == AF_INET, "Does not support IPv6"); - memcpy(&addr, res->ai_addr, res->ai_addrlen); - addr.sin_port = htons(port); - freeaddrinfo(res); -#endif // !IS_MINGW() - } - /*! \brief return port of the address*/ - inline int Port() const { - return ntohs(addr.sin_port); - } - /*! \return a string representation of the address */ - inline std::string AddrStr() const { - std::string buf; buf.resize(256); -#ifdef _WIN32 - const char *s = inet_ntop(AF_INET, (PVOID)&addr.sin_addr, - &buf[0], buf.length()); -#else - const char *s = inet_ntop(AF_INET, &addr.sin_addr, - &buf[0], buf.length()); -#endif // _WIN32 - Assert(s != nullptr, "cannot decode address"); - return std::string(s); - } -}; - -/*! - * \brief base class containing common operations of TCP and UDP sockets - */ -class Socket { - public: - /*! \brief the file descriptor of socket */ - SOCKET sockfd; - // default conversion to int - operator SOCKET() const { // NOLINT - return sockfd; - } - /*! - * \return last error of socket operation - */ - inline static int GetLastError() { -#ifdef _WIN32 - -#if IS_MINGW() - MingWError(); - return -1; -#else - return WSAGetLastError(); -#endif // IS_MINGW() - -#else - return errno; -#endif // _WIN32 - } - /*! \return whether last error was would block */ - inline static bool LastErrorWouldBlock() { - int errsv = GetLastError(); -#ifdef _WIN32 - return errsv == WSAEWOULDBLOCK; -#else - return errsv == EAGAIN || errsv == EWOULDBLOCK; -#endif // _WIN32 - } - /*! - * \brief start up the socket module - * call this before using the sockets - */ - inline static void Startup() { -#ifdef _WIN32 -#if !IS_MINGW() - WSADATA wsa_data; - if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) { - Socket::Error("Startup"); - } - if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) { - WSACleanup(); - utils::Error("Could not find a usable version of Winsock.dll\n"); - } -#endif // !IS_MINGW() -#endif // _WIN32 - } - /*! - * \brief shutdown the socket module after use, all sockets need to be closed - */ - inline static void Finalize() { -#ifdef _WIN32 -#if !IS_MINGW() - WSACleanup(); -#endif // !IS_MINGW() -#endif // _WIN32 - } - /*! - * \brief set this socket to use non-blocking mode - * \param non_block whether set it to be non-block, if it is false - * it will set it back to block mode - */ - inline void SetNonBlock(bool non_block) { -#ifdef _WIN32 -#if !IS_MINGW() - u_long mode = non_block ? 1 : 0; - if (ioctlsocket(sockfd, FIONBIO, &mode) != NO_ERROR) { - Socket::Error("SetNonBlock"); - } -#endif // !IS_MINGW() -#else - int flag = fcntl(sockfd, F_GETFL, 0); - if (flag == -1) { - Socket::Error("SetNonBlock-1"); - } - if (non_block) { - flag |= O_NONBLOCK; - } else { - flag &= ~O_NONBLOCK; - } - if (fcntl(sockfd, F_SETFL, flag) == -1) { - Socket::Error("SetNonBlock-2"); - } -#endif // _WIN32 - } - /*! - * \brief bind the socket to an address - * \param addr - */ - inline void Bind(const SockAddr &addr) { -#if !IS_MINGW() - if (bind(sockfd, reinterpret_cast(&addr.addr), - sizeof(addr.addr)) == -1) { - Socket::Error("Bind"); - } -#endif // !IS_MINGW() - } - /*! - * \brief try bind the socket to host, from start_port to end_port - * \param start_port starting port number to try - * \param end_port ending port number to try - * \return the port successfully bind to, return -1 if failed to bind any port - */ - inline int TryBindHost(int start_port, int end_port) { - // TODO(tqchen) add prefix check -#if !IS_MINGW() - for (int port = start_port; port < end_port; ++port) { - SockAddr addr("0.0.0.0", port); - if (bind(sockfd, reinterpret_cast(&addr.addr), - sizeof(addr.addr)) == 0) { - return port; - } -#if defined(_WIN32) - if (WSAGetLastError() != WSAEADDRINUSE) { - Socket::Error("TryBindHost"); - } -#else - if (errno != EADDRINUSE) { - Socket::Error("TryBindHost"); - } -#endif // defined(_WIN32) - } -#endif // !IS_MINGW() - return -1; - } - /*! \brief get last error code if any */ - inline int GetSockError() const { - int error = 0; - socklen_t len = sizeof(error); -#if !IS_MINGW() - if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, - reinterpret_cast(&error), &len) != 0) { - Error("GetSockError"); - } -#else - // undefined reference to `_imp__getsockopt@20' - MingWError(); -#endif // !IS_MINGW() - return error; - } - /*! \brief check if anything bad happens */ - inline bool BadSocket() const { - if (IsClosed()) return true; - int err = GetSockError(); - if (err == EBADF || err == EINTR) return true; - return false; - } - /*! \brief check if socket is already closed */ - inline bool IsClosed() const { - return sockfd == kInvalidSocket; - } - /*! \brief close the socket */ - inline void Close() { - if (sockfd != kInvalidSocket) { -#ifdef _WIN32 -#if !IS_MINGW() - closesocket(sockfd); -#endif // !IS_MINGW() -#else - close(sockfd); -#endif - sockfd = kInvalidSocket; - } else { - Error("Socket::Close double close the socket or close without create"); - } - } - // report an socket error - inline static void Error(const char *msg) { - int errsv = GetLastError(); -#ifdef _WIN32 - utils::Error("Socket %s Error:WSAError-code=%d", msg, errsv); -#else - utils::Error("Socket %s Error:%s", msg, strerror(errsv)); -#endif - } - - protected: - explicit Socket(SOCKET sockfd) : sockfd(sockfd) { - } -}; - -/*! - * \brief a wrapper of TCP socket that hopefully be cross platform - */ -class TCPSocket : public Socket{ - public: - // constructor - TCPSocket() : Socket(kInvalidSocket) { - } - explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) { - } - /*! - * \brief enable/disable TCP keepalive - * \param keepalive whether to set the keep alive option on - */ - void SetKeepAlive(bool keepalive) { -#if !IS_MINGW() - int opt = static_cast(keepalive); - if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, - reinterpret_cast(&opt), sizeof(opt)) < 0) { - Socket::Error("SetKeepAlive"); - } -#endif // !IS_MINGW() - } - inline void SetLinger(int timeout = 0) { -#if !IS_MINGW() - struct linger sl; - sl.l_onoff = 1; /* non-zero value enables linger option in kernel */ - sl.l_linger = timeout; /* timeout interval in seconds */ - if (setsockopt(sockfd, SOL_SOCKET, SO_LINGER, reinterpret_cast(&sl), sizeof(sl)) == -1) { - Socket::Error("SO_LINGER"); - } -#endif // !IS_MINGW() - } - /*! - * \brief create the socket, call this before using socket - * \param af domain - */ - inline void Create(int af = PF_INET) { -#if !IS_MINGW() - sockfd = socket(af, SOCK_STREAM, 0); - if (sockfd == kInvalidSocket) { - Socket::Error("Create"); - } -#endif // !IS_MINGW() - } - /*! - * \brief perform listen of the socket - * \param backlog backlog parameter - */ - inline void Listen(int backlog = 16) { -#if !IS_MINGW() - listen(sockfd, backlog); -#endif // !IS_MINGW() - } - /*! \brief get a new connection */ - TCPSocket Accept() { -#if !IS_MINGW() - SOCKET newfd = accept(sockfd, nullptr, nullptr); - if (newfd == kInvalidSocket) { - Socket::Error("Accept"); - } - return TCPSocket(newfd); -#else - return TCPSocket(); -#endif // !IS_MINGW() - } - /*! - * \brief decide whether the socket is at OOB mark - * \return 1 if at mark, 0 if not, -1 if an error occured - */ - inline int AtMark() const { -#if !IS_MINGW() - -#ifdef _WIN32 - unsigned long atmark; // NOLINT(*) - if (ioctlsocket(sockfd, SIOCATMARK, &atmark) != NO_ERROR) return -1; -#else - int atmark; - if (ioctl(sockfd, SIOCATMARK, &atmark) == -1) return -1; -#endif // _WIN32 - - return static_cast(atmark); - -#else - return -1; -#endif // !IS_MINGW() - } - /*! - * \brief connect to an address - * \param addr the address to connect to - * \return whether connect is successful - */ - inline bool Connect(const SockAddr &addr) { -#if !IS_MINGW() - return connect(sockfd, reinterpret_cast(&addr.addr), - sizeof(addr.addr)) == 0; -#else - return false; -#endif // !IS_MINGW() - } - /*! - * \brief send data using the socket - * \param buf the pointer to the buffer - * \param len the size of the buffer - * \param flags extra flags - * \return size of data actually sent - * return -1 if error occurs - */ - inline ssize_t Send(const void *buf_, size_t len, int flag = 0) { - const char *buf = reinterpret_cast(buf_); -#if !IS_MINGW() - return send(sockfd, buf, static_cast(len), flag); -#else - return 0; -#endif // !IS_MINGW() - } - /*! - * \brief receive data using the socket - * \param buf_ the pointer to the buffer - * \param len the size of the buffer - * \param flags extra flags - * \return size of data actually received - * return -1 if error occurs - */ - inline ssize_t Recv(void *buf_, size_t len, int flags = 0) { - char *buf = reinterpret_cast(buf_); -#if !IS_MINGW() - return recv(sockfd, buf, static_cast(len), flags); -#else - return 0; -#endif // !IS_MINGW() - } - /*! - * \brief peform block write that will attempt to send all data out - * can still return smaller than request when error occurs - * \param buf the pointer to the buffer - * \param len the size of the buffer - * \return size of data actually sent - */ - inline size_t SendAll(const void *buf_, size_t len) { - const char *buf = reinterpret_cast(buf_); - size_t ndone = 0; -#if !IS_MINGW() - while (ndone < len) { - ssize_t ret = send(sockfd, buf, static_cast(len - ndone), 0); - if (ret == -1) { - if (LastErrorWouldBlock()) return ndone; - Socket::Error("SendAll"); - } - buf += ret; - ndone += ret; - } -#endif // !IS_MINGW() - return ndone; - } - /*! - * \brief peforma block read that will attempt to read all data - * can still return smaller than request when error occurs - * \param buf_ the buffer pointer - * \param len length of data to recv - * \return size of data actually sent - */ - inline size_t RecvAll(void *buf_, size_t len) { - char *buf = reinterpret_cast(buf_); - size_t ndone = 0; -#if !IS_MINGW() - while (ndone < len) { - ssize_t ret = recv(sockfd, buf, - static_cast(len - ndone), MSG_WAITALL); - if (ret == -1) { - if (LastErrorWouldBlock()) return ndone; - Socket::Error("RecvAll"); - } - if (ret == 0) return ndone; - buf += ret; - ndone += ret; - } -#endif // !IS_MINGW() - return ndone; - } - /*! - * \brief send a string over network - * \param str the string to be sent - */ - inline void SendStr(const std::string &str) { - int len = static_cast(str.length()); - utils::Assert(this->SendAll(&len, sizeof(len)) == sizeof(len), - "error during send SendStr"); - if (len != 0) { - utils::Assert(this->SendAll(str.c_str(), str.length()) == str.length(), - "error during send SendStr"); - } - } - /*! - * \brief recv a string from network - * \param out_str the string to receive - */ - inline void RecvStr(std::string *out_str) { - int len; - utils::Assert(this->RecvAll(&len, sizeof(len)) == sizeof(len), - "error during send RecvStr"); - out_str->resize(len); - if (len != 0) { - utils::Assert(this->RecvAll(&(*out_str)[0], len) == out_str->length(), - "error during send SendStr"); - } - } -}; - /*! \brief helper data structure to perform poll */ struct PollHelper { public: @@ -579,6 +104,8 @@ struct PollHelper { pfd.fd = fd; pfd.events |= POLLIN; } + void WatchRead(xgboost::collective::TCPSocket const &socket) { this->WatchRead(socket.Handle()); } + /*! * \brief add file descriptor to watch for write * \param fd file descriptor to be watched @@ -588,6 +115,10 @@ struct PollHelper { pfd.fd = fd; pfd.events |= POLLOUT; } + void WatchWrite(xgboost::collective::TCPSocket const &socket) { + this->WatchWrite(socket.Handle()); + } + /*! * \brief add file descriptor to watch for exception * \param fd file descriptor to be watched @@ -597,6 +128,9 @@ struct PollHelper { pfd.fd = fd; pfd.events |= POLLPRI; } + void WatchException(xgboost::collective::TCPSocket const &socket) { + this->WatchException(socket.Handle()); + } /*! * \brief Check if the descriptor is ready for read * \param fd file descriptor to check status @@ -605,6 +139,10 @@ struct PollHelper { const auto& pfd = fds.find(fd); return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0); } + bool CheckRead(xgboost::collective::TCPSocket const &socket) const { + return this->CheckRead(socket.Handle()); + } + /*! * \brief Check if the descriptor is ready for write * \param fd file descriptor to check status @@ -613,7 +151,9 @@ struct PollHelper { const auto& pfd = fds.find(fd); return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0); } - + bool CheckWrite(xgboost::collective::TCPSocket const &socket) const { + return this->CheckWrite(socket.Handle()); + } /*! * \brief perform poll on the set defined, read, write, exception * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block @@ -629,7 +169,7 @@ struct PollHelper { if (ret == 0) { LOG(FATAL) << "Poll timeout"; } else if (ret < 0) { - Socket::Error("Poll"); + LOG(FATAL) << "Failed to poll."; } else { for (auto& pfd : fdset) { auto revents = pfd.revents & pfd.events; diff --git a/rabit/include/rabit/internal/utils.h b/rabit/include/rabit/internal/utils.h index f23e27477d12..c1739ce7967b 100644 --- a/rabit/include/rabit/internal/utils.h +++ b/rabit/include/rabit/internal/utils.h @@ -8,15 +8,17 @@ #define RABIT_INTERNAL_UTILS_H_ #include -#include + +#include #include -#include #include +#include #include +#include #include + #include "dmlc/io.h" #include "xgboost/logging.h" -#include #if !defined(__GNUC__) || defined(__FreeBSD__) #define fopen64 std::fopen diff --git a/rabit/src/allreduce_base.cc b/rabit/src/allreduce_base.cc index 75ba901b2145..563898a30c1c 100644 --- a/rabit/src/allreduce_base.cc +++ b/rabit/src/allreduce_base.cc @@ -27,8 +27,6 @@ AllreduceBase::AllreduceBase() { tracker_uri = "NULL"; tracker_port = 9000; host_uri = ""; - slave_port = 9010; - nport_trial = 1000; rank = 0; world_size = -1; connect_retry = 5; @@ -114,16 +112,16 @@ bool AllreduceBase::Init(int argc, char* argv[]) { this->rank = -1; //--------------------- // start socket - utils::Socket::Startup(); + xgboost::system::SocketStartup(); utils::Assert(all_links.size() == 0, "can only call Init once"); - this->host_uri = utils::SockAddr::GetHostName(); + this->host_uri = xgboost::collective::GetHostName(); // get information from tracker return this->ReConnectLinks(); } bool AllreduceBase::Shutdown() { try { - for (auto & all_link : all_links) { + for (auto &all_link : all_links) { if (!all_link.sock.IsClosed()) { all_link.sock.Close(); } @@ -133,12 +131,12 @@ bool AllreduceBase::Shutdown() { if (tracker_uri == "NULL") return true; // notify tracker rank i have shutdown - utils::TCPSocket tracker = this->ConnectTracker(); - tracker.SendStr(std::string("shutdown")); + xgboost::collective::TCPSocket tracker = this->ConnectTracker(); + tracker.Send(xgboost::StringView{"shutdown"}); tracker.Close(); - utils::TCPSocket::Finalize(); + xgboost::system::SocketFinalize(); return true; - } catch (const std::exception& e) { + } catch (std::exception const &e) { LOG(WARNING) << "Failed to shutdown due to" << e.what(); return false; } @@ -148,9 +146,9 @@ void AllreduceBase::TrackerPrint(const std::string &msg) { if (tracker_uri == "NULL") { utils::Printf("%s", msg.c_str()); return; } - utils::TCPSocket tracker = this->ConnectTracker(); - tracker.SendStr(std::string("print")); - tracker.SendStr(msg); + xgboost::collective::TCPSocket tracker = this->ConnectTracker(); + tracker.Send(xgboost::StringView{"print"}); + tracker.Send(xgboost::StringView{msg}); tracker.Close(); } @@ -227,21 +225,23 @@ void AllreduceBase::SetParam(const char *name, const char *val) { * \brief initialize connection to the tracker * \return a socket that initializes the connection */ -utils::TCPSocket AllreduceBase::ConnectTracker() const { +xgboost::collective::TCPSocket AllreduceBase::ConnectTracker() const { int magic = kMagic; // get information from tracker - utils::TCPSocket tracker; - tracker.Create(); + xgboost::collective::TCPSocket tracker; int retry = 0; do { - if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) { + auto rc = xgboost::collective::Connect( + xgboost::collective::MakeSockAddress(xgboost::StringView{tracker_uri}, tracker_port), + &tracker); + if (rc != std::errc()) { if (++retry >= connect_retry) { - LOG(WARNING) << "Connect to (failed): [" << tracker_uri << "]\n"; - utils::Socket::Error("Connect"); + LOG(FATAL) << "Connecting to (failed): [" << tracker_uri << "]\n" << rc.message(); } else { - LOG(WARNING) << "Retry connect to ip(retry time " << retry << "): [" << tracker_uri << "]\n"; -#if defined(_MSC_VER) || defined (__MINGW32__) + LOG(WARNING) << rc.message() << "\nRetry connecting to IP(retry time: " << retry << "): [" + << tracker_uri << "]"; +#if defined(_MSC_VER) || defined(__MINGW32__) Sleep(retry << 1); #else sleep(retry << 1); @@ -253,16 +253,13 @@ utils::TCPSocket AllreduceBase::ConnectTracker() const { } while (true); using utils::Assert; - Assert(tracker.SendAll(&magic, sizeof(magic)) == sizeof(magic), - "ReConnectLink failure 1"); - Assert(tracker.RecvAll(&magic, sizeof(magic)) == sizeof(magic), - "ReConnectLink failure 2"); + CHECK_EQ(tracker.SendAll(&magic, sizeof(magic)), sizeof(magic)); + CHECK_EQ(tracker.RecvAll(&magic, sizeof(magic)), sizeof(magic)); utils::Check(magic == kMagic, "sync::Invalid tracker message, init failure"); - Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), - "ReConnectLink failure 3"); + Assert(tracker.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3"); Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3"); - tracker.SendStr(task_id); + CHECK_EQ(tracker.Send(xgboost::StringView{task_id}), task_id.size()); return tracker; } /*! @@ -272,12 +269,15 @@ utils::TCPSocket AllreduceBase::ConnectTracker() const { bool AllreduceBase::ReConnectLinks(const char *cmd) { // single node mode if (tracker_uri == "NULL") { - rank = 0; world_size = 1; return true; + rank = 0; + world_size = 1; + return true; } + try { - utils::TCPSocket tracker = this->ConnectTracker(); + xgboost::collective::TCPSocket tracker = this->ConnectTracker(); LOG(INFO) << "task " << task_id << " connected to the tracker"; - tracker.SendStr(std::string(cmd)); + tracker.Send(xgboost::StringView{cmd}); // the rank of previous link, next link in ring int prev_rank, next_rank; @@ -315,13 +315,9 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank), "ReConnectLink failure 4"); - utils::TCPSocket sock_listen; - if (!sock_listen.IsClosed()) { - sock_listen.Close(); - } + auto sock_listen{xgboost::collective::TCPSocket::Create(tracker.Domain())}; // create listening socket - sock_listen.Create(); - int port = sock_listen.TryBindHost(slave_port, slave_port + nport_trial); + int port = sock_listen.BindHost(); utils::Check(port != -1, "ReConnectLink fail to bind the ports specified"); sock_listen.Listen(); @@ -338,29 +334,27 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { } } int ngood = static_cast(good_link.size()); - Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), - "ReConnectLink failure 5"); - for (int & i : good_link) { - Assert(tracker.SendAll(&i, sizeof(i)) == \ - sizeof(i), "ReConnectLink failure 6"); + // tracker construct goodset + Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), "ReConnectLink failure 5"); + for (int &i : good_link) { + Assert(tracker.SendAll(&i, sizeof(i)) == sizeof(i), "ReConnectLink failure 6"); } Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), "ReConnectLink failure 7"); - Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \ - sizeof(num_accept), "ReConnectLink failure 8"); + Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept), + "ReConnectLink failure 8"); num_error = 0; for (int i = 0; i < num_conn; ++i) { LinkRecord r; int hport, hrank; std::string hname; - tracker.RecvStr(&hname); - Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), - "ReConnectLink failure 9"); - Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), - "ReConnectLink failure 10"); + tracker.Recv(&hname); + Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), "ReConnectLink failure 9"); + Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), "ReConnectLink failure 10"); - r.sock.Create(); - if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) { + if (xgboost::collective::Connect( + xgboost::collective::MakeSockAddress(xgboost::StringView{hname}, hport), &r.sock) != + std::errc{}) { num_error += 1; r.sock.Close(); continue; @@ -376,12 +370,12 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { if (all_link.rank == hrank) { Assert(all_link.sock.IsClosed(), "Override a link that is active"); - all_link.sock = r.sock; + all_link.sock = std::move(r.sock); match = true; break; } } - if (!match) all_links.push_back(r); + if (!match) all_links.emplace_back(std::move(r)); } Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), "ReConnectLink failure 14"); @@ -404,30 +398,24 @@ bool AllreduceBase::ReConnectLinks(const char *cmd) { if (all_link.rank == r.rank) { utils::Assert(all_link.sock.IsClosed(), "Override a link that is active"); - all_link.sock = r.sock; + all_link.sock = std::move(r.sock); match = true; break; } } - if (!match) all_links.push_back(r); + if (!match) all_links.emplace_back(std::move(r)); } sock_listen.Close(); this->parent_index = -1; // setup tree links and ring structure tree_links.plinks.clear(); - int tcpNoDelay = 1; - for (auto & all_link : all_links) { + for (auto &all_link : all_links) { utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket"); // set the socket to non-blocking mode, enable TCP keepalive - all_link.sock.SetNonBlock(true); - all_link.sock.SetKeepAlive(true); + all_link.sock.SetNonBlock(); + all_link.sock.SetKeepAlive(); if (rabit_enable_tcp_no_delay) { -#if defined(__unix__) - setsockopt(all_link.sock, IPPROTO_TCP, - TCP_NODELAY, reinterpret_cast(&tcpNoDelay), sizeof(tcpNoDelay)); -#else - LOG(WARNING) << "tcp no delay is not implemented on non unix platforms"; -#endif + all_link.sock.SetNoDelay(); } if (tree_neighbors.count(all_link.rank) != 0) { if (all_link.rank == parent_rank) { diff --git a/rabit/src/allreduce_base.h b/rabit/src/allreduce_base.h index 0cd11c73ffa2..a3b67c980953 100644 --- a/rabit/src/allreduce_base.h +++ b/rabit/src/allreduce_base.h @@ -201,8 +201,8 @@ class AllreduceBase : public IEngine { } }; /*! \brief translate errno to return type */ - inline static ReturnType Errno2Return() { - int errsv = utils::Socket::GetLastError(); + static ReturnType Errno2Return() { + int errsv = xgboost::system::LastError(); if (errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == 0) return kSuccess; #ifdef _WIN32 if (errsv == WSAEWOULDBLOCK) return kSuccess; @@ -215,7 +215,7 @@ class AllreduceBase : public IEngine { struct LinkRecord { public: // socket to get data from/to link - utils::TCPSocket sock; + xgboost::collective::TCPSocket sock; // rank of the node in this link int rank; // size of data readed from link @@ -329,7 +329,7 @@ class AllreduceBase : public IEngine { * \brief initialize connection to the tracker * \return a socket that initializes the connection */ - utils::TCPSocket ConnectTracker() const; + xgboost::collective::TCPSocket ConnectTracker() const; /*! * \brief connect to the tracker to fix the the missing links * this function is also used when the engine start up @@ -473,8 +473,6 @@ class AllreduceBase : public IEngine { std::string dmlc_role; // NOLINT // port of tracker address int tracker_port; // NOLINT - // port of slave process - int slave_port, nport_trial; // NOLINT // reduce buffer size size_t reduce_buffer_size; // NOLINT // reduction method diff --git a/src/collective/socket.cc b/src/collective/socket.cc new file mode 100644 index 000000000000..1ab84cef35d8 --- /dev/null +++ b/src/collective/socket.cc @@ -0,0 +1,94 @@ +/*! + * Copyright (c) 2022 by XGBoost Contributors + */ +#include "xgboost/collective/socket.h" + +#include // std::size_t +#include // std::int32_t +#include // std::memcpy, std::memset +#include // std::error_code, std::system_category + +#if defined(__unix__) || defined(__APPLE__) +#include // getaddrinfo, freeaddrinfo +#endif // defined(__unix__) || defined(__APPLE__) + +namespace xgboost { +namespace collective { +SockAddress MakeSockAddress(StringView host, in_port_t port) { + struct addrinfo hints; + std::memset(&hints, 0, sizeof(hints)); + hints.ai_protocol = SOCK_STREAM; + struct addrinfo *res = nullptr; + int sig = getaddrinfo(host.c_str(), nullptr, &hints, &res); + if (sig != 0) { + return {}; + } + if (res->ai_family == static_cast(SockDomain::kV4)) { + sockaddr_in addr; + std::memcpy(&addr, res->ai_addr, res->ai_addrlen); + addr.sin_port = htons(port); + auto v = SockAddrV4{addr}; + freeaddrinfo(res); + return SockAddress{v}; + } else if (res->ai_family == static_cast(SockDomain::kV6)) { + sockaddr_in6 addr; + std::memcpy(&addr, res->ai_addr, res->ai_addrlen); + + addr.sin6_port = htons(port); + auto v = SockAddrV6{addr}; + freeaddrinfo(res); + return SockAddress{v}; + } else { + LOG(FATAL) << "Failed to get addr info for: " << host; + } + + return SockAddress{}; +} + +SockAddrV4 SockAddrV4::Loopback() { return MakeSockAddress("127.0.0.1", 0).V4(); } +SockAddrV4 SockAddrV4::InaddrAny() { return MakeSockAddress("0.0.0.0", 0).V4(); } + +SockAddrV6 SockAddrV6::Loopback() { return MakeSockAddress("::1", 0).V6(); } +SockAddrV6 SockAddrV6::InaddrAny() { return MakeSockAddress("::", 0).V6(); } + +std::size_t TCPSocket::Send(StringView str) { + CHECK(!this->IsClosed()); + CHECK_LT(str.size(), std::numeric_limits::max()); + std::int32_t len = static_cast(str.size()); + CHECK_EQ(this->SendAll(&len, sizeof(len)), sizeof(len)) << "Failed to send string length."; + auto bytes = this->SendAll(str.c_str(), str.size()); + CHECK_EQ(bytes, str.size()) << "Failed to send string."; + return bytes; +} + +std::size_t TCPSocket::Recv(std::string *p_str) { + CHECK(!this->IsClosed()); + std::int32_t len; + CHECK_EQ(this->RecvAll(&len, sizeof(len)), sizeof(len)) << "Failed to recv string length."; + p_str->resize(len); + auto bytes = this->RecvAll(&(*p_str)[0], len); + CHECK_EQ(bytes, len) << "Failed to recv string."; + return bytes; +} + +std::error_code Connect(SockAddress const &addr, TCPSocket *out) { + sockaddr const *addr_handle{nullptr}; + socklen_t addr_len{0}; + if (addr.IsV4()) { + addr_handle = reinterpret_cast(&addr.V4().Handle()); + addr_len = sizeof(addr.V4().Handle()); + } else { + addr_handle = reinterpret_cast(&addr.V6().Handle()); + addr_len = sizeof(addr.V6().Handle()); + } + auto socket = TCPSocket::Create(addr.Domain()); + CHECK_EQ(static_cast(socket.Domain()), static_cast(addr.Domain())); + auto rc = connect(socket.Handle(), addr_handle, addr_len); + if (rc != 0) { + return std::error_code{errno, std::system_category()}; + } + *out = std::move(socket); + return std::make_error_code(std::errc{}); +} +} // namespace collective +} // namespace xgboost diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py index f669cdbb25c4..cdaf6615bf3f 100644 --- a/tests/ci_build/lint_python.py +++ b/tests/ci_build/lint_python.py @@ -121,6 +121,7 @@ def print_summary_map(result_map: Dict[str, Dict[str, int]]) -> int: "python-package/xgboost/sklearn.py", "python-package/xgboost/spark", "python-package/xgboost/federated.py", + "python-package/xgboost/testing.py", # tests "tests/python/test_config.py", "tests/python/test_spark/", diff --git a/tests/cpp/collective/test_socket.cc b/tests/cpp/collective/test_socket.cc new file mode 100644 index 000000000000..571e95f4deb8 --- /dev/null +++ b/tests/cpp/collective/test_socket.cc @@ -0,0 +1,77 @@ +/*! + * Copyright (c) 2022 by XGBoost Contributors + */ +#include +#include + +#include // EADDRNOTAVAIL +#include // ifstream +#include // std::error_code, std::system_category + +#include "../helpers.h" + +namespace xgboost { +namespace collective { +TEST(Socket, Basic) { + system::SocketStartup(); + + SockAddress addr{SockAddrV6::Loopback()}; + ASSERT_TRUE(addr.IsV6()); + addr = SockAddress{SockAddrV4::Loopback()}; + ASSERT_TRUE(addr.IsV4()); + + std::string msg{"Skipping IPv6 test"}; + + auto run_test = [msg](SockDomain domain) { + auto server = TCPSocket::Create(domain); + ASSERT_EQ(server.Domain(), domain); + auto port = server.BindHost(); + server.Listen(); + + TCPSocket client; + if (domain == SockDomain::kV4) { + auto const& addr = SockAddrV4::Loopback().Addr(); + ASSERT_EQ(Connect(MakeSockAddress(StringView{addr}, port), &client), std::errc{}); + } else { + auto const& addr = SockAddrV6::Loopback().Addr(); + auto rc = Connect(MakeSockAddress(StringView{addr}, port), &client); + // some environment (docker) has restricted network configuration. + if (rc == std::error_code{EADDRNOTAVAIL, std::system_category()}) { + GTEST_SKIP_(msg.c_str()); + } + ASSERT_EQ(rc, std::errc{}); + } + ASSERT_EQ(client.Domain(), domain); + + auto accepted = server.Accept(); + StringView msg{"Hello world."}; + accepted.Send(msg); + + std::string str; + client.Recv(&str); + ASSERT_EQ(StringView{str}, msg); + }; + + run_test(SockDomain::kV4); + + std::string path{"/sys/module/ipv6/parameters/disable"}; + if (FileExists(path)) { + std::ifstream fin(path); + if (!fin) { + GTEST_SKIP_(msg.c_str()); + } + std::string s_value; + fin >> s_value; + auto value = std::stoi(s_value); + if (value != 0) { + GTEST_SKIP_(msg.c_str()); + } + } else { + GTEST_SKIP_(msg.c_str()); + } + run_test(SockDomain::kV6); + + system::SocketFinalize(); +} +} // namespace collective +} // namespace xgboost diff --git a/tests/cpp/filesystem.h b/tests/cpp/filesystem.h index 5410feede316..c8d144291b0f 100644 --- a/tests/cpp/filesystem.h +++ b/tests/cpp/filesystem.h @@ -1,7 +1,6 @@ /*! * Copyright (c) 2022 by XGBoost Contributors */ - #ifndef XGBOOST_TESTS_CPP_FILESYSTEM_H #define XGBOOST_TESTS_CPP_FILESYSTEM_H diff --git a/tests/cpp/tree/test_partitioner.h b/tests/cpp/tree/test_partitioner.h index 109749a2832f..093aa69ebdd4 100644 --- a/tests/cpp/tree/test_partitioner.h +++ b/tests/cpp/tree/test_partitioner.h @@ -1,8 +1,12 @@ /*! * Copyright 2021-2022, XGBoost contributors. */ +#ifndef XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_ +#define XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_ #include + #include + #include "../../../src/tree/hist/expand_entry.h" namespace xgboost { @@ -19,3 +23,4 @@ inline void GetSplit(RegTree *tree, float split_value, std::vector None: from distributed import Client, LocalCluster from test_with_dask import _get_client_workers