Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement secure boost scheme - secure evaluation and validation (during training) without local feature leakage #10079

Merged
merged 40 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
8570ba5
Add additional data split mode to cover the secure vertical pipeline
ZiyueXu77 Jan 31, 2024
2d00db6
Add IsSecure info and update corresponding functions
ZiyueXu77 Jan 31, 2024
ab17f5a
Modify evaluate_splits to block non-label owners to perform hist comp…
ZiyueXu77 Jan 31, 2024
fb1787c
Continue using Allgather for best split sync for secure vertical, equ…
ZiyueXu77 Feb 2, 2024
7a2a2b8
Modify histogram sync scheme for secure vertical case, can identify g…
ZiyueXu77 Feb 6, 2024
3ca3142
Sync cut information across clients, full pipeline works for testing …
ZiyueXu77 Feb 7, 2024
22dd522
Code cleanup, phase 1 of alternative vertical pipeline finished
ZiyueXu77 Feb 8, 2024
52e8951
Code clean
ZiyueXu77 Feb 8, 2024
e9eef15
change kColS to kColSecure to avoid confusion with kCols
ZiyueXu77 Feb 12, 2024
70e6ca6
Add additional data split mode to cover the secure vertical pipeline
ZiyueXu77 Jan 31, 2024
a54ea6a
Add IsSecure info and update corresponding functions
ZiyueXu77 Jan 31, 2024
6fe61dd
Modify evaluate_splits to block non-label owners to perform hist comp…
ZiyueXu77 Jan 31, 2024
1c2b7ed
Continue using Allgather for best split sync for secure vertical, equ…
ZiyueXu77 Feb 2, 2024
b36ff2b
Modify histogram sync scheme for secure vertical case, can identify g…
ZiyueXu77 Feb 6, 2024
0707731
Sync cut information across clients, full pipeline works for testing …
ZiyueXu77 Feb 7, 2024
dce7609
Code cleanup, phase 1 of alternative vertical pipeline finished
ZiyueXu77 Feb 8, 2024
6cebc31
Code clean
ZiyueXu77 Feb 8, 2024
1562f52
change kColS to kColSecure to avoid confusion with kCols
ZiyueXu77 Feb 12, 2024
f31c824
Add one unit test
YuanTingHsieh Feb 17, 2024
6fcbe02
Merge branch 'SecureBoost' into add_alternate_vertical_splits
ZiyueXu77 Feb 20, 2024
967e307
Merge pull request #1 from YuanTingHsieh/add_alternate_vertical_splits
ZiyueXu77 Feb 20, 2024
04cd1cb
Merge branch 'dmlc:master' into SecureBoost
ZiyueXu77 Feb 20, 2024
087a8dd
Merge branch 'dmlc:master' into SecureBoost
ZiyueXu77 Feb 23, 2024
5e85438
modify inference behavior of secure vertical from split value to inde…
ZiyueXu77 Feb 27, 2024
e008818
fix the logic for secure vertical inference, each client save a diffe…
ZiyueXu77 Feb 27, 2024
1fd1fb0
code clean
ZiyueXu77 Feb 27, 2024
72159b9
code clean
ZiyueXu77 Feb 27, 2024
069f811
code clean
ZiyueXu77 Feb 27, 2024
4e3c329
code clean
ZiyueXu77 Feb 27, 2024
d0bee2f
Merge branch 'vertical-federated-learning' into SecureBoostInf
ZiyueXu77 Mar 1, 2024
4624c3f
clean the conflicts, make sure the pipeline functions
ZiyueXu77 Mar 4, 2024
85b215d
address comments on split_value update et al
ZiyueXu77 Mar 5, 2024
cb6af9f
remove secure flags no longer needed
ZiyueXu77 Mar 5, 2024
b1ca59a
linting update
ZiyueXu77 Mar 5, 2024
2ba22dd
correction on split_value recovery, perform only for secure mode
ZiyueXu77 Mar 5, 2024
7cdde6f
Add secure inf unit tests
YuanTingHsieh Apr 15, 2024
be37fcd
Merge pull request #3 from YuanTingHsieh/add_secure_inf_unit_tests
ZiyueXu77 Apr 15, 2024
090cb1a
Merge branch 'vertical-federated-learning' into SecureBoostInf
ZiyueXu77 Apr 29, 2024
da97000
fix clang-tidy warning
ZiyueXu77 May 10, 2024
0c854a4
fix memory leakage for unit test
ZiyueXu77 May 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
40 changes: 21 additions & 19 deletions src/common/quantile.cc
Expand Up @@ -361,21 +361,23 @@ void SketchContainerImpl<WQSketch>::AllReduce(
}

template <typename SketchType>
void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin,
bool AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin,
HistogramCuts *cuts, bool secure) {
size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
// make a copy of required_cuts for mode selection
size_t required_cuts_original = required_cuts;
if (secure) {
// Sync the required_cuts across all workers
// sync the required_cuts across all workers
collective::Allreduce<collective::Operation::kMax>(&required_cuts, 1);
}
// add the cut points
auto &cut_values = cuts->cut_values_.HostVector();
// if empty column, fill the cut values with 0
// if secure and empty column, fill the cut values with NaN
if (secure && (required_cuts_original == 0)) {
for (size_t i = 1; i < required_cuts; ++i) {
cut_values.push_back(0.0);
cut_values.push_back(std::numeric_limits<double>::quiet_NaN());
}
return true;
} else {
// we use the min_value as the first (0th) element, hence starting from 1.
for (size_t i = 1; i < required_cuts; ++i) {
Expand All @@ -384,6 +386,7 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b
cut_values.push_back(cpt);
}
}
return false;
}
}

Expand Down Expand Up @@ -437,6 +440,7 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
for (size_t fid = 0; fid < reduced.size(); ++fid) {
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
// If vertical and secure mode, we need to sync the max_num_bins across workers
// to create the same global number of cut point bins for easier future processing
if (info.IsVerticalFederated() && info.IsSecure()) {
collective::Allreduce<collective::Operation::kMax>(&max_num_bins, 1);
}
Expand All @@ -445,29 +449,27 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
max_cat = std::max(max_cat, AddCategories(categories_.at(fid), p_cuts));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on my understanding, categorical features are not yet supported right?

Copy link
Author

@ZiyueXu77 ZiyueXu77 Mar 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, will need to find a proper use-case / testing data with categorical features to add the support, it seems the categorical feature is "experimental" according to some of the last year's release notes, is it still the case? maybe we can add the support later when we find it really necessary.

} else {
// use special AddCutPoint scheme for secure vertical federated learning
AddCutPoint<WQSketch>(a, max_num_bins, p_cuts, info.IsSecure());
// push a value that is greater than anything
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
p_cuts->cut_values_.HostVector().push_back(last);
bool is_nan = AddCutPoint<WQSketch>(a, max_num_bins, p_cuts, info.IsSecure());
// push a value that is greater than anything if the feature is not empty
// i.e. if the last value is not NaN
if (!is_nan) {
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
p_cuts->cut_values_.HostVector().push_back(last);
} else {
// if the feature is empty, push a NaN value
p_cuts->cut_values_.HostVector().push_back(std::numeric_limits<double>::quiet_NaN());
}
}

// Ensure that every feature gets at least one quantile point
CHECK_LE(p_cuts->cut_values_.HostVector().size(), std::numeric_limits<uint32_t>::max());
auto cut_size = static_cast<uint32_t>(p_cuts->cut_values_.HostVector().size());
CHECK_GT(cut_size, p_cuts->cut_ptrs_.HostVector().back());
p_cuts->cut_ptrs_.HostVector().push_back(cut_size);
}

if (info.IsVerticalFederated() && info.IsSecure()) {
// cut values need to be synced across all workers via Allreduce
auto cut_val = p_cuts->cut_values_.HostVector().data();
std::size_t n = p_cuts->cut_values_.HostVector().size();
collective::Allreduce<collective::Operation::kSum>(cut_val, n);
}

p_cuts->SetCategorical(this->has_categorical_, max_cat);
monitor_.Stop(__func__);
}
Expand Down
44 changes: 33 additions & 11 deletions src/tree/hist/evaluate_splits.h
Expand Up @@ -303,22 +303,35 @@ class HistEvaluator {
// forward enumeration: split at right bound of each bin
loss_chg =
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
GradStats{right_sum}) -
parent.root_gain);
split_pt = cut_val[i]; // not used for partition based
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
GradStats{right_sum}) - parent.root_gain);
if (!is_secure_) {
split_pt = cut_val[i]; // not used for partition based
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
} else {
// secure mode: record the best split point, rather than the actual value
// since it is not accessible at this point (active party finding best-split)
best.Update(loss_chg, fidx, i, d_step == -1, false, left_sum, right_sum);
}
} else {
Comment on lines +307 to 315
Copy link
Member

@trivialfis trivialfis Mar 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think a policy class might help here? Or maybe there are other efficient ways to handle these conditions? I'm losing track of these conditions, considering that we have three enumeration functions:

  • numeric
  • partition
  • one hot

Then we have three split modes:

  • column
  • row
  • column + secure

So, in combination, 9 potential cases, and we haven't counted vector leaf yet. Need to find a better way to manage these conditions.

Copy link
Author

@ZiyueXu77 ZiyueXu77 Mar 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it can be tricky to consolidate, since the 9 cases have high overlaps (e.g. same enumeration logic for all splits modes except when secure+passive party), some further processing only for col_split (w/ w/o secure), but irrelevant to enumeration.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another thing regarding this mode combinations: potentially with the upcoming processor interface we will be able to enable encrypted horizontal, shall we further add a row + secure mode, adding a 4th one for
enum class DataSplitMode : int { kRow = 0, kCol = 1, kColSecure = 2 };? (or maybe there are better options?)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my preference, I would have put it in the CommunicatorContext configuration for whether the channel is encrypted.

// backward enumeration: split at left bound of each bin
loss_chg =
static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{right_sum},
GradStats{left_sum}) -
parent.root_gain);
if (i == imin) {
split_pt = cut.MinValues()[fidx];
GradStats{left_sum}) - parent.root_gain);
if (!is_secure_) {
if (i == imin) {
split_pt = cut.MinValues()[fidx];
} else {
split_pt = cut_val[i - 1];
}
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
} else {
split_pt = cut_val[i - 1];
// secure mode: record the best split point, rather than the actual value
// since it is not accessible at this point (active party finding best-split)
if (i != imin) {
i = i - 1;
}
best.Update(loss_chg, fidx, i, d_step == -1, false, right_sum, left_sum);
}
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
}
}
}
Expand Down Expand Up @@ -352,7 +365,6 @@ class HistEvaluator {
}
auto evaluator = tree_evaluator_.GetEvaluator();
auto const &cut_ptrs = cut.Ptrs();

// Under secure vertical setting, only the active party is able to evaluate the split
// based on global histogram. Other parties will receive the final best split information
// Hence the below computation is not performed by the passive parties
Expand Down Expand Up @@ -417,6 +429,16 @@ class HistEvaluator {
all_entries[worker * entries.size() + nidx_in_set].split);
}
}
if (is_secure_) {
// At this point, all the workers have the best splits for all the nodes
// and workers can recover the actual split value with the split index
// Note that after the recovery, different workers will hold different
// split_value: real value for feature owner, NaN for others
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
auto cut_index = entries[nidx_in_set].split.split_value;
entries[nidx_in_set].split.split_value = cut.Values()[cut_index];
}
}
}
}

Expand Down
100 changes: 100 additions & 0 deletions tests/cpp/common/test_quantile.cc
Expand Up @@ -7,6 +7,7 @@

#include "../../../src/common/hist_util.h"
#include "../../../src/data/adapter.h"
#include "../../../src/data/simple_dmatrix.h" // SimpleDMatrix
#include "xgboost/context.h"

namespace xgboost::common {
Expand Down Expand Up @@ -296,6 +297,105 @@ TEST(Quantile, ColSplitSorted) {
TestColSplitQuantile<true>(kRows, kCols);
}

namespace {
template <bool use_column>
void DoTestColSplitQuantileSecure() {
Context ctx;
auto const world = collective::GetWorldSize();
auto const rank = collective::GetRank();
size_t cols = 2;
size_t rows = 3;

auto m = std::unique_ptr<DMatrix>{[=]() {
std::vector<float> data = {1, 1, 0.6, 0.4, 0.8};
std::vector<unsigned> row_idx = {0, 2, 0, 1, 2};
std::vector<size_t> col_ptr = {0, 2, 5};
data::CSCAdapter adapter(col_ptr.data(), row_idx.data(), data.data(), 2, 3);
std::unique_ptr<data::SimpleDMatrix> dmat(new data::SimpleDMatrix(
&adapter, std::numeric_limits<float>::quiet_NaN(), -1));
EXPECT_EQ(dmat->Info().num_col_, cols);
EXPECT_EQ(dmat->Info().num_row_, rows);
EXPECT_EQ(dmat->Info().num_nonzero_, 5);
return dmat->SliceCol(world, rank);
}()};

std::vector<bst_row_t> column_size(cols, 0);
auto const slice_size = cols / world;
auto const slice_start = slice_size * rank;
auto const slice_end = (rank == world - 1) ? cols : slice_start + slice_size;
for (auto i = slice_start; i < slice_end; i++) {
column_size[i] = rows;
}

auto const n_bins = 64;

m->Info().data_split_mode = DataSplitMode::kColSecure;
// Generate cuts for distributed environment.
HistogramCuts distributed_cuts;
{
ContainerType<use_column> sketch_distributed(
&ctx, n_bins, m->Info().feature_types.ConstHostSpan(), column_size, false);

std::vector<float> hessian(rows, 1.0);
auto hess = Span<float const>{hessian};
if (use_column) {
for (auto const& page : m->GetBatches<SortedCSCPage>(&ctx)) {
PushPage(&sketch_distributed, page, m->Info(), hess);
}
} else {
for (auto const& page : m->GetBatches<SparsePage>(&ctx)) {
PushPage(&sketch_distributed, page, m->Info(), hess);
}
}

sketch_distributed.MakeCuts(&ctx, m->Info(), &distributed_cuts);
}

auto const& dptrs = distributed_cuts.Ptrs();
auto const& dvals = distributed_cuts.Values();
auto const& dmins = distributed_cuts.MinValues();
std::vector<float> expected_ptrs = {0, 1, 4};
std::vector<float> expected_vals = {2, 0, 0, 0};
std::vector<float> expected_mins = {-1e-5, 1e-5};
if (rank == 1) {
expected_ptrs = {0, 1, 4};
expected_vals = {0, 0.6, 0.8, 1.6};
expected_mins = {1e-5, -1e-5};
}

EXPECT_EQ(dptrs.size(), expected_ptrs.size());
for (size_t i = 0; i < expected_ptrs.size(); ++i) {
EXPECT_EQ(dptrs[i], expected_ptrs[i]) << "rank: " << rank << ", i: " << i;
}

EXPECT_EQ(dvals.size(), expected_vals.size());
for (size_t i = 0; i < expected_vals.size(); ++i) {
if (!std::isnan(dvals[i])) {
EXPECT_NEAR(dvals[i], expected_vals[i], 2e-2f) << "rank: " << rank << ", i: " << i;
}
}

EXPECT_EQ(dmins.size(), expected_mins.size());
for (size_t i = 0; i < expected_mins.size(); ++i) {
EXPECT_FLOAT_EQ(dmins[i], expected_mins[i]) << "rank: " << rank << ", i: " << i;
}
}

template <bool use_column>
void TestColSplitQuantileSecure() {
  // Spin up a two-worker in-memory communicator; each worker runs the
  // secure column-split quantile test body.
  constexpr int kWorkers = 2;
  RunWithInMemoryCommunicator(kWorkers, DoTestColSplitQuantileSecure<use_column>);
}
} // anonymous namespace

// Secure column-split quantiles over the SparsePage (row-major) input path.
TEST(Quantile, ColSplitSecure) {
  TestColSplitQuantileSecure<false>();
}

// Secure column-split quantiles over the SortedCSCPage (column-major) input path.
TEST(Quantile, ColSplitSecureSorted) {
  TestColSplitQuantileSecure<true>();
}

namespace {
void TestSameOnAllWorkers() {
auto const world = collective::GetWorldSize();
Expand Down
83 changes: 83 additions & 0 deletions tests/cpp/tree/hist/test_evaluate_splits.cc
Expand Up @@ -289,4 +289,87 @@ TEST_F(TestCategoricalSplitWithMissing, HistEvaluator) {
GradientPairPrecise{split.left_sum.GetGrad(), split.left_sum.GetHess()},
GradientPairPrecise{split.right_sum.GetGrad(), split.right_sum.GetHess()});
}

namespace {
// End-to-end check of HistEvaluator under the secure vertical
// (column-split) mode: each worker holds a slice of columns, the matrix is
// flagged kColSecure, and the best split found must match a brute-force
// enumeration of all bin boundaries.
// @param force_read_by_column exercise the column-wise histogram build path
void DoTestEvaluateSplitsSecure(bool force_read_by_column) {
  Context ctx;
  auto const world = collective::GetWorldSize();
  auto const rank = collective::GetRank();
  int static constexpr kRows = 8, kCols = 16;
  auto sampler = std::make_shared<common::ColumnSampler>(1u);

  TrainParam param;
  param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});

  // Slice the generated matrix by column and mark it secure column-split.
  auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix();
  auto m = dmat->SliceCol(world, rank);
  m->Info().data_split_mode = DataSplitMode::kColSecure;

  auto evaluator = HistEvaluator{&ctx, &param, m->Info(), sampler};
  BoundedHistCollection hist;
  std::vector<GradientPair> row_gpairs = {
      {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
      {0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f}};

  size_t constexpr kMaxBins = 4;
  // dense, no missing values
  GHistIndexMatrix gmat(&ctx, dmat.get(), kMaxBins, 0.5, false);
  // All rows belong to the root node (node 0).
  common::RowSetCollection row_set_collection;
  std::vector<size_t> &row_indices = *row_set_collection.Data();
  row_indices.resize(kRows);
  std::iota(row_indices.begin(), row_indices.end(), 0);
  row_set_collection.Init();

  // Build the root histogram from the gradient pairs.
  HistMakerTrainParam hist_param;
  hist.Reset(gmat.cut.Ptrs().back(), hist_param.max_cached_hist_node);
  hist.AllocateHistograms({0});
  common::BuildHist<false>(row_gpairs, row_set_collection[0], gmat, hist[0], force_read_by_column);

  // Compute total gradient for all data points
  GradientPairPrecise total_gpair;
  for (const auto &e : row_gpairs) {
    total_gpair += GradientPairPrecise(e);
  }

  RegTree tree;
  std::vector<CPUExpandEntry> entries(1);
  entries.front().nid = 0;
  entries.front().depth = 0;

  evaluator.InitRoot(GradStats{total_gpair});
  evaluator.EvaluateSplits(hist, gmat.cut, {}, tree, &entries);

  // The reported loss change must equal the gain recomputed from the
  // chosen split's recorded left/right gradient sums.
  auto best_loss_chg =
      evaluator.Evaluator().CalcSplitGain(
          param, 0, entries.front().split.SplitIndex(),
          entries.front().split.left_sum, entries.front().split.right_sum) -
      evaluator.Stats().front().root_gain;
  ASSERT_EQ(entries.front().split.loss_chg, best_loss_chg);
  ASSERT_GT(entries.front().split.loss_chg, 16.2f);

  // Assert that's the best split
  // Brute-force every feature (i-1) and bin boundary (j): no candidate may
  // beat the evaluator's choice.
  // NOTE(review): the gain is computed before bin j's stats are folded into
  // `left`, so each check uses the accumulation up to the previous bin —
  // confirm this mirrors the enumeration order of the non-secure test.
  for (size_t i = 1; i < gmat.cut.Ptrs().size(); ++i) {
    GradStats left, right;
    for (size_t j = gmat.cut.Ptrs()[i-1]; j < gmat.cut.Ptrs()[i]; ++j) {
      auto loss_chg =
          evaluator.Evaluator().CalcSplitGain(param, 0, i - 1, left, right) -
          evaluator.Stats().front().root_gain;
      ASSERT_GE(best_loss_chg, loss_chg);
      left.Add(hist[0][j].GetGrad(), hist[0][j].GetHess());
      right.SetSubstract(GradStats{total_gpair}, left);
    }
  }
}

// Run the secure evaluate-splits test body across a two-worker in-memory
// communicator. (Also removes the stray space before the parameter list.)
void TestEvaluateSplitsSecure(bool force_read_by_column) {
  auto constexpr kWorkers = 2;
  RunWithInMemoryCommunicator(kWorkers, DoTestEvaluateSplitsSecure, force_read_by_column);
}
} // anonymous namespace

// Exercise both histogram build paths: row-wise (false) and forced
// column-wise reads (true).
TEST(HistEvaluator, SecureEvaluate) {
  TestEvaluateSplitsSecure(false);
  TestEvaluateSplitsSecure(true);
}

} // namespace xgboost::tree