Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support categorical data in ellpack. #6140

Merged
merged 2 commits into from Sep 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
47 changes: 29 additions & 18 deletions src/data/ellpack_page.cu
@@ -1,10 +1,10 @@
/*!
* Copyright 2019 XGBoost contributors
* Copyright 2019-2020 XGBoost contributors
*/

#include <xgboost/data.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include "../common/categorical.h"
#include "../common/hist_util.cuh"
#include "../common/random.h"
#include "./ellpack_page.cuh"
Expand Down Expand Up @@ -33,6 +33,7 @@ __global__ void CompressBinEllpackKernel(
const Entry* __restrict__ entries, // One batch of input data
const float* __restrict__ cuts, // HistogramCuts::cut_values_
const uint32_t* __restrict__ cut_rows, // HistogramCuts::cut_ptrs_
common::Span<FeatureType const> feature_types,
size_t base_row, // batch_row_begin
size_t n_rows,
size_t row_stride,
Expand All @@ -51,11 +52,19 @@ __global__ void CompressBinEllpackKernel(
// {feature_cuts, ncuts} forms the array of cuts of `feature'.
const float* feature_cuts = &cuts[cut_rows[feature]];
int ncuts = cut_rows[feature + 1] - cut_rows[feature];
bool is_cat = common::IsCat(feature_types, ifeature);
// Assigning the bin in current entry.
// S.t.: fvalue < feature_cuts[bin]
bin = thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts,
fvalue) -
feature_cuts;
if (is_cat) {
auto it = dh::MakeTransformIterator<int>(
feature_cuts, [](float v) { return common::AsCat(v); });
bin = thrust::lower_bound(thrust::seq, it, it + ncuts, common::AsCat(fvalue)) - it;
} else {
bin = thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts,
fvalue) -
feature_cuts;
}

if (bin >= ncuts) {
bin = ncuts - 1;
}
Expand Down Expand Up @@ -83,14 +92,13 @@ EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts,
}

EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts,
const SparsePage& page, bool is_dense,
size_t row_stride)
: cuts_(std::move(cuts)),
is_dense(is_dense),
n_rows(page.Size()),
const SparsePage &page, bool is_dense,
size_t row_stride,
common::Span<FeatureType const> feature_types)
: cuts_(std::move(cuts)), is_dense(is_dense), n_rows(page.Size()),
row_stride(row_stride) {
this->InitCompressedData(device);
this->CreateHistIndices(device, page);
this->CreateHistIndices(device, page, feature_types);
}

// Construct an ELLPACK matrix in memory.
Expand All @@ -108,12 +116,14 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param)
monitor_.Stop("Quantiles");

monitor_.Start("InitCompressedData");
InitCompressedData(param.gpu_id);
this->InitCompressedData(param.gpu_id);
monitor_.Stop("InitCompressedData");

dmat->Info().feature_types.SetDevice(param.gpu_id);
auto ft = dmat->Info().feature_types.ConstDeviceSpan();
monitor_.Start("BinningCompression");
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
CreateHistIndices(param.gpu_id, batch);
CreateHistIndices(param.gpu_id, batch, ft);
}
monitor_.Stop("BinningCompression");
}
Expand Down Expand Up @@ -365,7 +375,8 @@ void EllpackPageImpl::InitCompressedData(int device) {

// Compress a CSR page into ELLPACK.
void EllpackPageImpl::CreateHistIndices(int device,
const SparsePage& row_batch) {
const SparsePage& row_batch,
common::Span<FeatureType const> feature_types) {
if (row_batch.Size() == 0) return;
unsigned int null_gidx_value = NumSymbols() - 1;

Expand Down Expand Up @@ -397,9 +408,9 @@ void EllpackPageImpl::CreateHistIndices(int device,
size_t n_entries = ent_cnt_end - ent_cnt_begin;
dh::device_vector<Entry> entries_d(n_entries);
// copy data entries to device.
dh::safe_cuda(cudaMemcpy(entries_d.data().get(),
data_vec.data() + ent_cnt_begin,
n_entries * sizeof(Entry), cudaMemcpyDefault));
dh::safe_cuda(cudaMemcpyAsync(entries_d.data().get(),
data_vec.data() + ent_cnt_begin,
n_entries * sizeof(Entry), cudaMemcpyDefault));
const dim3 block3(32, 8, 1); // 256 threads
const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x),
common::DivRoundUp(row_stride, block3.y), 1);
Expand All @@ -408,7 +419,7 @@ void EllpackPageImpl::CreateHistIndices(int device,
CompressBinEllpackKernel, common::CompressedBufferWriter(NumSymbols()),
gidx_buffer.DevicePointer(), row_ptrs.data().get(),
entries_d.data().get(), device_accessor.gidx_fvalue_map.data(),
device_accessor.feature_segments.data(),
device_accessor.feature_segments.data(), feature_types,
row_batch.base_rowid + batch_row_begin, batch_nrows, row_stride,
null_gidx_value);
}
Expand Down
12 changes: 7 additions & 5 deletions src/data/ellpack_page.cuh
Expand Up @@ -118,10 +118,12 @@ class EllpackPageImpl {
*/
EllpackPageImpl(int device, common::HistogramCuts cuts, bool is_dense,
size_t row_stride, size_t n_rows);

/*!
* \brief Constructor used for external memory.
*/
EllpackPageImpl(int device, common::HistogramCuts cuts,
const SparsePage& page,
bool is_dense, size_t row_stride);
const SparsePage &page, bool is_dense, size_t row_stride,
common::Span<FeatureType const> feature_types);

/*!
* \brief Constructor from an existing DMatrix.
Expand Down Expand Up @@ -184,8 +186,8 @@ class EllpackPageImpl {
* @param row_batch The CSR page.
*/
void CreateHistIndices(int device,
const SparsePage& row_batch
);
const SparsePage& row_batch,
common::Span<FeatureType const> feature_types);
/*!
* \brief Initialize the buffer to store compressed features.
*/
Expand Down
5 changes: 3 additions & 2 deletions src/data/ellpack_page_source.cu
Expand Up @@ -55,6 +55,7 @@ void EllpackPageSource::WriteEllpackPages(int device, DMatrix* dmat,
SparsePage temp_host_page;
writer.Alloc(&page);
auto* impl = page->Impl();
auto ft = dmat->Info().feature_types.ConstDeviceSpan();

size_t bytes_write = 0;
double tstart = dmlc::GetTime();
Expand All @@ -66,7 +67,7 @@ void EllpackPageSource::WriteEllpackPages(int device, DMatrix* dmat,
if (mem_cost_bytes >= page_size_) {
bytes_write += mem_cost_bytes;
*impl = EllpackPageImpl(device, cuts, temp_host_page, dmat->IsDense(),
row_stride);
row_stride, ft);
writer.PushWrite(std::move(page));
writer.Alloc(&page);
impl = page->Impl();
Expand All @@ -79,7 +80,7 @@ void EllpackPageSource::WriteEllpackPages(int device, DMatrix* dmat,
}
if (temp_host_page.Size() != 0) {
*impl = EllpackPageImpl(device, cuts, temp_host_page, dmat->IsDense(),
row_stride);
row_stride, ft);
writer.PushWrite(std::move(page));
}
}
Expand Down
13 changes: 0 additions & 13 deletions tests/cpp/common/test_hist_util.h
Expand Up @@ -60,19 +60,6 @@ inline data::CupyAdapter AdapterFromData(const thrust::device_vector<float> &x,
}
#endif

/*!
 * \brief Generate one column of pseudo-random categorical values stored as float.
 *
 * Each value is an integer drawn uniformly from [0, num_categories - 1] using a
 * Mersenne Twister seeded with 0, so repeated calls on the same platform yield
 * the same column.  The leading entries are then overwritten with 0, 1, 2, ...
 * so that every category occurs at least once whenever n >= num_categories.
 *
 * \param n              Number of values to generate (must be non-negative).
 * \param num_categories Number of distinct categories.  When it is not
 *                       positive there is nothing to sample from and the
 *                       zero-initialised column is returned unchanged.
 *
 * \return A vector of n category codes.
 */
inline std::vector<float> GenerateRandomCategoricalSingleColumn(int n,
                                                                int num_categories) {
  std::vector<float> x(n);
  if (num_categories <= 0) {
    // uniform_int_distribution requires a <= b; (0, -1) would violate that.
    return x;
  }
  std::mt19937 rng(0);
  std::uniform_int_distribution<int> dist(0, num_categories - 1);
  std::generate(x.begin(), x.end(), [&]() { return dist(rng); });
  // Make sure each category is present.  Only as many leading slots as the
  // column actually has are written, so num_categories > n cannot run past
  // the end of the vector.
  const size_t n_present =
      std::min<size_t>(x.size(), static_cast<size_t>(num_categories));
  for (size_t i = 0; i < n_present; ++i) {
    x[i] = static_cast<float>(i);
  }
  return x;
}

inline std::shared_ptr<data::SimpleDMatrix>
GetDMatrixFromData(const std::vector<float> &x, int num_rows, int num_columns) {
data::DenseAdapter adapter(x.data(), num_rows, num_columns);
Expand Down
42 changes: 41 additions & 1 deletion tests/cpp/data/test_ellpack_page.cu
@@ -1,5 +1,5 @@
/*!
* Copyright 2019 XGBoost contributors
* Copyright 2019-2020 XGBoost contributors
*/
#include <xgboost/base.h>

Expand All @@ -9,6 +9,7 @@
#include "../histogram_helpers.h"
#include "gtest/gtest.h"

#include "../../../src/common/categorical.h"
#include "../../../src/common/hist_util.h"
#include "../../../src/data/ellpack_page.cuh"

Expand Down Expand Up @@ -77,6 +78,45 @@ TEST(EllpackPage, BuildGidxSparse) {
}
}

// Builds an ELLPACK page from a single column marked as categorical and checks
// that (a) the page ends up with one bin per category and (b) every stored bin
// index refers to a cut value whose category equals that of the input value.
TEST(EllpackPage, FromCategoricalBasic) {
  using common::AsCat;
  size_t constexpr kRows = 1000, kCats = 13, kCols = 1;
  size_t max_bins = 8;

  // Input: one column of category codes, flagged as categorical on the DMatrix.
  auto column = GenerateRandomCategoricalSingleColumn(kRows, kCats);
  auto dmat = GetDMatrixFromData(column, kRows, kCols);
  dmat->Info().feature_types.HostVector().resize(kCols,
                                                 FeatureType::kCategorical);

  BatchParam param(0, max_bins);
  EllpackPage page(dmat.get(), param);
  auto accessor = page.Impl()->GetDeviceAccessor(0);
  // max_bins is smaller than kCats, yet one bin per category is expected.
  ASSERT_EQ(kCats, accessor.NumBins());

  // Sanity check on the generator: all kCats categories occur in the input.
  auto sorted = column;
  std::sort(sorted.begin(), sorted.end());
  auto n_uniques = std::unique(sorted.begin(), sorted.end()) - sorted.begin();
  ASSERT_EQ(n_uniques, kCats);

  // Copy the cut pointers and cut values back to the host for inspection.
  std::vector<uint32_t> h_cuts_ptr(accessor.feature_segments.size());
  dh::CopyDeviceSpanToVector(&h_cuts_ptr, accessor.feature_segments);
  std::vector<float> h_cuts_values(accessor.gidx_fvalue_map.size());
  dh::CopyDeviceSpanToVector(&h_cuts_values, accessor.gidx_fvalue_map);

  ASSERT_EQ(h_cuts_ptr.size(), 2);
  ASSERT_EQ(h_cuts_values.size(), kCats);

  std::vector<common::CompressedByteT> const &h_gidx_buffer =
      page.Impl()->gidx_buffer.HostVector();
  auto h_gidx_iter = common::CompressedIterator<uint32_t>(
      h_gidx_buffer.data(), accessor.NumSymbols());

  // Each compressed entry must decode to the category of its input value.
  for (size_t idx = 0; idx < column.size(); ++idx) {
    auto bin = h_gidx_iter[idx];
    auto cut_value = h_cuts_values.at(bin);
    ASSERT_EQ(AsCat(column[idx]), AsCat(cut_value));
  }
}

struct ReadRowFunction {
EllpackDeviceAccessor matrix;
int row;
Expand Down
9 changes: 8 additions & 1 deletion tests/cpp/helpers.cc
Expand Up @@ -17,6 +17,7 @@
#include "helpers.h"
#include "xgboost/c_api.h"
#include "../../src/data/adapter.h"
#include "../../src/data/simple_dmatrix.h"
#include "../../src/gbm/gbtree_model.h"
#include "xgboost/predictor.h"

Expand Down Expand Up @@ -350,6 +351,13 @@ RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
return out;
}

std::shared_ptr<DMatrix>
GetDMatrixFromData(const std::vector<float> &x, int num_rows, int num_columns){
data::DenseAdapter adapter(x.data(), num_rows, num_columns);
return std::shared_ptr<DMatrix>(new data::SimpleDMatrix(
&adapter, std::numeric_limits<float>::quiet_NaN(), 1));
}

std::unique_ptr<DMatrix> CreateSparsePageDMatrix(
size_t n_entries, size_t page_size, std::string tmp_file) {
// Create sufficiently large data to make two row pages
Expand Down Expand Up @@ -539,5 +547,4 @@ RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv) {
return RMMAllocatorPtr(nullptr, DeleteRMMResource);
}
#endif // !defined(XGBOOST_USE_RMM) || XGBOOST_USE_RMM != 1

} // namespace xgboost
22 changes: 22 additions & 0 deletions tests/cpp/helpers.h
Expand Up @@ -42,6 +42,12 @@ struct LearnerModelParam;
class GradientBooster;
}

/*!
 * \brief Relative error of l with respect to r, computed as |1 - l / r|.
 *
 * Only instantiable for floating-point types.  There is no guard against
 * r == 0; the division follows ordinary floating-point rules.
 *
 * \param l Value being compared.
 * \param r Reference value used as the divisor.
 *
 * \return |1 - l / r| in the same floating-point type as the arguments.
 */
template <typename Float>
Float RelError(Float l, Float r) {
  static_assert(std::is_floating_point<Float>::value, "");
  Float const ratio = l / r;
  return std::abs(1.0f - ratio);
}

bool FileExists(const std::string& filename);

int64_t GetFileSize(const std::string& filename);
Expand Down Expand Up @@ -254,6 +260,22 @@ class RandomDataGenerator {
#endif
};

/*!
 * \brief Generate one column of pseudo-random categorical values stored as float.
 *
 * Each value is an integer drawn uniformly from [0, num_categories - 1] using a
 * Mersenne Twister seeded with 0, so repeated calls on the same platform yield
 * the same column.  The leading entries are then overwritten with 0, 1, 2, ...
 * so that every category occurs at least once whenever n >= num_categories.
 *
 * \param n              Number of values to generate (must be non-negative).
 * \param num_categories Number of distinct categories.  When it is 0 there is
 *                       nothing to sample from and the zero-initialised column
 *                       is returned unchanged.
 *
 * \return A vector of n category codes.
 */
inline std::vector<float>
GenerateRandomCategoricalSingleColumn(int n, size_t num_categories) {
  std::vector<float> x(n);
  if (num_categories == 0) {
    // num_categories - 1 would wrap around to the largest size_t and the
    // distribution would then cover the whole unsigned range.
    return x;
  }
  std::mt19937 rng(0);
  std::uniform_int_distribution<size_t> dist(0, num_categories - 1);
  std::generate(x.begin(), x.end(), [&]() { return dist(rng); });
  // Make sure each category is present.  Only as many leading slots as the
  // column actually has are written, so num_categories > n cannot run past
  // the end of the vector.
  const size_t n_present = std::min<size_t>(x.size(), num_categories);
  for (size_t i = 0; i < n_present; ++i) {
    x[i] = static_cast<float>(i);
  }
  return x;
}

std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float> &x,
int num_rows, int num_columns);

std::unique_ptr<DMatrix> CreateSparsePageDMatrix(
size_t n_entries, size_t page_size, std::string tmp_file);

Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/histogram_helpers.h
Expand Up @@ -45,7 +45,7 @@ inline std::unique_ptr<EllpackPageImpl> BuildEllpackPage(
}

auto page = std::unique_ptr<EllpackPageImpl>(
new EllpackPageImpl(0, cmat, batch, dmat->IsDense(), row_stride));
new EllpackPageImpl(0, cmat, batch, dmat->IsDense(), row_stride, {}));

return page;
}
Expand Down