Skip to content

Commit

Permalink
Remove the device parameter.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jul 31, 2020
1 parent 7f16aa4 commit 328eb14
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 40 deletions.
5 changes: 1 addition & 4 deletions include/xgboost/c_api.h
Expand Up @@ -500,13 +500,10 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
* uint32_t = 3
* uint64_t = 4
*
* \param device Where the data is resided. Currently supports only CPU input data.
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixSetFeatureInfo(DMatrixHandle handle, const char *field,
void *data, bst_ulong size, int type,
int device);
void *data, bst_ulong size, int type);

/*!
* \brief Get feature info in a thread local buffer.
Expand Down
3 changes: 1 addition & 2 deletions python-package/xgboost/data.py
Expand Up @@ -549,8 +549,7 @@ def _meta_from_numpy(data, field, dtype, handle, is_feature: bool = False):
c_str(field),
data,
size,
c_type,
ctypes.c_int(-1),
c_type
))
else:
data = _maybe_np_slice(data, dtype)
Expand Down
3 changes: 1 addition & 2 deletions src/c_api/c_api.cc
Expand Up @@ -318,10 +318,9 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,

XGB_DLL int XGDMatrixSetFeatureInfo(DMatrixHandle handle, const char *field,
void *data, xgboost::bst_ulong size,
int type, int device) {
int type) {
API_BEGIN();
CHECK_HANDLE();
CHECK_EQ(device, -1) << "GPU data for feature info is not yet supported.";
auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
info.SetFeatureInfo(field, data, static_cast<DataType>(type), size);
API_END();
Expand Down
38 changes: 6 additions & 32 deletions src/data/data.cc
Expand Up @@ -461,38 +461,12 @@ void MetaInfo::SetFeatureInfo(const char *c_field, const void *info, DataType ty
CHECK_EQ(field, "feature_weight") << "Only feature weight is supported for feature info.";
auto& h_feature_weights = feature_weigths.HostVector();
h_feature_weights.resize(size);
switch (type) {
case DataType::kFloat32: {
auto ptr = static_cast<float const*>(info);
std::copy(ptr, ptr + size, h_feature_weights.begin());
break;
}
case DataType::kDouble: {
auto ptr = static_cast<double const*>(info);
std::copy(ptr, ptr + size, h_feature_weights.begin());
break;
}
case DataType::kUInt32: {
auto ptr = static_cast<uint32_t const*>(info);
std::copy(ptr, ptr + size, h_feature_weights.begin());
break;
}
case DataType::kUInt64: {
auto ptr = static_cast<uint64_t const*>(info);
std::copy(ptr, ptr + size, h_feature_weights.begin());
break;
}
case DataType::kStr: {
LOG(FATAL) << "Use str feature info setter and getter instead.";
break;
}
default:
LOG(FATAL) << "Unknown data type for feature info: " << static_cast<int>(type);
}
std::for_each(h_feature_weights.cbegin(), h_feature_weights.cend(),
[](float w) {
CHECK_GE(w, 0) << "Feature weight must be greater than 0.";
});
DISPATCH_CONST_PTR(
type, info, cast_dptr,
std::copy(cast_dptr, cast_dptr + size, h_feature_weights.begin()));
bool valid = std::all_of(h_feature_weights.cbegin(), h_feature_weights.cend(),
[](float w) { return w >= 0; });
CHECK(valid) << "Feature weight must be greater than 0.";
}

void MetaInfo::GetFeatureInfo(const char *c_field, DataType *out_type,
Expand Down

0 comments on commit 328eb14

Please sign in to comment.