diff --git a/CMakeLists.txt b/CMakeLists.txt index fd087a68f208..64aeae29c737 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,6 +4,7 @@ include(cmake/Utils.cmake) list(APPEND CMAKE_MODULE_PATH "${xgboost_SOURCE_DIR}/cmake/modules") cmake_policy(SET CMP0022 NEW) cmake_policy(SET CMP0079 NEW) +cmake_policy(SET CMP0076 NEW) set(CMAKE_POLICY_DEFAULT_CMP0063 NEW) cmake_policy(SET CMP0063 NEW) @@ -117,6 +118,20 @@ endif (BUILD_STATIC_LIB AND (R_LIB OR JVM_BINDINGS)) if (PLUGIN_RMM AND (NOT BUILD_WITH_CUDA_CUB)) message(SEND_ERROR "Cannot build with RMM using cub submodule.") endif (PLUGIN_RMM AND (NOT BUILD_WITH_CUDA_CUB)) +if (PLUGIN_FEDERATED) + if (CMAKE_CROSSCOMPILING) + message(SEND_ERROR "Cannot cross compile with federated learning support") + endif () + if (BUILD_STATIC_LIB) + message(SEND_ERROR "Cannot build static lib with federated learning support") + endif () + if (R_LIB OR JVM_BINDINGS) + message(SEND_ERROR "Cannot enable federated learning support when R or JVM packages are enabled.") + endif () + if (WIN32) + message(SEND_ERROR "Federated learning not supported for Windows platform") + endif () +endif () #-- Sanitizer if (USE_SANITIZER) diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index 8dbf8227f887..4f0d1b05c2b8 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -1,7 +1,9 @@ # gRPC needs to be installed first. See README.md. +set(protobuf_MODULE_COMPATIBLE TRUE) +set(protobuf_BUILD_SHARED_LIBS TRUE) find_package(Protobuf CONFIG REQUIRED) find_package(gRPC CONFIG REQUIRED) -find_package(Threads) +message(STATUS "Found gRPC: ${gRPC_CONFIG}") # Generated code from the protobuf definition. add_library(federated_proto federated.proto) @@ -9,13 +11,16 @@ target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gR target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) xgboost_target_properties(federated_proto) -get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION) -protobuf_generate(TARGET federated_proto LANGUAGE cpp) +protobuf_generate( + TARGET federated_proto + LANGUAGE cpp + PROTOC_OUT_DIR "${PROTO_BINARY_DIR}") protobuf_generate( TARGET federated_proto LANGUAGE grpc GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc - PLUGIN "protoc-gen-grpc=${grpc_cpp_plugin_location}") + PLUGIN "protoc-gen-grpc=\$" + PROTOC_OUT_DIR "${PROTO_BINARY_DIR}") # Wrapper for the gRPC client. add_library(federated_client INTERFACE) diff --git a/plugin/federated/README.md b/plugin/federated/README.md index 5858d7cebf50..061cb77149d0 100644 --- a/plugin/federated/README.md +++ b/plugin/federated/README.md @@ -5,14 +5,7 @@ This folder contains the plugin for federated learning. Follow these steps to bu Install gRPC ------------ -```shell -sudo apt-get install build-essential autoconf libtool pkg-config cmake ninja-build -git clone -b v1.47.0 https://github.com/grpc/grpc -cd grpc -git submodule update --init -cmake -S . -B build -GNinja -DABSL_PROPAGATE_CXX_STD=ON -cmake --build build --target install -``` +Refer to the [installation guide from the gRPC website](https://grpc.io/docs/languages/cpp/quickstart/). Build the Plugin ---------------- @@ -20,16 +13,16 @@ Build the Plugin # Under xgboost source tree. mkdir build cd build -# For now NCCL needs to be turned off. -cmake .. -GNinja\ - -DPLUGIN_FEDERATED=ON\ +cmake .. -GNinja \ + -DPLUGIN_FEDERATED=ON \ + -DBUILD_WITH_CUDA_CUB=ON \ -DUSE_CUDA=ON\ - -DBUILD_WITH_CUDA_CUB=ON\ - -DUSE_NCCL=OFF + -DUSE_NCCL=ON ninja cd ../python-package pip install -e . # or equivalently python setup.py develop ``` +If CMake fails to locate gRPC, you may need to pass `-DCMAKE_PREFIX_PATH=` to CMake. Test Federated XGBoost ---------------------- diff --git a/tests/buildkite/build-containers.sh b/tests/buildkite/build-containers.sh index b12da2a634ed..41a13eaea5fb 100755 --- a/tests/buildkite/build-containers.sh +++ b/tests/buildkite/build-containers.sh @@ -6,7 +6,7 @@ set -x if [ "$#" -lt 1 ] then echo "Usage: $0 [container to build]" - return 1 + exit 1 fi container=$1 @@ -17,18 +17,21 @@ echo "--- Build container ${container}" BUILD_ARGS="" case "${container}" in + cpu) + ;; + gpu|rmm) BUILD_ARGS="$BUILD_ARGS --build-arg CUDA_VERSION_ARG=$CUDA_VERSION" BUILD_ARGS="$BUILD_ARGS --build-arg RAPIDS_VERSION_ARG=$RAPIDS_VERSION" ;; - jvm_gpu_build) + gpu_build_centos7|jvm_gpu_build) BUILD_ARGS="$BUILD_ARGS --build-arg CUDA_VERSION_ARG=$CUDA_VERSION" ;; *) echo "Unrecognized container ID: ${container}" - return 2 + exit 2 ;; esac diff --git a/tests/buildkite/build-cpu.sh b/tests/buildkite/build-cpu.sh index 60c84c52ccfb..88da7d39504a 100755 --- a/tests/buildkite/build-cpu.sh +++ b/tests/buildkite/build-cpu.sh @@ -14,7 +14,8 @@ $command_wrapper rm -fv dmlc-core/include/dmlc/build_config_default.h # the configured header build/dmlc/build_config.h instead of # include/dmlc/build_config_default.h. echo "--- Build libxgboost from the source" -$command_wrapper tests/ci_build/build_via_cmake.sh -DPLUGIN_DENSE_PARSER=ON +$command_wrapper tests/ci_build/build_via_cmake.sh -DCMAKE_PREFIX_PATH=/opt/grpc \ + -DPLUGIN_DENSE_PARSER=ON -DPLUGIN_FEDERATED=ON echo "--- Run Google Test" $command_wrapper bash -c "cd build && ctest --extra-verbose" echo "--- Stash XGBoost CLI executable" diff --git a/tests/buildkite/build-cuda.sh b/tests/buildkite/build-cuda.sh index f8efb0853c11..a50963f7c7fc 100755 --- a/tests/buildkite/build-cuda.sh +++ b/tests/buildkite/build-cuda.sh @@ -20,10 +20,10 @@ command_wrapper="tests/ci_build/ci_build.sh gpu_build_centos7 docker --build-arg echo "--- Build libxgboost from the source" $command_wrapper tests/ci_build/prune_libnccl.sh -$command_wrapper tests/ci_build/build_via_cmake.sh -DUSE_CUDA=ON -DUSE_NCCL=ON \ - -DUSE_OPENMP=ON -DHIDE_CXX_SYMBOLS=ON -DUSE_NCCL_LIB_PATH=ON \ - -DNCCL_INCLUDE_DIR=/usr/include -DNCCL_LIBRARY=/workspace/libnccl_static.a \ - ${arch_flag} +$command_wrapper tests/ci_build/build_via_cmake.sh -DCMAKE_PREFIX_PATH=/opt/grpc \ + -DUSE_CUDA=ON -DUSE_NCCL=ON -DUSE_OPENMP=ON -DHIDE_CXX_SYMBOLS=ON -DPLUGIN_FEDERATED=ON \ + -DUSE_NCCL_LIB_PATH=ON -DNCCL_INCLUDE_DIR=/usr/include \ + -DNCCL_LIBRARY=/workspace/libnccl_static.a ${arch_flag} echo "--- Build binary wheel" $command_wrapper bash -c \ "cd python-package && rm -rf dist/* && python setup.py bdist_wheel --universal" diff --git a/tests/buildkite/pipeline-mgpu.yml b/tests/buildkite/pipeline-mgpu.yml index 690027da5009..75d7855b6dc9 100644 --- a/tests/buildkite/pipeline-mgpu.yml +++ b/tests/buildkite/pipeline-mgpu.yml @@ -17,6 +17,7 @@ steps: - label: ":docker: Build containers" commands: - "tests/buildkite/build-containers.sh gpu" + - "tests/buildkite/build-containers.sh gpu_build_centos7" - "tests/buildkite/build-containers.sh jvm_gpu_build" key: build-containers agents: diff --git a/tests/buildkite/pipeline.yml b/tests/buildkite/pipeline.yml index 8d6ab86d95f6..e2a4fcaf2405 100644 --- a/tests/buildkite/pipeline.yml +++ b/tests/buildkite/pipeline.yml @@ -13,7 +13,9 @@ steps: #### -------- CONTAINER BUILD -------- - label: ":docker: Build containers" commands: + - "tests/buildkite/build-containers.sh cpu" - "tests/buildkite/build-containers.sh gpu" + - "tests/buildkite/build-containers.sh gpu_build_centos7" - "tests/buildkite/build-containers.sh rmm" key: build-containers agents: diff --git a/tests/ci_build/Dockerfile.cpu b/tests/ci_build/Dockerfile.cpu index 786ab834b014..5111f4d00184 100644 --- a/tests/ci_build/Dockerfile.cpu +++ b/tests/ci_build/Dockerfile.cpu @@ -26,6 +26,15 @@ ENV CPP=cpp-8 ENV GOSU_VERSION 1.10 ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/ +# Install gRPC +RUN git clone -b v1.49.1 https://github.com/grpc/grpc.git \ + --recurse-submodules --depth 1 --shallow-submodules && \ + pushd grpc && \ + cmake -S . -B build -GNinja -DCMAKE_INSTALL_PREFIX=/opt/grpc && \ + cmake --build build --target install && \ + popd && \ + rm -rf grpc + # Create new Conda environment COPY conda_env/cpu_test.yml /scripts/ RUN mamba env create -n cpu_test --file=/scripts/cpu_test.yml diff --git a/tests/ci_build/Dockerfile.gpu_build b/tests/ci_build/Dockerfile.gpu_build deleted file mode 100644 index 0d9f6a27c5ea..000000000000 --- a/tests/ci_build/Dockerfile.gpu_build +++ /dev/null @@ -1,49 +0,0 @@ -ARG CUDA_VERSION_ARG -FROM nvidia/cuda:$CUDA_VERSION_ARG-devel-ubuntu16.04 -ARG CUDA_VERSION_ARG - -# Environment -ENV DEBIAN_FRONTEND noninteractive -SHELL ["/bin/bash", "-c"] # Use Bash as shell - -# Install all basic requirements -RUN \ - apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/3bf863cc.pub && \ - apt-get update && \ - apt-get install -y software-properties-common && \ - add-apt-repository ppa:ubuntu-toolchain-r/test && \ - apt-get update && \ - apt-get install -y tar unzip wget bzip2 libgomp1 git build-essential doxygen graphviz llvm libasan2 libidn11 ninja-build gcc-8 g++-8 && \ - # CMake - wget -nv -nc https://cmake.org/files/v3.14/cmake-3.14.0-Linux-x86_64.sh --no-check-certificate && \ - bash cmake-3.14.0-Linux-x86_64.sh --skip-license --prefix=/usr && \ - # Python - wget -nv -O Miniconda3.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ - bash Miniconda3.sh -b -p /opt/python - -# NCCL2 (License: https://docs.nvidia.com/deeplearning/sdk/nccl-sla/index.html) -RUN \ - export CUDA_SHORT=`echo $CUDA_VERSION_ARG | grep -o -E '[0-9]+\.[0-9]'` && \ - export NCCL_VERSION=2.13.4-1 && \ - apt-get update && \ - apt-get install -y --allow-downgrades --allow-change-held-packages libnccl2=${NCCL_VERSION}+cuda${CUDA_SHORT} libnccl-dev=${NCCL_VERSION}+cuda${CUDA_SHORT} - -ENV PATH=/opt/python/bin:$PATH -ENV CC=gcc-8 -ENV CXX=g++-8 -ENV CPP=cpp-8 - -ENV GOSU_VERSION 1.10 - -# Install lightweight sudo (not bound to TTY) -RUN set -ex; \ - wget -nv -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \ - chmod +x /usr/local/bin/gosu && \ - gosu nobody true - -# Default entry-point to use if running locally -# It will preserve attributes of created files -COPY entrypoint.sh /scripts/ - -WORKDIR /workspace -ENTRYPOINT ["/scripts/entrypoint.sh"] diff --git a/tests/ci_build/Dockerfile.gpu_build_centos7 b/tests/ci_build/Dockerfile.gpu_build_centos7 index d92bb4984b0e..b6b38575bd6a 100644 --- a/tests/ci_build/Dockerfile.gpu_build_centos7 +++ b/tests/ci_build/Dockerfile.gpu_build_centos7 @@ -35,6 +35,15 @@ ENV CPP=/opt/rh/devtoolset-8/root/usr/bin/cpp ENV GOSU_VERSION 1.10 +# Install gRPC +RUN git clone -b v1.49.1 https://github.com/grpc/grpc.git \ + --recurse-submodules --depth 1 && \ + pushd grpc && \ + cmake -S . -B build -GNinja -DCMAKE_INSTALL_PREFIX=/opt/grpc && \ + cmake --build build --target install && \ + popd && \ + rm -rf grpc + # Install lightweight sudo (not bound to TTY) RUN set -ex; \ wget -nv -nc -O /usr/local/bin/gosu "https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-amd64" && \ diff --git a/tests/cpp/plugin/helpers.cc b/tests/cpp/plugin/helpers.cc new file mode 100644 index 000000000000..a70479b1bb1c --- /dev/null +++ b/tests/cpp/plugin/helpers.cc @@ -0,0 +1,19 @@ +#include +#include +#include +#include + +#include "helpers.h" + +using namespace std::chrono_literals; + +int GenerateRandomPort(int low, int high) { + // Ensure unique timestamp by introducing a small artificial delay + std::this_thread::sleep_for(100ms); + auto timestamp = static_cast(std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()).count()); + std::mt19937_64 rng(timestamp); + std::uniform_int_distribution dist(low, high); + int port = dist(rng); + return port; +} diff --git a/tests/cpp/plugin/helpers.h b/tests/cpp/plugin/helpers.h new file mode 100644 index 000000000000..ea72f1538af6 --- /dev/null +++ b/tests/cpp/plugin/helpers.h @@ -0,0 +1,10 @@ +/*! + * Copyright 2022 XGBoost contributors + */ + +#ifndef XGBOOST_TESTS_CPP_PLUGIN_HELPERS_H_ +#define XGBOOST_TESTS_CPP_PLUGIN_HELPERS_H_ + +int GenerateRandomPort(int low, int high); + +#endif // XGBOOST_TESTS_CPP_PLUGIN_HELPERS_H_ diff --git a/tests/cpp/plugin/test_federated_adapter.cu b/tests/cpp/plugin/test_federated_adapter.cu index 09187f940c5f..794c60909e76 100644 --- a/tests/cpp/plugin/test_federated_adapter.cu +++ b/tests/cpp/plugin/test_federated_adapter.cu @@ -5,24 +5,36 @@ #include #include +#include #include +#include +#include "./helpers.h" #include "../../../plugin/federated/federated_communicator.h" #include "../../../plugin/federated/federated_server.h" #include "../../../src/collective/device_communicator_adapter.cuh" +namespace { + +std::string GetServerAddress() { + int port = GenerateRandomPort(50000, 60000); + std::string address = std::string("localhost:") + std::to_string(port); + return address; +} + +} // anonymous namespace + namespace xgboost { namespace collective { -std::string const kServerAddress{"localhost:56789"}; // NOLINT(cert-err58-cpp) - class FederatedAdapterTest : public ::testing::Test { protected: void SetUp() override { + server_address_ = GetServerAddress(); server_thread_.reset(new std::thread([this] { grpc::ServerBuilder builder; federated::FederatedService service{kWorldSize}; - builder.AddListeningPort(kServerAddress, grpc::InsecureServerCredentials()); + builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials()); builder.RegisterService(&service); server_ = builder.BuildAndStart(); server_->Wait(); @@ -35,6 +47,7 @@ class FederatedAdapterTest : public ::testing::Test { } static int const kWorldSize{2}; + std::string server_address_; std::unique_ptr server_thread_; std::unique_ptr server_; }; @@ -52,9 +65,10 @@ TEST(FederatedAdapterSimpleTest, ThrowOnInvalidCommunicator) { TEST_F(FederatedAdapterTest, DeviceAllReduceSum) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(std::thread([rank] { - FederatedCommunicator comm{kWorldSize, rank, kServerAddress}; - DeviceCommunicatorAdapter adapter{rank, &comm}; + threads.emplace_back(std::thread([rank, server_address=server_address_] { + FederatedCommunicator comm{kWorldSize, rank, server_address}; + // Assign device 0 to all workers, since we run gtest in a single-GPU machine + DeviceCommunicatorAdapter adapter{0, &comm}; int const count = 3; thrust::device_vector buffer(count, 0); thrust::sequence(buffer.begin(), buffer.end()); @@ -74,9 +88,10 @@ TEST_F(FederatedAdapterTest, DeviceAllReduceSum) { TEST_F(FederatedAdapterTest, DeviceAllGatherV) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(std::thread([rank] { - FederatedCommunicator comm{kWorldSize, rank, kServerAddress}; - DeviceCommunicatorAdapter adapter{rank, &comm}; + threads.emplace_back(std::thread([rank, server_address=server_address_] { + FederatedCommunicator comm{kWorldSize, rank, server_address}; + // Assign device 0 to all workers, since we run gtest in a single-GPU machine + DeviceCommunicatorAdapter adapter{0, &comm}; int const count = rank + 2; thrust::device_vector buffer(count, 0); diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index 3cfa15fb1eb7..51d258f02b57 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -5,34 +5,46 @@ #include #include +#include #include +#include +#include "helpers.h" #include "../../../plugin/federated/federated_communicator.h" #include "../../../plugin/federated/federated_server.h" +namespace { + +std::string GetServerAddress() { + int port = GenerateRandomPort(50000, 60000); + std::string address = std::string("localhost:") + std::to_string(port); + return address; +} + +} // anonymous namespace + namespace xgboost { namespace collective { -std::string const kServerAddress{"localhost:56789"}; // NOLINT(cert-err58-cpp) - class FederatedCommunicatorTest : public ::testing::Test { public: - static void VerifyAllreduce(int rank) { - FederatedCommunicator comm{kWorldSize, rank, kServerAddress}; + static void VerifyAllreduce(int rank, const std::string& server_address) { + FederatedCommunicator comm{kWorldSize, rank, server_address}; CheckAllreduce(comm); } - static void VerifyBroadcast(int rank) { - FederatedCommunicator comm{kWorldSize, rank, kServerAddress}; + static void VerifyBroadcast(int rank, const std::string& server_address) { + FederatedCommunicator comm{kWorldSize, rank, server_address}; CheckBroadcast(comm, rank); } protected: void SetUp() override { + server_address_ = GetServerAddress(); server_thread_.reset(new std::thread([this] { grpc::ServerBuilder builder; federated::FederatedService service{kWorldSize}; - builder.AddListeningPort(kServerAddress, grpc::InsecureServerCredentials()); + builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials()); builder.RegisterService(&service); server_ = builder.BuildAndStart(); server_->Wait(); @@ -66,29 +78,40 @@ class FederatedCommunicatorTest : public ::testing::Test { } static int const kWorldSize{3}; + std::string server_address_; std::unique_ptr server_thread_; std::unique_ptr server_; }; TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) { - auto construct = []() { FederatedCommunicator comm{0, 0, kServerAddress, "", "", ""}; }; + std::string server_address{GetServerAddress()}; + auto construct = [server_address]() { + FederatedCommunicator comm{0, 0, server_address, "", "", ""}; + }; EXPECT_THROW(construct(), dmlc::Error); } TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooSmall) { - auto construct = []() { FederatedCommunicator comm{1, -1, kServerAddress, "", "", ""}; }; + std::string server_address{GetServerAddress()}; + auto construct = [server_address]() { + FederatedCommunicator comm{1, -1, server_address, "", "", ""}; + }; EXPECT_THROW(construct(), dmlc::Error); } TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) { - auto construct = []() { FederatedCommunicator comm{1, 1, kServerAddress, "", "", ""}; }; + std::string server_address{GetServerAddress()}; + auto construct = [server_address]() { + FederatedCommunicator comm{1, 1, server_address, "", "", ""}; + }; EXPECT_THROW(construct(), dmlc::Error); } TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) { - auto construct = []() { + std::string server_address{GetServerAddress()}; + auto construct = [server_address]() { Json config{JsonObject()}; - config["federated_server_address"] = kServerAddress; + config["federated_server_address"] = server_address; config["federated_world_size"] = std::string("1"); config["federated_rank"] = Integer(0); auto *comm = FederatedCommunicator::Create(config); @@ -97,9 +120,10 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) { } TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) { - auto construct = []() { + std::string server_address{GetServerAddress()}; + auto construct = [server_address]() { Json config{JsonObject()}; - config["federated_server_address"] = kServerAddress; + config["federated_server_address"] = server_address; config["federated_world_size"] = 1; config["federated_rank"] = std::string("0"); auto *comm = FederatedCommunicator::Create(config); @@ -108,20 +132,23 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) { } TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) { - FederatedCommunicator comm{6, 3, kServerAddress}; + std::string server_address{GetServerAddress()}; + FederatedCommunicator comm{6, 3, server_address}; EXPECT_EQ(comm.GetWorldSize(), 6); EXPECT_EQ(comm.GetRank(), 3); } TEST(FederatedCommunicatorSimpleTest, IsDistributed) { - FederatedCommunicator comm{2, 1, kServerAddress}; + std::string server_address{GetServerAddress()}; + FederatedCommunicator comm{2, 1, server_address}; EXPECT_TRUE(comm.IsDistributed()); } TEST_F(FederatedCommunicatorTest, Allreduce) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(std::thread(&FederatedCommunicatorTest::VerifyAllreduce, rank)); + threads.emplace_back( + std::thread(&FederatedCommunicatorTest::VerifyAllreduce, rank, server_address_)); } for (auto &thread : threads) { thread.join(); @@ -131,7 +158,8 @@ TEST_F(FederatedCommunicatorTest, Allreduce) { TEST_F(FederatedCommunicatorTest, Broadcast) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(std::thread(&FederatedCommunicatorTest::VerifyBroadcast, rank)); + threads.emplace_back( + std::thread(&FederatedCommunicatorTest::VerifyBroadcast, rank, server_address_)); } for (auto &thread : threads) { thread.join(); diff --git a/tests/cpp/plugin/test_federated_server.cc b/tests/cpp/plugin/test_federated_server.cc index 1c3e4f0bc84c..2e7afe5a294d 100644 --- a/tests/cpp/plugin/test_federated_server.cc +++ b/tests/cpp/plugin/test_federated_server.cc @@ -4,32 +4,45 @@ #include #include +#include #include +#include +#include "helpers.h" #include "federated_client.h" #include "federated_server.h" +namespace { + +std::string GetServerAddress() { + int port = GenerateRandomPort(50000, 60000); + std::string address = std::string("localhost:") + std::to_string(port); + return address; +} + +} // anonymous namespace + namespace xgboost { class FederatedServerTest : public ::testing::Test { public: - static void VerifyAllgather(int rank) { - federated::FederatedClient client{kServerAddress, rank}; + static void VerifyAllgather(int rank, const std::string& server_address) { + federated::FederatedClient client{server_address, rank}; CheckAllgather(client, rank); } - static void VerifyAllreduce(int rank) { - federated::FederatedClient client{kServerAddress, rank}; + static void VerifyAllreduce(int rank, const std::string& server_address) { + federated::FederatedClient client{server_address, rank}; CheckAllreduce(client); } - static void VerifyBroadcast(int rank) { - federated::FederatedClient client{kServerAddress, rank}; + static void VerifyBroadcast(int rank, const std::string& server_address) { + federated::FederatedClient client{server_address, rank}; CheckBroadcast(client, rank); } - static void VerifyMixture(int rank) { - federated::FederatedClient client{kServerAddress, rank}; + static void VerifyMixture(int rank, const std::string& server_address) { + federated::FederatedClient client{server_address, rank}; for (auto i = 0; i < 10; i++) { CheckAllgather(client, rank); CheckAllreduce(client); @@ -39,10 +52,11 @@ class FederatedServerTest : public ::testing::Test { protected: void SetUp() override { + server_address_ = GetServerAddress(); server_thread_.reset(new std::thread([this] { grpc::ServerBuilder builder; federated::FederatedService service{kWorldSize}; - builder.AddListeningPort(kServerAddress, grpc::InsecureServerCredentials()); + builder.AddListeningPort(server_address_, grpc::InsecureServerCredentials()); builder.RegisterService(&service); server_ = builder.BuildAndStart(); server_->Wait(); @@ -80,17 +94,15 @@ class FederatedServerTest : public ::testing::Test { } static int const kWorldSize{3}; - static std::string const kServerAddress; + std::string server_address_; std::unique_ptr server_thread_; std::unique_ptr server_; }; -std::string const FederatedServerTest::kServerAddress{"localhost:56789"}; // NOLINT(cert-err58-cpp) - TEST_F(FederatedServerTest, Allgather) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllgather, rank)); + threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllgather, rank, server_address_)); } for (auto& thread : threads) { thread.join(); @@ -100,7 +112,7 @@ TEST_F(FederatedServerTest, Allgather) { TEST_F(FederatedServerTest, Allreduce) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllreduce, rank)); + threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllreduce, rank, server_address_)); } for (auto& thread : threads) { thread.join(); @@ -110,7 +122,7 @@ TEST_F(FederatedServerTest, Allreduce) { TEST_F(FederatedServerTest, Broadcast) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(std::thread(&FederatedServerTest::VerifyBroadcast, rank)); + threads.emplace_back(std::thread(&FederatedServerTest::VerifyBroadcast, rank, server_address_)); } for (auto& thread : threads) { thread.join(); @@ -120,7 +132,7 @@ TEST_F(FederatedServerTest, Broadcast) { TEST_F(FederatedServerTest, Mixture) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(std::thread(&FederatedServerTest::VerifyMixture, rank)); + threads.emplace_back(std::thread(&FederatedServerTest::VerifyMixture, rank, server_address_)); } for (auto& thread : threads) { thread.join();