diff --git a/.ci/test.sh b/.ci/test.sh index 435614bb826..7fd09d8d20d 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -132,6 +132,16 @@ if [[ $TASK == "gpu" ]]; then exit 0 fi cmake -DUSE_GPU=ON -DOpenCL_INCLUDE_DIR=$AMDAPPSDK_PATH/include/ .. +elif [[ $TASK == "cuda" ]]; then + sed -i'.bak' 's/std::string device_type = "cpu";/std::string device_type = "cuda";/' $BUILD_DIRECTORY/include/LightGBM/config.h + grep -q 'std::string device_type = "cuda"' $BUILD_DIRECTORY/include/LightGBM/config.h || exit -1 # make sure that changes were really done + if [[ $METHOD == "pip" ]]; then + cd $BUILD_DIRECTORY/python-package && python setup.py sdist || exit -1 + pip install --user $BUILD_DIRECTORY/python-package/dist/lightgbm-$LGB_VER.tar.gz -v --install-option=--cuda || exit -1 + pytest $BUILD_DIRECTORY/tests/python_package_test || exit -1 + exit 0 + fi + cmake -DUSE_CUDA=ON .. elif [[ $TASK == "mpi" ]]; then if [[ $METHOD == "pip" ]]; then cd $BUILD_DIRECTORY/python-package && python setup.py sdist || exit -1 diff --git a/CMakeLists.txt b/CMakeLists.txt index 78c6c0d18ef..b2e206fe5fd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,10 +1,16 @@ if(USE_GPU OR APPLE) cmake_minimum_required(VERSION 3.2) +elseif(USE_CUDA) + cmake_minimum_required(VERSION 3.16) else() cmake_minimum_required(VERSION 2.8) endif() -PROJECT(lightgbm) +if(USE_CUDA) + PROJECT(lightgbm LANGUAGES C CXX CUDA) +else() + PROJECT(lightgbm LANGUAGES C CXX) +endif() OPTION(USE_MPI "Enable MPI-based parallel learning" OFF) OPTION(USE_OPENMP "Enable OpenMP" ON) @@ -12,6 +18,7 @@ OPTION(USE_GPU "Enable GPU-accelerated training" OFF) OPTION(USE_SWIG "Enable SWIG to generate Java API" OFF) OPTION(USE_HDFS "Enable HDFS support (EXPERIMENTAL)" OFF) OPTION(USE_TIMETAG "Set to ON to output time costs" OFF) +OPTION(USE_CUDA "Enable CUDA-accelerated training (EXPERIMENTAL)" OFF) OPTION(USE_DEBUG "Set to ON for Debug mode" OFF) OPTION(BUILD_STATIC_LIB "Build static library" OFF) OPTION(BUILD_FOR_R "Set to ON if building lib_lightgbm for use with the R package" OFF) @@ -94,6 +101,10 @@ else() ADD_DEFINITIONS(-DUSE_SOCKET) endif(USE_MPI) +if(USE_CUDA) + SET(USE_OPENMP ON CACHE BOOL "CUDA requires OpenMP" FORCE) +endif(USE_CUDA) + if(USE_OPENMP) find_package(OpenMP REQUIRED) SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") @@ -123,6 +134,67 @@ if(USE_GPU) ADD_DEFINITIONS(-DUSE_GPU) endif(USE_GPU) +if(USE_CUDA) + find_package(CUDA REQUIRED) + include_directories(${CUDA_INCLUDE_DIRS}) + LIST(APPEND CMAKE_CUDA_FLAGS -Xcompiler=${OpenMP_CXX_FLAGS} -Xcompiler=-fPIC -Xcompiler=-Wall) + CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS 6.0 6.1 6.2 7.0 7.5+PTX) + + LIST(APPEND CMAKE_CUDA_FLAGS ${CUDA_ARCH_FLAGS}) + if(USE_DEBUG) + SET(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g") + else() + SET(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -lineinfo") + endif() + string(REPLACE ";" " " CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}") + message(STATUS "CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") + + ADD_DEFINITIONS(-DUSE_CUDA) + if (NOT DEFINED CMAKE_CUDA_STANDARD) + set(CMAKE_CUDA_STANDARD 11) + set(CMAKE_CUDA_STANDARD_REQUIRED ON) + endif() + + set(BASE_DEFINES + -DPOWER_FEATURE_WORKGROUPS=12 + -DUSE_CONSTANT_BUF=0 + ) + set(ALLFEATS_DEFINES + ${BASE_DEFINES} + -DENABLE_ALL_FEATURES + ) + set(FULLDATA_DEFINES + ${ALLFEATS_DEFINES} + -DIGNORE_INDICES + ) + + message(STATUS "ALLFEATS_DEFINES: ${ALLFEATS_DEFINES}") + message(STATUS "FULLDATA_DEFINES: ${FULLDATA_DEFINES}") + + function(add_histogram hsize hname hadd hconst hdir) + add_library(histo${hsize}${hname} OBJECT 
src/treelearner/kernels/histogram${hsize}.cu) + set_target_properties(histo${hsize}${hname} PROPERTIES CUDA_SEPARABLE_COMPILATION ON) + if(hadd) + list(APPEND histograms histo${hsize}${hname}) + set(histograms ${histograms} PARENT_SCOPE) + endif() + target_compile_definitions( + histo${hsize}${hname} PRIVATE + -DCONST_HESSIAN=${hconst} + ${hdir} + ) + endfunction() + + foreach (hsize _16_64_256) + add_histogram("${hsize}" "_sp_const" "True" "1" "${BASE_DEFINES}") + add_histogram("${hsize}" "_sp" "True" "0" "${BASE_DEFINES}") + add_histogram("${hsize}" "-allfeats_sp_const" "False" "1" "${ALLFEATS_DEFINES}") + add_histogram("${hsize}" "-allfeats_sp" "False" "0" "${ALLFEATS_DEFINES}") + add_histogram("${hsize}" "-fulldata_sp_const" "True" "1" "${FULLDATA_DEFINES}") + add_histogram("${hsize}" "-fulldata_sp" "True" "0" "${FULLDATA_DEFINES}") + endforeach() +endif(USE_CUDA) + if(USE_HDFS) find_package(JNI REQUIRED) find_path(HDFS_INCLUDE_DIR hdfs.h REQUIRED) @@ -228,6 +300,9 @@ file(GLOB SOURCES src/objective/*.cpp src/network/*.cpp src/treelearner/*.cpp +if(USE_CUDA) + src/treelearner/*.cu +endif(USE_CUDA) ) add_executable(lightgbm src/main.cpp ${SOURCES}) @@ -303,6 +378,19 @@ if(USE_GPU) TARGET_LINK_LIBRARIES(_lightgbm ${OpenCL_LIBRARY} ${Boost_LIBRARIES}) endif(USE_GPU) +if(USE_CUDA) + set_target_properties(lightgbm PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON) + TARGET_LINK_LIBRARIES( + lightgbm + ${histograms} + ) + set_target_properties(_lightgbm PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON) + TARGET_LINK_LIBRARIES( + _lightgbm + ${histograms} + ) +endif(USE_CUDA) + if(USE_HDFS) TARGET_LINK_LIBRARIES(lightgbm ${HDFS_CXX_LIBRARIES}) TARGET_LINK_LIBRARIES(_lightgbm ${HDFS_CXX_LIBRARIES}) diff --git a/docs/Parameters.rst b/docs/Parameters.rst index 14d7a8098cf..dcd1353e152 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -1120,7 +1120,13 @@ GPU Parameters - ``gpu_use_dp`` :raw-html:`🔗︎`, default = ``false``, type = bool - - set this to ``true`` to use double precision math on GPU (by default single precision is used) + - set this to ``true`` to use double precision math on GPU (by default single precision is used in OpenCL implementation and double precision is used in CUDA implementation) + +- ``num_gpu`` :raw-html:`🔗︎`, default = ``1``, type = int, constraints: ``num_gpu > 0`` + + - number of GPUs + + - **Note**: can be used only in CUDA implementation .. end params list diff --git a/include/LightGBM/bin.h b/include/LightGBM/bin.h index 4f320698c83..987279e4771 100644 --- a/include/LightGBM/bin.h +++ b/include/LightGBM/bin.h @@ -288,6 +288,9 @@ class Bin { /*! \brief Number of all data */ virtual data_size_t num_data() const = 0; + /*! \brief Get data pointer */ + virtual void* get_data() = 0; + virtual void ReSize(data_size_t num_data) = 0; /*! 
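A minimal usage sketch for the new device type and num_gpu parameter introduced above, written against LightGBM's existing C API. It is not part of the patch; the training file name and the parameter values are placeholders chosen for illustration (max_bin=63 follows the performance hint logged by the CUDA learner later in this diff).

#include <LightGBM/c_api.h>

int main() {
  DatasetHandle train = nullptr;
  BoosterHandle booster = nullptr;
  // "device_type=cuda" selects the CUDA tree learner added by this patch;
  // "num_gpu" is the new parameter (default 1, constraint num_gpu > 0).
  const char* params = "objective=binary device_type=cuda num_gpu=1 max_bin=63";
  if (LGBM_DatasetCreateFromFile("train.txt", params, nullptr, &train) != 0) return 1;
  if (LGBM_BoosterCreate(train, params, &booster) != 0) return 1;
  int is_finished = 0;
  for (int iter = 0; iter < 10 && !is_finished; ++iter) {
    LGBM_BoosterUpdateOneIter(booster, &is_finished);
  }
  LGBM_BoosterFree(booster);
  LGBM_DatasetFree(train);
  return 0;
}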
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index bfcb09a4004..5e190261390 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -965,9 +965,14 @@ struct Config { // desc = **Note**: refer to `GPU Targets <./GPU-Targets.rst#query-opencl-devices-in-your-system>`__ for more details int gpu_device_id = -1; - // desc = set this to ``true`` to use double precision math on GPU (by default single precision is used) + // desc = set this to ``true`` to use double precision math on GPU (by default single precision is used in OpenCL implementation and double precision is used in CUDA implementation) bool gpu_use_dp = false; + // check = >0 + // desc = number of GPUs + // desc = **Note**: can be used only in CUDA implementation + int num_gpu = 1; + #pragma endregion #pragma endregion diff --git a/include/LightGBM/cuda/cuda_utils.h b/include/LightGBM/cuda/cuda_utils.h new file mode 100644 index 00000000000..1054e09daf1 --- /dev/null +++ b/include/LightGBM/cuda/cuda_utils.h @@ -0,0 +1,24 @@ +/*! + * Copyright (c) 2020 IBM Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. + */ +#ifndef LIGHTGBM_CUDA_CUDA_UTILS_H_ +#define LIGHTGBM_CUDA_CUDA_UTILS_H_ + +#ifdef USE_CUDA + +#include +#include +#include + +#define CUDASUCCESS_OR_FATAL(ans) { gpuAssert((ans), __FILE__, __LINE__); } +inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) { + if (code != cudaSuccess) { + LightGBM::Log::Fatal("[CUDA] %s %s %d\n", cudaGetErrorString(code), file, line); + if (abort) exit(code); + } +} + +#endif // USE_CUDA + +#endif // LIGHTGBM_CUDA_CUDA_UTILS_H_ diff --git a/include/LightGBM/cuda/vector_cudahost.h b/include/LightGBM/cuda/vector_cudahost.h new file mode 100644 index 00000000000..f81cc4dd905 --- /dev/null +++ b/include/LightGBM/cuda/vector_cudahost.h @@ -0,0 +1,86 @@ +/*! + * Copyright (c) 2020 IBM Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. 
+ */ +#ifndef LIGHTGBM_CUDA_VECTOR_CUDAHOST_H_ +#define LIGHTGBM_CUDA_VECTOR_CUDAHOST_H_ + +#include + +#ifdef USE_CUDA +#include +#include +#endif +#include + +enum LGBM_Device { + lgbm_device_cpu, + lgbm_device_gpu, + lgbm_device_cuda +}; + +enum Use_Learner { + use_cpu_learner, + use_gpu_learner, + use_cuda_learner +}; + +namespace LightGBM { + +class LGBM_config_ { + public: + static int current_device; // Default: lgbm_device_cpu + static int current_learner; // Default: use_cpu_learner +}; + + +template +struct CHAllocator { + typedef T value_type; + CHAllocator() {} + template CHAllocator(const CHAllocator& other); + T* allocate(std::size_t n) { + T* ptr; + if (n == 0) return NULL; + #ifdef USE_CUDA + if (LGBM_config_::current_device == lgbm_device_cuda) { + cudaError_t ret = cudaHostAlloc(&ptr, n*sizeof(T), cudaHostAllocPortable); + if (ret != cudaSuccess) { + Log::Warning("Defaulting to malloc in CHAllocator!!!"); + ptr = reinterpret_cast(_mm_malloc(n*sizeof(T), 16)); + } + } else { + ptr = reinterpret_cast(_mm_malloc(n*sizeof(T), 16)); + } + #else + ptr = reinterpret_cast(_mm_malloc(n*sizeof(T), 16)); + #endif + return ptr; + } + + void deallocate(T* p, std::size_t n) { + (void)n; // UNUSED + if (p == NULL) return; + #ifdef USE_CUDA + if (LGBM_config_::current_device == lgbm_device_cuda) { + cudaPointerAttributes attributes; + cudaPointerGetAttributes(&attributes, p); + if ((attributes.type == cudaMemoryTypeHost) && (attributes.devicePointer != NULL)) { + cudaFreeHost(p); + } + } else { + _mm_free(p); + } + #else + _mm_free(p); + #endif + } +}; +template +bool operator==(const CHAllocator&, const CHAllocator&); +template +bool operator!=(const CHAllocator&, const CHAllocator&); + +} // namespace LightGBM + +#endif // LIGHTGBM_CUDA_VECTOR_CUDAHOST_H_ diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index 3cf82c2aa1d..2c6d74caa1d 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -589,6 +589,14 @@ class Dataset { return feature_groups_[i]->is_multi_val_; } + inline size_t FeatureGroupSizesInByte(int group) const { + return feature_groups_[group]->FeatureGroupSizesInByte(); + } + + inline void* FeatureGroupData(int group) const { + return feature_groups_[group]->FeatureGroupData(); + } + inline double RealThreshold(int i, uint32_t threshold) const { const int group = feature2group_[i]; const int sub_feature = feature2subfeature_[i]; diff --git a/include/LightGBM/feature_group.h b/include/LightGBM/feature_group.h index 2b17e98bb9c..3ba5c143f85 100644 --- a/include/LightGBM/feature_group.h +++ b/include/LightGBM/feature_group.h @@ -228,6 +228,17 @@ class FeatureGroup { return bin_data_->GetIterator(min_bin, max_bin, most_freq_bin); } + inline size_t FeatureGroupSizesInByte() { + return bin_data_->SizesInByte(); + } + + inline void* FeatureGroupData() { + if (is_multi_val_) { + return nullptr; + } + return bin_data_->get_data(); + } + inline data_size_t Split(int sub_feature, const uint32_t* threshold, int num_threshold, bool default_left, const data_size_t* data_indices, data_size_t cnt, diff --git a/python-package/setup.py b/python-package/setup.py index 9104bc2694b..f8782fba47a 100644 --- a/python-package/setup.py +++ b/python-package/setup.py @@ -87,7 +87,7 @@ def silent_call(cmd, raise_error=False, error_msg=''): return 1 -def compile_cpp(use_mingw=False, use_gpu=False, use_mpi=False, +def compile_cpp(use_mingw=False, use_gpu=False, use_cuda=False, use_mpi=False, use_hdfs=False, boost_root=None, boost_dir=None, boost_include_dir=None, 
boost_librarydir=None, opencl_include_dir=None, opencl_library=None, @@ -115,6 +115,8 @@ def compile_cpp(use_mingw=False, use_gpu=False, use_mpi=False, cmake_cmd.append("-DOpenCL_INCLUDE_DIR={0}".format(opencl_include_dir)) if opencl_library: cmake_cmd.append("-DOpenCL_LIBRARY={0}".format(opencl_library)) + elif use_cuda: + cmake_cmd.append("-DUSE_CUDA=ON") if use_mpi: cmake_cmd.append("-DUSE_MPI=ON") if nomp: @@ -188,6 +190,7 @@ class CustomInstall(install): user_options = install.user_options + [ ('mingw', 'm', 'Compile with MinGW'), ('gpu', 'g', 'Compile GPU version'), + ('cuda', None, 'Compile CUDA version'), ('mpi', None, 'Compile MPI version'), ('nomp', None, 'Compile version without OpenMP support'), ('hdfs', 'h', 'Compile HDFS version'), @@ -205,6 +208,7 @@ def initialize_options(self): install.initialize_options(self) self.mingw = 0 self.gpu = 0 + self.cuda = 0 self.boost_root = None self.boost_dir = None self.boost_include_dir = None @@ -228,7 +232,7 @@ def run(self): open(LOG_PATH, 'wb').close() if not self.precompile: copy_files(use_gpu=self.gpu) - compile_cpp(use_mingw=self.mingw, use_gpu=self.gpu, use_mpi=self.mpi, + compile_cpp(use_mingw=self.mingw, use_gpu=self.gpu, use_cuda=self.cuda, use_mpi=self.mpi, use_hdfs=self.hdfs, boost_root=self.boost_root, boost_dir=self.boost_dir, boost_include_dir=self.boost_include_dir, boost_librarydir=self.boost_librarydir, opencl_include_dir=self.opencl_include_dir, opencl_library=self.opencl_library, diff --git a/src/application/application.cpp b/src/application/application.cpp index 21163a5a30e..d9be76d67c9 100644 --- a/src/application/application.cpp +++ b/src/application/application.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -38,6 +39,10 @@ Application::Application(int argc, char** argv) { if (config_.data.size() == 0 && config_.task != TaskType::kConvertModel) { Log::Fatal("No training/prediction data, application quit"); } + + if (config_.device_type == std::string("cuda")) { + LGBM_config_::current_device = lgbm_device_cuda; + } } Application::~Application() { diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index 03f5fe25d55..fcb7185a151 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -17,6 +17,9 @@ namespace LightGBM { +int LGBM_config_::current_device = lgbm_device_cpu; +int LGBM_config_::current_learner = use_cpu_learner; + GBDT::GBDT() : iter_(0), train_data_(nullptr), @@ -58,6 +61,10 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective es_first_metric_only_ = config_->first_metric_only; shrinkage_rate_ = config_->learning_rate; + if (config_->device_type == std::string("cuda")) { + LGBM_config_::current_learner = use_cuda_learner; + } + // load forced_splits file if (!config->forcedsplits_filename.empty()) { std::ifstream forced_splits_file(config->forcedsplits_filename.c_str()); diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index a84b321531f..0d38385d5f0 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -479,10 +480,19 @@ class GBDT : public GBDTBase { std::vector> models_; /*! \brief Max feature index of training data*/ int max_feature_idx_; + +#ifdef USE_CUDA + /*! \brief First order derivative of training data */ + std::vector> gradients_; + /*! \brief Second order derivative of training data */ + std::vector> hessians_; +#else /*! \brief First order derivative of training data */ std::vector> gradients_; - /*! 
\brief Secend order derivative of training data */ + /*! \brief Second order derivative of training data */ std::vector> hessians_; +#endif + /*! \brief Store the indices of in-bag data */ std::vector> bag_data_indices_; /*! \brief Number of in-bag data */ diff --git a/src/c_api.cpp b/src/c_api.cpp index 61b3038e660..a389e8e47b1 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -1611,10 +1611,14 @@ int LGBM_BoosterUpdateOneIterCustom(BoosterHandle handle, const float* hess, int* is_finished) { API_BEGIN(); - Booster* ref_booster = reinterpret_cast(handle); #ifdef SCORE_T_USE_DOUBLE + (void) handle; // UNUSED VARIABLE + (void) grad; // UNUSED VARIABLE + (void) hess; // UNUSED VARIABLE + (void) is_finished; // UNUSED VARIABLE Log::Fatal("Don't support custom loss function when SCORE_T_USE_DOUBLE is enabled"); #else + Booster* ref_booster = reinterpret_cast(handle); if (ref_booster->TrainOneIter(grad, hess)) { *is_finished = 1; } else { diff --git a/src/io/config.cpp b/src/io/config.cpp index d569a7401e1..6878896deb5 100644 --- a/src/io/config.cpp +++ b/src/io/config.cpp @@ -4,6 +4,7 @@ */ #include +#include #include #include #include @@ -126,6 +127,8 @@ void GetDeviceType(const std::unordered_map& params, s *device_type = "cpu"; } else if (value == std::string("gpu")) { *device_type = "gpu"; + } else if (value == std::string("cuda")) { + *device_type = "cuda"; } else { Log::Fatal("Unknown device type %s", value.c_str()); } @@ -206,6 +209,9 @@ void Config::Set(const std::unordered_map& params) { GetMetricType(params, &metric); GetObjectiveType(params, &objective); GetDeviceType(params, &device_type); + if (device_type == std::string("cuda")) { + LGBM_config_::current_device = lgbm_device_cuda; + } GetTreeLearnerType(params, &tree_learner); GetMembersFromString(params); @@ -319,11 +325,18 @@ void Config::CheckParamConflict() { num_leaves = static_cast(full_num_leaves); } } - // force col-wise for gpu - if (device_type == std::string("gpu")) { + // force col-wise for gpu & CUDA + if (device_type == std::string("gpu") || device_type == std::string("cuda")) { force_col_wise = true; force_row_wise = false; } + + // force gpu_use_dp for CUDA + if (device_type == std::string("cuda") && !gpu_use_dp) { + Log::Warning("CUDA currently requires double precision calculations."); + gpu_use_dp = true; + } + // min_data_in_leaf must be at least 2 if path smoothing is active. This is because when the split is calculated // the count is calculated using the proportion of hessian in the leaf which is rounded up to nearest int, so it can // be 1 when there is actually no data in the leaf. 
In rare cases this can cause a bug because with path smoothing the diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp index b14af67fd30..ad102020322 100644 --- a/src/io/config_auto.cpp +++ b/src/io/config_auto.cpp @@ -296,6 +296,7 @@ const std::unordered_set& Config::parameter_set() { "gpu_platform_id", "gpu_device_id", "gpu_use_dp", + "num_gpu", }); return params; } @@ -611,6 +612,9 @@ void Config::GetMembersFromString(const std::unordered_map #include +#include #include #include #include @@ -334,13 +335,24 @@ void Dataset::Construct(std::vector>* bin_mappers, "constant."); } auto features_in_group = NoGroup(used_features); + + auto is_sparse = io_config.is_enable_sparse; + if (io_config.device_type == std::string("cuda")) { + LGBM_config_::current_device = lgbm_device_cuda; + if (is_sparse) { + Log::Warning("Using sparse features with CUDA is currently not supported."); + } + is_sparse = false; + } + std::vector group_is_multi_val(used_features.size(), 0); if (io_config.enable_bundle && !used_features.empty()) { + bool lgbm_is_gpu_used = io_config.device_type == std::string("gpu") || io_config.device_type == std::string("cuda"); features_in_group = FastFeatureBundling( *bin_mappers, sample_non_zero_indices, sample_values, num_per_col, num_sample_col, static_cast(total_sample_cnt), - used_features, num_data_, io_config.device_type == std::string("gpu"), - io_config.is_enable_sparse, &group_is_multi_val); + used_features, num_data_, lgbm_is_gpu_used, + is_sparse, &group_is_multi_val); } num_features_ = 0; diff --git a/src/io/dense_bin.hpp b/src/io/dense_bin.hpp index e821fe32f08..4a1cc43fa79 100644 --- a/src/io/dense_bin.hpp +++ b/src/io/dense_bin.hpp @@ -7,6 +7,7 @@ #define LIGHTGBM_IO_DENSE_BIN_HPP_ #include +#include #include #include @@ -364,6 +365,8 @@ class DenseBin : public Bin { data_size_t num_data() const override { return num_data_; } + void* get_data() override { return data_.data(); } + void FinishLoad() override { if (IS_4BIT) { if (buf_.empty()) { @@ -458,7 +461,11 @@ class DenseBin : public Bin { private: data_size_t num_data_; +#ifdef USE_CUDA + std::vector> data_; +#else std::vector> data_; +#endif std::vector buf_; DenseBin(const DenseBin& other) diff --git a/src/io/sparse_bin.hpp b/src/io/sparse_bin.hpp index 07f57c4480a..1fc07657609 100644 --- a/src/io/sparse_bin.hpp +++ b/src/io/sparse_bin.hpp @@ -409,6 +409,8 @@ class SparseBin : public Bin { data_size_t num_data() const override { return num_data_; } + void* get_data() override { return nullptr; } + void FinishLoad() override { // get total non zero size size_t pair_cnt = 0; diff --git a/src/treelearner/cuda_kernel_launcher.cu b/src/treelearner/cuda_kernel_launcher.cu new file mode 100644 index 00000000000..8ceb5b813c9 --- /dev/null +++ b/src/treelearner/cuda_kernel_launcher.cu @@ -0,0 +1,171 @@ +/*! + * Copyright (c) 2020 IBM Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. 
+ */ +#ifdef USE_CUDA + +#include "cuda_kernel_launcher.h" + +#include + +#include + +#include + +namespace LightGBM { + +void cuda_histogram( + int histogram_size, + data_size_t leaf_num_data, + data_size_t num_data, + bool use_all_features, + bool is_constant_hessian, + int num_workgroups, + cudaStream_t stream, + uint8_t* arg0, + uint8_t* arg1, + data_size_t arg2, + data_size_t* arg3, + data_size_t arg4, + score_t* arg5, + score_t* arg6, + score_t arg6_const, + char* arg7, + volatile int* arg8, + void* arg9, + size_t exp_workgroups_per_feature) { + if (histogram_size == 16) { + if (leaf_num_data == num_data) { + if (use_all_features) { + if (!is_constant_hessian) + histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + else + histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6_const, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + } else { + if (!is_constant_hessian) + histogram16_fulldata<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + else + histogram16_fulldata<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6_const, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + } + } else { + if (use_all_features) { + // seems all features is always enabled, so this should be the same as fulldata + if (!is_constant_hessian) + histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + else + histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6_const, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + } else { + if (!is_constant_hessian) + histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + else + histogram16<<<16*num_workgroups, 16, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6_const, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + } + } + } else if (histogram_size == 64) { + if (leaf_num_data == num_data) { + if (use_all_features) { + if (!is_constant_hessian) + histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + else + histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6_const, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + } else { + if (!is_constant_hessian) + histogram64_fulldata<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + else + histogram64_fulldata<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6_const, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + } + } else { + if (use_all_features) { + // seems all features is always enabled, so this should be the same as fulldata + if (!is_constant_hessian) + histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + else + histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6_const, arg7, arg8, static_cast(arg9), 
exp_workgroups_per_feature); + } else { + if (!is_constant_hessian) + histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + else + histogram64<<<4*num_workgroups, 64, 0, stream>>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6_const, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + } + } + } else { + if (leaf_num_data == num_data) { + if (use_all_features) { + if (!is_constant_hessian) + histogram256<<>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + else + histogram256<<>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6_const, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + } else { + if (!is_constant_hessian) + histogram256_fulldata<<>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + else + histogram256_fulldata<<>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6_const, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + } + } else { + if (use_all_features) { + // seems all features is always enabled, so this should be the same as fulldata + if (!is_constant_hessian) + histogram256<<>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + else + histogram256<<>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6_const, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + } else { + if (!is_constant_hessian) + histogram256<<>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + else + histogram256<<>>(arg0, arg1, arg2, + arg3, arg4, arg5, + arg6_const, arg7, arg8, static_cast(arg9), exp_workgroups_per_feature); + } + } + } +} + +} // namespace LightGBM + +#endif // USE_CUDA diff --git a/src/treelearner/cuda_kernel_launcher.h b/src/treelearner/cuda_kernel_launcher.h new file mode 100644 index 00000000000..0714e05b2f2 --- /dev/null +++ b/src/treelearner/cuda_kernel_launcher.h @@ -0,0 +1,70 @@ +/*! + * Copyright (c) 2020 IBM Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. 
+ */ +#ifndef LIGHTGBM_TREELEARNER_CUDA_KERNEL_LAUNCHER_H_ +#define LIGHTGBM_TREELEARNER_CUDA_KERNEL_LAUNCHER_H_ + +#ifdef USE_CUDA +#include +#include "kernels/histogram_16_64_256.hu" // kernel, acc_type, data_size_t, uchar, score_t + +namespace LightGBM { + +struct ThreadData { + // device id + int device_id; + // parameters for cuda_histogram + int histogram_size; + data_size_t leaf_num_data; + data_size_t num_data; + bool use_all_features; + bool is_constant_hessian; + int num_workgroups; + cudaStream_t stream; + uint8_t* device_features; + uint8_t* device_feature_masks; + data_size_t* device_data_indices; + score_t* device_gradients; + score_t* device_hessians; + score_t hessians_const; + char* device_subhistograms; + volatile int* sync_counters; + void* device_histogram_outputs; + size_t exp_workgroups_per_feature; + // cuda events + cudaEvent_t* kernel_start; + cudaEvent_t* kernel_wait_obj; + std::chrono::duration* kernel_input_wait_time; + // copy histogram + size_t output_size; + char* host_histogram_output; + cudaEvent_t* histograms_wait_obj; +}; + + +void cuda_histogram( + int histogram_size, + data_size_t leaf_num_data, + data_size_t num_data, + bool use_all_features, + bool is_constant_hessian, + int num_workgroups, + cudaStream_t stream, + uint8_t* arg0, + uint8_t* arg1, + data_size_t arg2, + data_size_t* arg3, + data_size_t arg4, + score_t* arg5, + score_t* arg6, + score_t arg6_const, + char* arg7, + volatile int* arg8, + void* arg9, + size_t exp_workgroups_per_feature); + +} // namespace LightGBM + +#endif // USE_CUDA +#endif // LIGHTGBM_TREELEARNER_CUDA_KERNEL_LAUNCHER_H_ diff --git a/src/treelearner/cuda_tree_learner.cpp b/src/treelearner/cuda_tree_learner.cpp new file mode 100644 index 00000000000..16569eef257 --- /dev/null +++ b/src/treelearner/cuda_tree_learner.cpp @@ -0,0 +1,974 @@ +/*! + * Copyright (c) 2020 IBM Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. 
+ */ +#ifdef USE_CUDA +#include "cuda_tree_learner.h" + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include "../io/dense_bin.hpp" + +namespace LightGBM { + +#define cudaMemcpy_DEBUG 0 // 1: DEBUG cudaMemcpy +#define ResetTrainingData_DEBUG 0 // 1: Debug ResetTrainingData + +#define CUDA_DEBUG 0 + +static void *launch_cuda_histogram(void *thread_data) { + ThreadData td = *(reinterpret_cast(thread_data)); + int device_id = td.device_id; + CUDASUCCESS_OR_FATAL(cudaSetDevice(device_id)); + + // launch cuda kernel + cuda_histogram(td.histogram_size, + td.leaf_num_data, td.num_data, td.use_all_features, + td.is_constant_hessian, td.num_workgroups, td.stream, + td.device_features, + td.device_feature_masks, + td.num_data, + td.device_data_indices, + td.leaf_num_data, + td.device_gradients, + td.device_hessians, td.hessians_const, + td.device_subhistograms, td.sync_counters, + td.device_histogram_outputs, + td.exp_workgroups_per_feature); + + CUDASUCCESS_OR_FATAL(cudaGetLastError()); + + return NULL; +} + +CUDATreeLearner::CUDATreeLearner(const Config* config) + :SerialTreeLearner(config) { + use_bagging_ = false; + nthreads_ = 0; + if (config->gpu_use_dp && USE_DP_FLOAT) { + Log::Info("LightGBM using CUDA trainer with DP float!!"); + } else { + Log::Info("LightGBM using CUDA trainer with SP float!!"); + } +} + +CUDATreeLearner::~CUDATreeLearner() { +} + + +void CUDATreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) { + // initialize SerialTreeLearner + SerialTreeLearner::Init(train_data, is_constant_hessian); + + // some additional variables needed for GPU trainer + num_feature_groups_ = train_data_->num_feature_groups(); + + // Initialize GPU buffers and kernels: get device info + InitGPU(config_->num_gpu); +} + +// some functions used for debugging the GPU histogram construction +#if CUDA_DEBUG > 0 + +void PrintHistograms(hist_t* h, size_t size) { + double total_hess = 0; + for (size_t i = 0; i < size; ++i) { + printf("%03lu=%9.3g,%9.3g\t", i, GET_GRAD(h, i), GET_HESS(h, i)); + if ((i & 3) == 3) + printf("\n"); + total_hess += GET_HESS(h, i); + } + printf("\nSum hessians: %9.3g\n", total_hess); +} + +union Float_t { + int64_t i; + double f; + static int64_t ulp_diff(Float_t a, Float_t b) { + return abs(a.i - b.i); + } +}; + +int CompareHistograms(hist_t* h1, hist_t* h2, size_t size, int feature_id, int dp_flag, int const_flag) { + int i; + int retval = 0; + printf("Comparing Histograms, feature_id = %d, size = %d\n", feature_id, static_cast(size)); + if (dp_flag) { // double precision + double af, bf; + int64_t ai, bi; + for (i = 0; i < static_cast(size); ++i) { + af = GET_GRAD(h1, i); + bf = GET_GRAD(h2, i); + if ((((std::fabs(af - bf))/af) >= 1e-6) && ((std::fabs(af - bf)) >= 1e-6)) { + printf("i = %5d, h1.grad %13.6lf, h2.grad %13.6lf\n", i, af, bf); + ++retval; + } + if (const_flag) { + ai = GET_HESS((reinterpret_cast(h1)), i); + bi = GET_HESS((reinterpret_cast(h2)), i); + if (ai != bi) { + printf("i = %5d, h1.hess %" PRId64 ", h2.hess %" PRId64 "\n", i, ai, bi); + ++retval; + } + } else { + af = GET_HESS(h1, i); + bf = GET_HESS(h2, i); + if (((std::fabs(af - bf))/af) >= 1e-6) { + printf("i = %5d, h1.hess %13.6lf, h2.hess %13.6lf\n", i, af, bf); + ++retval; + } + } + } + } else { // single precision + float af, bf; + int ai, bi; + for (i = 0; i < static_cast(size); ++i) { + af = GET_GRAD(h1, i); + bf = GET_GRAD(h2, i); + if ((((std::fabs(af - bf))/af) >= 1e-6) && ((std::fabs(af - bf)) >= 1e-6)) { + printf("i = %5d, 
h1.grad %13.6f, h2.grad %13.6f\n", i, af, bf); + ++retval; + } + if (const_flag) { + ai = GET_HESS(h1, i); + bi = GET_HESS(h2, i); + if (ai != bi) { + printf("i = %5d, h1.hess %d, h2.hess %d\n", i, ai, bi); + ++retval; + } + } else { + af = GET_HESS(h1, i); + bf = GET_HESS(h2, i); + if (((std::fabs(af - bf))/af) >= 1e-5) { + printf("i = %5d, h1.hess %13.6f, h2.hess %13.6f\n", i, af, bf); + ++retval; + } + } + } + } + printf("DONE Comparing Histograms...\n"); + return retval; +} +#endif + +int CUDATreeLearner::GetNumWorkgroupsPerFeature(data_size_t leaf_num_data) { + // we roughly want 256 workgroups per device, and we have num_dense_feature4_ feature tuples. + // also guarantee that there are at least 2K examples per workgroup + double x = 256.0 / num_dense_feature_groups_; + + int exp_workgroups_per_feature = static_cast(ceil(log2(x))); + double t = leaf_num_data / 1024.0; + + Log::Debug("We can have at most %d workgroups per feature4 for efficiency reasons\n" + "Best workgroup size per feature for full utilization is %d\n", static_cast(ceil(t)), (1 << exp_workgroups_per_feature)); + + exp_workgroups_per_feature = std::min(exp_workgroups_per_feature, static_cast(ceil(log(static_cast(t))/log(2.0)))); + if (exp_workgroups_per_feature < 0) + exp_workgroups_per_feature = 0; + if (exp_workgroups_per_feature > kMaxLogWorkgroupsPerFeature) + exp_workgroups_per_feature = kMaxLogWorkgroupsPerFeature; + + return exp_workgroups_per_feature; +} + +void CUDATreeLearner::GPUHistogram(data_size_t leaf_num_data, bool use_all_features) { + // we have already copied ordered gradients, ordered hessians and indices to GPU + // decide the best number of workgroups working on one feature4 tuple + // set work group size based on feature size + // each 2^exp_workgroups_per_feature workgroups work on a feature4 tuple + int exp_workgroups_per_feature = GetNumWorkgroupsPerFeature(leaf_num_data); + std::vector num_gpu_workgroups; + ThreadData *thread_data = reinterpret_cast(_mm_malloc(sizeof(ThreadData) * num_gpu_, 16)); + + for (int device_id = 0; device_id < num_gpu_; ++device_id) { + int num_gpu_feature_groups = num_gpu_feature_groups_[device_id]; + int num_workgroups = (1 << exp_workgroups_per_feature) * num_gpu_feature_groups; + num_gpu_workgroups.push_back(num_workgroups); + if (num_workgroups > preallocd_max_num_wg_[device_id]) { + preallocd_max_num_wg_.at(device_id) = num_workgroups; + CUDASUCCESS_OR_FATAL(cudaFree(device_subhistograms_[device_id])); + CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_subhistograms_[device_id]), static_cast(num_workgroups * dword_features_ * device_bin_size_ * (3 * hist_bin_entry_sz_ / 2)))); + } + // set thread_data + SetThreadData(thread_data, device_id, histogram_size_, leaf_num_data, use_all_features, + num_workgroups, exp_workgroups_per_feature); + } + + for (int device_id = 0; device_id < num_gpu_; ++device_id) { + if (pthread_create(cpu_threads_[device_id], NULL, launch_cuda_histogram, reinterpret_cast(&thread_data[device_id]))) { + Log::Fatal("Error in creating threads."); + } + } + + /* Wait for the threads to finish */ + for (int device_id = 0; device_id < num_gpu_; ++device_id) { + if (pthread_join(*(cpu_threads_[device_id]), NULL)) { + Log::Fatal("Error in joining threads."); + } + } + + for (int device_id = 0; device_id < num_gpu_; ++device_id) { + // copy the results asynchronously. 
Size depends on if double precision is used + + size_t output_size = num_gpu_feature_groups_[device_id] * dword_features_ * device_bin_size_ * hist_bin_entry_sz_; + size_t host_output_offset = offset_gpu_feature_groups_[device_id] * dword_features_ * device_bin_size_ * hist_bin_entry_sz_; + + CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(reinterpret_cast(host_histogram_outputs_) + host_output_offset, device_histogram_outputs_[device_id], output_size, cudaMemcpyDeviceToHost, stream_[device_id])); + CUDASUCCESS_OR_FATAL(cudaEventRecord(histograms_wait_obj_[device_id], stream_[device_id])); + } +} + + +template +void CUDATreeLearner::WaitAndGetHistograms(FeatureHistogram* leaf_histogram_array) { + HistType* hist_outputs = reinterpret_cast(host_histogram_outputs_); + + #pragma omp parallel for schedule(static, num_gpu_) + for (int device_id = 0; device_id < num_gpu_; ++device_id) { + // when the output is ready, the computation is done + CUDASUCCESS_OR_FATAL(cudaEventSynchronize(histograms_wait_obj_[device_id])); + } + + HistType* histograms = reinterpret_cast(leaf_histogram_array[0].RawData() - kHistOffset); + #pragma omp parallel for schedule(static) + for (int i = 0; i < num_dense_feature_groups_; ++i) { + if (!feature_masks_[i]) { + continue; + } + int dense_group_index = dense_feature_group_map_[i]; + auto old_histogram_array = histograms + train_data_->GroupBinBoundary(dense_group_index) * 2; + int bin_size = train_data_->FeatureGroupNumBin(dense_group_index); + + for (int j = 0; j < bin_size; ++j) { + GET_GRAD(old_histogram_array, j) = GET_GRAD(hist_outputs, i * device_bin_size_+ j); + GET_HESS(old_histogram_array, j) = GET_HESS(hist_outputs, i * device_bin_size_+ j); + } + } +} + +void CUDATreeLearner::CountDenseFeatureGroups() { + num_dense_feature_groups_ = 0; + + for (int i = 0; i < num_feature_groups_; ++i) { + if (!train_data_->IsMultiGroup(i)) { + num_dense_feature_groups_++; + } + } + if (!num_dense_feature_groups_) { + Log::Warning("GPU acceleration is disabled because no non-trival dense features can be found"); + } +} + +void CUDATreeLearner::prevAllocateGPUMemory() { + // how many feature-group tuples we have + // leave some safe margin for prefetching + // 256 work-items per workgroup. Each work-item prefetches one tuple for that feature + + allocated_num_data_ = std::max(num_data_ + 256 * (1 << kMaxLogWorkgroupsPerFeature), allocated_num_data_); + + // clear sparse/dense maps + dense_feature_group_map_.clear(); + sparse_feature_group_map_.clear(); + + // do nothing it there is no dense feature + if (!num_dense_feature_groups_) { + return; + } + + // calculate number of feature groups per gpu + num_gpu_feature_groups_.resize(num_gpu_); + offset_gpu_feature_groups_.resize(num_gpu_); + int num_features_per_gpu = num_dense_feature_groups_ / num_gpu_; + int remain_features = num_dense_feature_groups_ - num_features_per_gpu * num_gpu_; + + int offset = 0; + + for (int i = 0; i < num_gpu_; ++i) { + offset_gpu_feature_groups_.at(i) = offset; + num_gpu_feature_groups_.at(i) = (i < remain_features) ? num_features_per_gpu + 1 : num_features_per_gpu; + offset += num_gpu_feature_groups_.at(i); + } + + feature_masks_.resize(num_dense_feature_groups_); + Log::Debug("Resized feature masks"); + + ptr_pinned_feature_masks_ = feature_masks_.data(); + Log::Debug("Memset pinned_feature_masks_"); + memset(ptr_pinned_feature_masks_, 0, num_dense_feature_groups_); + + // histogram bin entry size depends on the precision (single/double) + hist_bin_entry_sz_ = 2 * (config_->gpu_use_dp ? 
sizeof(hist_t) : sizeof(gpu_hist_t)); // two elements in this "size" + + CUDASUCCESS_OR_FATAL(cudaHostAlloc(reinterpret_cast(&host_histogram_outputs_), static_cast(num_dense_feature_groups_ * device_bin_size_ * hist_bin_entry_sz_), cudaHostAllocPortable)); + + nthreads_ = std::min(omp_get_max_threads(), num_dense_feature_groups_ / dword_features_); + nthreads_ = std::max(nthreads_, 1); +} + +// allocate GPU memory for each GPU +void CUDATreeLearner::AllocateGPUMemory() { + #pragma omp parallel for schedule(static, num_gpu_) + + for (int device_id = 0; device_id < num_gpu_; ++device_id) { + // do nothing it there is no gpu feature + int num_gpu_feature_groups = num_gpu_feature_groups_[device_id]; + if (num_gpu_feature_groups) { + CUDASUCCESS_OR_FATAL(cudaSetDevice(device_id)); + + // allocate memory for all features + if (device_features_[device_id] != NULL) { + CUDASUCCESS_OR_FATAL(cudaFree(device_features_[device_id])); + } + + CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_features_[device_id]), static_cast(num_gpu_feature_groups * num_data_ * sizeof(uint8_t)))); + Log::Debug("Allocated device_features_ addr=%p sz=%lu", device_features_[device_id], num_gpu_feature_groups * num_data_); + + // allocate space for gradients and hessians on device + // we will copy gradients and hessians in after ordered_gradients_ and ordered_hessians_ are constructed + if (device_gradients_[device_id] != NULL) { + CUDASUCCESS_OR_FATAL(cudaFree(device_gradients_[device_id])); + } + + if (device_hessians_[device_id] != NULL) { + CUDASUCCESS_OR_FATAL(cudaFree(device_hessians_[device_id])); + } + + if (device_feature_masks_[device_id] != NULL) { + CUDASUCCESS_OR_FATAL(cudaFree(device_feature_masks_[device_id])); + } + + CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_gradients_[device_id]), static_cast(allocated_num_data_ * sizeof(score_t)))); + CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_hessians_[device_id]), static_cast(allocated_num_data_ * sizeof(score_t)))); + + CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_feature_masks_[device_id]), static_cast(num_gpu_feature_groups))); + + // copy indices to the device + if (device_data_indices_[device_id] != NULL) { + CUDASUCCESS_OR_FATAL(cudaFree(device_data_indices_[device_id])); + } + + CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_data_indices_[device_id]), static_cast(allocated_num_data_ * sizeof(data_size_t)))); + CUDASUCCESS_OR_FATAL(cudaMemsetAsync(device_data_indices_[device_id], 0, allocated_num_data_ * sizeof(data_size_t), stream_[device_id])); + + Log::Debug("Memset device_data_indices_"); + + // create output buffer, each feature has a histogram with device_bin_size_ bins, + // each work group generates a sub-histogram of dword_features_ features. 
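+      // Sizing example for the allocation below: hist_bin_entry_sz_ is the byte size of one
+      // bin's gradient/hessian pair (2 * sizeof(hist_t) with gpu_use_dp, else
+      // 2 * sizeof(gpu_hist_t)), so the (3 * hist_bin_entry_sz_ / 2) factor reserves space for
+      // three values per bin rather than two. With the default of 1024 preallocated workgroups,
+      // dword_features_ = 1, 256 bins and double precision (hist_bin_entry_sz_ = 16 bytes),
+      // the sub-histogram buffer is 1024 * 1 * 256 * 24 bytes = 6 MiB per GPU.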
+ if (!device_subhistograms_[device_id]) { + // only initialize once here, as this will not need to change when ResetTrainingData() is called + CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_subhistograms_[device_id]), static_cast(preallocd_max_num_wg_[device_id] * dword_features_ * device_bin_size_ * (3 * hist_bin_entry_sz_ / 2)))); + + Log::Debug("created device_subhistograms_: %p", device_subhistograms_[device_id]); + } + + // create atomic counters for inter-group coordination + CUDASUCCESS_OR_FATAL(cudaFree(sync_counters_[device_id])); + CUDASUCCESS_OR_FATAL(cudaMalloc(&(sync_counters_[device_id]), static_cast(num_gpu_feature_groups * sizeof(int)))); + CUDASUCCESS_OR_FATAL(cudaMemsetAsync(sync_counters_[device_id], 0, num_gpu_feature_groups * sizeof(int), stream_[device_id])); + + // The output buffer is allocated to host directly, to overlap compute and data transfer + CUDASUCCESS_OR_FATAL(cudaFree(device_histogram_outputs_[device_id])); + CUDASUCCESS_OR_FATAL(cudaMalloc(&(device_histogram_outputs_[device_id]), static_cast(num_gpu_feature_groups * device_bin_size_ * hist_bin_entry_sz_))); + } + } +} + +void CUDATreeLearner::ResetGPUMemory() { + // clear sparse/dense maps + dense_feature_group_map_.clear(); + sparse_feature_group_map_.clear(); +} + +void CUDATreeLearner::copyDenseFeature() { + if (num_feature_groups_ == 0) { + LGBM_config_::current_learner = use_cpu_learner; + return; + } + + Log::Debug("Started copying dense features from CPU to GPU"); + // find the dense feature-groups and group then into Feature4 data structure (several feature-groups packed into 4 bytes) + size_t copied_feature = 0; + // set device info + int device_id = 0; + uint8_t* device_features = device_features_[device_id]; + CUDASUCCESS_OR_FATAL(cudaSetDevice(device_id)); + Log::Debug("Started copying dense features from CPU to GPU - 1"); + + for (int i = 0; i < num_feature_groups_; ++i) { + // looking for dword_features_ non-sparse feature-groups + if (!train_data_->IsMultiGroup(i)) { + dense_feature_group_map_.push_back(i); + auto sizes_in_byte = train_data_->FeatureGroupSizesInByte(i); + void* tmp_data = train_data_->FeatureGroupData(i); + Log::Debug("Started copying dense features from CPU to GPU - 2"); + CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(&device_features[copied_feature * num_data_], tmp_data, sizes_in_byte, cudaMemcpyHostToDevice, stream_[device_id])); + Log::Debug("Started copying dense features from CPU to GPU - 3"); + copied_feature++; + // reset device info + if (copied_feature == static_cast(num_gpu_feature_groups_[device_id])) { + CUDASUCCESS_OR_FATAL(cudaEventRecord(features_future_[device_id], stream_[device_id])); + device_id += 1; + copied_feature = 0; + if (device_id < num_gpu_) { + device_features = device_features_[device_id]; + CUDASUCCESS_OR_FATAL(cudaSetDevice(device_id)); + } + } + } else { + sparse_feature_group_map_.push_back(i); + } + } +} + + + +// InitGPU w/ num_gpu +void CUDATreeLearner::InitGPU(int num_gpu) { + // Get the max bin size, used for selecting best GPU kernel + max_num_bin_ = 0; + + #if CUDA_DEBUG >= 1 + printf("bin_size: "); + #endif + for (int i = 0; i < num_feature_groups_; ++i) { + if (train_data_->IsMultiGroup(i)) { + continue; + } + #if CUDA_DEBUG >= 1 + printf("%d, ", train_data_->FeatureGroupNumBin(i)); + #endif + max_num_bin_ = std::max(max_num_bin_, train_data_->FeatureGroupNumBin(i)); + } + #if CUDA_DEBUG >= 1 + printf("\n"); + #endif + + if (max_num_bin_ <= 16) { + device_bin_size_ = 16; + histogram_size_ = 16; + dword_features_ = 1; + } else if 
(max_num_bin_ <= 64) { + device_bin_size_ = 64; + histogram_size_ = 64; + dword_features_ = 1; + } else if (max_num_bin_ <= 256) { + Log::Debug("device_bin_size_ = 256"); + device_bin_size_ = 256; + histogram_size_ = 256; + dword_features_ = 1; + } else { + Log::Fatal("bin size %d cannot run on GPU", max_num_bin_); + } + if (max_num_bin_ == 65) { + Log::Warning("Setting max_bin to 63 is sugguested for best performance"); + } + if (max_num_bin_ == 17) { + Log::Warning("Setting max_bin to 15 is sugguested for best performance"); + } + + // get num_dense_feature_groups_ + CountDenseFeatureGroups(); + + if (num_gpu > num_dense_feature_groups_) num_gpu = num_dense_feature_groups_; + + // initialize GPU + int gpu_count; + + CUDASUCCESS_OR_FATAL(cudaGetDeviceCount(&gpu_count)); + num_gpu_ = (gpu_count < num_gpu) ? gpu_count : num_gpu; + + // set cpu threads + cpu_threads_ = reinterpret_cast(_mm_malloc(sizeof(pthread_t *)*num_gpu_, 16)); + for (int device_id = 0; device_id < num_gpu_; ++device_id) { + cpu_threads_[device_id] = reinterpret_cast(_mm_malloc(sizeof(pthread_t), 16)); + } + + // resize device memory pointers + device_features_.resize(num_gpu_); + device_gradients_.resize(num_gpu_); + device_hessians_.resize(num_gpu_); + device_feature_masks_.resize(num_gpu_); + device_data_indices_.resize(num_gpu_); + sync_counters_.resize(num_gpu_); + device_subhistograms_.resize(num_gpu_); + device_histogram_outputs_.resize(num_gpu_); + + // create stream & events to handle multiple GPUs + preallocd_max_num_wg_.resize(num_gpu_, 1024); + stream_.resize(num_gpu_); + hessians_future_.resize(num_gpu_); + gradients_future_.resize(num_gpu_); + indices_future_.resize(num_gpu_); + features_future_.resize(num_gpu_); + kernel_start_.resize(num_gpu_); + kernel_wait_obj_.resize(num_gpu_); + histograms_wait_obj_.resize(num_gpu_); + + for (int i = 0; i < num_gpu_; ++i) { + CUDASUCCESS_OR_FATAL(cudaSetDevice(i)); + CUDASUCCESS_OR_FATAL(cudaStreamCreate(&(stream_[i]))); + CUDASUCCESS_OR_FATAL(cudaEventCreate(&(hessians_future_[i]))); + CUDASUCCESS_OR_FATAL(cudaEventCreate(&(gradients_future_[i]))); + CUDASUCCESS_OR_FATAL(cudaEventCreate(&(indices_future_[i]))); + CUDASUCCESS_OR_FATAL(cudaEventCreate(&(features_future_[i]))); + CUDASUCCESS_OR_FATAL(cudaEventCreate(&(kernel_start_[i]))); + CUDASUCCESS_OR_FATAL(cudaEventCreate(&(kernel_wait_obj_[i]))); + CUDASUCCESS_OR_FATAL(cudaEventCreate(&(histograms_wait_obj_[i]))); + } + + allocated_num_data_ = 0; + prevAllocateGPUMemory(); + + AllocateGPUMemory(); + + copyDenseFeature(); +} + +Tree* CUDATreeLearner::Train(const score_t* gradients, const score_t *hessians) { + Tree *ret = SerialTreeLearner::Train(gradients, hessians); + return ret; +} + +void CUDATreeLearner::ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) { + // check data size + data_size_t old_allocated_num_data = allocated_num_data_; + + SerialTreeLearner::ResetTrainingDataInner(train_data, is_constant_hessian, reset_multi_val_bin); + + #if ResetTrainingData_DEBUG == 1 + serial_time = std::chrono::steady_clock::now() - start_serial_time; + #endif + + num_feature_groups_ = train_data_->num_feature_groups(); + + // GPU memory has to been reallocated because data may have been changed + #if ResetTrainingData_DEBUG == 1 + auto start_alloc_gpu_time = std::chrono::steady_clock::now(); + #endif + + // AllocateGPUMemory only when the number of data increased + int old_num_feature_groups = num_dense_feature_groups_; + CountDenseFeatureGroups(); + if 
((old_allocated_num_data < (num_data_ + 256 * (1 << kMaxLogWorkgroupsPerFeature))) || (old_num_feature_groups < num_dense_feature_groups_)) { + prevAllocateGPUMemory(); + AllocateGPUMemory(); + } else { + ResetGPUMemory(); + } + + copyDenseFeature(); + + #if ResetTrainingData_DEBUG == 1 + alloc_gpu_time = std::chrono::steady_clock::now() - start_alloc_gpu_time; + #endif + + // setup GPU kernel arguments after we allocating all the buffers + #if ResetTrainingData_DEBUG == 1 + auto start_set_arg_time = std::chrono::steady_clock::now(); + #endif + + #if ResetTrainingData_DEBUG == 1 + set_arg_time = std::chrono::steady_clock::now() - start_set_arg_time; + reset_training_data_time = std::chrono::steady_clock::now() - start_reset_training_data_time; + Log::Info("reset_training_data_time: %f secs.", reset_training_data_time.count() * 1e-3); + Log::Info("serial_time: %f secs.", serial_time.count() * 1e-3); + Log::Info("alloc_gpu_time: %f secs.", alloc_gpu_time.count() * 1e-3); + Log::Info("set_arg_time: %f secs.", set_arg_time.count() * 1e-3); + #endif +} + +void CUDATreeLearner::BeforeTrain() { + #if cudaMemcpy_DEBUG == 1 + std::chrono::duration device_hessians_time = std::chrono::milliseconds(0); + std::chrono::duration device_gradients_time = std::chrono::milliseconds(0); + #endif + + SerialTreeLearner::BeforeTrain(); + + #if CUDA_DEBUG >= 2 + printf("CUDATreeLearner::BeforeTrain() Copying initial full gradients and hessians to device\n"); + #endif + + // Copy initial full hessians and gradients to GPU. + // We start copying as early as possible, instead of at ConstructHistogram(). + if ((hessians_ != NULL) && (gradients_ != NULL)) { + if (!use_bagging_ && num_dense_feature_groups_) { + Log::Debug("CudaTreeLearner::BeforeTrain() No baggings, dense_feature_groups_=%d", num_dense_feature_groups_); + + for (int device_id = 0; device_id < num_gpu_; ++device_id) { + if (!(share_state_->is_constant_hessian)) { + Log::Debug("CUDATreeLearner::BeforeTrain(): Starting hessians_ -> device_hessians_"); + + #if cudaMemcpy_DEBUG == 1 + auto start_device_hessians_time = std::chrono::steady_clock::now(); + #endif + + CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_hessians_[device_id], hessians_, num_data_*sizeof(score_t), cudaMemcpyHostToDevice, stream_[device_id])); + + CUDASUCCESS_OR_FATAL(cudaEventRecord(hessians_future_[device_id], stream_[device_id])); + + #if cudaMemcpy_DEBUG == 1 + device_hessians_time = std::chrono::steady_clock::now() - start_device_hessians_time; + #endif + + Log::Debug("queued copy of device_hessians_"); + } + + #if cudaMemcpy_DEBUG == 1 + auto start_device_gradients_time = std::chrono::steady_clock::now(); + #endif + + CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_gradients_[device_id], gradients_, num_data_ * sizeof(score_t), cudaMemcpyHostToDevice, stream_[device_id])); + CUDASUCCESS_OR_FATAL(cudaEventRecord(gradients_future_[device_id], stream_[device_id])); + + #if cudaMemcpy_DEBUG == 1 + device_gradients_time = std::chrono::steady_clock::now() - start_device_gradients_time; + #endif + + Log::Debug("CUDATreeLearner::BeforeTrain: issued gradients_ -> device_gradients_"); + } + } + } + + // use bagging + if ((hessians_ != NULL) && (gradients_ != NULL)) { + if (data_partition_->leaf_count(0) != num_data_ && num_dense_feature_groups_) { + // On GPU, we start copying indices, gradients and hessians now, instead at ConstructHistogram() + // copy used gradients and hessians to ordered buffer + const data_size_t* indices = data_partition_->indices(); + data_size_t cnt = 
data_partition_->leaf_count(0); + + // transfer the indices to GPU + for (int device_id = 0; device_id < num_gpu_; ++device_id) { + CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_data_indices_[device_id], indices, cnt * sizeof(*indices), cudaMemcpyHostToDevice, stream_[device_id])); + CUDASUCCESS_OR_FATAL(cudaEventRecord(indices_future_[device_id], stream_[device_id])); + + if (!(share_state_->is_constant_hessian)) { + CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_hessians_[device_id], const_cast(reinterpret_cast(&(hessians_[0]))), num_data_ * sizeof(score_t), cudaMemcpyHostToDevice, stream_[device_id])); + CUDASUCCESS_OR_FATAL(cudaEventRecord(hessians_future_[device_id], stream_[device_id])); + } + + CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_gradients_[device_id], const_cast(reinterpret_cast(&(gradients_[0]))), num_data_ * sizeof(score_t), cudaMemcpyHostToDevice, stream_[device_id])); + CUDASUCCESS_OR_FATAL(cudaEventRecord(gradients_future_[device_id], stream_[device_id])); + } + } + } +} + +bool CUDATreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) { + int smaller_leaf; + + data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf); + data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf); + + // only have root + if (right_leaf < 0) { + smaller_leaf = -1; + } else if (num_data_in_left_child < num_data_in_right_child) { + smaller_leaf = left_leaf; + } else { + smaller_leaf = right_leaf; + } + + // Copy indices, gradients and hessians as early as possible + if (smaller_leaf >= 0 && num_dense_feature_groups_) { + // only need to initialize for smaller leaf + // Get leaf boundary + const data_size_t* indices = data_partition_->indices(); + data_size_t begin = data_partition_->leaf_begin(smaller_leaf); + data_size_t end = begin + data_partition_->leaf_count(smaller_leaf); + + for (int device_id = 0; device_id < num_gpu_; ++device_id) { + CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_data_indices_[device_id], &indices[begin], (end-begin) * sizeof(data_size_t), cudaMemcpyHostToDevice, stream_[device_id])); + CUDASUCCESS_OR_FATAL(cudaEventRecord(indices_future_[device_id], stream_[device_id])); + } + } + + const bool ret = SerialTreeLearner::BeforeFindBestSplit(tree, left_leaf, right_leaf); + + return ret; +} + +bool CUDATreeLearner::ConstructGPUHistogramsAsync( + const std::vector& is_feature_used, + const data_size_t* data_indices, data_size_t num_data) { + if (num_data <= 0) { + return false; + } + + // do nothing if no features can be processed on GPU + if (!num_dense_feature_groups_) { + Log::Debug("no dense feature groups, returning"); + return false; + } + + // copy data indices if it is not null + if (data_indices != nullptr && num_data != num_data_) { + for (int device_id = 0; device_id < num_gpu_; ++device_id) { + CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_data_indices_[device_id], data_indices, num_data * sizeof(data_size_t), cudaMemcpyHostToDevice, stream_[device_id])); + CUDASUCCESS_OR_FATAL(cudaEventRecord(indices_future_[device_id], stream_[device_id])); + } + } + + // converted indices in is_feature_used to feature-group indices + std::vector is_feature_group_used(num_feature_groups_, 0); + + #pragma omp parallel for schedule(static, 1024) if (num_features_ >= 2048) + for (int i = 0; i < num_features_; ++i) { + if (is_feature_used[i]) { + int feature_group = train_data_->Feature2Group(i); + is_feature_group_used[feature_group] = (train_data_->FeatureGroupNumBin(feature_group) <= 16) ? 
2 : 1; + } + } + + // construct the feature masks for dense feature-groups + int used_dense_feature_groups = 0; + #pragma omp parallel for schedule(static, 1024) reduction(+:used_dense_feature_groups) if (num_dense_feature_groups_ >= 2048) + for (int i = 0; i < num_dense_feature_groups_; ++i) { + if (is_feature_group_used[dense_feature_group_map_[i]]) { + feature_masks_[i] = is_feature_group_used[dense_feature_group_map_[i]]; + ++used_dense_feature_groups; + } else { + feature_masks_[i] = 0; + } + } + bool use_all_features = ((used_dense_feature_groups == num_dense_feature_groups_) && (data_indices != nullptr)); + // if no feature group is used, just return and do not use GPU + if (used_dense_feature_groups == 0) { + return false; + } + + // if not all feature groups are used, we need to transfer the feature mask to GPU + // otherwise, we will use a specialized GPU kernel with all feature groups enabled + + // We now copy even if all features are used. + #pragma omp parallel for schedule(static, num_gpu_) + for (int device_id = 0; device_id < num_gpu_; ++device_id) { + int offset = offset_gpu_feature_groups_[device_id]; + CUDASUCCESS_OR_FATAL(cudaMemcpyAsync(device_feature_masks_[device_id], ptr_pinned_feature_masks_ + offset, num_gpu_feature_groups_[device_id] , cudaMemcpyHostToDevice, stream_[device_id])); + } + + // All data have been prepared, now run the GPU kernel + GPUHistogram(num_data, use_all_features); + + return true; +} + +void CUDATreeLearner::ConstructHistograms(const std::vector& is_feature_used, bool use_subtract) { + std::vector is_sparse_feature_used(num_features_, 0); + std::vector is_dense_feature_used(num_features_, 0); + int num_dense_features = 0, num_sparse_features = 0; + + #pragma omp parallel for schedule(static) + for (int feature_index = 0; feature_index < num_features_; ++feature_index) { + if (!col_sampler_.is_feature_used_bytree()[feature_index]) continue; + if (!is_feature_used[feature_index]) continue; + if (train_data_->IsMultiGroup(train_data_->Feature2Group(feature_index))) { + is_sparse_feature_used[feature_index] = 1; + num_sparse_features++; + } else { + is_dense_feature_used[feature_index] = 1; + num_dense_features++; + } + } + + // construct smaller leaf + hist_t* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - kHistOffset; + + // Check workgroups per feature4 tuple.. 
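+  // Worked example of the sizing done in GetNumWorkgroupsPerFeature(): the exponent starts at
+  // ceil(log2(256 / num_dense_feature_groups_)), is capped by ceil(log2(num_data_in_leaf / 1024))
+  // and by kMaxLogWorkgroupsPerFeature, and is clamped at 0. With, say, 28 dense feature groups
+  // and 100000 rows in the leaf this gives min(ceil(log2(9.14)), ceil(log2(97.7))) = 4, i.e.
+  // 2^4 = 16 workgroups per feature group. A leaf with at most 1024 rows drives the exponent
+  // to 0, which triggers the CPU fallback just below.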
+  int exp_workgroups_per_feature = GetNumWorkgroupsPerFeature(smaller_leaf_splits_->num_data_in_leaf());
+
+  // if the workgroup count per feature is 1 (2^0), fall back to the CPU learner, as the work is too small for a GPU
+  if (exp_workgroups_per_feature == 0) {
+    return SerialTreeLearner::ConstructHistograms(is_feature_used, use_subtract);
+  }
+
+  // ConstructGPUHistogramsAsync will return true if there are available feature groups dispatched to the GPU
+  bool is_gpu_used = ConstructGPUHistogramsAsync(is_feature_used,
+    nullptr, smaller_leaf_splits_->num_data_in_leaf());
+
+  // then construct sparse features on CPU
+  // We set data_indices to null to avoid rebuilding ordered gradients/hessians
+  if (num_sparse_features > 0) {
+    train_data_->ConstructHistograms(is_sparse_feature_used,
+      smaller_leaf_splits_->data_indices(), smaller_leaf_splits_->num_data_in_leaf(),
+      gradients_, hessians_,
+      ordered_gradients_.data(), ordered_hessians_.data(),
+      share_state_.get(),
+      ptr_smaller_leaf_hist_data);
+  }
+
+  // wait for the GPU to finish, only if the GPU is actually used
+  if (is_gpu_used) {
+    if (config_->gpu_use_dp) {
+      // use double precision
+      WaitAndGetHistograms<hist_t>(smaller_leaf_histogram_array_);
+    } else {
+      // use single precision
+      WaitAndGetHistograms<gpu_hist_t>(smaller_leaf_histogram_array_);
+    }
+  }
+
+  // Compare the GPU histogram with the CPU histogram, useful for debugging GPU code problems
+  // #define CUDA_DEBUG_COMPARE
+#ifdef CUDA_DEBUG_COMPARE
+  printf("Start Comparing Histogram between GPU and CPU, num_dense_feature_groups_ = %d\n", num_dense_feature_groups_);
+  bool compare = true;
+  for (int i = 0; i < num_dense_feature_groups_; ++i) {
+    if (!feature_masks_[i])
+      continue;
+    int dense_feature_group_index = dense_feature_group_map_[i];
+    size_t size = train_data_->FeatureGroupNumBin(dense_feature_group_index);
+    hist_t* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - kHistOffset;
+    hist_t* current_histogram = ptr_smaller_leaf_hist_data + train_data_->GroupBinBoundary(dense_feature_group_index) * 2;
+    hist_t* gpu_histogram = new hist_t[size * 2];
+    data_size_t num_data = smaller_leaf_splits_->num_data_in_leaf();
+    printf("Comparing histogram for feature %d, num_data %d, num_data_ = %d, %lu bins\n", dense_feature_group_index, num_data, num_data_, size);
+    std::copy(current_histogram, current_histogram + size * 2, gpu_histogram);
+    std::memset(current_histogram, 0, size * sizeof(hist_t) * 2);
+    if (train_data_->FeatureGroupBin(dense_feature_group_index) == nullptr) {
+      continue;
+    }
+    if (num_data == num_data_) {
+      if (share_state_->is_constant_hessian) {
+        printf("ConstructHistogram(): num_data == num_data_ is_constant_hessian\n");
+        train_data_->FeatureGroupBin(dense_feature_group_index)->ConstructHistogram(
+          0,
+          num_data,
+          gradients_,
+          current_histogram);
+      } else {
+        printf("ConstructHistogram(): num_data == num_data_\n");
+        train_data_->FeatureGroupBin(dense_feature_group_index)->ConstructHistogram(
+          0,
+          num_data,
+          gradients_, hessians_,
+          current_histogram);
+      }
+    } else {
+      if (share_state_->is_constant_hessian) {
+        printf("ConstructHistogram(): is_constant_hessian\n");
+        train_data_->FeatureGroupBin(dense_feature_group_index)->ConstructHistogram(
+          smaller_leaf_splits_->data_indices(),
+          0,
+          num_data,
+          gradients_,
+          current_histogram);
+      } else {
+        printf("ConstructHistogram(): 4, num_data = %d, num_data_ = %d\n", num_data, num_data_);
+        train_data_->FeatureGroupBin(dense_feature_group_index)->ConstructHistogram(
+          smaller_leaf_splits_->data_indices(),
+          0,
+          num_data,
+          gradients_, hessians_,
current_histogram); + } + } + int retval; + if ((num_data != num_data_) && compare) { + retval = CompareHistograms(gpu_histogram, current_histogram, size, dense_feature_group_index, config_->gpu_use_dp, share_state_->is_constant_hessian); + printf("CompareHistograms reports %d errors\n", retval); + compare = false; + } + retval = CompareHistograms(gpu_histogram, current_histogram, size, dense_feature_group_index, config_->gpu_use_dp, share_state_->is_constant_hessian); + if (num_data == num_data_) { + printf("CompareHistograms reports %d errors\n", retval); + } else { + printf("CompareHistograms reports %d errors\n", retval); + } + std::copy(gpu_histogram, gpu_histogram + size * 2, current_histogram); + delete [] gpu_histogram; + } + printf("End Comparing Histogram between GPU and CPU\n"); + fflush(stderr); + fflush(stdout); +#endif + + if (larger_leaf_histogram_array_ != nullptr && !use_subtract) { + // construct larger leaf + hist_t* ptr_larger_leaf_hist_data = larger_leaf_histogram_array_[0].RawData() - kHistOffset; + + is_gpu_used = ConstructGPUHistogramsAsync(is_feature_used, + larger_leaf_splits_->data_indices(), larger_leaf_splits_->num_data_in_leaf()); + + // then construct sparse features on CPU + // We set data_indices to null to avoid rebuilding ordered gradients/hessians + if (num_sparse_features > 0) { + train_data_->ConstructHistograms(is_sparse_feature_used, + larger_leaf_splits_->data_indices(), larger_leaf_splits_->num_data_in_leaf(), + gradients_, hessians_, + ordered_gradients_.data(), ordered_hessians_.data(), + share_state_.get(), + ptr_larger_leaf_hist_data); + } + + // wait for GPU to finish, only if GPU is actually used + if (is_gpu_used) { + if (config_->gpu_use_dp) { + // use double precision + WaitAndGetHistograms(larger_leaf_histogram_array_); + } else { + // use single precision + WaitAndGetHistograms(larger_leaf_histogram_array_); + } + } + } +} + +void CUDATreeLearner::FindBestSplits(const Tree* tree) { + SerialTreeLearner::FindBestSplits(tree); + +#if CUDA_DEBUG >= 3 + for (int feature_index = 0; feature_index < num_features_; ++feature_index) { + if (!col_sampler_.is_feature_used_bytree()[feature_index]) continue; + if (parent_leaf_histogram_array_ != nullptr + && !parent_leaf_histogram_array_[feature_index].is_splittable()) { + smaller_leaf_histogram_array_[feature_index].set_is_splittable(false); + continue; + } + size_t bin_size = train_data_->FeatureNumBin(feature_index) + 1; + printf("CUDATreeLearner::FindBestSplits() Feature %d bin_size=%zd smaller leaf:\n", feature_index, bin_size); + PrintHistograms(smaller_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size); + if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->leaf_index() < 0) { continue; } + printf("CUDATreeLearner::FindBestSplits() Feature %d bin_size=%zd larger leaf:\n", feature_index, bin_size); + + PrintHistograms(larger_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size); + } +#endif +} + +void CUDATreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) { + const SplitInfo& best_split_info = best_split_per_leaf_[best_Leaf]; +#if CUDA_DEBUG >= 2 + printf("Splitting leaf %d with feature %d thresh %d gain %f stat %f %f %f %f\n", best_Leaf, best_split_info.feature, best_split_info.threshold, best_split_info.gain, best_split_info.left_sum_gradient, best_split_info.right_sum_gradient, best_split_info.left_sum_hessian, best_split_info.right_sum_hessian); +#endif + SerialTreeLearner::Split(tree, best_Leaf, left_leaf, right_leaf); 
+ if (Network::num_machines() == 1) { + // do some sanity check for the GPU algorithm + if (best_split_info.left_count < best_split_info.right_count) { + if ((best_split_info.left_count != smaller_leaf_splits_->num_data_in_leaf()) || + (best_split_info.right_count!= larger_leaf_splits_->num_data_in_leaf())) { + Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf()); + } + } else { + if ((best_split_info.left_count != larger_leaf_splits_->num_data_in_leaf()) || + (best_split_info.right_count!= smaller_leaf_splits_->num_data_in_leaf())) { + Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf()); + } + } + } +} + +} // namespace LightGBM +#undef cudaMemcpy_DEBUG +#endif // USE_CUDA diff --git a/src/treelearner/cuda_tree_learner.h b/src/treelearner/cuda_tree_learner.h new file mode 100644 index 00000000000..442c2f53ea0 --- /dev/null +++ b/src/treelearner/cuda_tree_learner.h @@ -0,0 +1,265 @@ +/*! + * Copyright (c) 2020 IBM Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. + */ +#ifndef LIGHTGBM_TREELEARNER_CUDA_TREE_LEARNER_H_ +#define LIGHTGBM_TREELEARNER_CUDA_TREE_LEARNER_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#ifdef USE_CUDA +#include +#endif + +#include "feature_histogram.hpp" +#include "serial_tree_learner.h" +#include "data_partition.hpp" +#include "split_info.hpp" +#include "leaf_splits.hpp" + +#ifdef USE_CUDA +#include +#include "cuda_kernel_launcher.h" + + +using json11::Json; + +namespace LightGBM { + +/*! +* \brief CUDA-based parallel learning algorithm. +*/ +class CUDATreeLearner: public SerialTreeLearner { + public: + explicit CUDATreeLearner(const Config* tree_config); + ~CUDATreeLearner(); + void Init(const Dataset* train_data, bool is_constant_hessian) override; + void ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) override; + Tree* Train(const score_t* gradients, const score_t *hessians); + void SetBaggingData(const Dataset* subset, const data_size_t* used_indices, data_size_t num_data) override { + SerialTreeLearner::SetBaggingData(subset, used_indices, num_data); + if (subset == nullptr && used_indices != nullptr) { + if (num_data != num_data_) { + use_bagging_ = true; + return; + } + } + use_bagging_ = false; + } + + protected: + void BeforeTrain() override; + bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override; + void FindBestSplits(const Tree* tree) override; + void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override; + void ConstructHistograms(const std::vector& is_feature_used, bool use_subtract) override; + + private: + typedef float gpu_hist_t; + + /*! + * \brief Find the best number of workgroups processing one feature for maximizing efficiency + * \param leaf_num_data The number of data examples on the current leaf being processed + * \return Log2 of the best number for workgroups per feature, in range 0...kMaxLogWorkgroupsPerFeature + */ + int GetNumWorkgroupsPerFeature(data_size_t leaf_num_data); + + /*! 
+ * \brief Initialize GPU device + * \param num_gpu: number of maximum gpus + */ + void InitGPU(int num_gpu); + + /*! + * \brief Allocate memory for GPU computation // alloc only + */ + void CountDenseFeatureGroups(); // compute num_dense_feature_group + void prevAllocateGPUMemory(); // compute CPU-side param calculation & Pin HostMemory + void AllocateGPUMemory(); + + /*! + * \ ResetGPUMemory + */ + void ResetGPUMemory(); + + /*! + * \ copy dense feature from CPU to GPU + */ + void copyDenseFeature(); + + /*! + * \brief Compute GPU feature histogram for the current leaf. + * Indices, gradients and hessians have been copied to the device. + * \param leaf_num_data Number of data on current leaf + * \param use_all_features Set to true to not use feature masks, with a faster kernel + */ + void GPUHistogram(data_size_t leaf_num_data, bool use_all_features); + + void SetThreadData(ThreadData* thread_data, int device_id, int histogram_size, + int leaf_num_data, bool use_all_features, + int num_workgroups, int exp_workgroups_per_feature) { + ThreadData* td = &thread_data[device_id]; + td->device_id = device_id; + td->histogram_size = histogram_size; + td->leaf_num_data = leaf_num_data; + td->num_data = num_data_; + td->use_all_features = use_all_features; + td->is_constant_hessian = share_state_->is_constant_hessian; + td->num_workgroups = num_workgroups; + td->stream = stream_[device_id]; + td->device_features = device_features_[device_id]; + td->device_feature_masks = reinterpret_cast(device_feature_masks_[device_id]); + td->device_data_indices = device_data_indices_[device_id]; + td->device_gradients = device_gradients_[device_id]; + td->device_hessians = device_hessians_[device_id]; + td->hessians_const = hessians_[0]; + td->device_subhistograms = device_subhistograms_[device_id]; + td->sync_counters = sync_counters_[device_id]; + td->device_histogram_outputs = device_histogram_outputs_[device_id]; + td->exp_workgroups_per_feature = exp_workgroups_per_feature; + + td->kernel_start = &(kernel_start_[device_id]); + td->kernel_wait_obj = &(kernel_wait_obj_[device_id]); + td->kernel_input_wait_time = &(kernel_input_wait_time_[device_id]); + + size_t output_size = num_gpu_feature_groups_[device_id] * dword_features_ * device_bin_size_ * hist_bin_entry_sz_; + size_t host_output_offset = offset_gpu_feature_groups_[device_id] * dword_features_ * device_bin_size_ * hist_bin_entry_sz_; + td->output_size = output_size; + td->host_histogram_output = reinterpret_cast(host_histogram_outputs_) + host_output_offset; + td->histograms_wait_obj = &(histograms_wait_obj_[device_id]); + } + + /*! + * \brief Wait for GPU kernel execution and read histogram + * \param histograms Destination of histogram results from GPU. + */ + template + void WaitAndGetHistograms(FeatureHistogram* leaf_histogram_array); + + /*! + * \brief Construct GPU histogram asynchronously. + * Interface is similar to Dataset::ConstructHistograms(). + * \param is_feature_used A predicate vector for enabling each feature + * \param data_indices Array of data example IDs to be included in histogram, will be copied to GPU. + * Set to nullptr to skip copy to GPU. + * \param num_data Number of data examples to be included in histogram + * \return true if GPU kernel is launched, false if GPU is not used + */ + bool ConstructGPUHistogramsAsync( + const std::vector& is_feature_used, + const data_size_t* data_indices, data_size_t num_data); + + /*! 
brief Log2 of max number of workgroups per feature*/ + const int kMaxLogWorkgroupsPerFeature = 10; // 2^10 + /*! brief Max total number of workgroups with preallocated workspace. + * If we use more than this number of workgroups, we have to reallocate subhistograms */ + std::vector preallocd_max_num_wg_; + + /*! \brief True if bagging is used */ + bool use_bagging_; + + /*! \brief GPU command queue object */ + std::vector stream_; + + /*! \brief total number of feature-groups */ + int num_feature_groups_; + /*! \brief total number of dense feature-groups, which will be processed on GPU */ + int num_dense_feature_groups_; + std::vector num_gpu_feature_groups_; + std::vector offset_gpu_feature_groups_; + /*! \brief On GPU we read one DWORD (4-byte) of features of one example once. + * With bin size > 16, there are 4 features per DWORD. + * With bin size <=16, there are 8 features per DWORD. + */ + int dword_features_; + /*! \brief Max number of bins of training data, used to determine + * which GPU kernel to use */ + int max_num_bin_; + /*! \brief Used GPU kernel bin size (64, 256) */ + int histogram_size_; + int device_bin_size_; + /*! \brief Size of histogram bin entry, depending if single or double precision is used */ + size_t hist_bin_entry_sz_; + /*! \brief Indices of all dense feature-groups */ + std::vector dense_feature_group_map_; + /*! \brief Indices of all sparse feature-groups */ + std::vector sparse_feature_group_map_; + /*! \brief GPU memory object holding the training data */ + std::vector device_features_; + /*! \brief GPU memory object holding the ordered gradient */ + std::vector device_gradients_; + /*! \brief Pointer to pinned memory of ordered gradient */ + void * ptr_pinned_gradients_ = nullptr; + /*! \brief GPU memory object holding the ordered hessian */ + std::vector device_hessians_; + /*! \brief Pointer to pinned memory of ordered hessian */ + void * ptr_pinned_hessians_ = nullptr; + /*! \brief A vector of feature mask. 1 = feature used, 0 = feature not used */ + std::vector feature_masks_; + /*! \brief GPU memory object holding the feature masks */ + std::vector device_feature_masks_; + /*! \brief Pointer to pinned memory of feature masks */ + char* ptr_pinned_feature_masks_ = nullptr; + /*! \brief GPU memory object holding indices of the leaf being processed */ + std::vector device_data_indices_; + /*! \brief GPU memory object holding counters for workgroup coordination */ + std::vector sync_counters_; + /*! \brief GPU memory object holding temporary sub-histograms per workgroup */ + std::vector device_subhistograms_; + /*! \brief Host memory object for histogram output (GPU will write to Host memory directly) */ + std::vector device_histogram_outputs_; + /*! \brief Host memory pointer for histogram outputs */ + void *host_histogram_outputs_; + /*! CUDA waitlist object for waiting for data transfer before kernel execution */ + std::vector kernel_wait_obj_; + /*! CUDA waitlist object for reading output histograms after kernel execution */ + std::vector histograms_wait_obj_; + /*! CUDA Asynchronous waiting object for copying indices */ + std::vector indices_future_; + /*! Asynchronous waiting object for copying gradients */ + std::vector gradients_future_; + /*! Asynchronous waiting object for copying hessians */ + std::vector hessians_future_; + /*! 
Asynchronous waiting object for copying dense features */
+  std::vector features_future_;
+
+  // host-side buffer for converting feature data into feature4 data
+  int nthreads_;  // number of Feature4* vectors on host4_vecs_
+  std::vector kernel_start_;
+  std::vector kernel_time_;  // measure histogram kernel time
+  std::vector> kernel_input_wait_time_;
+  int num_gpu_;
+  int allocated_num_data_;  // allocated data instances
+  pthread_t **cpu_threads_;  // pthreads, 1 CPU thread per GPU
+};
+
+}  // namespace LightGBM
+#else  // USE_CUDA
+
+// When CUDA support is not compiled in, quit with an error message
+
+namespace LightGBM {
+
+class CUDATreeLearner: public SerialTreeLearner {
+ public:
+  #pragma warning(disable : 4702)
+  explicit CUDATreeLearner(const Config* tree_config) : SerialTreeLearner(tree_config) {
+    Log::Fatal("CUDA Tree Learner was not enabled in this build.\n"
+               "Please recompile with CMake option -DUSE_CUDA=1");
+  }
+};
+
+}  // namespace LightGBM
+
+#endif  // USE_CUDA
+#endif  // LIGHTGBM_TREELEARNER_CUDA_TREE_LEARNER_H_
diff --git a/src/treelearner/data_parallel_tree_learner.cpp b/src/treelearner/data_parallel_tree_learner.cpp
index 0d6f9df251b..30d8df84acf 100644
--- a/src/treelearner/data_parallel_tree_learner.cpp
+++ b/src/treelearner/data_parallel_tree_learner.cpp
@@ -256,6 +256,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, in
 }
 // instantiate template classes, otherwise linker cannot find the code
+template class DataParallelTreeLearner<CUDATreeLearner>;
 template class DataParallelTreeLearner<GPUTreeLearner>;
 template class DataParallelTreeLearner<SerialTreeLearner>;
diff --git a/src/treelearner/feature_parallel_tree_learner.cpp b/src/treelearner/feature_parallel_tree_learner.cpp
index c5202f3d706..f4edfe03dc1 100644
--- a/src/treelearner/feature_parallel_tree_learner.cpp
+++ b/src/treelearner/feature_parallel_tree_learner.cpp
@@ -77,6 +77,7 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(
 }
 // instantiate template classes, otherwise linker cannot find the code
+template class FeatureParallelTreeLearner<CUDATreeLearner>;
 template class FeatureParallelTreeLearner<GPUTreeLearner>;
 template class FeatureParallelTreeLearner<SerialTreeLearner>;
 }  // namespace LightGBM
diff --git a/src/treelearner/gpu_tree_learner.cpp b/src/treelearner/gpu_tree_learner.cpp
index 43ccadfd176..df90aafb945 100644
--- a/src/treelearner/gpu_tree_learner.cpp
+++ b/src/treelearner/gpu_tree_learner.cpp
@@ -52,7 +52,7 @@ void PrintHistograms(hist_t* h, size_t size) {
   double total_hess = 0;
   for (size_t i = 0; i < size; ++i) {
     printf("%03lu=%9.3g,%9.3g\t", i, GET_GRAD(h, i), GET_HESS(h, i));
-    if ((i & 2) == 2)
+    if ((i & 3) == 3)
       printf("\n");
     total_hess += GET_HESS(h, i);
   }
@@ -1068,10 +1068,10 @@ void GPUTreeLearner::FindBestSplits(const Tree* tree) {
     }
     size_t bin_size = train_data_->FeatureNumBin(feature_index) + 1;
     printf("Feature %d smaller leaf:\n", feature_index);
-    PrintHistograms(smaller_leaf_histogram_array_[feature_index].RawData() - 1, bin_size);
+    PrintHistograms(smaller_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size);
     if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) { continue; }
     printf("Feature %d larger leaf:\n", feature_index);
-    PrintHistograms(larger_leaf_histogram_array_[feature_index].RawData() - 1, bin_size);
+    PrintHistograms(larger_leaf_histogram_array_[feature_index].RawData() - kHistOffset, bin_size);
   }
 #endif
 }
diff --git a/src/treelearner/kernels/histogram_16_64_256.cu b/src/treelearner/kernels/histogram_16_64_256.cu
new file mode 100644
index 00000000000..105ccbb6203
--- /dev/null
+++ 
b/src/treelearner/kernels/histogram_16_64_256.cu @@ -0,0 +1,949 @@ +/*! + * Copyright (c) 2020 IBM Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. + */ + +#include + +#include +#include + +#include "histogram_16_64_256.hu" + +namespace LightGBM { + +// atomic add for float number in local memory +inline __device__ void atomic_local_add_f(acc_type *addr, const acc_type val) { + atomicAdd(addr, static_cast(val)); +} + +// histogram16 stuff +#ifdef ENABLE_ALL_FEATURES +#ifdef IGNORE_INDICES +#define KERNEL_NAME histogram16_fulldata +#else // IGNORE_INDICES +#define KERNEL_NAME histogram16 +#endif // IGNORE_INDICES +#else // ENABLE_ALL_FEATURES +#error "ENABLE_ALL_FEATURES should always be 1" +#define KERNEL_NAME histogram16 +#endif // ENABLE_ALL_FEATURES +#define NUM_BINS 16 +#define LOCAL_MEM_SIZE ((sizeof(unsigned int) + 2 * sizeof(acc_type)) * NUM_BINS) + +// this function will be called by histogram16 +// we have one sub-histogram of one feature in local memory, and need to read others +inline void __device__ within_kernel_reduction16x4(const acc_type* __restrict__ feature_sub_hist, + const unsigned int skip_id, + const unsigned int old_val_cont_bin0, + const uint16_t num_sub_hist, + acc_type* __restrict__ output_buf, + acc_type* __restrict__ local_hist, + const size_t power_feature_workgroups) { + const uint16_t ltid = threadIdx.x; + acc_type grad_bin = local_hist[ltid * 2]; + acc_type hess_bin = local_hist[ltid * 2 + 1]; + unsigned int* __restrict__ local_cnt = reinterpret_cast(local_hist + 2 * NUM_BINS); + + unsigned int cont_bin; + if (power_feature_workgroups != 0) { + cont_bin = ltid ? local_cnt[ltid] : old_val_cont_bin0; + } else { + cont_bin = local_cnt[ltid]; + } + uint16_t i; + + if (power_feature_workgroups != 0) { + // add all sub-histograms for feature + const acc_type* __restrict__ p = feature_sub_hist + ltid; + for (i = 0; i < skip_id; ++i) { + grad_bin += *p; p += NUM_BINS; + hess_bin += *p; p += NUM_BINS; + cont_bin += as_acc_int_type(*p); p += NUM_BINS; + } + + // skip the counters we already have + p += 3 * NUM_BINS; + + for (i = i + 1; i < num_sub_hist; ++i) { + grad_bin += *p; p += NUM_BINS; + hess_bin += *p; p += NUM_BINS; + cont_bin += as_acc_int_type(*p); p += NUM_BINS; + } + } + __syncthreads(); + + output_buf[ltid * 2 + 0] = grad_bin; + output_buf[ltid * 2 + 1] = hess_bin; +} + +#if USE_CONSTANT_BUF == 1 +__kernel void KERNEL_NAME(__global const uchar* restrict feature_data_base, + __constant const uchar* restrict feature_masks __attribute__((max_constant_size(65536))), + const data_size_t feature_size, + __constant const data_size_t* restrict data_indices __attribute__((max_constant_size(65536))), + const data_size_t num_data, + __constant const score_t* restrict ordered_gradients __attribute__((max_constant_size(65536))), +#if CONST_HESSIAN == 0 + __constant const score_t* restrict ordered_hessians __attribute__((max_constant_size(65536))), +#else + const score_t const_hessian, +#endif + char* __restrict__ output_buf, + volatile int * sync_counters, + acc_type* __restrict__ hist_buf_base, + const size_t power_feature_workgroups) { +#else +__global__ void KERNEL_NAME(const uchar* feature_data_base, + const uchar* __restrict__ feature_masks, + const data_size_t feature_size, + const data_size_t* data_indices, + const data_size_t num_data, + const score_t* ordered_gradients, +#if CONST_HESSIAN == 0 + const score_t* ordered_hessians, +#else + const score_t const_hessian, +#endif + 
char* __restrict__ output_buf, + volatile int * sync_counters, + acc_type* __restrict__ hist_buf_base, + const size_t power_feature_workgroups) { +#endif + // allocate the local memory array aligned with float2, to guarantee correct alignment on NVIDIA platforms + // otherwise a "Misaligned Address" exception may occur + __shared__ float2 shared_array[LOCAL_MEM_SIZE/sizeof(float2)]; + const unsigned int gtid = blockIdx.x * blockDim.x + threadIdx.x; + const uint16_t ltid = threadIdx.x; + const uint16_t lsize = NUM_BINS; // get_local_size(0); + const uint16_t group_id = blockIdx.x; + + // local memory per workgroup is 3 KB + // clear local memory + unsigned int *ptr = reinterpret_cast(shared_array); + for (int i = ltid; i < LOCAL_MEM_SIZE/sizeof(unsigned int); i += lsize) { + ptr[i] = 0; + } + __syncthreads(); + // gradient/hessian histograms + // assume this starts at 32 * 4 = 128-byte boundary // What does it mean? boundary?? + // total size: 2 * 256 * size_of(float) = 2 KB + // organization: each feature/grad/hessian is at a different bank, + // as indepedent of the feature value as possible + acc_type *gh_hist = reinterpret_cast(shared_array); + + // counter histogram + // total size: 256 * size_of(unsigned int) = 1 KB + unsigned int *cnt_hist = reinterpret_cast(gh_hist + 2 * NUM_BINS); + + // odd threads (1, 3, ...) compute histograms for hessians first + // even thread (0, 2, ...) compute histograms for gradients first + // etc. + uchar is_hessian_first = ltid & 1; + + uint16_t feature_id = group_id >> power_feature_workgroups; + + // each 2^POWER_FEATURE_WORKGROUPS workgroups process on one feature (compile-time constant) + // feature_size is the number of examples per feature + const uchar *feature_data = feature_data_base + feature_id * feature_size; + + // size of threads that process this feature4 + const unsigned int subglobal_size = lsize * (1 << power_feature_workgroups); + + // equavalent thread ID in this subgroup for this feature4 + const unsigned int subglobal_tid = gtid - feature_id * subglobal_size; + + + data_size_t ind; + data_size_t ind_next; + #ifdef IGNORE_INDICES + ind = subglobal_tid; + #else + ind = data_indices[subglobal_tid]; + #endif + + // extract feature mask, when a byte is set to 0, that feature is disabled + uchar feature_mask = feature_masks[feature_id]; + // exit if the feature is masked + if (!feature_mask) { + return; + } else { + feature_mask = feature_mask - 1; // feature_mask is used for get feature (1: 4bit feature, 0: 8bit feature) + } + + // STAGE 1: read feature data, and gradient and hessian + // first half of the threads read feature data from global memory + // We will prefetch data into the "next" variable at the beginning of each iteration + uchar feature; + uchar feature_next; + uint16_t bin; + + feature = feature_data[ind >> feature_mask]; + if (feature_mask) { + feature = (feature >> ((ind & 1) << 2)) & 0xf; + } + bin = feature; + acc_type grad_bin = 0.0f, hess_bin = 0.0f; + acc_type *addr_bin; + + // store gradient and hessian + score_t grad, hess; + score_t grad_next, hess_next; + grad = ordered_gradients[ind]; + #if CONST_HESSIAN == 0 + hess = ordered_hessians[ind]; + #endif + + // there are 2^POWER_FEATURE_WORKGROUPS workgroups processing each feature4 + for (unsigned int i = subglobal_tid; i < num_data; i += subglobal_size) { + // prefetch the next iteration variables + // we don't need bondary check because we have made the buffer large + int i_next = i + subglobal_size; + #ifdef IGNORE_INDICES + // we need to check to bounds here 
+ ind_next = i_next < num_data ? i_next : i; + #else + ind_next = data_indices[i_next]; + #endif + + grad_next = ordered_gradients[ind_next]; + #if CONST_HESSIAN == 0 + hess_next = ordered_hessians[ind_next]; + #endif + + // STAGE 2: accumulate gradient and hessian + if (bin != feature) { + addr_bin = gh_hist + bin * 2 + is_hessian_first; + #if CONST_HESSIAN == 0 + acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin; + atomic_local_add_f(addr_bin, acc_bin); + + addr_bin = addr_bin + 1 - 2 * is_hessian_first; + acc_bin = is_hessian_first ? grad_bin : hess_bin; + atomic_local_add_f(addr_bin, acc_bin); + + #elif CONST_HESSIAN == 1 + atomic_local_add_f(addr_bin, grad_bin); + #endif + + bin = feature; + grad_bin = grad; + hess_bin = hess; + } else { + grad_bin += grad; + hess_bin += hess; + } + + // prefetch the next iteration variables + feature_next = feature_data[ind_next >> feature_mask]; + + // STAGE 3: accumulate counter + atomicAdd(cnt_hist + feature, 1); + + // STAGE 4: update next stat + grad = grad_next; + hess = hess_next; + if (!feature_mask) { + feature = feature_next; + } else { + feature = (feature_next >> ((ind_next & 1) << 2)) & 0xf; + } + } + + + addr_bin = gh_hist + bin * 2 + is_hessian_first; + #if CONST_HESSIAN == 0 + acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin; + atomic_local_add_f(addr_bin, acc_bin); + + addr_bin = addr_bin + 1 - 2 * is_hessian_first; + acc_bin = is_hessian_first ? grad_bin : hess_bin; + atomic_local_add_f(addr_bin, acc_bin); + + #elif CONST_HESSIAN == 1 + atomic_local_add_f(addr_bin, grad_bin); + #endif + __syncthreads(); + + #if CONST_HESSIAN == 1 + // make a final reduction + gh_hist[ltid * 2] += gh_hist[ltid * 2 + 1]; + gh_hist[ltid * 2 + 1] = const_hessian * cnt_hist[ltid]; // counter move to this position + __syncthreads(); + #endif + +#if POWER_FEATURE_WORKGROUPS != 0 + acc_type *__restrict__ output = reinterpret_cast(output_buf) + group_id * 3 * NUM_BINS; + // write gradients and hessians + acc_type *__restrict__ ptr_f = output; + for (uint16_t i = ltid; i < 2 * NUM_BINS; i += lsize) { + // even threads read gradients, odd threads read hessians + acc_type value = gh_hist[i]; + ptr_f[(i & 1) * NUM_BINS + (i >> 1)] = value; + } + // write counts + acc_int_type *__restrict__ ptr_i = reinterpret_cast(output + 2 * NUM_BINS); + for (uint16_t i = ltid; i < NUM_BINS; i += lsize) { + unsigned int value = cnt_hist[i]; + ptr_i[i] = value; + } + __syncthreads(); + __threadfence(); + unsigned int * counter_val = cnt_hist; + // backup the old value + unsigned int old_val = *counter_val; + if (ltid == 0) { + // all workgroups processing the same feature add this counter + *counter_val = atomicAdd(const_cast(sync_counters + feature_id), 1); + } + // make sure everyone in this workgroup is here + __syncthreads(); + // everyone in this workgroup: if we are the last workgroup, then do reduction! 
+ if (*counter_val == (1 << power_feature_workgroups) - 1) { + if (ltid == 0) { + sync_counters[feature_id] = 0; + } +#else + } + // only 1 work group, no need to increase counter + // the reduction will become a simple copy + { + unsigned int old_val; // dummy +#endif + // locate our feature's block in output memory + unsigned int output_offset = (feature_id << power_feature_workgroups); + acc_type const * __restrict__ feature_subhists = + reinterpret_cast(output_buf) + output_offset * 3 * NUM_BINS; + // skip reading the data already in local memory + unsigned int skip_id = group_id - output_offset; + // locate output histogram location for this feature4 + acc_type *__restrict__ hist_buf = hist_buf_base + feature_id * 2 * NUM_BINS; + + within_kernel_reduction16x4(feature_subhists, skip_id, old_val, 1 << power_feature_workgroups, hist_buf, reinterpret_cast(shared_array), power_feature_workgroups); + } +} + +// end of histogram16 stuff + +// histogram64 stuff +#undef KERNEL_NAME +#undef NUM_BINS +#undef LOCAL_MEM_SIZE +#ifdef ENABLE_ALL_FEATURES +#ifdef IGNORE_INDICES +#define KERNEL_NAME histogram64_fulldata +#else // IGNORE_INDICES +#define KERNEL_NAME histogram64 // seems like ENABLE_ALL_FEATURES is set to 1 in the header if its disabled +// #define KERNEL_NAME histogram64_allfeats +#endif // IGNORE_INDICES +#else // ENABLE_ALL_FEATURES +#error "ENABLE_ALL_FEATURES should always be 1" +#define KERNEL_NAME histogram64 +#endif // ENABLE_ALL_FEATURES +#define NUM_BINS 64 +#define LOCAL_MEM_SIZE ((sizeof(unsigned int) + 2 * sizeof(acc_type)) * NUM_BINS) + +// this function will be called by histogram64 +// we have one sub-histogram of one feature in local memory, and need to read others +inline void __device__ within_kernel_reduction64x4(const acc_type* __restrict__ feature_sub_hist, + const unsigned int skip_id, + const unsigned int old_val_cont_bin0, + const uint16_t num_sub_hist, + acc_type* __restrict__ output_buf, + acc_type* __restrict__ local_hist, + const size_t power_feature_workgroups) { + const uint16_t ltid = threadIdx.x; + acc_type grad_bin = local_hist[ltid * 2]; + acc_type hess_bin = local_hist[ltid * 2 + 1]; + unsigned int* __restrict__ local_cnt = reinterpret_cast(local_hist + 2 * NUM_BINS); + + unsigned int cont_bin; + if (power_feature_workgroups != 0) { + cont_bin = ltid ? 
local_cnt[ltid] : old_val_cont_bin0; + } else { + cont_bin = local_cnt[ltid]; + } + uint16_t i; + + if (power_feature_workgroups != 0) { + // add all sub-histograms for feature + const acc_type* __restrict__ p = feature_sub_hist + ltid; + for (i = 0; i < skip_id; ++i) { + grad_bin += *p; p += NUM_BINS; + hess_bin += *p; p += NUM_BINS; + cont_bin += as_acc_int_type(*p); p += NUM_BINS; + } + + // skip the counters we already have + p += 3 * NUM_BINS; + + for (i = i + 1; i < num_sub_hist; ++i) { + grad_bin += *p; p += NUM_BINS; + hess_bin += *p; p += NUM_BINS; + cont_bin += as_acc_int_type(*p); p += NUM_BINS; + } + } + __syncthreads(); + + output_buf[ltid * 2 + 0] = grad_bin; + output_buf[ltid * 2 + 1] = hess_bin; +} + +#if USE_CONSTANT_BUF == 1 +__kernel void KERNEL_NAME(__global const uchar* restrict feature_data_base, + __constant const uchar* restrict feature_masks __attribute__((max_constant_size(65536))), + const data_size_t feature_size, + __constant const data_size_t* restrict data_indices __attribute__((max_constant_size(65536))), + const data_size_t num_data, + __constant const score_t* restrict ordered_gradients __attribute__((max_constant_size(65536))), +#if CONST_HESSIAN == 0 + __constant const score_t* restrict ordered_hessians __attribute__((max_constant_size(65536))), +#else + const score_t const_hessian, +#endif + char* __restrict__ output_buf, + volatile int * sync_counters, + acc_type* __restrict__ hist_buf_base, + const size_t power_feature_workgroups) { +#else +__global__ void KERNEL_NAME(const uchar* feature_data_base, + const uchar* __restrict__ feature_masks, + const data_size_t feature_size, + const data_size_t* data_indices, + const data_size_t num_data, + const score_t* ordered_gradients, +#if CONST_HESSIAN == 0 + const score_t* ordered_hessians, +#else + const score_t const_hessian, +#endif + char* __restrict__ output_buf, + volatile int * sync_counters, + acc_type* __restrict__ hist_buf_base, + const size_t power_feature_workgroups) { +#endif + // allocate the local memory array aligned with float2, to guarantee correct alignment on NVIDIA platforms + // otherwise a "Misaligned Address" exception may occur + __shared__ float2 shared_array[LOCAL_MEM_SIZE/sizeof(float2)]; + const unsigned int gtid = blockIdx.x * blockDim.x + threadIdx.x; + const uint16_t ltid = threadIdx.x; + const uint16_t lsize = NUM_BINS; // get_local_size(0); + const uint16_t group_id = blockIdx.x; + + // local memory per workgroup is 3 KB + // clear local memory + unsigned int *ptr = reinterpret_cast(shared_array); + for (int i = ltid; i < LOCAL_MEM_SIZE/sizeof(unsigned int); i += lsize) { + ptr[i] = 0; + } + __syncthreads(); + // gradient/hessian histograms + // assume this starts at 32 * 4 = 128-byte boundary // What does it mean? boundary?? + // total size: 2 * 256 * size_of(float) = 2 KB + // organization: each feature/grad/hessian is at a different bank, + // as indepedent of the feature value as possible + acc_type *gh_hist = reinterpret_cast(shared_array); + + // counter histogram + // total size: 256 * size_of(unsigned int) = 1 KB + unsigned int *cnt_hist = reinterpret_cast(gh_hist + 2 * NUM_BINS); + + // odd threads (1, 3, ...) compute histograms for hessians first + // even thread (0, 2, ...) compute histograms for gradients first + // etc. 
+ uchar is_hessian_first = ltid & 1; + + uint16_t feature_id = group_id >> power_feature_workgroups; + + // each 2^POWER_FEATURE_WORKGROUPS workgroups process on one feature (compile-time constant) + // feature_size is the number of examples per feature + const uchar *feature_data = feature_data_base + feature_id * feature_size; + + // size of threads that process this feature4 + const unsigned int subglobal_size = lsize * (1 << power_feature_workgroups); + + // equavalent thread ID in this subgroup for this feature4 + const unsigned int subglobal_tid = gtid - feature_id * subglobal_size; + + data_size_t ind; + data_size_t ind_next; + #ifdef IGNORE_INDICES + ind = subglobal_tid; + #else + ind = data_indices[subglobal_tid]; + #endif + + // extract feature mask, when a byte is set to 0, that feature is disabled + uchar feature_mask = feature_masks[feature_id]; + // exit if the feature is masked + if (!feature_mask) { + return; + } else { + feature_mask = feature_mask - 1; // feature_mask is used for get feature (1: 4bit feature, 0: 8bit feature) + } + + // STAGE 1: read feature data, and gradient and hessian + // first half of the threads read feature data from global memory + // We will prefetch data into the "next" variable at the beginning of each iteration + uchar feature; + uchar feature_next; + uint16_t bin; + + feature = feature_data[ind >> feature_mask]; + if (feature_mask) { + feature = (feature >> ((ind & 1) << 2)) & 0xf; + } + bin = feature; + acc_type grad_bin = 0.0f, hess_bin = 0.0f; + acc_type *addr_bin; + + // store gradient and hessian + score_t grad, hess; + score_t grad_next, hess_next; + grad = ordered_gradients[ind]; + #if CONST_HESSIAN == 0 + hess = ordered_hessians[ind]; + #endif + + // there are 2^POWER_FEATURE_WORKGROUPS workgroups processing each feature4 + for (unsigned int i = subglobal_tid; i < num_data; i += subglobal_size) { + // prefetch the next iteration variables + // we don't need bondary check because we have made the buffer large + int i_next = i + subglobal_size; + #ifdef IGNORE_INDICES + // we need to check to bounds here + ind_next = i_next < num_data ? i_next : i; + #else + ind_next = data_indices[i_next]; + #endif + + grad_next = ordered_gradients[ind_next]; + #if CONST_HESSIAN == 0 + hess_next = ordered_hessians[ind_next]; + #endif + + // STAGE 2: accumulate gradient and hessian + if (bin != feature) { + addr_bin = gh_hist + bin * 2 + is_hessian_first; + #if CONST_HESSIAN == 0 + acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin; + atomic_local_add_f(addr_bin, acc_bin); + + addr_bin = addr_bin + 1 - 2 * is_hessian_first; + acc_bin = is_hessian_first ? grad_bin : hess_bin; + atomic_local_add_f(addr_bin, acc_bin); + + #elif CONST_HESSIAN == 1 + atomic_local_add_f(addr_bin, grad_bin); + #endif + + bin = feature; + grad_bin = grad; + hess_bin = hess; + } else { + grad_bin += grad; + hess_bin += hess; + } + + // prefetch the next iteration variables + feature_next = feature_data[ind_next >> feature_mask]; + + // STAGE 3: accumulate counter + atomicAdd(cnt_hist + feature, 1); + + // STAGE 4: update next stat + grad = grad_next; + hess = hess_next; + if (!feature_mask) { + feature = feature_next; + } else { + feature = (feature_next >> ((ind_next & 1) << 2)) & 0xf; + } + } + + addr_bin = gh_hist + bin * 2 + is_hessian_first; + #if CONST_HESSIAN == 0 + acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin; + atomic_local_add_f(addr_bin, acc_bin); + + addr_bin = addr_bin + 1 - 2 * is_hessian_first; + acc_bin = is_hessian_first ? 
grad_bin : hess_bin; + atomic_local_add_f(addr_bin, acc_bin); + + #elif CONST_HESSIAN == 1 + atomic_local_add_f(addr_bin, grad_bin); + #endif + __syncthreads(); + + #if CONST_HESSIAN == 1 + // make a final reduction + gh_hist[ltid * 2] += gh_hist[ltid * 2 + 1]; + gh_hist[ltid * 2 + 1] = const_hessian * cnt_hist[ltid]; // counter move to this position + __syncthreads(); + #endif + +#if POWER_FEATURE_WORKGROUPS != 0 + acc_type *__restrict__ output = reinterpret_cast(output_buf) + group_id * 3 * NUM_BINS; + // write gradients and hessians + acc_type *__restrict__ ptr_f = output; + for (uint16_t i = ltid; i < 2 * NUM_BINS; i += lsize) { + // even threads read gradients, odd threads read hessians + acc_type value = gh_hist[i]; + ptr_f[(i & 1) * NUM_BINS + (i >> 1)] = value; + } + // write counts + acc_int_type *__restrict__ ptr_i = reinterpret_cast(output + 2 * NUM_BINS); + for (uint16_t i = ltid; i < NUM_BINS; i += lsize) { + unsigned int value = cnt_hist[i]; + ptr_i[i] = value; + } + __syncthreads(); + __threadfence(); + unsigned int * counter_val = cnt_hist; + // backup the old value + unsigned int old_val = *counter_val; + if (ltid == 0) { + // all workgroups processing the same feature add this counter + *counter_val = atomicAdd(const_cast(sync_counters + feature_id), 1); + } + // make sure everyone in this workgroup is here + __syncthreads(); + // everyone in this workgroup: if we are the last workgroup, then do reduction! + if (*counter_val == (1 << power_feature_workgroups) - 1) { + if (ltid == 0) { + sync_counters[feature_id] = 0; + } +#else + } + // only 1 work group, no need to increase counter + // the reduction will become a simple copy + { + unsigned int old_val; // dummy +#endif + // locate our feature's block in output memory + unsigned int output_offset = (feature_id << power_feature_workgroups); + acc_type const * __restrict__ feature_subhists = + reinterpret_cast(output_buf) + output_offset * 3 * NUM_BINS; + // skip reading the data already in local memory + unsigned int skip_id = group_id - output_offset; + // locate output histogram location for this feature4 + acc_type *__restrict__ hist_buf = hist_buf_base + feature_id * 2 * NUM_BINS; + + within_kernel_reduction64x4(feature_subhists, skip_id, old_val, 1 << power_feature_workgroups, hist_buf, reinterpret_cast(shared_array), power_feature_workgroups); + } +} + +// end of histogram64 stuff + +// histogram256 stuff +#undef KERNEL_NAME +#undef NUM_BINS +#undef LOCAL_MEM_SIZE +#ifdef ENABLE_ALL_FEATURES +#ifdef IGNORE_INDICES +#define KERNEL_NAME histogram256_fulldata +#else // IGNORE_INDICES +#define KERNEL_NAME histogram256 // seems like ENABLE_ALL_FEATURES is set to 1 in the header if its disabled +// #define KERNEL_NAME histogram256_allfeats +#endif // IGNORE_INDICES +#else // ENABLE_ALL_FEATURES +#error "ENABLE_ALL_FEATURES should always be 1" +#define KERNEL_NAME histogram256 +#endif // ENABLE_ALL_FEATURES +#define NUM_BINS 256 +#define LOCAL_MEM_SIZE ((sizeof(unsigned int) + 2 * sizeof(acc_type)) * NUM_BINS) + +// this function will be called by histogram256 +// we have one sub-histogram of one feature in local memory, and need to read others +inline void __device__ within_kernel_reduction256x4(const acc_type* __restrict__ feature_sub_hist, + const unsigned int skip_id, + const unsigned int old_val_cont_bin0, + const uint16_t num_sub_hist, + acc_type* __restrict__ output_buf, + acc_type* __restrict__ local_hist, + const size_t power_feature_workgroups) { + const uint16_t ltid = threadIdx.x; + acc_type grad_bin = 
local_hist[ltid * 2]; + acc_type hess_bin = local_hist[ltid * 2 + 1]; + unsigned int* __restrict__ local_cnt = reinterpret_cast(local_hist + 2 * NUM_BINS); + + unsigned int cont_bin; + if (power_feature_workgroups != 0) { + cont_bin = ltid ? local_cnt[ltid] : old_val_cont_bin0; + } else { + cont_bin = local_cnt[ltid]; + } + uint16_t i; + + if (power_feature_workgroups != 0) { + // add all sub-histograms for feature + const acc_type* __restrict__ p = feature_sub_hist + ltid; + for (i = 0; i < skip_id; ++i) { + grad_bin += *p; p += NUM_BINS; + hess_bin += *p; p += NUM_BINS; + cont_bin += as_acc_int_type(*p); p += NUM_BINS; + } + + // skip the counters we already have + p += 3 * NUM_BINS; + + for (i = i + 1; i < num_sub_hist; ++i) { + grad_bin += *p; p += NUM_BINS; + hess_bin += *p; p += NUM_BINS; + cont_bin += as_acc_int_type(*p); p += NUM_BINS; + } + } + + __syncthreads(); + + output_buf[ltid * 2 + 0] = grad_bin; + output_buf[ltid * 2 + 1] = hess_bin; +} + +#if USE_CONSTANT_BUF == 1 +__kernel void KERNEL_NAME(__global const uchar* restrict feature_data_base, + __constant const uchar* restrict feature_masks __attribute__((max_constant_size(65536))), + const data_size_t feature_size, + __constant const data_size_t* restrict data_indices __attribute__((max_constant_size(65536))), + const data_size_t num_data, + __constant const score_t* restrict ordered_gradients __attribute__((max_constant_size(65536))), +#if CONST_HESSIAN == 0 + __constant const score_t* restrict ordered_hessians __attribute__((max_constant_size(65536))), +#else + const score_t const_hessian, +#endif + char* __restrict__ output_buf, + volatile int * sync_counters, + acc_type* __restrict__ hist_buf_base, + const size_t power_feature_workgroups) { +#else +__global__ void KERNEL_NAME(const uchar* feature_data_base, + const uchar* __restrict__ feature_masks, + const data_size_t feature_size, + const data_size_t* data_indices, + const data_size_t num_data, + const score_t* ordered_gradients, +#if CONST_HESSIAN == 0 + const score_t* ordered_hessians, +#else + const score_t const_hessian, +#endif + char* __restrict__ output_buf, + volatile int * sync_counters, + acc_type* __restrict__ hist_buf_base, + const size_t power_feature_workgroups) { +#endif + // allocate the local memory array aligned with float2, to guarantee correct alignment on NVIDIA platforms + // otherwise a "Misaligned Address" exception may occur + __shared__ float2 shared_array[LOCAL_MEM_SIZE/sizeof(float2)]; + const unsigned int gtid = blockIdx.x * blockDim.x + threadIdx.x; + const uint16_t ltid = threadIdx.x; + const uint16_t lsize = NUM_BINS; // get_local_size(0); + const uint16_t group_id = blockIdx.x; + + // local memory per workgroup is 3 KB + // clear local memory + unsigned int *ptr = reinterpret_cast(shared_array); + for (int i = ltid; i < LOCAL_MEM_SIZE/sizeof(unsigned int); i += lsize) { + ptr[i] = 0; + } + __syncthreads(); + // gradient/hessian histograms + // assume this starts at 32 * 4 = 128-byte boundary // What does it mean? boundary?? + // total size: 2 * 256 * size_of(float) = 2 KB + // organization: each feature/grad/hessian is at a different bank, + // as indepedent of the feature value as possible + acc_type *gh_hist = reinterpret_cast(shared_array); + + // counter histogram + // total size: 256 * size_of(unsigned int) = 1 KB + unsigned int *cnt_hist = reinterpret_cast(gh_hist + 2 * NUM_BINS); + + // odd threads (1, 3, ...) compute histograms for hessians first + // even thread (0, 2, ...) compute histograms for gradients first + // etc. 
+ uchar is_hessian_first = ltid & 1; + + uint16_t feature_id = group_id >> power_feature_workgroups; + + // each 2^POWER_FEATURE_WORKGROUPS workgroups process on one feature (compile-time constant) + // feature_size is the number of examples per feature + const uchar *feature_data = feature_data_base + feature_id * feature_size; + + // size of threads that process this feature4 + const unsigned int subglobal_size = lsize * (1 << power_feature_workgroups); + + // equavalent thread ID in this subgroup for this feature4 + const unsigned int subglobal_tid = gtid - feature_id * subglobal_size; + + data_size_t ind; + data_size_t ind_next; + #ifdef IGNORE_INDICES + ind = subglobal_tid; + #else + ind = data_indices[subglobal_tid]; + #endif + + // extract feature mask, when a byte is set to 0, that feature is disabled + uchar feature_mask = feature_masks[feature_id]; + // exit if the feature is masked + if (!feature_mask) { + return; + } else { + feature_mask = feature_mask - 1; // feature_mask is used for get feature (1: 4bit feature, 0: 8bit feature) + } + + // STAGE 1: read feature data, and gradient and hessian + // first half of the threads read feature data from global memory + // We will prefetch data into the "next" variable at the beginning of each iteration + uchar feature; + uchar feature_next; + uint16_t bin; + + feature = feature_data[ind >> feature_mask]; + if (feature_mask) { + feature = (feature >> ((ind & 1) << 2)) & 0xf; + } + bin = feature; + acc_type grad_bin = 0.0f, hess_bin = 0.0f; + acc_type *addr_bin; + + // store gradient and hessian + score_t grad, hess; + score_t grad_next, hess_next; + grad = ordered_gradients[ind]; + #if CONST_HESSIAN == 0 + hess = ordered_hessians[ind]; + #endif + + // there are 2^POWER_FEATURE_WORKGROUPS workgroups processing each feature4 + for (unsigned int i = subglobal_tid; i < num_data; i += subglobal_size) { + // prefetch the next iteration variables + // we don't need bondary check because we have made the buffer large + int i_next = i + subglobal_size; + #ifdef IGNORE_INDICES + // we need to check to bounds here + ind_next = i_next < num_data ? i_next : i; + #else + ind_next = data_indices[i_next]; + #endif + + grad_next = ordered_gradients[ind_next]; + #if CONST_HESSIAN == 0 + hess_next = ordered_hessians[ind_next]; + #endif + // STAGE 2: accumulate gradient and hessian + if (bin != feature) { + addr_bin = gh_hist + bin * 2 + is_hessian_first; + #if CONST_HESSIAN == 0 + acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin; + atomic_local_add_f(addr_bin, acc_bin); + + addr_bin = addr_bin + 1 - 2 * is_hessian_first; + acc_bin = is_hessian_first ? grad_bin : hess_bin; + atomic_local_add_f(addr_bin, acc_bin); + + #elif CONST_HESSIAN == 1 + atomic_local_add_f(addr_bin, grad_bin); + #endif + + bin = feature; + grad_bin = grad; + hess_bin = hess; + } else { + grad_bin += grad; + hess_bin += hess; + } + + // prefetch the next iteration variables + feature_next = feature_data[ind_next >> feature_mask]; + + // STAGE 3: accumulate counter + atomicAdd(cnt_hist + feature, 1); + + // STAGE 4: update next stat + grad = grad_next; + hess = hess_next; + if (!feature_mask) { + feature = feature_next; + } else { + feature = (feature_next >> ((ind_next & 1) << 2)) & 0xf; + } + } + + addr_bin = gh_hist + bin * 2 + is_hessian_first; + #if CONST_HESSIAN == 0 + acc_type acc_bin = is_hessian_first ? hess_bin : grad_bin; + atomic_local_add_f(addr_bin, acc_bin); + + addr_bin = addr_bin + 1 - 2 * is_hessian_first; + acc_bin = is_hessian_first ? 
grad_bin : hess_bin; + + atomic_local_add_f(addr_bin, acc_bin); + + #elif CONST_HESSIAN == 1 + atomic_local_add_f(addr_bin, grad_bin); + #endif + __syncthreads(); + + #if CONST_HESSIAN == 1 + // make a final reduction + gh_hist[ltid * 2] += gh_hist[ltid * 2 + 1]; + gh_hist[ltid * 2 + 1] = const_hessian * cnt_hist[ltid]; // counter move to this position + __syncthreads(); + #endif + +#if POWER_FEATURE_WORKGROUPS != 0 + acc_type *__restrict__ output = reinterpret_cast(output_buf) + group_id * 3 * NUM_BINS; + // write gradients and hessians + acc_type *__restrict__ ptr_f = output; + for (uint16_t i = ltid; i < 2 * NUM_BINS; i += lsize) { + // even threads read gradients, odd threads read hessians + acc_type value = gh_hist[i]; + ptr_f[(i & 1) * NUM_BINS + (i >> 1)] = value; + } + // write counts + acc_int_type *__restrict__ ptr_i = reinterpret_cast(output + 2 * NUM_BINS); + for (uint16_t i = ltid; i < NUM_BINS; i += lsize) { + unsigned int value = cnt_hist[i]; + ptr_i[i] = value; + } + __syncthreads(); + __threadfence(); + unsigned int * counter_val = cnt_hist; + // backup the old value + unsigned int old_val = *counter_val; + if (ltid == 0) { + // all workgroups processing the same feature add this counter + *counter_val = atomicAdd(const_cast(sync_counters + feature_id), 1); + } + // make sure everyone in this workgroup is here + __syncthreads(); + // everyone in this workgroup: if we are the last workgroup, then do reduction! + if (*counter_val == (1 << power_feature_workgroups) - 1) { + if (ltid == 0) { + sync_counters[feature_id] = 0; + } +#else + } + // only 1 work group, no need to increase counter + // the reduction will become a simple copy + { + unsigned int old_val; // dummy +#endif + // locate our feature's block in output memory + unsigned int output_offset = (feature_id << power_feature_workgroups); + acc_type const * __restrict__ feature_subhists = + reinterpret_cast(output_buf) + output_offset * 3 * NUM_BINS; + // skip reading the data already in local memory + unsigned int skip_id = group_id - output_offset; + // locate output histogram location for this feature4 + acc_type *__restrict__ hist_buf = hist_buf_base + feature_id * 2 * NUM_BINS; + + within_kernel_reduction256x4(feature_subhists, skip_id, old_val, 1 << power_feature_workgroups, hist_buf, reinterpret_cast(shared_array), power_feature_workgroups); + } +} + +// end of histogram256 stuff + +} // namespace LightGBM diff --git a/src/treelearner/kernels/histogram_16_64_256.hu b/src/treelearner/kernels/histogram_16_64_256.hu new file mode 100644 index 00000000000..8e3d3a5ec78 --- /dev/null +++ b/src/treelearner/kernels/histogram_16_64_256.hu @@ -0,0 +1,161 @@ +/*! + * Copyright (c) 2020 IBM Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for license information. 
+ */
+
+#ifndef LIGHTGBM_TREELEARNER_KERNELS_HISTOGRAM_16_64_256_HU_
+#define LIGHTGBM_TREELEARNER_KERNELS_HISTOGRAM_16_64_256_HU_
+
+#include "LightGBM/meta.h"
+
+namespace LightGBM {
+
+// use double precision or not
+#ifndef USE_DP_FLOAT
+#define USE_DP_FLOAT 1
+#endif
+
+// ignore hessian, and use the local memory for hessian as an additional bank for gradient
+#ifndef CONST_HESSIAN
+#define CONST_HESSIAN 0
+#endif
+
+typedef unsigned char uchar;
+
+template <typename T>
+__device__ double as_double(const T t) {
+  static_assert(sizeof(T) == sizeof(double), "size mismatch");
+  double d;
+  memcpy(&d, &t, sizeof(T));
+  return d;
+}
+template <typename T>
+__device__ unsigned long long as_ulong_ulong(const T t) {
+  static_assert(sizeof(T) == sizeof(unsigned long long), "size mismatch");
+  unsigned long long u;
+  memcpy(&u, &t, sizeof(T));
+  return u;
+}
+template <typename T>
+__device__ float as_float(const T t) {
+  static_assert(sizeof(T) == sizeof(float), "size mismatch");
+  float f;
+  memcpy(&f, &t, sizeof(T));
+  return f;
+}
+template <typename T>
+__device__ unsigned int as_uint(const T t) {
+  static_assert(sizeof(T) == sizeof(unsigned int), "size_mismatch");
+  unsigned int u;
+  memcpy(&u, &t, sizeof(T));
+  return u;
+}
+template <typename T>
+__device__ uchar4 as_uchar4(const T t) {
+  static_assert(sizeof(T) == sizeof(uchar4), "size mismatch");
+  uchar4 u;
+  memcpy(&u, &t, sizeof(T));
+  return u;
+}
+
+#if USE_DP_FLOAT == 1
+typedef double acc_type;
+typedef unsigned long long acc_int_type;
+#define as_acc_type as_double
+#define as_acc_int_type as_ulong_ulong
+#else
+typedef float acc_type;
+typedef unsigned int acc_int_type;
+#define as_acc_type as_float
+#define as_acc_int_type as_uint
+#endif
+
+// use all features and do not use feature mask
+#ifndef ENABLE_ALL_FEATURES
+#define ENABLE_ALL_FEATURES 1
+#endif
+
+// define all of the different kernels
+
+#define DECLARE_CONST_BUF(name) \
+__global__ void name(__global const uchar* restrict feature_data_base, \
+                     const uchar* restrict feature_masks,\
+                     const data_size_t feature_size,\
+                     const data_size_t* restrict data_indices, \
+                     const data_size_t num_data, \
+                     const score_t* restrict ordered_gradients, \
+                     const score_t* restrict ordered_hessians,\
+                     char* __restrict__ output_buf,\
+                     volatile int * sync_counters,\
+                     acc_type* __restrict__ hist_buf_base, \
+                     const size_t power_feature_workgroups);
+
+
+#define DECLARE_CONST_HES_CONST_BUF(name) \
+__global__ void name(const uchar* __restrict__ feature_data_base, \
+                     const uchar* __restrict__ feature_masks,\
+                     const data_size_t feature_size,\
+                     const data_size_t* __restrict__ data_indices, \
+                     const data_size_t num_data, \
+                     const score_t* __restrict__ ordered_gradients, \
+                     const score_t const_hessian,\
+                     char* __restrict__ output_buf,\
+                     volatile int * sync_counters,\
+                     acc_type* __restrict__ hist_buf_base, \
+                     const size_t power_feature_workgroups);
+
+
+
+#define DECLARE_CONST_HES(name) \
+__global__ void name(const uchar* feature_data_base, \
+                     const uchar* __restrict__ feature_masks,\
+                     const data_size_t feature_size,\
+                     const data_size_t* data_indices, \
+                     const data_size_t num_data, \
+                     const score_t* ordered_gradients, \
+                     const score_t const_hessian,\
+                     char* __restrict__ output_buf, \
+                     volatile int * sync_counters,\
+                     acc_type* __restrict__ hist_buf_base, \
+                     const size_t power_feature_workgroups);
+
+
+#define DECLARE(name) \
+__global__ void name(const uchar* feature_data_base, \
+                     const uchar* __restrict__ feature_masks,\
+                     const data_size_t feature_size,\
+                     const data_size_t* data_indices, \
+                     const data_size_t num_data, \
+                     const score_t* ordered_gradients, \
+                     const score_t* ordered_hessians,\
+                     char* __restrict__ output_buf, \
+                     volatile int * sync_counters,\
+                     acc_type* __restrict__ hist_buf_base, \
+                     const size_t power_feature_workgroups);
+
+
+DECLARE_CONST_HES(histogram16_allfeats);
+DECLARE_CONST_HES(histogram16_fulldata);
+DECLARE_CONST_HES(histogram16);
+DECLARE(histogram16_allfeats);
+DECLARE(histogram16_fulldata);
+DECLARE(histogram16);
+
+DECLARE_CONST_HES(histogram64_allfeats);
+DECLARE_CONST_HES(histogram64_fulldata);
+DECLARE_CONST_HES(histogram64);
+DECLARE(histogram64_allfeats);
+DECLARE(histogram64_fulldata);
+DECLARE(histogram64);
+
+DECLARE_CONST_HES(histogram256_allfeats);
+DECLARE_CONST_HES(histogram256_fulldata);
+DECLARE_CONST_HES(histogram256);
+DECLARE(histogram256_allfeats);
+DECLARE(histogram256_fulldata);
+DECLARE(histogram256);
+
+}  // namespace LightGBM
+
+#endif  // LIGHTGBM_TREELEARNER_KERNELS_HISTOGRAM_16_64_256_HU_
+
diff --git a/src/treelearner/parallel_tree_learner.h b/src/treelearner/parallel_tree_learner.h
index 137697408e8..2001f2e0dfe 100644
--- a/src/treelearner/parallel_tree_learner.h
+++ b/src/treelearner/parallel_tree_learner.h
@@ -12,6 +12,7 @@
 #include
 #include
+#include "cuda_tree_learner.h"
 #include "gpu_tree_learner.h"
 #include "serial_tree_learner.h"
diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index 22b353952ee..92f26930441 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -326,7 +326,16 @@ void SerialTreeLearner::FindBestSplits(const Tree* tree) {
     is_feature_used[feature_index] = 1;
   }
   bool use_subtract = parent_leaf_histogram_array_ != nullptr;
+
+#ifdef USE_CUDA
+  if (LGBM_config_::current_learner == use_cpu_learner) {
+    SerialTreeLearner::ConstructHistograms(is_feature_used, use_subtract);
+  } else {
+    ConstructHistograms(is_feature_used, use_subtract);
+  }
+#else
   ConstructHistograms(is_feature_used, use_subtract);
+#endif
   FindBestSplitsFromHistograms(is_feature_used, use_subtract, tree);
 }
diff --git a/src/treelearner/serial_tree_learner.h b/src/treelearner/serial_tree_learner.h
index e6ac8e3ad09..59ba770fb95 100644
--- a/src/treelearner/serial_tree_learner.h
+++ b/src/treelearner/serial_tree_learner.h
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -201,6 +202,11 @@ class SerialTreeLearner: public TreeLearner {
   std::vector> ordered_gradients_;
   /*! \brief hessians of current iteration, ordered for cache optimized, aligned to 4K page */
   std::vector> ordered_hessians_;
+#elif USE_CUDA
+  /*! \brief gradients of current iteration, ordered for cache optimized */
+  std::vector> ordered_gradients_;
+  /*! \brief hessians of current iteration, ordered for cache optimized */
+  std::vector> ordered_hessians_;
+#else
+  /*! \brief gradients of current iteration, ordered for cache optimized */
   std::vector> ordered_gradients_;
diff --git a/src/treelearner/tree_learner.cpp b/src/treelearner/tree_learner.cpp
index 7172f6b655c..ab009a0b100 100644
--- a/src/treelearner/tree_learner.cpp
+++ b/src/treelearner/tree_learner.cpp
@@ -4,6 +4,7 @@
  */
 #include
+#include "cuda_tree_learner.h"
 #include "gpu_tree_learner.h"
 #include "parallel_tree_learner.h"
 #include "serial_tree_learner.h"
@@ -31,6 +32,16 @@ TreeLearner* TreeLearner::CreateTreeLearner(const std::string& learner_type, con
     } else if (learner_type == std::string("voting")) {
       return new VotingParallelTreeLearner(config);
     }
+  } else if (device_type == std::string("cuda")) {
+    if (learner_type == std::string("serial")) {
+      return new CUDATreeLearner(config);
+    } else if (learner_type == std::string("feature")) {
+      return new FeatureParallelTreeLearner(config);
+    } else if (learner_type == std::string("data")) {
+      return new DataParallelTreeLearner(config);
+    } else if (learner_type == std::string("voting")) {
+      return new VotingParallelTreeLearner(config);
+    }
   }
   return nullptr;
 }
diff --git a/src/treelearner/voting_parallel_tree_learner.cpp b/src/treelearner/voting_parallel_tree_learner.cpp
index 1c9c36ba8bb..51ee2096380 100644
--- a/src/treelearner/voting_parallel_tree_learner.cpp
+++ b/src/treelearner/voting_parallel_tree_learner.cpp
@@ -454,6 +454,7 @@ void VotingParallelTreeLearner::Split(Tree* tree, int best_Leaf,
 }
 // instantiate template classes, otherwise linker cannot find the code
+template class VotingParallelTreeLearner;
 template class VotingParallelTreeLearner;
 template class VotingParallelTreeLearner;
 }  // namespace LightGBM
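
Editor's note on the shared-memory accumulation used by the histogram kernels in this patch: per-bin gradient and hessian sums are added into gh_hist through atomic_local_add_f, whose definition is not part of the hunks shown above. The sketch below only illustrates the usual CAS-based pattern for atomically adding doubles into shared memory and the per-bin scatter it enables; the names atomic_add_double and demo_histogram, the fixed 256-bin layout, and the single-block launch are assumptions made for this example, not code from the patch.

// Illustrative sketch only; not part of the patch. Compile with nvcc.

// CAS-based atomic add for double, the common fallback when a hardware
// double-precision atomicAdd is not available on the target architecture.
__device__ void atomic_add_double(double *addr, double val) {
  unsigned long long *addr_as_ull = reinterpret_cast<unsigned long long *>(addr);
  unsigned long long old = *addr_as_ull, assumed;
  do {
    assumed = old;
    old = atomicCAS(addr_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);  // retry if another thread updated the bin first
}

// Each thread scatters its gradient into a shared-memory histogram bin,
// mirroring the gh_hist accumulation pattern of the kernels above.
// Bin count is fixed at 256 here; launch with a single block for the demo.
__global__ void demo_histogram(const double *gradients, const int *bins,
                               int num_data, double *out_hist) {
  __shared__ double hist[256];
  for (int i = threadIdx.x; i < 256; i += blockDim.x) hist[i] = 0.0;
  __syncthreads();
  for (int i = threadIdx.x; i < num_data; i += blockDim.x) {
    atomic_add_double(&hist[bins[i]], gradients[i]);
  }
  __syncthreads();
  for (int i = threadIdx.x; i < 256; i += blockDim.x) out_hist[i] = hist[i];
}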