From 9a4e8b1d8196076098cccfac53bcfe0cc60e1d54 Mon Sep 17 00:00:00 2001
From: Rory Mitchell
Date: Tue, 25 Aug 2020 12:47:41 +1200
Subject: [PATCH] GPUTreeShap (#6038)

---
 .gitmodules                               |   3 +
 gputreeshap                               |   1 +
 src/CMakeLists.txt                        |   1 +
 src/common/device_helpers.cuh             |  12 +-
 src/gbm/gbtree.h                          |   8 +-
 src/predictor/gpu_predictor.cu            | 198 ++++++++++++++++------
 tests/cpp/predictor/test_gpu_predictor.cu |  56 ++++++
 tests/python-gpu/test_gpu_prediction.py   |  31 +++-
 tests/python/testing.py                   |  18 +-
 9 files changed, 266 insertions(+), 62 deletions(-)
 create mode 160000 gputreeshap

diff --git a/.gitmodules b/.gitmodules
index b4aa65438e34..8df83c3ec0bd 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -4,3 +4,6 @@
 [submodule "cub"]
 	path = cub
 	url = https://github.com/NVlabs/cub
+[submodule "gputreeshap"]
+	path = gputreeshap
+	url = https://github.com/rapidsai/gputreeshap.git
diff --git a/gputreeshap b/gputreeshap
new file mode 160000
index 000000000000..a3d4c44cc6a0
--- /dev/null
+++ b/gputreeshap
@@ -0,0 +1 @@
+Subproject commit a3d4c44cc6a0a6c3870e7cebcd1ef1d09d7bc0cb
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 52813b83e717..7b5a221bb4e9 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -9,6 +9,7 @@ if (USE_CUDA)
   file(GLOB_RECURSE CUDA_SOURCES *.cu *.cuh)
   target_sources(objxgboost PRIVATE ${CUDA_SOURCES})
   target_compile_definitions(objxgboost PRIVATE -DXGBOOST_USE_CUDA=1)
+  target_include_directories(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/gputreeshap)
   if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
     target_include_directories(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/cub/)
   endif (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index beb94680f493..5e4f1eae0d6b 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -474,8 +474,18 @@ class TemporaryArray {
   using AllocT = XGBCachingDeviceAllocator<T>;
   using value_type = T;  // NOLINT
   explicit TemporaryArray(size_t n) : size_(n) { ptr_ = AllocT().allocate(n); }
+  TemporaryArray(size_t n, T val) : size_(n) {
+    ptr_ = AllocT().allocate(n);
+    this->fill(val);
+  }
   ~TemporaryArray() { AllocT().deallocate(ptr_, this->size()); }
-
+  void fill(T val)  // NOLINT
+  {
+    int device = 0;
+    dh::safe_cuda(cudaGetDevice(&device));
+    auto d_data = ptr_.get();
+    LaunchN(device, this->size(), [=] __device__(size_t idx) { d_data[idx] = val; });
+  }
   thrust::device_ptr<T> data() { return ptr_; }  // NOLINT
   size_t size() { return size_; }  // NOLINT
diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h
index b96e825e32f8..c2c12f63afd7 100644
--- a/src/gbm/gbtree.h
+++ b/src/gbm/gbtree.h
@@ -238,11 +238,11 @@ class GBTree : public GradientBooster {
   void PredictContribution(DMatrix* p_fmat,
                            std::vector<bst_float>* out_contribs,
-                           unsigned ntree_limit, bool approximate, int condition,
-                           unsigned condition_feature) override {
+                           unsigned ntree_limit, bool approximate,
+                           int condition, unsigned condition_feature) override {
     CHECK(configured_);
-    cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_,
-                                        ntree_limit, nullptr, approximate);
+    this->GetPredictor()->PredictContribution(
+        p_fmat, out_contribs, model_, ntree_limit, nullptr, approximate);
   }

   void PredictInteractionContributions(DMatrix* p_fmat,
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index c05688eaf4d8..a36a131fa3bb 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -5,6 +5,7 @@
 #include <thrust/device_ptr.h>
 #include <thrust/device_vector.h>
 #include <thrust/fill.h>
+#include <GPUTreeShap/gpu_treeshap.h>
 #include <memory>

 #include "xgboost/data.h"
@@ -27,72 +28,79 @@ DMLC_REGISTRY_FILE_TAG(gpu_predictor);
 struct SparsePageView {
   common::Span<const Entry> d_data;
   common::Span<const bst_row_t> d_row_ptr;
+  bst_feature_t num_features;

   XGBOOST_DEVICE SparsePageView(common::Span<const Entry> data,
-                                common::Span<const bst_row_t> row_ptr) :
-      d_data{data}, d_row_ptr{row_ptr} {}
+                                common::Span<const bst_row_t> row_ptr,
+                                bst_feature_t num_features)
+      : d_data{data}, d_row_ptr{row_ptr}, num_features(num_features) {}
+  __device__ float GetElement(size_t ridx, size_t fidx) const {
+    // Binary search
+    auto begin_ptr = d_data.begin() + d_row_ptr[ridx];
+    auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1];
+    if (end_ptr - begin_ptr == this->NumCols()) {
+      // Bypass span check for dense data
+      return d_data.data()[d_row_ptr[ridx] + fidx].fvalue;
+    }
+    common::Span<const Entry>::iterator previous_middle;
+    while (end_ptr != begin_ptr) {
+      auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
+      if (middle == previous_middle) {
+        break;
+      } else {
+        previous_middle = middle;
+      }
+
+      if (middle->index == fidx) {
+        return middle->fvalue;
+      } else if (middle->index < fidx) {
+        begin_ptr = middle;
+      } else {
+        end_ptr = middle;
+      }
+    }
+    // Value is missing
+    return nanf("");
+  }
+  XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; }
+  XGBOOST_DEVICE size_t NumCols() const { return num_features; }
 };

 struct SparsePageLoader {
   bool use_shared;
-  common::Span<const bst_row_t> d_row_ptr;
-  common::Span<const Entry> d_data;
-  bst_feature_t num_features;
+  SparsePageView data;
   float* smem;
   size_t entry_start;
   __device__ SparsePageLoader(SparsePageView data, bool use_shared,
                               bst_feature_t num_features, bst_row_t num_rows,
                               size_t entry_start)
       : use_shared(use_shared),
-        d_row_ptr(data.d_row_ptr),
-        d_data(data.d_data),
-        num_features(num_features),
+        data(data),
         entry_start(entry_start) {
     extern __shared__ float _smem[];
     smem = _smem;
     // Copy instances
     if (use_shared) {
       bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
-      int shared_elements = blockDim.x * num_features;
+      int shared_elements = blockDim.x * data.num_features;
       dh::BlockFill(smem, shared_elements, nanf(""));
       __syncthreads();
       if (global_idx < num_rows) {
-        bst_uint elem_begin = d_row_ptr[global_idx];
-        bst_uint elem_end = d_row_ptr[global_idx + 1];
+        bst_uint elem_begin = data.d_row_ptr[global_idx];
+        bst_uint elem_end = data.d_row_ptr[global_idx + 1];
         for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) {
-          Entry elem = d_data[elem_idx - entry_start];
-          smem[threadIdx.x * num_features + elem.index] = elem.fvalue;
+          Entry elem = data.d_data[elem_idx - entry_start];
+          smem[threadIdx.x * data.num_features + elem.index] = elem.fvalue;
         }
       }
       __syncthreads();
     }
   }
-  __device__ float GetFvalue(int ridx, int fidx) const {
+  __device__ float GetElement(size_t ridx, size_t fidx) const {
     if (use_shared) {
-      return smem[threadIdx.x * num_features + fidx];
+      return smem[threadIdx.x * data.num_features + fidx];
     } else {
-      // Binary search
-      auto begin_ptr = d_data.begin() + (d_row_ptr[ridx] - entry_start);
-      auto end_ptr = d_data.begin() + (d_row_ptr[ridx + 1] - entry_start);
-      common::Span<const Entry>::iterator previous_middle;
-      while (end_ptr != begin_ptr) {
-        auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
-        if (middle == previous_middle) {
-          break;
-        } else {
-          previous_middle = middle;
-        }
-
-        if (middle->index == fidx) {
-          return middle->fvalue;
-        } else if (middle->index < fidx) {
-          begin_ptr = middle;
-        } else {
-          end_ptr = middle;
-        }
-      }
-      // Value is missing
-      return nanf("");
+      return data.GetElement(ridx, fidx);
     }
   }
 };
@@ -103,7 +111,7 @@ struct EllpackLoader {
                                bst_feature_t num_features, bst_row_t num_rows,
                                size_t entry_start)
       : matrix{m} {}
-  __device__ __forceinline__ float GetFvalue(int ridx, int fidx) const {
+  __device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const {
     auto gidx = matrix.GetBinIndex(ridx, fidx);
     if (gidx == -1) {
       return nan("");
@@ -150,7 +158,7 @@ struct DeviceAdapterLoader {
     __syncthreads();
   }

-  DEV_INLINE float GetFvalue(bst_row_t ridx, bst_feature_t fidx) const {
+  DEV_INLINE float GetElement(size_t ridx, size_t fidx) const {
     if (use_shared) {
       return smem[threadIdx.x * columns + fidx];
     }
@@ -163,7 +171,7 @@ __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
                                Loader* loader) {
   RegTree::Node n = tree[0];
   while (!n.IsLeaf()) {
-    float fvalue = loader->GetFvalue(ridx, n.SplitIndex());
+    float fvalue = loader->GetElement(ridx, n.SplitIndex());
     // Missing value
     if (isnan(fvalue)) {
       n = tree[n.DefaultChild()];
@@ -273,7 +281,8 @@ class GPUPredictor : public xgboost::Predictor {
       use_shared = false;
     }
     size_t entry_start = 0;
-    SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan()};
+    SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
+                        num_features);
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
         PredictKernel<SparsePageLoader, SparsePageView>, data,
@@ -447,6 +456,60 @@ class GPUPredictor : public xgboost::Predictor {
     }
   }

+  void PredictContribution(DMatrix* p_fmat,
+                           std::vector<bst_float>* out_contribs,
+                           const gbm::GBTreeModel& model, unsigned ntree_limit,
+                           std::vector<bst_float>* tree_weights,
+                           bool approximate, int condition,
+                           unsigned condition_feature) override {
+    if (approximate) {
+      LOG(FATAL) << "[Internal error]: " << __func__
+                 << " approximate is not implemented in GPU Predictor.";
+    }
+
+    uint32_t real_ntree_limit =
+        ntree_limit * model.learner_model_param->num_output_group;
+    if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
+      real_ntree_limit = static_cast<uint32_t>(model.trees.size());
+    }
+
+    const int ngroup = model.learner_model_param->num_output_group;
+    CHECK_NE(ngroup, 0);
+    // allocate space for (number of features + bias) times the number of rows
+    std::vector<bst_float>& contribs = *out_contribs;
+    size_t contributions_columns =
+        model.learner_model_param->num_feature + 1;  // +1 for bias
+    contribs.resize(p_fmat->Info().num_row_ * contributions_columns *
+                    model.learner_model_param->num_output_group);
+    dh::TemporaryArray<float> phis(contribs.size(), 0.0);
+    p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
+    const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan();
+    float base_score = model.learner_model_param->base_score;
+    auto d_phis = phis.data().get();
+    // Add the base margin term to last column
+    dh::LaunchN(
+        generic_param_->gpu_id,
+        p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
+        [=] __device__(size_t idx) {
+          d_phis[(idx + 1) * contributions_columns - 1] =
+              margin.empty() ? base_score : margin[idx];
+        });
+
+    const auto& paths = this->ExtractPaths(model, real_ntree_limit);
+    for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
+      batch.data.SetDevice(generic_param_->gpu_id);
+      batch.offset.SetDevice(generic_param_->gpu_id);
+      SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
+                       model.learner_model_param->num_feature);
+      gpu_treeshap::GPUTreeShap(
+          X, paths, ngroup,
+          phis.data().get() + batch.base_rowid * contributions_columns);
+    }
+    dh::safe_cuda(cudaMemcpyAsync(contribs.data(), phis.data().get(),
+                                  sizeof(float) * phis.size(),
+                                  cudaMemcpyDefault));
+  }
+
  protected:
   void InitOutPredictions(const MetaInfo& info,
                           HostDeviceVector<bst_float>* out_preds,
@@ -478,16 +541,6 @@ class GPUPredictor : public xgboost::Predictor {
                << " is not implemented in GPU Predictor.";
   }

-  void PredictContribution(DMatrix* p_fmat,
-                           std::vector<bst_float>* out_contribs,
-                           const gbm::GBTreeModel& model, unsigned ntree_limit,
-                           std::vector<bst_float>* tree_weights,
-                           bool approximate, int condition,
-                           unsigned condition_feature) override {
-    LOG(FATAL) << "[Internal error]: " << __func__
-               << " is not implemented in GPU Predictor.";
-  }
-
   void PredictInteractionContributions(DMatrix* p_fmat,
                                        std::vector<bst_float>* out_contribs,
                                        const gbm::GBTreeModel& model,
@@ -510,6 +563,49 @@ class GPUPredictor : public xgboost::Predictor {
     }
   }

+  std::vector<gpu_treeshap::PathElement> ExtractPaths(
+      const gbm::GBTreeModel& model, size_t tree_limit) {
+    std::vector<gpu_treeshap::PathElement> paths;
+    size_t path_idx = 0;
+    CHECK_LE(tree_limit, model.trees.size());
+    for (auto i = 0ull; i < tree_limit; i++) {
+      const auto& tree = *model.trees.at(i);
+      size_t group = model.tree_info[i];
+      const auto& nodes = tree.GetNodes();
+      for (auto j = 0ull; j < nodes.size(); j++) {
+        if (nodes[j].IsLeaf() && !nodes[j].IsDeleted()) {
+          auto child = nodes[j];
+          float v = child.LeafValue();
+          size_t child_idx = j;
+          const float inf = std::numeric_limits<float>::infinity();
+          while (!child.IsRoot()) {
+            float child_cover = tree.Stat(child_idx).sum_hess;
+            float parent_cover = tree.Stat(child.Parent()).sum_hess;
+            float zero_fraction = child_cover / parent_cover;
+            CHECK(zero_fraction >= 0.0 && zero_fraction <= 1.0);
+            auto parent = nodes[child.Parent()];
+            CHECK(parent.LeftChild() == child_idx ||
+                  parent.RightChild() == child_idx);
+            bool is_left_path = parent.LeftChild() == child_idx;
+            bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) ||
+                                   (parent.DefaultLeft() && is_left_path);
+            float lower_bound = is_left_path ? -inf : parent.SplitCond();
+            float upper_bound = is_left_path ? parent.SplitCond() : inf;
+            paths.emplace_back(path_idx, parent.SplitIndex(), group,
+                               lower_bound, upper_bound, is_missing_path,
+                               zero_fraction, v);
+            child_idx = child.Parent();
+            child = parent;
+          }
+          // Root node has feature -1
+          paths.emplace_back(path_idx, -1, group, -inf, inf, false, 1.0, v);
+          path_idx++;
+        }
+      }
+    }
+    return paths;
+  }
+
   std::mutex lock_;
   DeviceModel model_;
   size_t max_shared_memory_bytes_;
diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu
index b14b731142c2..585acf1790b6 100644
--- a/tests/cpp/predictor/test_gpu_predictor.cu
+++ b/tests/cpp/predictor/test_gpu_predictor.cu
@@ -163,5 +163,61 @@ TEST(GPUPredictor, MGPU_InplacePredict) {  // NOLINT
 TEST(GpuPredictor, LesserFeatures) {
   TestPredictionWithLesserFeatures("gpu_predictor");
 }
+
+// Very basic test of empty model
+TEST(GPUPredictor, ShapStump) {
+  cudaSetDevice(0);
+  LearnerModelParam param;
+  param.num_feature = 1;
+  param.num_output_group = 1;
+  param.base_score = 0.5;
+  gbm::GBTreeModel model(&param);
+  std::vector<std::unique_ptr<RegTree>> trees;
+  trees.push_back(std::unique_ptr<RegTree>(new RegTree));
+  model.CommitModel(std::move(trees), 0);
+
+  auto gpu_lparam = CreateEmptyGenericParam(0);
+  std::unique_ptr<Predictor> gpu_predictor = std::unique_ptr<Predictor>(
+      Predictor::Create("gpu_predictor", &gpu_lparam));
+  gpu_predictor->Configure({});
+  std::vector<float> phis;
+  auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
+  gpu_predictor->PredictContribution(dmat.get(), &phis, model);
+  EXPECT_EQ(phis[0], 0.0);
+  EXPECT_EQ(phis[1], param.base_score);
+  EXPECT_EQ(phis[2], 0.0);
+  EXPECT_EQ(phis[3], param.base_score);
+  EXPECT_EQ(phis[4], 0.0);
+  EXPECT_EQ(phis[5], param.base_score);
+}
+
+TEST(GPUPredictor, Shap) {
+  LearnerModelParam param;
+  param.num_feature = 1;
+  param.num_output_group = 1;
+  param.base_score = 0.5;
+  gbm::GBTreeModel model(&param);
+  std::vector<std::unique_ptr<RegTree>> trees;
+  trees.push_back(std::unique_ptr<RegTree>(new RegTree));
+  trees[0]->ExpandNode(0, 0, 0.5, true, 1.0, -1.0, 1.0, 0.0, 5.0, 2.0, 3.0);
+  model.CommitModel(std::move(trees), 0);
+
+  auto gpu_lparam = CreateEmptyGenericParam(0);
+  auto cpu_lparam = CreateEmptyGenericParam(-1);
+  std::unique_ptr<Predictor> gpu_predictor = std::unique_ptr<Predictor>(
+      Predictor::Create("gpu_predictor", &gpu_lparam));
+  std::unique_ptr<Predictor> cpu_predictor = std::unique_ptr<Predictor>(
+      Predictor::Create("cpu_predictor", &cpu_lparam));
+  gpu_predictor->Configure({});
+  cpu_predictor->Configure({});
+  std::vector<float> phis;
+  std::vector<float> cpu_phis;
+  auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
+  gpu_predictor->PredictContribution(dmat.get(), &phis, model);
+  cpu_predictor->PredictContribution(dmat.get(), &cpu_phis, model);
+  for (auto i = 0ull; i < phis.size(); i++) {
+    EXPECT_NEAR(cpu_phis[i], phis[i], 1e-3);
+  }
+}
+
 }  // namespace predictor
 }  // namespace xgboost
diff --git a/tests/python-gpu/test_gpu_prediction.py b/tests/python-gpu/test_gpu_prediction.py
index 26ae95d1bf5c..3810c30daa1a 100644
--- a/tests/python-gpu/test_gpu_prediction.py
+++ b/tests/python-gpu/test_gpu_prediction.py
@@ -4,6 +4,7 @@
 import numpy as np
 import xgboost as xgb
+from hypothesis import given, strategies, assume, settings, note

 sys.path.append("tests/python")
 import testing as tm
@@ -11,6 +12,12 @@
 rng = np.random.RandomState(1994)

+shap_parameter_strategy = strategies.fixed_dictionaries({
+    'max_depth': strategies.integers(0, 11),
+    'max_leaves': strategies.integers(0, 256),
+    'num_parallel_tree': strategies.sampled_from([1, 10]),
+})
+

 class TestGPUPredict(unittest.TestCase):
     def test_predict(self):
@@ -149,7 +156,8 @@ def predict_dense(x):

         # Don't do this on Windows, see issue #5793
         if sys.platform.startswith("win"):
-            pytest.skip('Multi-threaded in-place prediction with cuPy is not working on Windows')
+            pytest.skip(
+                'Multi-threaded in-place prediction with cuPy is not working on Windows')
         for i in range(10):
             run_threaded_predict(X, rows, predict_dense)
@@ -185,3 +193,24 @@ def predict_df(x):

         for i in range(10):
             run_threaded_predict(X, rows, predict_df)
+
+    @given(strategies.integers(1, 200),
+           tm.dataset_strategy, shap_parameter_strategy, strategies.booleans())
+    @settings(deadline=None)
+    def test_shap(self, num_rounds, dataset, param, all_rows):
+        param.update({"predictor": "gpu_predictor", "gpu_id": 0})
+        param = dataset.set_params(param)
+        dmat = dataset.get_dmat()
+        bst = xgb.train(param, dmat, num_rounds)
+        if all_rows:
+            test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
+        else:
+            test_dmat = xgb.DMatrix(dataset.X[0:1, :])
+        shap = bst.predict(test_dmat, pred_contribs=True)
+        bst.set_param({"predictor": "cpu_predictor"})
+        cpu_shap = bst.predict(test_dmat, pred_contribs=True)
+        margin = bst.predict(test_dmat, output_margin=True)
+        assert np.allclose(shap, cpu_shap, 1e-3, 1e-3)
+        # feature contributions should add up to predictions
+        assume(len(dataset.y) > 0)
+        assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-3, 1e-3)
diff --git a/tests/python/testing.py b/tests/python/testing.py
index 30b44079607b..a81e7ea87048 100644
--- a/tests/python/testing.py
+++ b/tests/python/testing.py
@@ -131,6 +131,7 @@ def __init__(self, name, get_dataset, objective, metric
         self.metric = metric
         self.X, self.y = get_dataset()
         self.w = None
+        self.margin = None

     def set_params(self, params_in):
         params_in['objective'] = self.objective
@@ -140,13 +141,13 @@ def set_params(self, params_in):
         return params_in

     def get_dmat(self):
-        return xgb.DMatrix(self.X, self.y, self.w)
+        return xgb.DMatrix(self.X, self.y, self.w, base_margin=self.margin)

     def get_device_dmat(self):
         w = None if self.w is None else cp.array(self.w)
         X = cp.array(self.X, dtype=np.float32)
         y = cp.array(self.y, dtype=np.float32)
-        return xgb.DeviceQuantileDMatrix(X, y, w)
+        return xgb.DeviceQuantileDMatrix(X, y, w, base_margin=self.margin)

     def get_external_dmat(self):
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -157,7 +158,7 @@ def get_external_dmat(self):
         uri = path + '?format=csv&label_column=0#tmptmp_'
         # The uri looks like:
         # 'tmptmp_1234.csv?format=csv&label_column=0#tmptmp_'
-        return xgb.DMatrix(uri, weight=self.w)
+        return xgb.DMatrix(uri, weight=self.w, base_margin=self.margin)

     def __repr__(self):
         return self.name
@@ -206,16 +207,23 @@ def get_sparse():


 @strategies.composite
-def _dataset_and_weight(draw):
+def _dataset_weight_margin(draw):
     data = draw(_unweighted_datasets_strategy)
     if draw(strategies.booleans()):
         data.w = draw(
             arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0)))
+    if draw(strategies.booleans()):
+        num_class = 1
+        if data.objective == "multi:softmax":
+            num_class = int(np.max(data.y) + 1)
+        data.margin = draw(
+            arrays(np.float64, (len(data.y) * num_class), elements=strategies.floats(0.5, 1.0)))
+
     return data


 # A strategy for drawing from a set of example datasets
 # May add random weights to the dataset
-dataset_strategy = _dataset_and_weight()
+dataset_strategy = _dataset_weight_margin()


 def non_increasing(L, tolerance=1e-4):
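
Usage note: with this patch, pred_contribs=True is served by the GPU predictor, which decomposes each tree into its root-to-leaf paths (ExtractPaths above) and evaluates them with gpu_treeshap::GPUTreeShap. Below is a minimal end-to-end sketch of the new code path, mirroring test_shap above; the synthetic data, parameter values, and round count are illustrative assumptions, not part of the patch.

import numpy as np
import xgboost as xgb

# Illustrative synthetic regression data (an assumption, not from the patch).
rng = np.random.RandomState(1994)
X = rng.randn(100, 10)
y = X[:, 0] - 2.0 * X[:, 1] + rng.randn(100) * 0.1
dtrain = xgb.DMatrix(X, y)

# Route prediction through the GPU predictor added by this patch.
params = {"tree_method": "gpu_hist", "predictor": "gpu_predictor", "gpu_id": 0}
bst = xgb.train(params, dtrain, num_boost_round=10)

# SHAP values: one column per feature plus a trailing bias column.
shap = bst.predict(dtrain, pred_contribs=True)
assert shap.shape == (100, 11)

# Contributions (including bias) sum to the raw margin prediction per row.
margin = bst.predict(dtrain, output_margin=True)
assert np.allclose(shap.sum(axis=1), margin, 1e-3, 1e-3)

The 1e-3 tolerances mirror those used in test_shap, since GPU and CPU SHAP values can differ slightly in floating point.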