Skip to content

Commit

Permalink
Fixed a memory issue with deleted gh_pairs
Browse files Browse the repository at this point in the history
  • Loading branch information
nvidianz committed Mar 23, 2024
1 parent 0a7ec90 commit cf890e2
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 10 deletions.
46 changes: 39 additions & 7 deletions src/processing/plugins/dummy_processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ using std::endl;
const char kSignature[] = "NVDADAM1"; // DAM (Direct Accessible Marshalling) V1
const int kPrefixLen = 24;

bool ValidDam(std::int8_t *buffer) {
return memcmp(buffer, kSignature, strlen(kSignature)) == 0;
}

xgboost::common::Span<int8_t> DummyProcessor::ProcessGHPairs(vector<double> &pairs) {
cout << "ProcessGHPairs called with pairs size: " << pairs.size() << endl;

Expand All @@ -31,13 +35,18 @@ xgboost::common::Span<int8_t> DummyProcessor::ProcessGHPairs(vector<double> &pai
}

// Save pairs for future operations
this->gh_pairs_ = &pairs;
this->gh_pairs_ = new vector<double>(pairs);

return xgboost::common::Span<int8_t>(reinterpret_cast<int8_t *>(buf), buf_size);
}

xgboost::common::Span<int8_t> DummyProcessor::HandleGHPairs(xgboost::common::Span<int8_t> buffer) {
cout << "HandleGHPairs called with buffer size: " << buffer.size() << endl;
cout << "HandleGHPairs called with buffer size: " << buffer.size() << " Active: " << active_ << endl;

if (!ValidDam(buffer.data())) {
cout << "Invalid buffer received" << endl;
return buffer;
}

// For dummy, this call is used to set gh_pairs for passive sites
if (!active_) {
Expand All @@ -48,6 +57,7 @@ xgboost::common::Span<int8_t> DummyProcessor::HandleGHPairs(xgboost::common::Spa
for (int i = 0; i < num; i += 10) {
gh_pairs_->push_back(pairs[i]);
}
cout << "GH Pairs saved. Size: " << gh_pairs_->size() << endl;
}

return buffer;
Expand All @@ -58,6 +68,7 @@ xgboost::common::Span<std::int8_t> DummyProcessor::ProcessAggregation(
auto total_bin_size = gidx_->Cuts().Values().size();
auto histo_size = total_bin_size*2;
auto buf_size = kPrefixLen + 8*histo_size*nodes_to_build.size();
cout << "ProcessAggregation called with bin size: " << total_bin_size << " Buffer Size: " << buf_size << endl;
std::int8_t *buf = static_cast<std::int8_t *>(calloc(buf_size, 1));
memcpy(buf, kSignature, strlen(kSignature));
memcpy(buf + 8, &buf_size, 8);
Expand All @@ -74,6 +85,15 @@ xgboost::common::Span<std::int8_t> DummyProcessor::ProcessAggregation(
continue;
}

if (slot >= total_bin_size) {
cout << "Slot too big, ignored: " << slot << endl;
continue;
}

if (row_id >= gh_pairs_->size()/2) {
cout << "Row ID too big: " << row_id << endl;
}

auto g = (*gh_pairs_)[row_id*2];
auto h = (*gh_pairs_)[row_id*2+1];
histo[slot*2] += g;
Expand All @@ -86,17 +106,29 @@ xgboost::common::Span<std::int8_t> DummyProcessor::ProcessAggregation(
return xgboost::common::Span<int8_t>(reinterpret_cast<int8_t *>(buf), buf_size);
}

std::vector<double> DummyProcessor::HandleAggregation(std::vector<xgboost::common::Span<std::int8_t>> buffers) {
std::vector<double> DummyProcessor::HandleAggregation(xgboost::common::Span<std::int8_t> buffer) {
cout << "HandleAggregation called with buffer size: " << buffer.size() << endl;
std::vector<double> result = std::vector<double>();

for (auto buf : buffers) {
int8_t *ptr = buf.data();
int8_t* ptr = buffer.data();
auto rest_size = buffer.size();

while (rest_size > kPrefixLen) {
if (!ValidDam(ptr)) {
cout << "Invalid buffer at offset " << buffer.size() - rest_size << endl;
continue;
}
std::int64_t *size_ptr = reinterpret_cast<std::int64_t *>(ptr + 8);
double *array_start = reinterpret_cast<double *>(ptr + kPrefixLen);
auto array_size = (*size_ptr - kPrefixLen) / 8;
auto array_size = (*size_ptr - kPrefixLen)/8;
cout << "Histo size for buffer: " << array_size << endl;
result.insert(result.end(), array_start, array_start + array_size);
cout << "Result size: " << result.size() << endl;
rest_size -= *size_ptr;
ptr = ptr + *size_ptr;
}

cout << "Total histo size: " << result.size() << endl;

return result;
}

2 changes: 1 addition & 1 deletion src/processing/plugins/dummy_processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,5 @@ class DummyProcessor: public xgboost::processing::Processor {
xgboost::common::Span<std::int8_t> ProcessAggregation(std::vector<xgboost::bst_node_t> const &nodes_to_build,
xgboost::common::RowSetCollection const &row_set) override;

std::vector<double> HandleAggregation(std::vector<xgboost::common::Span<std::int8_t>> buffers) override;
std::vector<double> HandleAggregation(xgboost::common::Span<std::int8_t> buffer) override;
};
4 changes: 2 additions & 2 deletions src/processing/processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ class Processor {
/*!
* \brief Handle all gather result
*
* \param buffers Buffer from all gather, only buffer from active site is needed
* \param buffer Buffer from all gather, only buffer from active site is needed
*
* \return A flattened vector of histograms for each site, each node in the form of
* site1_node1, site1_node2 site1_node3, site2_node1, site2_node2, site2_node3
*/
virtual std::vector<double> HandleAggregation(std::vector<common::Span<std::int8_t>> buffers) = 0;
virtual std::vector<double> HandleAggregation(xgboost::common::Span<std::int8_t> buffer) = 0;
};

class ProcessorLoader {
Expand Down

0 comments on commit cf890e2

Please sign in to comment.