From 328eb142ec7bd3e43507658b699376db7dac5f17 Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 31 Jul 2020 19:28:09 +0800 Subject: [PATCH] Remove the device parameter. --- include/xgboost/c_api.h | 5 +---- python-package/xgboost/data.py | 3 +-- src/c_api/c_api.cc | 3 +-- src/data/data.cc | 38 ++++++---------------------------- 4 files changed, 9 insertions(+), 40 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 354e4b3e8d54..ddb01e7d98db 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -500,13 +500,10 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field, * uint32_t = 3 * uint64_t = 4 * - * \param device Where the data is resided. Currently supports only CPU input data. - * * \return 0 when success, -1 when failure happens */ XGB_DLL int XGDMatrixSetFeatureInfo(DMatrixHandle handle, const char *field, - void *data, bst_ulong size, int type, - int device); + void *data, bst_ulong size, int type); /*! * \brief Get feature info in a thread local buffer. diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 9a6bbac9e193..39590e86f752 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -549,8 +549,7 @@ def _meta_from_numpy(data, field, dtype, handle, is_feature: bool = False): c_str(field), data, size, - c_type, - ctypes.c_int(-1), + c_type )) else: data = _maybe_np_slice(data, dtype) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index e2f5ea9d0538..54e675bb1883 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -318,10 +318,9 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field, XGB_DLL int XGDMatrixSetFeatureInfo(DMatrixHandle handle, const char *field, void *data, xgboost::bst_ulong size, - int type, int device) { + int type) { API_BEGIN(); CHECK_HANDLE(); - CHECK_EQ(device, -1) << "GPU data for feature info is not yet supported."; auto &info = static_cast *>(handle)->get()->Info(); info.SetFeatureInfo(field, data, static_cast(type), size); API_END(); diff --git a/src/data/data.cc b/src/data/data.cc index cdf78468d452..00cb62ce4eeb 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -461,38 +461,12 @@ void MetaInfo::SetFeatureInfo(const char *c_field, const void *info, DataType ty CHECK_EQ(field, "feature_weight") << "Only feature weight is supported for feature info."; auto& h_feature_weights = feature_weigths.HostVector(); h_feature_weights.resize(size); - switch (type) { - case DataType::kFloat32: { - auto ptr = static_cast(info); - std::copy(ptr, ptr + size, h_feature_weights.begin()); - break; - } - case DataType::kDouble: { - auto ptr = static_cast(info); - std::copy(ptr, ptr + size, h_feature_weights.begin()); - break; - } - case DataType::kUInt32: { - auto ptr = static_cast(info); - std::copy(ptr, ptr + size, h_feature_weights.begin()); - break; - } - case DataType::kUInt64: { - auto ptr = static_cast(info); - std::copy(ptr, ptr + size, h_feature_weights.begin()); - break; - } - case DataType::kStr: { - LOG(FATAL) << "Use str feature info setter and getter instead."; - break; - } - default: - LOG(FATAL) << "Unknown data type for feature info: " << static_cast(type); - } - std::for_each(h_feature_weights.cbegin(), h_feature_weights.cend(), - [](float w) { - CHECK_GE(w, 0) << "Feature weight must be greater than 0."; - }); + DISPATCH_CONST_PTR( + type, info, cast_dptr, + std::copy(cast_dptr, cast_dptr + size, h_feature_weights.begin())); + bool valid = std::all_of(h_feature_weights.cbegin(), h_feature_weights.cend(), + [](float w) { return w >= 0; }); + CHECK(valid) << "Feature weight must be greater than 0."; } void MetaInfo::GetFeatureInfo(const char *c_field, DataType *out_type,