diff --git a/include/xgboost/feature_map.h b/include/xgboost/feature_map.h index a48e28ba1bfa..d5ff431d64eb 100644 --- a/include/xgboost/feature_map.h +++ b/include/xgboost/feature_map.h @@ -82,7 +82,9 @@ class FeatureMap { if (!strcmp("q", tname)) return kQuantitive; if (!strcmp("int", tname)) return kInteger; if (!strcmp("float", tname)) return kFloat; - LOG(FATAL) << "unknown feature type, use i for indicator and q for quantity"; + if (!strcmp("categorical", tname)) return kInteger; + LOG(FATAL) << "unknown feature type, use i for indicator, q for quantity " + "and categorical for categorical split."; return kIndicator; } /*! \brief name of the feature */ diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 60863dbd45cb..b1122153103d 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -384,7 +384,8 @@ def __init__(self, data, label=None, weight=None, base_margin=None, silent=False, feature_names=None, feature_types=None, - nthread=None): + nthread=None, + enable_categorical=False): """Parameters ---------- data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/ @@ -419,6 +420,17 @@ def __init__(self, data, label=None, weight=None, base_margin=None, Number of threads to use for loading data when parallelization is applicable. If -1, uses maximum threads available on the system. + enable_categorical: boolean, optional + + .. versionadded:: 1.3.0 + + Experimental support of specializing for categorical features. Do + not set to True unless you are interested in development. + Currently it's only available for `gpu_hist` tree method with 1 vs + rest (one hot) categorical split. Also, JSON serialization format, + `enable_experimental_json_serialization`, `gpu_predictor` and + pandas input are required. + """ if isinstance(data, list): raise TypeError('Input data can not be a list.') @@ -437,7 +449,8 @@ def __init__(self, data, label=None, weight=None, base_margin=None, data, missing=self.missing, threads=self.nthread, feature_names=feature_names, - feature_types=feature_types) + feature_types=feature_types, + enable_categorical=enable_categorical) assert handle is not None self.handle = handle diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 03d929b4d645..75d46824f1aa 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -175,20 +175,24 @@ def _is_modin_df(data): } -def _transform_pandas_df(data, feature_names=None, feature_types=None, +def _transform_pandas_df(data, enable_categorical, + feature_names=None, feature_types=None, meta=None, meta_type=None): from pandas import MultiIndex, Int64Index - from pandas.api.types import is_sparse + from pandas.api.types import is_sparse, is_categorical + data_dtypes = data.dtypes - if not all(dtype.name in _pandas_dtype_mapper or is_sparse(dtype) + if not all(dtype.name in _pandas_dtype_mapper or is_sparse(dtype) or + (is_categorical(dtype) and enable_categorical) for dtype in data_dtypes): bad_fields = [ str(data.columns[i]) for i, dtype in enumerate(data_dtypes) if dtype.name not in _pandas_dtype_mapper ] - msg = """DataFrame.dtypes for data must be int, float or bool. - Did not expect the data types in fields """ + msg = """DataFrame.dtypes for data must be int, float, bool or categorical. When + categorical type is supplied, DMatrix parameter + `enable_categorical` must be set to `True`.""" raise ValueError(msg + ', '.join(bad_fields)) if feature_names is None and meta is None: @@ -207,6 +211,8 @@ def _transform_pandas_df(data, feature_names=None, feature_types=None, if is_sparse(dtype): feature_types.append(_pandas_dtype_mapper[ dtype.subtype.name]) + elif is_categorical(dtype) and enable_categorical: + feature_types.append('categorical') else: feature_types.append(_pandas_dtype_mapper[dtype.name]) @@ -215,15 +221,21 @@ def _transform_pandas_df(data, feature_names=None, feature_types=None, 'DataFrame for {meta} cannot have multiple columns'.format( meta=meta)) - dtype = meta_type if meta_type else np.float32 data = np.ascontiguousarray(data.values, dtype=dtype) + dtype = meta_type if meta_type else np.float32 + try: + data = data.values.astype(dtype) + except ValueError as e: + raise ValueError('Data must be convertable to float, even ' + + 'for categorical data.') from e return data, feature_names, feature_types -def _from_pandas_df(data, missing, nthread, feature_names, feature_types): +def _from_pandas_df(data, enable_categorical, missing, nthread, + feature_names, feature_types): data, feature_names, feature_types = _transform_pandas_df( - data, feature_names, feature_types) + data, enable_categorical, feature_names, feature_types) return _from_numpy_array(data, missing, nthread, feature_names, feature_types) @@ -498,7 +510,8 @@ def _has_array_protocol(data): def dispatch_data_backend(data, missing, threads, - feature_names, feature_types): + feature_names, feature_types, + enable_categorical=False): '''Dispatch data for DMatrix.''' if _is_scipy_csr(data): return _from_scipy_csr(data, missing, feature_names, feature_types) @@ -514,7 +527,7 @@ def dispatch_data_backend(data, missing, threads, if _is_tuple(data): return _from_tuple(data, missing, feature_names, feature_types) if _is_pandas_df(data): - return _from_pandas_df(data, missing, threads, + return _from_pandas_df(data, enable_categorical, missing, threads, feature_names, feature_types) if _is_pandas_series(data): return _from_pandas_series(data, missing, threads, feature_names, @@ -644,7 +657,8 @@ def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None): _meta_from_numpy(data, name, dtype, handle) return if _is_pandas_df(data): - data, _, _ = _transform_pandas_df(data, meta=name, meta_type=dtype) + data, _, _ = _transform_pandas_df(data, False, meta=name, + meta_type=dtype) _meta_from_numpy(data, name, dtype, handle) return if _is_pandas_series(data): diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 471ec31f42ac..33905a9fae43 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -80,6 +80,11 @@ struct AtomicDispatcher { using Type = unsigned long long; // NOLINT static_assert(sizeof(Type) == sizeof(uint64_t), "Unsigned long long should be of size 64 bits."); }; + +template <> +struct AtomicDispatcher { + using Type = uint8_t; // NOLINT +}; } // namespace detail } // namespace dh @@ -536,6 +541,17 @@ void CopyDeviceSpanToVector(std::vector *dst, xgboost::common::Span cudaMemcpyDeviceToHost)); } +template +void CopyToD(HContainer const &h, DContainer *d) { + d->resize(h.size()); + using HVT = std::remove_cv_t; + using DVT = std::remove_cv_t; + static_assert(std::is_same::value, + "Host and device containers must have same value type."); + dh::safe_cuda(cudaMemcpyAsync(d->data().get(), h.data(), h.size() * sizeof(HVT), + cudaMemcpyHostToDevice)); +} + // Keep track of pinned memory allocation struct PinnedMemory { void *temp_storage{nullptr}; diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index b60c2a5d52eb..376a64742abb 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -178,7 +178,7 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page, dh::XGBCachingDeviceAllocator alloc; const auto& host_data = page.data.ConstHostVector(); dh::device_vector sorted_entries(host_data.begin() + begin, - host_data.begin() + end); + host_data.begin() + end); thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), sorted_entries.end(), detail::EntryCompareOp()); diff --git a/src/common/host_device_vector.cc b/src/common/host_device_vector.cc index f9974f8ecfaf..a16154966588 100644 --- a/src/common/host_device_vector.cc +++ b/src/common/host_device_vector.cc @@ -10,6 +10,7 @@ #include #include #include +#include "xgboost/tree_model.h" #include "xgboost/host_device_vector.h" namespace xgboost { @@ -176,6 +177,7 @@ template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; // bst_row_t template class HostDeviceVector; // bst_feature_t +template class HostDeviceVector; #if defined(__APPLE__) /* diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index 2de1bb652900..02a47bea8e60 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -404,6 +404,7 @@ template class HostDeviceVector; template class HostDeviceVector; // bst_row_t template class HostDeviceVector; // bst_feature_t template class HostDeviceVector; +template class HostDeviceVector; template class HostDeviceVector; #if defined(__APPLE__) diff --git a/src/data/adapter.h b/src/data/adapter.h index c3981c24fffd..aa7d44f2e609 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -68,7 +68,7 @@ namespace data { /** \brief An adapter can return this value for number of rows or columns * indicating that this value is currently unknown and should be inferred while * passing over the data. */ -constexpr size_t kAdapterUnknownSize = std::numeric_limits::max(); +constexpr size_t kAdapterUnknownSize = std::numeric_limits::max(); struct COOTuple { COOTuple() = default; diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu index eb3d34443659..58b7b4628c67 100644 --- a/src/data/iterative_device_dmatrix.cu +++ b/src/data/iterative_device_dmatrix.cu @@ -98,6 +98,9 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin })); nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), row_counts.end()); + + this->Info().feature_types.Resize(proxy->Info().feature_types.Size()); + this->Info().feature_types.Copy(proxy->Info().feature_types); batches++; } iter.Reset(); diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index fe9664a55254..3265ebbfb83a 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -18,6 +18,8 @@ #include "../data/ellpack_page.cuh" #include "../data/device_adapter.cuh" #include "../common/common.h" +#include "../common/bitfield.h" +#include "../common/categorical.h" #include "../common/device_helpers.cuh" namespace xgboost { @@ -169,33 +171,49 @@ struct DeviceAdapterLoader { template __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree, + common::Span split_types, + common::Span d_cat_ptrs, + common::Span d_categories, Loader* loader) { - RegTree::Node n = tree[0]; + bst_node_t nidx = 0; + RegTree::Node n = tree[nidx]; while (!n.IsLeaf()) { float fvalue = loader->GetElement(ridx, n.SplitIndex()); // Missing value - if (isnan(fvalue)) { - n = tree[n.DefaultChild()]; + if (common::CheckNAN(fvalue)) { + nidx = n.DefaultChild(); } else { - if (fvalue < n.SplitCond()) { - n = tree[n.LeftChild()]; + bool go_left = true; + if (common::IsCat(split_types, nidx)) { + auto categories = d_categories.subspan(d_cat_ptrs[nidx].beg, + d_cat_ptrs[nidx].size); + go_left = Decision(categories, common::AsCat(fvalue)); } else { - n = tree[n.RightChild()]; + go_left = fvalue < n.SplitCond(); + } + if (go_left) { + nidx = n.LeftChild(); + } else { + nidx = n.RightChild(); } } + n = tree[nidx]; } - return n.LeafValue(); + return tree[nidx].LeafValue(); } template -__global__ void PredictKernel(Data data, - common::Span d_nodes, - common::Span d_out_predictions, - common::Span d_tree_segments, - common::Span d_tree_group, - size_t tree_begin, size_t tree_end, size_t num_features, - size_t num_rows, size_t entry_start, - bool use_shared, int num_group) { +__global__ void +PredictKernel(Data data, common::Span d_nodes, + common::Span d_out_predictions, + common::Span d_tree_segments, + common::Span d_tree_group, + common::Span d_tree_split_types, + common::Span d_cat_tree_segments, + common::Span d_cat_node_segments, + common::Span d_categories, size_t tree_begin, + size_t tree_end, size_t num_features, size_t num_rows, + size_t entry_start, bool use_shared, int num_group) { bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x; Loader loader(data, use_shared, num_features, num_rows, entry_start); if (global_idx >= num_rows) return; @@ -204,7 +222,18 @@ __global__ void PredictKernel(Data data, for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]]; - float leaf = GetLeafWeight(global_idx, d_tree, &loader); + auto tree_cat_ptrs = d_cat_node_segments.subspan( + d_tree_segments[tree_idx - tree_begin], + d_tree_segments[tree_idx - tree_begin + 1] - + d_tree_segments[tree_idx - tree_begin]); + auto tree_categories = + d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin], + d_cat_tree_segments[tree_idx - tree_begin + 1] - + d_cat_tree_segments[tree_idx - tree_begin]); + float leaf = GetLeafWeight(global_idx, d_tree, d_tree_split_types, + tree_cat_ptrs, + tree_categories, + &loader); sum += leaf; } d_out_predictions[global_idx] += sum; @@ -214,8 +243,19 @@ __global__ void PredictKernel(Data data, const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]]; bst_uint out_prediction_idx = global_idx * num_group + tree_group; + auto tree_cat_ptrs = d_cat_node_segments.subspan( + d_tree_segments[tree_idx - tree_begin], + d_tree_segments[tree_idx - tree_begin + 1] - + d_tree_segments[tree_idx - tree_begin]); + auto tree_categories = + d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin], + d_cat_tree_segments[tree_idx - tree_begin + 1] - + d_cat_tree_segments[tree_idx - tree_begin]); d_out_predictions[out_prediction_idx] += - GetLeafWeight(global_idx, d_tree, &loader); + GetLeafWeight(global_idx, d_tree, d_tree_split_types, + tree_cat_ptrs, + tree_categories, + &loader); } } } @@ -223,10 +263,16 @@ __global__ void PredictKernel(Data data, class DeviceModel { public: // Need to lazily construct the vectors because GPU id is only known at runtime - HostDeviceVector nodes; HostDeviceVector stats; HostDeviceVector tree_segments; + HostDeviceVector nodes; HostDeviceVector tree_group; + HostDeviceVector split_types; + + HostDeviceVector categories; + HostDeviceVector categories_tree_segments; + HostDeviceVector categories_node_segments; + size_t tree_beg_; // NOLINT size_t tree_end_; // NOLINT int num_group; @@ -264,10 +310,43 @@ class DeviceModel { } tree_group = std::move(HostDeviceVector(model.tree_info.size(), 0, gpu_id)); - auto d_tree_group = tree_group.DevicePointer(); - dh::safe_cuda(cudaMemcpyAsync(d_tree_group, model.tree_info.data(), - sizeof(int) * model.tree_info.size(), - cudaMemcpyDefault)); + auto& h_tree_group = tree_group.HostVector(); + std::memcpy(h_tree_group.data(), model.tree_info.data(), sizeof(int) * model.tree_info.size()); + + // Initialize categorical splits. + split_types.SetDevice(gpu_id); + std::vector& h_split_types = split_types.HostVector(); + h_split_types.resize(h_tree_segments.back()); + for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { + auto const& src_st = model.trees.at(tree_idx)->GetSplitTypes(); + std::copy(src_st.cbegin(), src_st.cend(), + h_split_types.begin() + h_tree_segments[tree_idx - tree_begin]); + } + + categories = HostDeviceVector({}, gpu_id); + categories_tree_segments = HostDeviceVector(1, 0, gpu_id); + std::vector &h_categories = categories.HostVector(); + std::vector& h_split_cat_segments = categories_tree_segments.HostVector(); + for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { + auto const& src_cats = model.trees.at(tree_idx)->GetSplitCategories(); + size_t orig_size = h_categories.size(); + h_categories.resize(orig_size + src_cats.size()); + std::copy(src_cats.cbegin(), src_cats.cend(), + h_categories.begin() + orig_size); + h_split_cat_segments.push_back(h_categories.size()); + } + + categories_node_segments = + HostDeviceVector(h_tree_segments.back(), {}, gpu_id); + std::vector &h_categories_node_segments = + categories_node_segments.HostVector(); + for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { + auto const &src_cats_ptr = model.trees.at(tree_idx)->GetSplitCategoriesPtr(); + std::copy(src_cats_ptr.cbegin(), src_cats_ptr.cend(), + h_categories_node_segments.begin() + + h_tree_segments[tree_idx - tree_begin]); + } + this->tree_beg_ = tree_begin; this->tree_end_ = tree_end; this->num_group = model.learner_model_param->num_output_group; @@ -360,7 +439,8 @@ void ExtractPaths(dh::device_vector* paths, class GPUPredictor : public xgboost::Predictor { private: - void PredictInternal(const SparsePage& batch, size_t num_features, + void PredictInternal(const SparsePage& batch, + size_t num_features, HostDeviceVector* predictions, size_t batch_offset) { batch.offset.SetDevice(generic_param_->gpu_id); @@ -380,14 +460,18 @@ class GPUPredictor : public xgboost::Predictor { SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features); dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( - PredictKernel, - data, - model_.nodes.DeviceSpan(), predictions->DeviceSpan().subspan(batch_offset), - model_.tree_segments.DeviceSpan(), model_.tree_group.DeviceSpan(), - model_.tree_beg_, model_.tree_end_, num_features, num_rows, - entry_start, use_shared, model_.num_group); + PredictKernel, data, + model_.nodes.ConstDeviceSpan(), + predictions->DeviceSpan().subspan(batch_offset), + model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(), + model_.split_types.ConstDeviceSpan(), + model_.categories_tree_segments.ConstDeviceSpan(), + model_.categories_node_segments.ConstDeviceSpan(), + model_.categories.ConstDeviceSpan(), model_.tree_beg_, model_.tree_end_, + num_features, num_rows, entry_start, use_shared, model_.num_group); } - void PredictInternal(EllpackDeviceAccessor const& batch, HostDeviceVector* out_preds, + void PredictInternal(EllpackDeviceAccessor const& batch, + HostDeviceVector* out_preds, size_t batch_offset) { const uint32_t BLOCK_THREADS = 256; size_t num_rows = batch.n_rows; @@ -396,12 +480,15 @@ class GPUPredictor : public xgboost::Predictor { bool use_shared = false; size_t entry_start = 0; dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} ( - PredictKernel, - batch, - model_.nodes.DeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset), - model_.tree_segments.DeviceSpan(), model_.tree_group.DeviceSpan(), - model_.tree_beg_, model_.tree_end_, batch.NumFeatures(), num_rows, - entry_start, use_shared, model_.num_group); + PredictKernel, batch, + model_.nodes.ConstDeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset), + model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(), + model_.split_types.ConstDeviceSpan(), + model_.categories_tree_segments.ConstDeviceSpan(), + model_.categories_node_segments.ConstDeviceSpan(), + model_.categories.ConstDeviceSpan(), model_.tree_beg_, model_.tree_end_, + batch.NumFeatures(), num_rows, entry_start, use_shared, + model_.num_group); } void DevicePredictInternal(DMatrix* dmat, HostDeviceVector* out_preds, @@ -413,6 +500,7 @@ class GPUPredictor : public xgboost::Predictor { } model_.Init(model, tree_begin, tree_end, generic_param_->gpu_id); out_preds->SetDevice(generic_param_->gpu_id); + auto const& info = dmat->Info(); if (dmat->PageExists()) { size_t batch_offset = 0; @@ -425,7 +513,8 @@ class GPUPredictor : public xgboost::Predictor { size_t batch_offset = 0; for (auto const& page : dmat->GetBatches()) { this->PredictInternal( - page.Impl()->GetDeviceAccessor(generic_param_->gpu_id), out_preds, + page.Impl()->GetDeviceAccessor(generic_param_->gpu_id), + out_preds, batch_offset); batch_offset += page.Impl()->n_rows; } @@ -528,12 +617,14 @@ class GPUPredictor : public xgboost::Predictor { size_t entry_start = 0; dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( - PredictKernel, - m->Value(), - d_model.nodes.DeviceSpan(), out_preds->predictions.DeviceSpan(), - d_model.tree_segments.DeviceSpan(), d_model.tree_group.DeviceSpan(), - tree_begin, tree_end, m->NumColumns(), info.num_row_, - entry_start, use_shared, output_groups); + PredictKernel, m->Value(), + d_model.nodes.ConstDeviceSpan(), out_preds->predictions.DeviceSpan(), + d_model.tree_segments.ConstDeviceSpan(), d_model.tree_group.ConstDeviceSpan(), + d_model.split_types.ConstDeviceSpan(), + d_model.categories_tree_segments.ConstDeviceSpan(), + d_model.categories_node_segments.ConstDeviceSpan(), + d_model.categories.ConstDeviceSpan(), tree_begin, tree_end, m->NumColumns(), + info.num_row_, entry_start, use_shared, output_groups); } void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model, diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu index ef1c6946cf3b..7d7edcb55e36 100644 --- a/src/tree/gpu_hist/evaluate_splits.cu +++ b/src/tree/gpu_hist/evaluate_splits.cu @@ -1,8 +1,9 @@ /*! * Copyright 2020 by XGBoost Contributors */ -#include "evaluate_splits.cuh" #include +#include "evaluate_splits.cuh" +#include "../../common/categorical.h" namespace xgboost { namespace tree { @@ -66,13 +67,84 @@ ReduceFeature(common::Span feature_histogram, if (threadIdx.x == 0) { shared_sum = local_sum; } - __syncthreads(); + cub::CTA_SYNC(); return shared_sum; } +template struct OneHotBin { + GradientSumT __device__ operator()( + bool thread_active, uint32_t scan_begin, + SumCallbackOp*, + GradientSumT const &missing, + EvaluateSplitInputs const &inputs, TempStorageT *) { + GradientSumT bin = thread_active + ? inputs.gradient_histogram[scan_begin + threadIdx.x] + : GradientSumT(); + auto rest = inputs.parent_sum - bin - missing; + return rest; + } +}; + +template +struct UpdateOneHot { + void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain, + bst_feature_t fidx, GradientSumT const &missing, + GradientSumT const &bin, + EvaluateSplitInputs const &inputs, + DeviceSplitCandidate *best_split) { + int split_gidx = (scan_begin + threadIdx.x); + float fvalue = inputs.feature_values[split_gidx]; + GradientSumT left = missing_left ? bin + missing : bin; + GradientSumT right = inputs.parent_sum - left; + best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, + GradientPair(left), GradientPair(right), true, + inputs.param); + } +}; + +template +struct NumericBin { + GradientSumT __device__ operator()(bool thread_active, uint32_t scan_begin, + SumCallbackOp* prefix_callback, + GradientSumT const &missing, + EvaluateSplitInputs inputs, + TempStorageT *temp_storage) { + GradientSumT bin = thread_active + ? inputs.gradient_histogram[scan_begin + threadIdx.x] + : GradientSumT(); + ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), *prefix_callback); + return bin; + } +}; + +template +struct UpdateNumeric { + void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain, + bst_feature_t fidx, GradientSumT const &missing, + GradientSumT const &bin, + EvaluateSplitInputs const &inputs, + DeviceSplitCandidate *best_split) { + // Use pointer from cut to indicate begin and end of bins for each feature. + uint32_t gidx_begin = inputs.feature_segments[fidx]; // begining bin + int split_gidx = (scan_begin + threadIdx.x) - 1; + float fvalue; + if (split_gidx < static_cast(gidx_begin)) { + fvalue = inputs.min_fvalue[fidx]; + } else { + fvalue = inputs.feature_values[split_gidx]; + } + GradientSumT left = missing_left ? bin + missing : bin; + GradientSumT right = inputs.parent_sum - left; + best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, + fidx, GradientPair(left), GradientPair(right), + false, inputs.param); + } +}; + /*! \brief Find the thread with best gain. */ template + typename MaxReduceT, typename TempStorageT, typename GradientSumT, + typename BinFn, typename UpdateFn> __device__ void EvaluateFeature( int fidx, EvaluateSplitInputs inputs, TreeEvaluator::SplitEvaluator evaluator, @@ -83,12 +155,14 @@ __device__ void EvaluateFeature( uint32_t gidx_begin = inputs.feature_segments[fidx]; // begining bin uint32_t gidx_end = inputs.feature_segments[fidx + 1]; // end bin for i^th feature + auto feature_hist = inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin); + auto bin_fn = BinFn(); + auto update_fn = UpdateFn(); // Sum histogram bins for current feature GradientSumT const feature_sum = ReduceFeature( - inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin), - temp_storage); + feature_hist, temp_storage); GradientSumT const missing = inputs.parent_sum - feature_sum; float const null_gain = -std::numeric_limits::infinity(); @@ -97,12 +171,7 @@ __device__ void EvaluateFeature( for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += BLOCK_THREADS) { bool thread_active = (scan_begin + threadIdx.x) < gidx_end; - - // Gradient value for current bin. - GradientSumT bin = thread_active - ? inputs.gradient_histogram[scan_begin + threadIdx.x] - : GradientSumT(); - ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op); + auto bin = bin_fn(thread_active, scan_begin, &prefix_op, missing, inputs, temp_storage); // Whether the gradient of missing values is put to the left side. bool missing_left = true; @@ -127,24 +196,14 @@ __device__ void EvaluateFeature( block_max = best; } - __syncthreads(); + cub::CTA_SYNC(); // Best thread updates split if (threadIdx.x == block_max.key) { - int split_gidx = (scan_begin + threadIdx.x) - 1; - float fvalue; - if (split_gidx < static_cast(gidx_begin)) { - fvalue = inputs.min_fvalue[fidx]; - } else { - fvalue = inputs.feature_values[split_gidx]; - } - GradientSumT left = missing_left ? bin + missing : bin; - GradientSumT right = inputs.parent_sum - left; - best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, - fidx, GradientPair(left), GradientPair(right), - inputs.param); + update_fn(missing_left, scan_begin, gain, fidx, missing, bin, inputs, + best_split); } - __syncthreads(); + cub::CTA_SYNC(); } } @@ -186,11 +245,21 @@ __global__ void EvaluateSplitsKernel( // One block for each feature. Features are sampled, so fidx != blockIdx.x int fidx = inputs.feature_set[is_left ? blockIdx.x : blockIdx.x - left.feature_set.size()]; + if (common::IsCat(inputs.feature_types, fidx)) { + EvaluateFeature, + UpdateOneHot>(fidx, inputs, evaluator, &best_split, + &temp_storage); + } else { + EvaluateFeature, + UpdateNumeric>(fidx, inputs, evaluator, &best_split, + &temp_storage); + } - EvaluateFeature( - fidx, inputs, evaluator, &best_split, &temp_storage); - - __syncthreads(); + cub::CTA_SYNC(); if (threadIdx.x == 0) { // Record best loss for each feature diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh index f847518dba68..e30901134cda 100644 --- a/src/tree/gpu_hist/evaluate_splits.cuh +++ b/src/tree/gpu_hist/evaluate_splits.cuh @@ -18,6 +18,7 @@ struct EvaluateSplitInputs { GradientSumT parent_sum; GPUTrainingParam param; common::Span feature_set; + common::Span feature_types; common::Span feature_segments; common::Span feature_values; common::Span min_fvalue; diff --git a/src/tree/gpu_hist/feature_groups.cu b/src/tree/gpu_hist/feature_groups.cu index 5a2c8ee6cbd8..9bb9d816283a 100644 --- a/src/tree/gpu_hist/feature_groups.cu +++ b/src/tree/gpu_hist/feature_groups.cu @@ -23,13 +23,13 @@ FeatureGroups::FeatureGroups(const common::HistogramCuts& cuts, bool is_dense, return; } - std::vector& feature_segments_h = feature_segments.HostVector(); + std::vector& feature_segments_h = feature_segments.HostVector(); std::vector& bin_segments_h = bin_segments.HostVector(); feature_segments_h.push_back(0); bin_segments_h.push_back(0); const std::vector& cut_ptrs = cuts.Ptrs(); - int max_shmem_bins = shm_size / bin_size; + size_t max_shmem_bins = shm_size / bin_size; max_group_bins = 0; for (size_t i = 2; i < cut_ptrs.size(); ++i) { @@ -49,7 +49,7 @@ FeatureGroups::FeatureGroups(const common::HistogramCuts& cuts, bool is_dense, } void FeatureGroups::InitSingle(const common::HistogramCuts& cuts) { - std::vector& feature_segments_h = feature_segments.HostVector(); + std::vector& feature_segments_h = feature_segments.HostVector(); feature_segments_h.push_back(0); feature_segments_h.push_back(cuts.Ptrs().size() - 1); diff --git a/src/tree/gpu_hist/feature_groups.cuh b/src/tree/gpu_hist/feature_groups.cuh index 3af230c2ccf6..a0fc765a6b4a 100644 --- a/src/tree/gpu_hist/feature_groups.cuh +++ b/src/tree/gpu_hist/feature_groups.cuh @@ -20,7 +20,7 @@ namespace tree { consecutive feature indices, and also contains a range of all bin indices associated with those features. */ struct FeatureGroup { - __host__ __device__ FeatureGroup(int start_feature_, int num_features_, + __host__ __device__ FeatureGroup(size_t start_feature_, size_t num_features_, int start_bin_, int num_bins_) : start_feature(start_feature_), num_features(num_features_), start_bin(start_bin_), num_bins(num_bins_) {} @@ -36,24 +36,24 @@ struct FeatureGroup { /** \brief FeatureGroupsAccessor is a non-owning accessor for FeatureGroups. */ struct FeatureGroupsAccessor { - FeatureGroupsAccessor(common::Span feature_segments_, + FeatureGroupsAccessor(common::Span feature_segments_, common::Span bin_segments_, int max_group_bins_) : feature_segments(feature_segments_), bin_segments(bin_segments_), max_group_bins(max_group_bins_) {} - - common::Span feature_segments; + + common::Span feature_segments; common::Span bin_segments; int max_group_bins; - + /** \brief Gets the number of feature groups. */ - __host__ __device__ int NumGroups() const { + __host__ __device__ size_t NumGroups() const { return feature_segments.size() - 1; } /** \brief Gets the information about a feature group with index i. */ __host__ __device__ FeatureGroup operator[](int i) const { return {feature_segments[i], feature_segments[i + 1] - feature_segments[i], - bin_segments[i], bin_segments[i + 1] - bin_segments[i]}; + bin_segments[i], bin_segments[i + 1] - bin_segments[i]}; } }; @@ -78,13 +78,13 @@ struct FeatureGroupsAccessor { */ struct FeatureGroups { /** Group cuts for features. Size equals to (number of groups + 1). */ - HostDeviceVector feature_segments; + HostDeviceVector feature_segments; /** Group cuts for bins. Size equals to (number of groups + 1) */ HostDeviceVector bin_segments; /** Maximum number of bins in a group. Useful to compute the amount of dynamic shared memory when launching a kernel. */ int max_group_bins; - + /** Creates feature groups by splitting features into groups. \param cuts Histogram cuts that given the number of bins per feature. \param is_dense Whether the data matrix is dense. @@ -106,12 +106,12 @@ struct FeatureGroups { feature_segments.SetDevice(device); bin_segments.SetDevice(device); return {feature_segments.ConstDeviceSpan(), bin_segments.ConstDeviceSpan(), - max_group_bins}; + max_group_bins}; } private: void InitSingle(const common::HistogramCuts& cuts); -}; +}; } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index e48e0d4f2cfc..4219a3399e40 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -59,6 +59,7 @@ struct DeviceSplitCandidate { DefaultDirection dir {kLeftDir}; int findex {-1}; float fvalue {0}; + bool is_cat { false }; GradientPair left_sum; GradientPair right_sum; @@ -79,6 +80,7 @@ struct DeviceSplitCandidate { float fvalue_in, int findex_in, GradientPair left_sum_in, GradientPair right_sum_in, + bool cat, const GPUTrainingParam& param) { if (loss_chg_in > loss_chg && left_sum_in.GetHess() >= param.min_child_weight && @@ -86,6 +88,7 @@ struct DeviceSplitCandidate { loss_chg = loss_chg_in; dir = dir_in; fvalue = fvalue_in; + is_cat = cat; left_sum = left_sum_in; right_sum = right_sum_in; findex = findex_in; @@ -98,6 +101,7 @@ struct DeviceSplitCandidate { << "dir: " << c.dir << ", " << "findex: " << c.findex << ", " << "fvalue: " << c.fvalue << ", " + << "is_cat: " << c.is_cat << ", " << "left sum: " << c.left_sum << ", " << "right sum: " << c.right_sum << std::endl; return os; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 19698695544e..c63d510de86c 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -19,7 +19,9 @@ #include "../common/io.h" #include "../common/device_helpers.cuh" #include "../common/hist_util.h" +#include "../common/bitfield.h" #include "../common/timer.h" +#include "../common/categorical.h" #include "../data/ellpack_page.cuh" #include "param.h" @@ -161,6 +163,7 @@ template struct GPUHistMakerDevice { int device_id; EllpackPageImpl* page; + common::Span feature_types; BatchParam batch_param; std::unique_ptr row_partitioner; @@ -191,9 +194,12 @@ struct GPUHistMakerDevice { std::unique_ptr sampler; std::unique_ptr feature_groups; + // Storing split categories for 1 node. + dh::caching_device_vector node_categories; GPUHistMakerDevice(int _device_id, EllpackPageImpl* _page, + common::Span _feature_types, bst_uint _n_rows, TrainParam _param, uint32_t column_sampler_seed, @@ -202,6 +208,7 @@ struct GPUHistMakerDevice { BatchParam _batch_param) : device_id(_device_id), page(_page), + feature_types{_feature_types}, param(std::move(_param)), tree_evaluator(param, n_features, _device_id), column_sampler(column_sampler_seed), @@ -293,6 +300,7 @@ struct GPUHistMakerDevice { {root_sum.GetGrad(), root_sum.GetHess()}, gpu_param, feature_set, + feature_types, matrix.feature_segments, matrix.gidx_fvalue_map, matrix.min_fvalue, @@ -331,6 +339,7 @@ struct GPUHistMakerDevice { candidate.split.left_sum.GetHess()}, gpu_param, left_feature_set, + feature_types, matrix.feature_segments, matrix.gidx_fvalue_map, matrix.min_fvalue, @@ -341,6 +350,7 @@ struct GPUHistMakerDevice { candidate.split.right_sum.GetHess()}, gpu_param, right_feature_set, + feature_types, matrix.feature_segments, matrix.gidx_fvalue_map, matrix.min_fvalue, @@ -399,8 +409,11 @@ struct GPUHistMakerDevice { hist.HistogramExists(nidx_parent); } - void UpdatePosition(int nidx, RegTree::Node split_node) { + void UpdatePosition(int nidx, RegTree* p_tree) { + RegTree::Node split_node = (*p_tree)[nidx]; + auto split_type = p_tree->NodeSplitType(nidx); auto d_matrix = page->GetDeviceAccessor(device_id); + auto node_cats = dh::ToSpan(node_categories); row_partitioner->UpdatePosition( nidx, split_node.LeftChild(), split_node.RightChild(), @@ -409,11 +422,17 @@ struct GPUHistMakerDevice { bst_float cut_value = d_matrix.GetFvalue(ridx, split_node.SplitIndex()); // Missing value - int new_position = 0; + bst_node_t new_position = 0; if (isnan(cut_value)) { new_position = split_node.DefaultChild(); } else { - if (cut_value <= split_node.SplitCond()) { + bool go_left = true; + if (split_type == FeatureType::kCategorical) { + go_left = common::Decision(node_cats, common::AsCat(cut_value)); + } else { + go_left = cut_value <= split_node.SplitCond(); + } + if (go_left) { new_position = split_node.LeftChild(); } else { new_position = split_node.RightChild(); @@ -428,48 +447,77 @@ struct GPUHistMakerDevice { // prediction cache void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat) { dh::TemporaryArray d_nodes(p_tree->GetNodes().size()); - dh::safe_cuda(cudaMemcpy(d_nodes.data().get(), p_tree->GetNodes().data(), - d_nodes.size() * sizeof(RegTree::Node), - cudaMemcpyHostToDevice)); + dh::safe_cuda(cudaMemcpyAsync(d_nodes.data().get(), p_tree->GetNodes().data(), + d_nodes.size() * sizeof(RegTree::Node), + cudaMemcpyHostToDevice)); + auto const& h_split_types = p_tree->GetSplitTypes(); + auto const& categories = p_tree->GetSplitCategories(); + auto const& categories_segments = p_tree->GetSplitCategoriesPtr(); + + dh::device_vector d_split_types; + dh::device_vector d_categories; + dh::device_vector d_categories_segments; + + dh::CopyToD(h_split_types, &d_split_types); + dh::CopyToD(categories, &d_categories); + dh::CopyToD(categories_segments, &d_categories_segments); if (row_partitioner->GetRows().size() != p_fmat->Info().num_row_) { row_partitioner.reset(); // Release the device memory first before reallocating row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_)); } if (page->n_rows == p_fmat->Info().num_row_) { - FinalisePositionInPage(page, dh::ToSpan(d_nodes)); + FinalisePositionInPage(page, dh::ToSpan(d_nodes), + dh::ToSpan(d_split_types), dh::ToSpan(d_categories), + dh::ToSpan(d_categories_segments)); } else { for (auto& batch : p_fmat->GetBatches(batch_param)) { - FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes)); + FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes), + dh::ToSpan(d_split_types), dh::ToSpan(d_categories), + dh::ToSpan(d_categories_segments)); } } } - void FinalisePositionInPage(EllpackPageImpl* page, const common::Span d_nodes) { + void FinalisePositionInPage(EllpackPageImpl *page, + const common::Span d_nodes, + common::Span d_feature_types, + common::Span categories, + common::Span categories_segments) { auto d_matrix = page->GetDeviceAccessor(device_id); row_partitioner->FinalisePosition( [=] __device__(size_t row_id, int position) { - if (!d_matrix.IsInRange(row_id)) { - return RowPartitioner::kIgnoredTreePosition; - } - auto node = d_nodes[position]; + // What happens if user prune the tree? + if (!d_matrix.IsInRange(row_id)) { + return RowPartitioner::kIgnoredTreePosition; + } + auto node = d_nodes[position]; - while (!node.IsLeaf()) { - bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex()); - // Missing value - if (isnan(element)) { - position = node.DefaultChild(); - } else { - if (element <= node.SplitCond()) { - position = node.LeftChild(); - } else { - position = node.RightChild(); + while (!node.IsLeaf()) { + bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex()); + // Missing value + if (isnan(element)) { + position = node.DefaultChild(); + } else { + bool go_left = true; + if (common::IsCat(d_feature_types, position)) { + auto node_cats = + categories.subspan(categories_segments[position].beg, + categories_segments[position].size); + go_left = common::Decision(node_cats, common::AsCat(element)); + } else { + go_left = element <= node.SplitCond(); + } + if (go_left) { + position = node.LeftChild(); + } else { + position = node.RightChild(); + } + } + node = d_nodes[position]; } - } - node = d_nodes[position]; - } - return position; - }); + return position; + }); } void UpdatePredictionCache(bst_float* out_preds_d) { @@ -561,11 +609,30 @@ struct GPUHistMakerDevice { auto left_weight = candidate.left_weight * param.learning_rate; auto right_weight = candidate.right_weight * param.learning_rate; - tree.ExpandNode(candidate.nid, candidate.split.findex, - candidate.split.fvalue, candidate.split.dir == kLeftDir, - base_weight, left_weight, right_weight, - candidate.split.loss_chg, parent_sum.GetHess(), - candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess()); + auto is_cat = candidate.split.is_cat; + if (is_cat) { + auto cat = common::AsCat(candidate.split.fvalue); + std::vector split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1))); + LBitField32 cats_bits(split_cats); + cats_bits.Set(cat); + node_categories.resize(split_cats.size()); + dh::safe_cuda(cudaMemcpyAsync( + node_categories.data().get(), split_cats.data(), + split_cats.size() * sizeof(uint32_t), cudaMemcpyHostToDevice)); + tree.ExpandCategorical( + candidate.nid, candidate.split.findex, split_cats, + candidate.split.dir == kLeftDir, base_weight, left_weight, + right_weight, candidate.split.loss_chg, parent_sum.GetHess(), + candidate.split.left_sum.GetHess(), + candidate.split.right_sum.GetHess()); + } else { + tree.ExpandNode(candidate.nid, candidate.split.findex, + candidate.split.fvalue, candidate.split.dir == kLeftDir, + base_weight, left_weight, right_weight, + candidate.split.loss_chg, parent_sum.GetHess(), + candidate.split.left_sum.GetHess(), + candidate.split.right_sum.GetHess()); + } // Set up child constraints auto left_child = tree[candidate.nid].LeftChild(); @@ -664,7 +731,7 @@ struct GPUHistMakerDevice { if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), num_leaves)) { monitor.Start("UpdatePosition"); - this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]); + this->UpdatePosition(candidate.nid, p_tree); monitor.Stop("UpdatePosition"); monitor.Start("BuildHist"); @@ -752,8 +819,10 @@ class GPUHistMakerSpecialised { }; auto page = (*dmat->GetBatches(batch_param).begin()).Impl(); dh::safe_cuda(cudaSetDevice(device_)); + info_->feature_types.SetDevice(device_); maker.reset(new GPUHistMakerDevice(device_, page, + info_->feature_types.ConstDeviceSpan(), info_->num_row_, param_, column_sampling_seed, diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index 9b025ed9b17a..c2c9cdb296ec 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -6,6 +6,7 @@ #include #include #include "../../../src/common/hist_util.h" +#include "../../../src/common/categorical.h" #include "../../../src/data/simple_dmatrix.h" #include "../../../src/data/adapter.h" @@ -60,13 +61,6 @@ inline data::CupyAdapter AdapterFromData(const thrust::device_vector &x, } #endif -inline std::shared_ptr -GetDMatrixFromData(const std::vector &x, int num_rows, int num_columns) { - data::DenseAdapter adapter(x.data(), num_rows, num_columns); - return std::shared_ptr(new data::SimpleDMatrix( - &adapter, std::numeric_limits::quiet_NaN(), 1)); -} - inline std::shared_ptr GetExternalMemoryDMatrixFromData( const std::vector& x, int num_rows, int num_columns, size_t page_size, const dmlc::TemporaryDirectory& tempdir) { @@ -134,12 +128,14 @@ inline void TestRank(const std::vector &column_cuts, inline void ValidateColumn(const HistogramCuts& cuts, int column_idx, const std::vector& sorted_column, const std::vector& sorted_weights, - size_t num_bins) { - + size_t num_bins, bool is_cat = false) { // Check the endpoints are correct CHECK_GT(sorted_column.size(), 0); - EXPECT_LT(cuts.MinValues().at(column_idx), sorted_column.front()); - EXPECT_GT(cuts.Values()[cuts.Ptrs()[column_idx]], sorted_column.front()); + if (is_cat) { + EXPECT_EQ(cuts.Values()[cuts.Ptrs()[column_idx]], sorted_column.front()); + } else { + EXPECT_GT(cuts.Values()[cuts.Ptrs()[column_idx]], sorted_column.front()); + } EXPECT_GE(cuts.Values()[cuts.Ptrs()[column_idx+1]-1], sorted_column.back()); // Check the cuts are sorted @@ -174,7 +170,9 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx, inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, int num_bins) { // Collect data into columns - std::vector> columns(dmat->Info().num_col_); + auto const& info = dmat->Info(); + auto const& ft = info.feature_types.ConstHostSpan(); + std::vector> columns(info.num_col_); for (auto& batch : dmat->GetBatches()) { ASSERT_GT(batch.Size(), 0ul); for (auto i = 0ull; i < batch.Size(); i++) { @@ -184,7 +182,7 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, } } // Sort - for (auto i = 0ull; i < columns.size(); i++) { + for (auto i = 0ul; i < columns.size(); i++) { auto& col = columns.at(i); const auto& w = dmat->Info().weights_.HostVector(); std::vector index(col.size()); @@ -201,7 +199,7 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, } } - ValidateColumn(cuts, i, sorted_column, sorted_weights, num_bins); + ValidateColumn(cuts, i, sorted_column, sorted_weights, num_bins, IsCat(ft, i)); } } diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index fa748de1cc6c..345bfe5d4c4c 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -1,11 +1,12 @@ #include +#include "test_hist_util.h" #include "test_quantile.h" + #include "../../../src/common/quantile.h" #include "../../../src/common/hist_util.h" namespace xgboost { namespace common { - TEST(Quantile, LoadBalance) { size_t constexpr kRows = 1000, kCols = 100; auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); @@ -183,5 +184,17 @@ TEST(Quantile, SameOnAllWorkers) { rabit::Finalize(); #endif // defined(__unix__) } + +TEST(CPUQuantile, FromOneHot) { + std::vector x = BasicOneHotEncodedData(); + auto m = GetDMatrixFromData(x, 5, 3); + + int32_t max_bins = 16; + HistogramCuts cuts = SketchOnDMatrix(m.get(), max_bins); + + std::vector const& h_cuts_ptr = cuts.Ptrs(); + std::vector h_cuts_values = cuts.Values(); + ValidateBasicOneHot(h_cuts_ptr, h_cuts_values); +} } // namespace common } // namespace xgboost diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index 8bb426aa4d76..398148f9f910 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -1,6 +1,7 @@ #include #include "test_quantile.h" #include "../helpers.h" +#include "test_quantile.h" #include "../../../src/common/hist_util.cuh" #include "../../../src/common/quantile.cuh" @@ -95,8 +96,7 @@ void TestQuantileElemRank(int32_t device, Span in, TEST(GPUQuantile, Prune) { constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, - MetaInfo const &info) { + RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { HostDeviceVector ft; SketchContainer sketch(ft, n_bins, kCols, kRows, 0); @@ -293,9 +293,8 @@ TEST(GPUQuantile, AllReduceBasic) { } constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, - MetaInfo const &info) { - // Set up single node version + RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { + // Set up single node version; HostDeviceVector ft; SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, 0); @@ -444,5 +443,17 @@ TEST(GPUQuantile, SameOnAllWorkers) { return; #endif // !defined(__linux__) && defined(XGBOOST_USE_NCCL) } + +TEST(GPUQuantile, FromOneHot) { + std::vector x = BasicOneHotEncodedData(); + auto m = GetDMatrixFromData(x, 5, 3); + int32_t max_bins = 16; + auto cuts = DeviceSketch(0, m.get(), max_bins); + + std::vector const& h_cuts_ptr = cuts.Ptrs(); + std::vector h_cuts_values = cuts.Values(); + + ValidateBasicOneHot(h_cuts_ptr, h_cuts_values); +} } // namespace common } // namespace xgboost diff --git a/tests/cpp/common/test_quantile.h b/tests/cpp/common/test_quantile.h index e91f19ef84a8..e17465752803 100644 --- a/tests/cpp/common/test_quantile.h +++ b/tests/cpp/common/test_quantile.h @@ -1,4 +1,9 @@ +#ifndef XGBOOST_TEST_QUANTILE_H_ +#define XGBOOST_TEST_QUANTILE_H_ + #include +#include + #include #include #include @@ -50,5 +55,33 @@ template void RunWithSeedsAndBins(size_t rows, Fn fn) { } } } +inline auto BasicOneHotEncodedData() { + std::vector x { + 0, 1, 0, + 0, 1, 0, + 0, 1, 0, + 0, 0, 1, + 1, 0, 0 + }; + return x; +} + +inline void ValidateBasicOneHot(std::vector const &h_cuts_ptr, + std::vector const &h_cuts_values) { + size_t const cols = 3; + ASSERT_EQ(h_cuts_ptr.size(), cols + 1); + ASSERT_EQ(h_cuts_values.size(), cols * 2); + + for (size_t i = 1; i < h_cuts_ptr.size(); ++i) { + auto feature = + common::Span(h_cuts_values) + .subspan(h_cuts_ptr[i - 1], h_cuts_ptr[i] - h_cuts_ptr[i - 1]); + EXPECT_EQ(feature.size(), 2); + // 0 is discarded as min value. + EXPECT_EQ(feature[0], 1.0f); + EXPECT_GT(feature[1], 1.0f); + } +} } // namespace common } // namespace xgboost +#endif // XGBOOST_TEST_QUANTILE_H_ diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index 23d566068619..183eac436731 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -7,11 +7,14 @@ #include "../helpers.h" #include "../histogram_helpers.h" +#include "../common/test_quantile.h" #include "gtest/gtest.h" #include "../../../src/common/categorical.h" #include "../../../src/common/hist_util.h" #include "../../../src/data/ellpack_page.cuh" +#include "../../../src/data/adapter.h" +#include "../../../src/data/simple_dmatrix.h" namespace xgboost { @@ -117,6 +120,25 @@ TEST(EllpackPage, FromCategoricalBasic) { } } +TEST(EllpackPage, FromOneHot) { + std::vector x = common::BasicOneHotEncodedData(); + auto m = GetDMatrixFromData(x, 5, 3); + int32_t max_bins = 16; + BatchParam p(0, max_bins); + auto ellpack = EllpackPage(m.get(), p); + auto accessor = ellpack.Impl()->GetDeviceAccessor(0); + + std::vector h_cuts_ptr(accessor.feature_segments.size()); + dh::CopyDeviceSpanToVector(&h_cuts_ptr, accessor.feature_segments); + std::vector h_cuts_values(accessor.gidx_fvalue_map.size()); + dh::CopyDeviceSpanToVector(&h_cuts_values, accessor.gidx_fvalue_map); + + size_t const cols = 3; + ASSERT_EQ(h_cuts_ptr.size(), cols + 1); + ASSERT_EQ(h_cuts_values.size(), cols * 2); + common::ValidateBasicOneHot(h_cuts_ptr, h_cuts_values); +} + struct ReadRowFunction { EllpackDeviceAccessor matrix; int row; @@ -234,5 +256,4 @@ TEST(EllpackPage, Compact) { } } } - } // namespace xgboost diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 585acf1790b6..c928c2886d91 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -160,9 +160,11 @@ TEST(GPUPredictor, MGPU_InplacePredict) { // NOLINT dmlc::Error); } + TEST(GpuPredictor, LesserFeatures) { TestPredictionWithLesserFeatures("gpu_predictor"); } + // Very basic test of empty model TEST(GPUPredictor, ShapStump) { cudaSetDevice(0); @@ -189,6 +191,7 @@ TEST(GPUPredictor, ShapStump) { EXPECT_EQ(phis[4], 0.0); EXPECT_EQ(phis[5], param.base_score); } + TEST(GPUPredictor, Shap) { LearnerModelParam param; param.num_feature = 1; @@ -219,5 +222,8 @@ TEST(GPUPredictor, Shap) { } } +TEST(GPUPredictor, CategoricalPrediction) { + TestCategoricalPrediction("gpu_predictor"); +} } // namespace predictor } // namespace xgboost diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index 3005d585f5a6..7f6de563125b 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -12,6 +12,8 @@ #include "../helpers.h" #include "../../../src/common/io.h" +#include "../../../src/common/categorical.h" +#include "../../../src/common/bitfield.h" namespace xgboost { TEST(Predictor, PredictionCache) { @@ -27,7 +29,7 @@ TEST(Predictor, PredictionCache) { }; add_cache(); - ASSERT_EQ(container.Container().size(), 0); + ASSERT_EQ(container.Container().size(), 0ul); add_cache(); EXPECT_ANY_THROW(container.Entry(m)); } @@ -174,4 +176,55 @@ void TestPredictionWithLesserFeatures(std::string predictor_name) { } #endif // defined(XGBOOST_USE_CUDA) } + +void TestCategoricalPrediction(std::string name) { + size_t constexpr kCols = 10; + PredictionCacheEntry out_predictions; + + LearnerModelParam param; + param.num_feature = kCols; + param.num_output_group = 1; + param.base_score = 0.5; + + gbm::GBTreeModel model(¶m); + + std::vector> trees; + trees.push_back(std::unique_ptr(new RegTree)); + auto& p_tree = trees.front(); + + uint32_t split_ind = 3; + bst_cat_t split_cat = 4; + float left_weight = 1.3f; + float right_weight = 1.7f; + + std::vector split_cats(LBitField32::ComputeStorageSize(split_cat)); + LBitField32 cats_bits(split_cats); + cats_bits.Set(split_cat); + + p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f, + left_weight, right_weight, + 3.0f, 2.2f, 7.0f, 9.0f); + model.CommitModel(std::move(trees), 0); + + GenericParameter runtime; + runtime.gpu_id = 0; + std::unique_ptr predictor{ + Predictor::Create(name.c_str(), &runtime)}; + + std::vector row(kCols); + row[split_ind] = split_cat; + auto m = GetDMatrixFromData(row, 1, kCols); + + predictor->PredictBatch(m.get(), &out_predictions, model, 0); + ASSERT_EQ(out_predictions.predictions.Size(), 1ul); + ASSERT_EQ(out_predictions.predictions.HostVector()[0], + right_weight + param.base_score); + + row[split_ind] = split_cat + 1; + m = GetDMatrixFromData(row, 1, kCols); + out_predictions.version = 0; + predictor->PredictBatch(m.get(), &out_predictions, model, 0); + ASSERT_EQ(out_predictions.predictions.HostVector()[0], + left_weight + param.base_score); +} } // namespace xgboost diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h index b6a3180111f2..68e034e0a581 100644 --- a/tests/cpp/predictor/test_predictor.h +++ b/tests/cpp/predictor/test_predictor.h @@ -61,6 +61,8 @@ void TestInplacePrediction(dmlc::any x, std::string predictor, int32_t device = -1); void TestPredictionWithLesserFeatures(std::string preditor_name); + +void TestCategoricalPrediction(std::string name); } // namespace xgboost #endif // XGBOOST_TEST_PREDICTOR_H_ diff --git a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu index 84b2d13c7fb9..d90d47b150a7 100644 --- a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu +++ b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu @@ -15,7 +15,7 @@ auto ZeroParam() { } } // anonymous namespace -TEST(GpuHist, EvaluateSingleSplit) { +void TestEvaluateSingleSplit(bool is_categorical) { thrust::device_vector out_splits(1); GradientPair parent_sum(0.0, 1.0); TrainParam tparam = ZeroParam(); @@ -33,11 +33,19 @@ TEST(GpuHist, EvaluateSingleSplit) { thrust::device_vector feature_histogram = std::vector{ {-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}}; + thrust::device_vector monotonic_constraints(feature_set.size(), 0); + dh::device_vector feature_types(feature_set.size(), + FeatureType::kCategorical); + common::Span d_feature_types; + if (is_categorical) { + d_feature_types = dh::ToSpan(feature_types); + } EvaluateSplitInputs input{1, parent_sum, param, dh::ToSpan(feature_set), + d_feature_types, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), dh::ToSpan(feature_min_values), @@ -55,6 +63,14 @@ TEST(GpuHist, EvaluateSingleSplit) { parent_sum.GetHess()); } +TEST(GpuHist, EvaluateSingleSplit) { + TestEvaluateSingleSplit(false); +} + +TEST(GpuHist, EvaluateCategoricalSplit) { + TestEvaluateSingleSplit(true); +} + TEST(GpuHist, EvaluateSingleSplitMissing) { thrust::device_vector out_splits(1); GradientPair parent_sum(1.0, 1.5); @@ -74,6 +90,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) { parent_sum, param, dh::ToSpan(feature_set), + {}, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), dh::ToSpan(feature_min_values), @@ -134,6 +151,7 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) { parent_sum, param, dh::ToSpan(feature_set), + {}, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), dh::ToSpan(feature_min_values), @@ -174,6 +192,7 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) { parent_sum, param, dh::ToSpan(feature_set), + {}, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), dh::ToSpan(feature_min_values), @@ -215,6 +234,7 @@ TEST(GpuHist, EvaluateSplits) { parent_sum, param, dh::ToSpan(feature_set), + {}, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), dh::ToSpan(feature_min_values), @@ -224,6 +244,7 @@ TEST(GpuHist, EvaluateSplits) { parent_sum, param, dh::ToSpan(feature_set), + {}, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), dh::ToSpan(feature_min_values), @@ -241,6 +262,5 @@ TEST(GpuHist, EvaluateSplits) { EXPECT_EQ(result_right.findex, 0); EXPECT_EQ(result_right.fvalue, 1.0); } - } // namespace tree } // namespace xgboost diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu index 4ff7ec1062b0..2975d34fc3ac 100644 --- a/tests/cpp/tree/gpu_hist/test_histogram.cu +++ b/tests/cpp/tree/gpu_hist/test_histogram.cu @@ -4,6 +4,7 @@ #include "../../../../src/common/categorical.h" #include "../../../../src/tree/gpu_hist/row_partitioner.cuh" #include "../../../../src/tree/gpu_hist/histogram.cuh" +#include "../../../../src/common/categorical.h" namespace xgboost { namespace tree { diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index ec598c5fce56..2c4b6be34dd6 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -80,7 +80,7 @@ void TestBuildHist(bool use_shared_memory_histograms) { param.Init(args); auto page = BuildEllpackPage(kNRows, kNCols); BatchParam batch_param{}; - GPUHistMakerDevice maker(0, page.get(), kNRows, param, kNCols, kNCols, + GPUHistMakerDevice maker(0, page.get(), {}, kNRows, param, kNCols, kNCols, true, batch_param); xgboost::SimpleLCG gen; xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); @@ -154,19 +154,18 @@ TEST(GpuHist, EvaluateRootSplit) { TrainParam param; - std::vector> args { - {"max_depth", "1"}, - {"max_leaves", "0"}, - - // Disable all other parameters. - {"colsample_bynode", "1"}, - {"colsample_bylevel", "1"}, - {"colsample_bytree", "1"}, - {"min_child_weight", "0.01"}, - {"reg_alpha", "0"}, - {"reg_lambda", "0"}, - {"max_delta_step", "0"} - }; + std::vector> args{ + {"max_depth", "1"}, + {"max_leaves", "0"}, + + // Disable all other parameters. + {"colsample_bynode", "1"}, + {"colsample_bylevel", "1"}, + {"colsample_bytree", "1"}, + {"min_child_weight", "0.01"}, + {"reg_alpha", "0"}, + {"reg_lambda", "0"}, + {"max_delta_step", "0"}}; param.Init(args); for (size_t i = 0; i < kNCols; ++i) { param.monotone_constraints.emplace_back(0); @@ -178,7 +177,7 @@ TEST(GpuHist, EvaluateRootSplit) { auto page = BuildEllpackPage(kNRows, kNCols); BatchParam batch_param{}; GPUHistMakerDevice - maker(0, page.get(), kNRows, param, kNCols, kNCols, true, batch_param); + maker(0, page.get(), {}, kNRows, param, kNCols, kNCols, true, batch_param); // Initialize GPUHistMakerDevice::node_sum_gradients maker.node_sum_gradients = {}; @@ -257,7 +256,6 @@ void TestHistogramIndexImpl() { ASSERT_EQ(maker->page->Cuts().TotalBins(), maker_ext->page->Cuts().TotalBins()); ASSERT_EQ(maker->page->gidx_buffer.Size(), maker_ext->page->gidx_buffer.Size()); - } TEST(GpuHist, TestHistogramIndex) { diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index ce555bd6a5a6..a1b607865305 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -41,6 +41,52 @@ def test_gpu_hist(self, param, num_rounds, dataset): note(result) assert tm.non_increasing(result['train'][dataset.metric]) + def run_categorical_basic(self, cat, onehot, label, rounds): + by_etl_results = {} + by_builtin_results = {} + + parameters = {'tree_method': 'gpu_hist', + 'predictor': 'gpu_predictor', + 'enable_experimental_json_serialization': True} + + m = xgb.DMatrix(onehot, label, enable_categorical=True) + xgb.train(parameters, m, + num_boost_round=rounds, + evals=[(m, 'Train')], evals_result=by_etl_results) + + m = xgb.DMatrix(cat, label, enable_categorical=True) + xgb.train(parameters, m, + num_boost_round=rounds, + evals=[(m, 'Train')], evals_result=by_builtin_results) + np.testing.assert_allclose( + np.array(by_etl_results['Train']['rmse']), + np.array(by_builtin_results['Train']['rmse']), + rtol=1e-4) + assert tm.non_increasing(by_builtin_results['Train']['rmse']) + + @given(strategies.integers(10, 400), strategies.integers(5, 10), + strategies.integers(1, 6), strategies.integers(4, 8)) + @settings(deadline=None) + @pytest.mark.skipif(**tm.no_pandas()) + def test_categorical(self, rows, cols, rounds, cats): + import pandas as pd + rng = np.random.RandomState(1994) + + pd_dict = {} + for i in range(cols): + c = rng.randint(low=0, high=cats+1, size=rows) + pd_dict[str(i)] = pd.Series(c, dtype=np.int64) + + df = pd.DataFrame(pd_dict) + label = df.iloc[:, 0] + for i in range(0, cols-1): + label += df.iloc[:, i] + label += 1 + df = df.astype('category') + x = pd.get_dummies(df) + + self.run_categorical_basic(df, x, label, rounds) + @pytest.mark.skipif(**tm.no_cupy()) @given(parameter_strategy, strategies.integers(1, 20), tm.dataset_strategy) diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index a06bfc28361f..0b9d68491974 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -165,7 +165,8 @@ def test_dask_dataframe(self, local_cuda_cluster): @settings(deadline=duration(seconds=120)) @pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_dask_cuda()) - @pytest.mark.parametrize('local_cuda_cluster', [{'n_workers': 2}], indirect=['local_cuda_cluster']) + @pytest.mark.parametrize('local_cuda_cluster', [{'n_workers': 2}], + indirect=['local_cuda_cluster']) @pytest.mark.mgpu def test_gpu_hist(self, params, num_rounds, dataset, local_cuda_cluster): with Client(local_cuda_cluster) as client: diff --git a/tests/python/test_with_pandas.py b/tests/python/test_with_pandas.py index 56aa3e9f3247..7ea8012425ff 100644 --- a/tests/python/test_with_pandas.py +++ b/tests/python/test_with_pandas.py @@ -67,7 +67,8 @@ def test_pandas(self): # 0 1 1 0 0 # 1 2 0 1 0 # 2 3 0 0 1 - result, _, _ = xgb.data._transform_pandas_df(dummies) + result, _, _ = xgb.data._transform_pandas_df(dummies, + enable_categorical=False) exp = np.array([[1., 1., 0., 0.], [2., 0., 1., 0.], [3., 0., 0., 1.]]) @@ -109,6 +110,16 @@ def test_pandas(self): assert dm.num_row() == 2 assert dm.num_col() == 6 + def test_pandas_categorical(self): + rng = np.random.RandomState(1994) + rows = 100 + X = rng.randint(3, 7, size=rows) + X = pd.Series(X, dtype="category") + X = pd.DataFrame({'f0': X}) + y = rng.randn(rows) + m = xgb.DMatrix(X, y, enable_categorical=True) + assert m.feature_types[0] == 'categorical' + def test_pandas_sparse(self): import pandas as pd rows = 100 @@ -129,15 +140,15 @@ def test_pandas_label(self): # label must be a single column df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]}) self.assertRaises(ValueError, xgb.data._transform_pandas_df, df, - None, None, 'label', 'float') + False, None, None, 'label', 'float') # label must be supported dtype df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)}) self.assertRaises(ValueError, xgb.data._transform_pandas_df, df, - None, None, 'label', 'float') + False, None, None, 'label', 'float') df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)}) - result, _, _ = xgb.data._transform_pandas_df(df, None, None, + result, _, _ = xgb.data._transform_pandas_df(df, False, None, None, 'label', 'float') np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]], dtype=float)) diff --git a/tests/python/testing.py b/tests/python/testing.py index f6a05a5d7d22..97a77be4f4b1 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -53,6 +53,17 @@ def no_dt(): 'reason': 'Datatable is not installed.'} +def no_cupy(): + reason = 'cupy is not installed.' + try: + import cupy # noqa + return {'condition': False, + 'reason': reason} + except ImportError: + return {'condition': True, + 'reason': reason} + + def no_matplotlib(): reason = 'Matplotlib is not installed.' try: