From cbd8613941e0c2efe1cefab3c281a047838a81e8 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 14 Aug 2020 13:51:40 +0800 Subject: [PATCH] Add GPU support. --- src/data/data.cu | 24 +++++++++++++++++++ .../test_device_quantile_dmatrix.py | 14 +++++++++++ 2 files changed, 38 insertions(+) diff --git a/src/data/data.cu b/src/data/data.cu index 5e63a828c207..15260498734d 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -58,6 +58,15 @@ void CopyGroupInfoImpl(ArrayInterface column, std::vector* out) { std::partial_sum(out->begin(), out->end(), out->begin()); } +namespace { +// thrust::all_of tries to copy lambda function. +struct AllOfOp { + __device__ bool operator()(float w) { + return w >= 0; + } +}; +} // anonymous namespace + void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()}); auto const& j_arr = get(j_interface); @@ -82,6 +91,21 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { } else if (key == "group") { CopyGroupInfoImpl(array_interface, &group_ptr_); return; + } else if (key == "label_lower_bound") { + CopyInfoImpl(array_interface, &labels_lower_bound_); + return; + } else if (key == "label_upper_bound") { + CopyInfoImpl(array_interface, &labels_upper_bound_); + return; + } else if (key == "feature_weights") { + CopyInfoImpl(array_interface, &feature_weigths); + auto d_feature_weights = feature_weigths.ConstDeviceSpan(); + auto valid = + thrust::all_of(thrust::device, d_feature_weights.data(), + d_feature_weights.data() + d_feature_weights.size(), + AllOfOp{}); + CHECK(valid) << "Feature weight must be greater than 0."; + return; } else { LOG(FATAL) << "Unknown metainfo: " << key; } diff --git a/tests/python-gpu/test_device_quantile_dmatrix.py b/tests/python-gpu/test_device_quantile_dmatrix.py index f0978a0afaf4..c44de28bd2ff 100644 --- a/tests/python-gpu/test_device_quantile_dmatrix.py +++ b/tests/python-gpu/test_device_quantile_dmatrix.py @@ -16,6 +16,20 @@ def test_dmatrix_numpy_init(self): match='is not supported for DeviceQuantileDMatrix'): xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64)) + @pytest.mark.skipif(**tm.no_cupy()) + def test_dmatrix_feature_weights(self): + import cupy as cp + rng = cp.random.RandomState(1994) + data = rng.randn(5, 5) + m = xgb.DMatrix(data) + + feature_weights = rng.uniform(size=5) + m.set_info(feature_weights=feature_weights) + + cp.testing.assert_array_equal( + cp.array(m.get_float_info('feature_weights')), + feature_weights.astype(np.float32)) + @pytest.mark.skipif(**tm.no_cupy()) def test_dmatrix_cupy_init(self): import cupy as cp