Skip to content

Commit

Permalink
Add GPU support.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Aug 18, 2020
1 parent 171b062 commit cbd8613
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
24 changes: 24 additions & 0 deletions src/data/data.cu
Expand Up @@ -58,6 +58,15 @@ void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
std::partial_sum(out->begin(), out->end(), out->begin());
}

namespace {
// thrust::all_of tries to copy lambda function.
struct AllOfOp {
__device__ bool operator()(float w) {
return w >= 0;
}
};
} // anonymous namespace

void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
auto const& j_arr = get<Array>(j_interface);
Expand All @@ -82,6 +91,21 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
} else if (key == "group") {
CopyGroupInfoImpl(array_interface, &group_ptr_);
return;
} else if (key == "label_lower_bound") {
CopyInfoImpl(array_interface, &labels_lower_bound_);
return;
} else if (key == "label_upper_bound") {
CopyInfoImpl(array_interface, &labels_upper_bound_);
return;
} else if (key == "feature_weights") {
CopyInfoImpl(array_interface, &feature_weigths);
auto d_feature_weights = feature_weigths.ConstDeviceSpan();
auto valid =
thrust::all_of(thrust::device, d_feature_weights.data(),
d_feature_weights.data() + d_feature_weights.size(),
AllOfOp{});
CHECK(valid) << "Feature weight must be greater than 0.";
return;
} else {
LOG(FATAL) << "Unknown metainfo: " << key;
}
Expand Down
14 changes: 14 additions & 0 deletions tests/python-gpu/test_device_quantile_dmatrix.py
Expand Up @@ -16,6 +16,20 @@ def test_dmatrix_numpy_init(self):
match='is not supported for DeviceQuantileDMatrix'):
xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))

@pytest.mark.skipif(**tm.no_cupy())
def test_dmatrix_feature_weights(self):
import cupy as cp
rng = cp.random.RandomState(1994)
data = rng.randn(5, 5)
m = xgb.DMatrix(data)

feature_weights = rng.uniform(size=5)
m.set_info(feature_weights=feature_weights)

cp.testing.assert_array_equal(
cp.array(m.get_float_info('feature_weights')),
feature_weights.astype(np.float32))

@pytest.mark.skipif(**tm.no_cupy())
def test_dmatrix_cupy_init(self):
import cupy as cp
Expand Down

0 comments on commit cbd8613

Please sign in to comment.