Implement secure boost scheme phase 2 - enable secure xgboost via processor interface #10124

Open · wants to merge 62 commits into base: vertical-federated-learning
Commits
8570ba5
Add additional data split mode to cover the secure vertical pipeline
ZiyueXu77 Jan 31, 2024
2d00db6
Add IsSecure info and update corresponding functions
ZiyueXu77 Jan 31, 2024
ab17f5a
Modify evaluate_splits to block non-label owners to perform hist comp…
ZiyueXu77 Jan 31, 2024
fb1787c
Continue using Allgather for best split sync for secure vertical, equ…
ZiyueXu77 Feb 2, 2024
7a2a2b8
Modify histogram sync scheme for secure vertical case, can identify g…
ZiyueXu77 Feb 6, 2024
3ca3142
Sync cut informaiton across clients, full pipeline works for testing …
ZiyueXu77 Feb 7, 2024
22dd522
Code cleanup, phase 1 of alternative vertical pipeline finished
ZiyueXu77 Feb 8, 2024
52e8951
Code clean
ZiyueXu77 Feb 8, 2024
e9eef15
change kColS to kColSecure to avoid confusion with kCols
ZiyueXu77 Feb 12, 2024
91c8a2f
Replace allreduce with allgather, functional but inefficient version
ZiyueXu77 Feb 13, 2024
8340c26
Update AllGather behavior from individual pair to bulk by adopting ne…
ZiyueXu77 Feb 13, 2024
42a9df1
comment out the record printing
ZiyueXu77 Feb 13, 2024
41e5abb
fix pointer bug for histsync with allgather
ZiyueXu77 Feb 20, 2024
ea5dc98
Merge branch 'dmlc:master' into SecureBoostP2
ZiyueXu77 Feb 20, 2024
5d542f8
Merge branch 'dmlc:master' into SecureBoostP2
ZiyueXu77 Feb 23, 2024
d91be10
identify the HE adding locations
ZiyueXu77 Feb 23, 2024
dd60317
revise and simplify template code
ZiyueXu77 Mar 6, 2024
8da824c
revise and simplify template code
ZiyueXu77 Mar 6, 2024
fb9f4fa
prepare aggregator for gh broadcast
ZiyueXu77 Mar 13, 2024
e77f8c6
prepare histogram for histindex and row index for allgather
ZiyueXu77 Mar 14, 2024
7ef48c8
Merge branch 'vertical-federated-learning' into SecureBoostP2
ZiyueXu77 Mar 15, 2024
8405791
fix conflicts
ZiyueXu77 Mar 15, 2024
db7d518
fix conflicts
ZiyueXu77 Mar 15, 2024
dd6adde
fix format
ZiyueXu77 Mar 15, 2024
9567e67
fix allgather logic and update unit test
ZiyueXu77 Mar 19, 2024
53800f2
fix linting
ZiyueXu77 Mar 19, 2024
b7e70f1
fix linting and other unit test issues
ZiyueXu77 Mar 20, 2024
49e8fd6
fix linting and other unit test issues
ZiyueXu77 Mar 20, 2024
da0f7a6
integration with interface initial attempt
ZiyueXu77 Mar 22, 2024
406cda3
integration with interface initial attempt
ZiyueXu77 Mar 22, 2024
f6c63aa
integration with interface initial attempt
ZiyueXu77 Mar 22, 2024
f223df7
functional integration with interface
ZiyueXu77 Apr 1, 2024
d881d84
remove debugging prints
ZiyueXu77 Apr 1, 2024
2997cf7
remove processor from another PR
ZiyueXu77 Apr 1, 2024
3a1f9ac
Update the processor functions according to new processor implementation
ZiyueXu77 Apr 12, 2024
1107604
Move processor interface init from learner to communicator
ZiyueXu77 Apr 12, 2024
30b7ed5
Move processor interface init from learner to communicator functional
ZiyueXu77 Apr 12, 2024
a3ddf7d
switch to allgatherV for encrypted message with varying lenghts
ZiyueXu77 Apr 15, 2024
3123b51
consolidate with processor interface PR
ZiyueXu77 Apr 19, 2024
73225a0
remove prints and fix format
ZiyueXu77 Apr 23, 2024
e85b1fb
fix linting over reference pass
ZiyueXu77 Apr 24, 2024
57750b4
fix undefined symbol issue
ZiyueXu77 Apr 24, 2024
fa2665a
fix processor test
ZiyueXu77 Apr 24, 2024
87d2fdb
secure vertical relies on processor, move the unit test
ZiyueXu77 Apr 24, 2024
9941293
type correction
ZiyueXu77 Apr 24, 2024
dd4f440
type correction
ZiyueXu77 Apr 24, 2024
5b2dfe6
extra linting from last change
ZiyueXu77 Apr 24, 2024
80d3b89
Added Windows support
nvidianz Apr 24, 2024
184b67f
Merge pull request #4 from nvidianz/processor-windows-support
ZiyueXu77 Apr 25, 2024
3382707
fix for cstdint types
ZiyueXu77 Apr 25, 2024
2a8f19a
fix for cstdint types
ZiyueXu77 Apr 25, 2024
9ff2935
Added support for horizontal secure XGBoost
nvidianz Apr 25, 2024
38e9d3d
Merge pull request #5 from nvidianz/processor-horizontal-support
ZiyueXu77 Apr 25, 2024
82ad9a8
Merge branch 'vertical-federated-learning' into SecureBoostP2
ZiyueXu77 Apr 29, 2024
6418503
remove horizontal funcs from this PR
ZiyueXu77 Apr 29, 2024
3a86daa
change loader and proc params input pattern to align with std map
ZiyueXu77 Apr 29, 2024
f3967c5
add processor shutdown
ZiyueXu77 May 10, 2024
20bb965
move processor shutdown
ZiyueXu77 May 10, 2024
0be6129
fix memory leakage in processor test
ZiyueXu77 May 15, 2024
47176d8
Merge branch 'vertical-federated-learning' into SecureBoostP2
ZiyueXu77 May 16, 2024
6d1bbe7
fix double free issue
ZiyueXu77 May 17, 2024
3aa64b3
linting update
ZiyueXu77 May 17, 2024
45 changes: 42 additions & 3 deletions src/collective/aggregator.h
@@ -14,6 +14,7 @@
#include "communicator-inl.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/data.h" // for MetaINfo
#include "../processing/processor.h" // for Processor

namespace xgboost::collective {

@@ -69,7 +70,7 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::si
* @param result The HostDeviceVector storing the results.
* @param function The function used to calculate the results.
*/
template <typename T, typename Function>
template <bool is_gpair, typename T, typename Function>
void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>* result,
Function&& function) {
if (info.IsVerticalFederated()) {
@@ -96,8 +97,46 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
}
collective::Broadcast(&size, sizeof(std::size_t), 0);

result->Resize(size);
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
if (info.IsSecure() && is_gpair) {
// Under secure mode, gradient pairs are flattened into a double vector and
// handed to the processor plugin for encryption; plaintext values are only
// available on rank 0
std::size_t buffer_size{};
std::int8_t *buffer;
if (collective::GetRank() == 0) {
std::vector<double> vector_gh;
for (std::size_t i = 0; i < size; i++) {
auto gpair = result->HostVector()[i];
// cast from GradientPair to float pointer
auto gpair_ptr = reinterpret_cast<float*>(&gpair);
// save to vector
vector_gh.push_back(gpair_ptr[0]);
vector_gh.push_back(gpair_ptr[1]);
}
// provide the vectors to the processor interface
size_t size;
auto buf = processor_instance->ProcessGHPairs(&size, vector_gh);
buffer_size = size;
buffer = reinterpret_cast<std::int8_t *>(buf);
}

// broadcast the buffer size for other ranks to prepare
collective::Broadcast(&buffer_size, sizeof(std::size_t), 0);
// prepare buffer on passive parties for satisfying broadcast mpi call
if (collective::GetRank() != 0) {
buffer = reinterpret_cast<std::int8_t *>(malloc(buffer_size));
}

// broadcast the data buffer holding processed gpairs
collective::Broadcast(buffer, buffer_size, 0);

// call HandleGHPairs
size_t size;
processor_instance->HandleGHPairs(&size, buffer, buffer_size);
} else {
// clear text mode, broadcast the data directly
result->Resize(size);
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
}
} else {
std::forward<Function>(function)();
}
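The flattening performed on rank 0 above packs each GradientPair into two consecutive doubles before handing the vector to the processor. A minimal sketch of that step, assuming the GradientPair accessors GetGrad() and GetHess() in place of the reinterpret_cast used in the patch:

std::vector<double> vector_gh;
for (auto const& gpair : result->HostVector()) {
  vector_gh.push_back(gpair.GetGrad());  // gradient
  vector_gh.push_back(gpair.GetHess());  // hessian
}
std::size_t encoded_size{};
void* encoded = processor_instance->ProcessGHPairs(&encoded_size, vector_gh);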
27 changes: 25 additions & 2 deletions src/collective/communicator.cc
@@ -1,6 +1,7 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <map>
#include "communicator.h"

#include "comm.h"
@@ -9,9 +10,12 @@
#include "rabit_communicator.h"

#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_communicator.h"
#include "../../plugin/federated/federated_communicator.h"
#endif

#include "../processing/processor.h"
processing::Processor *processor_instance;

namespace xgboost::collective {
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
thread_local CommunicatorType Communicator::type_{};
@@ -38,7 +42,26 @@ void Communicator::Init(Json const& config) {
}
case CommunicatorType::kFederated: {
#if defined(XGBOOST_USE_FEDERATED)
communicator_.reset(FederatedCommunicator::Create(config));
communicator_.reset(FederatedCommunicator::Create(config));
// Get processor configs
std::string plugin_name{};
std::string loader_params_key{};
std::string loader_params_map{};
std::string proc_params_key{};
std::string proc_params_map{};
plugin_name = OptionalArg<String>(config, "plugin_name", plugin_name);
loader_params_key = OptionalArg<String>(config, "loader_params_key", loader_params_key);
loader_params_map = OptionalArg<String>(config, "loader_params_map", loader_params_map);
proc_params_key = OptionalArg<String>(config, "proc_params_key", proc_params_key);
proc_params_map = OptionalArg<String>(config, "proc_params_map", proc_params_map);
// Initialize processor if plugin_name is provided
if (!plugin_name.empty()) {
std::map<std::string, std::string> loader_params = {{loader_params_key, loader_params_map}};
std::map<std::string, std::string> proc_params = {{proc_params_key, proc_params_map}};
processing::ProcessorLoader loader(loader_params);
processor_instance = loader.load(plugin_name);
processor_instance->Initialize(collective::GetRank() == 0, proc_params);
}
#else
LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
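For context, a condensed sketch of how the new keys might be supplied to Communicator::Init for the federated backend; the plugin name "mock", the LIBRARY_PATH key, and the /tmp path are illustrative values, and the usual federated settings (server address, world size, rank) are omitted:

Json config{Object{}};
// ... existing federated communicator settings ...
config["plugin_name"] = String("mock");                // selects the processor plugin
config["loader_params_key"] = String("LIBRARY_PATH");  // illustrative loader option
config["loader_params_map"] = String("/tmp");
config["proc_params_key"] = String("");
config["proc_params_map"] = String("");
collective::Communicator::Init(config);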
7 changes: 4 additions & 3 deletions src/learner.cc
@@ -846,7 +846,7 @@ class LearnerConfiguration : public Learner {

void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
base_score->Reshape(1);
collective::ApplyWithLabels(this->Ctx(), info, base_score->Data(),
collective::ApplyWithLabels<false>(this->Ctx(), info, base_score->Data(),
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
}
};
@@ -1472,8 +1472,9 @@ class LearnerImpl : public LearnerIO {
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info,
std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) {
out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
collective::ApplyWithLabels(&ctx_, info, out_gpair->Data(),
[&] { obj_->GetGradient(preds, info, iter, out_gpair); });
// calculate gradient and communicate
collective::ApplyWithLabels<true>(&ctx_, info, out_gpair->Data(),
[&] { obj_->GetGradient(preds, info, iter, out_gpair); });
}

/*! \brief random number transformation seed. */
174 changes: 174 additions & 0 deletions src/processing/plugins/mock_processor.cc
@@ -0,0 +1,174 @@
/**
* Copyright 2014-2024 by XGBoost Contributors
*/
#include <iostream>
#include <cstring>
#include <cstdint>
#include "./mock_processor.h"

const char kSignature[] = "NVDADAM1"; // DAM (Direct Accessible Marshalling) V1
const int64_t kPrefixLen = 24;

bool ValidDam(void *buffer, std::size_t size) {
return size >= kPrefixLen && memcmp(buffer, kSignature, strlen(kSignature)) == 0;
}
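// Illustration of the 24-byte DAM prefix used by the Process*/Handle* calls
// below (hypothetical struct, shown for reference only):
//   struct DamHeader {
//     char    signature[8];   // "NVDADAM1"
//     int64_t buffer_size;    // total buffer length, including this prefix
//     int64_t data_type;      // kDataTypeGHPairs, kDataTypeHisto, ...
//   };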

void* MockProcessor::ProcessGHPairs(std::size_t *size, const std::vector<double>& pairs) {
*size = kPrefixLen + pairs.size()*10*8; // Assume encrypted size is 10x

int64_t buf_size = *size;
// This memory needs to be freed
char *buf = static_cast<char *>(calloc(*size, 1));
memcpy(buf, kSignature, strlen(kSignature));
memcpy(buf + 8, &buf_size, 8);
memcpy(buf + 16, &kDataTypeGHPairs, 8);

// Simulate encryption by duplicating value 10 times
int index = kPrefixLen;
for (auto value : pairs) {
for (std::size_t i = 0; i < 10; i++) {
memcpy(buf+index, &value, 8);
index += 8;
}
}

// Save pairs for future operations
this->gh_pairs_ = new std::vector<double>(pairs);

return buf;
}


void* MockProcessor::HandleGHPairs(std::size_t *size, void *buffer, std::size_t buf_size) {
*size = buf_size;
if (!ValidDam(buffer, *size)) {
return buffer;
}

// For mock, this call is used to set gh_pairs for passive sites
if (!active_) {
int8_t *ptr = static_cast<int8_t *>(buffer);
ptr += kPrefixLen;
double *pairs = reinterpret_cast<double *>(ptr);
std::size_t num = (buf_size - kPrefixLen) / 8;
gh_pairs_ = new std::vector<double>();
for (std::size_t i = 0; i < num; i += 10) {
gh_pairs_->push_back(pairs[i]);
}
}

return buffer;
}

void *MockProcessor::ProcessAggregation(std::size_t *size, std::map<int, std::vector<int>> nodes) {
int total_bin_size = cuts_.back();
int histo_size = total_bin_size*2;
*size = kPrefixLen + 8*histo_size*nodes.size();
int64_t buf_size = *size;
int8_t *buf = static_cast<int8_t *>(calloc(buf_size, 1));
memcpy(buf, kSignature, strlen(kSignature));
memcpy(buf + 8, &buf_size, 8);
memcpy(buf + 16, &kDataTypeHisto, 8);

double *histo = reinterpret_cast<double *>(buf + kPrefixLen);
for ( const auto &node : nodes ) {
auto rows = node.second;
for (const auto &row_id : rows) {
auto num = cuts_.size() - 1;
for (std::size_t f = 0; f < num; f++) {
int slot = slots_[f + num*row_id];
if ((slot < 0) || (slot >= total_bin_size)) {
continue;
}

auto g = (*gh_pairs_)[row_id*2];
auto h = (*gh_pairs_)[row_id*2+1];
histo[slot*2] += g;
histo[slot*2+1] += h;
}
}
histo += histo_size;
}

return buf;
}
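// Layout of the buffer returned above: after the 24-byte DAM prefix, each
// entry in `nodes` contributes 2 * total_bin_size doubles, stored as
// interleaved (gradient, hessian) sums per histogram slot.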

std::vector<double> MockProcessor::HandleAggregation(void *buffer, std::size_t buf_size) {
std::vector<double> result = std::vector<double>();

int8_t* ptr = static_cast<int8_t *>(buffer);
auto rest_size = buf_size;

while (rest_size > kPrefixLen) {
if (!ValidDam(ptr, rest_size)) {
break;
}
int64_t *size_ptr = reinterpret_cast<int64_t *>(ptr + 8);
double *array_start = reinterpret_cast<double *>(ptr + kPrefixLen);
auto array_size = (*size_ptr - kPrefixLen)/8;
result.insert(result.end(), array_start, array_start + array_size);
rest_size -= *size_ptr;
ptr = ptr + *size_ptr;
}

return result;
}

void* MockProcessor::ProcessHistograms(std::size_t *size, const std::vector<double>& histograms) {
*size = kPrefixLen + histograms.size()*10*8; // Assume encrypted size is 10x

int64_t buf_size = *size;
// This memory needs to be freed
char *buf = static_cast<char *>(malloc(buf_size));
memcpy(buf, kSignature, strlen(kSignature));
memcpy(buf + 8, &buf_size, 8);
memcpy(buf + 16, &kDataTypeAggregatedHisto, 8);

// Simulate encryption by duplicating value 10 times
int index = kPrefixLen;
for (auto value : histograms) {
for (std::size_t i = 0; i < 10; i++) {
memcpy(buf+index, &value, 8);
index += 8;
}
}

return buf;
}

std::vector<double> MockProcessor::HandleHistograms(void *buffer, std::size_t buf_size) {
std::vector<double> result = std::vector<double>();

int8_t* ptr = static_cast<int8_t *>(buffer);
auto rest_size = buf_size;

while (rest_size > kPrefixLen) {
if (!ValidDam(ptr, rest_size)) {
break;
}
int64_t *size_ptr = reinterpret_cast<int64_t *>(ptr + 8);
double *array_start = reinterpret_cast<double *>(ptr + kPrefixLen);
auto array_size = (*size_ptr - kPrefixLen)/8;
auto empty = result.empty();
if (!empty) {
if (result.size() != array_size / 10) {
std::cout << "Histogram size doesn't match " << result.size() << " != " << array_size << std::endl;
return result;
}
}

for (std::size_t i = 0; i < array_size/10; i++) {
auto value = array_start[i*10];
if (empty) {
result.push_back(value);
} else {
result[i] += value;
}
}

rest_size -= *size_ptr;
ptr = ptr + *size_ptr;
}

return result;
}
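A minimal round trip through the mock plugin, sketching how the calls above fit together (values are illustrative; the caller owns the returned buffers and releases them with FreeBuffer):

MockProcessor proc;
proc.Initialize(/*active=*/true, {});

std::size_t buf_size{};
std::vector<double> gh{0.5, 1.0, -0.25, 1.0};        // two (g, h) pairs
void* encoded = proc.ProcessGHPairs(&buf_size, gh);  // rank-0 side
proc.HandleGHPairs(&buf_size, encoded, buf_size);    // no-op for the active rank; passive ranks recover gh pairs here
proc.FreeBuffer(encoded);
proc.Shutdown();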
58 changes: 58 additions & 0 deletions src/processing/plugins/mock_processor.h
@@ -0,0 +1,58 @@
/**
* Copyright 2014-2024 by XGBoost Contributors
*/
#pragma once
#include <string>
#include <vector>
#include <map>
#include "../processor.h"

// Data type definition
const int64_t kDataTypeGHPairs = 1;
const int64_t kDataTypeHisto = 2;
const int64_t kDataTypeAggregatedHisto = 3;

class MockProcessor: public processing::Processor {
private:
bool active_ = false;
std::map<std::string, std::string> params_{};  // owned copy of the Initialize() params
std::vector<double> *gh_pairs_{nullptr};
std::vector<uint32_t> cuts_;
std::vector<int> slots_;

public:
void Initialize(bool active, std::map<std::string, std::string> params) override {
this->active_ = active;
this->params_ = params;  // copy; keeping a pointer to the by-value argument would dangle
}

void Shutdown() override {
this->gh_pairs_ = nullptr;
this->cuts_.clear();
this->slots_.clear();
}

void FreeBuffer(void *buffer) override {
free(buffer);
}

void* ProcessGHPairs(size_t *size, const std::vector<double>& pairs) override;

void* HandleGHPairs(size_t *size, void *buffer, size_t buf_size) override;

void InitAggregationContext(const std::vector<uint32_t> &cuts,
const std::vector<int> &slots) override {
this->cuts_ = cuts;
if (this->slots_.empty()) {
this->slots_ = slots;
}
}

void *ProcessAggregation(size_t *size, std::map<int, std::vector<int>> nodes) override;

std::vector<double> HandleAggregation(void *buffer, size_t buf_size) override;

void *ProcessHistograms(size_t *size, const std::vector<double>& histograms) override;

std::vector<double> HandleHistograms(void *buffer, size_t buf_size) override;
};
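Loading the plugin mirrors what Communicator::Init does in the diff above; a condensed sketch (the library path, plugin name, and empty processor params are illustrative):

std::map<std::string, std::string> loader_params = {{"LIBRARY_PATH", "/tmp"}};
processing::ProcessorLoader loader(loader_params);
processing::Processor* proc = loader.load("mock");
proc->Initialize(collective::GetRank() == 0, {});
// ... run secure vertical training ...
proc->Shutdown();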