diff --git a/doc/python/callbacks.rst b/doc/python/callbacks.rst index 009b4d742fe5..943df4d511b8 100644 --- a/doc/python/callbacks.rst +++ b/doc/python/callbacks.rst @@ -7,9 +7,9 @@ package. In XGBoost 1.3, a new callback interface is designed for Python packag provides the flexiblity of designing various extension for training. Also, XGBoost has a number of pre-defined callbacks for supporting early stopping, checkpoints etc. -####################### + Using builtin callbacks -####################### +----------------------- By default, training methods in XGBoost have parameters like ``early_stopping_rounds`` and ``verbose``/``verbose_eval``, when specified the training procedure will define the @@ -50,9 +50,9 @@ this callback function directly into XGBoost: dump = booster.get_dump(dump_format='json') assert len(early_stop.stopping_history['Valid']['CustomErr']) == len(dump) -########################## + Defining your own callback -########################## +-------------------------- XGBoost provides an callback interface class: ``xgboost.callback.TrainingCallback``, user defined callbacks should inherit this class and override corresponding methods. There's a diff --git a/doc/python/index.rst b/doc/python/index.rst index 7596be247f9b..d46b6cb45df3 100644 --- a/doc/python/index.rst +++ b/doc/python/index.rst @@ -12,4 +12,5 @@ Contents python_intro python_api callbacks + model Python examples diff --git a/doc/python/model.rst b/doc/python/model.rst new file mode 100644 index 000000000000..ea5c46024fcb --- /dev/null +++ b/doc/python/model.rst @@ -0,0 +1,38 @@ +##### +Model +##### + +Slice tree model +---------------- + +When ``booster`` is set to ``gbtree`` or ``dart``, XGBoost builds a tree model, which is a +list of trees and can be sliced into multiple sub-models. + +.. code-block:: python + + from sklearn.datasets import make_classification + num_classes = 3 + X, y = make_classification(n_samples=1000, n_informative=5, + n_classes=num_classes) + dtrain = xgb.DMatrix(data=X, label=y) + num_parallel_tree = 4 + num_boost_round = 16 + # total number of built trees is num_parallel_tree * num_classes * num_boost_round + + # We build a boosted random forest for classification here. + booster = xgb.train({ + 'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3}, + num_boost_round=num_boost_round, dtrain=dtrain) + + # This is the sliced model, containing [3, 7) forests + # step is also supported with some limitations like negative step is invalid. + sliced: xgb.Booster = booster[3:7] + + # Access individual tree layer + trees = [_ for _ in booster] + assert len(trees) == num_boost_round + + +The sliced model is a copy of selected trees, that means the model itself is immutable +during slicing. This feature is the basis of `save_best` option in early stopping +callback. diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 4db461d11b1c..12f395c66e45 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -580,6 +580,23 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[], */ XGB_DLL int XGBoosterFree(BoosterHandle handle); +/*! + * \brief Slice a model using boosting index. The slice m:n indicates taking all trees + * that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1). + * + * \param handle Booster to be sliced. + * \param begin_layer start of the slice + * \param end_layer end of the slice; end_layer=0 is equivalent to + * end_layer=num_boost_round + * \param step step size of the slice + * \param out Sliced booster. + * + * \return 0 when success, -1 when failure happens, -2 when index is out of bound. + */ +XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, + int end_layer, int step, + BoosterHandle *out); + /*! * \brief set parameters * \param handle handle diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index 8081e15d0922..20b2fbf11218 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -60,6 +60,17 @@ class GradientBooster : public Model, public Configurable { * \param fo output stream */ virtual void Save(dmlc::Stream* fo) const = 0; + /*! + * \brief Slice a model using boosting index. The slice m:n indicates taking all trees + * that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1). + * \param layer_begin Begining of boosted tree layer used for prediction. + * \param layer_end End of booster layer. 0 means do not limit trees. + * \param out Output gradient booster + */ + virtual void Slice(int32_t layer_begin, int32_t layer_end, int32_t step, + GradientBooster *out, bool* out_of_bound) const { + LOG(FATAL) << "Slice is not supported by current booster."; + } /*! * \brief whether the model allow lazy checkpoint * return true if model is only updated in DoBoost diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index 50b20de677f1..a3c46085bf35 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -195,6 +195,18 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { * \return whether the model allow lazy checkpoint in rabit. */ bool AllowLazyCheckPoint() const; + /*! + * \brief Slice the model. + * + * See InplacePredict for layer parameters. + * + * \param step step size between slice. + * \param out_of_bound Return true if end layer is out of bound. + * + * \return a sliced model. + */ + virtual Learner *Slice(int32_t begin_layer, int32_t end_layer, int32_t step, + bool *out_of_bound) = 0; /*! * \brief dump the model in the requested format * \param fmap feature map that may help give interpretations of feature diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py index 32cc957cb5c9..7c8bdaff3c60 100644 --- a/python-package/xgboost/callback.py +++ b/python-package/xgboost/callback.py @@ -10,7 +10,7 @@ import numpy from . import rabit -from .core import EarlyStopException, CallbackEnv +from .core import EarlyStopException, CallbackEnv, Booster, XGBoostError from .compat import STRING_TYPES @@ -279,9 +279,11 @@ def __init__(self): def before_training(self, model): '''Run before training starts.''' + return model def after_training(self, model): '''Run after training is finished.''' + return model def before_iteration(self, model, epoch, evals_log): '''Run before each iteration. Return True when training should stop.''' @@ -362,12 +364,24 @@ def __init__(self, callbacks: List[TrainingCallback], def before_training(self, model): '''Function called before training.''' for c in self.callbacks: - c.before_training(model=model) + model = c.before_training(model=model) + msg = 'before_training should return the model' + if self.is_cv: + assert isinstance(model.cvfolds, list), msg + else: + assert isinstance(model, Booster), msg + return model def after_training(self, model): '''Function called after training.''' for c in self.callbacks: - c.after_training(model) + model = c.after_training(model=model) + msg = 'after_training should return the model' + if self.is_cv: + assert isinstance(model.cvfolds, list), msg + else: + assert isinstance(model, Booster), msg + return model def before_iteration(self, model, epoch, dtrain, evals): '''Function called before training iteration.''' @@ -461,7 +475,7 @@ class EarlyStopping(TrainingCallback): maximize : bool Whether to maximize evaluation metric. None means auto (discouraged). save_best : bool - Placeholder, the feature is not yet supported. + Whether training should return the best model or the last model. ''' def __init__(self, rounds, @@ -473,9 +487,6 @@ def __init__(self, self.metric_name = metric_name self.rounds = rounds self.save_best = save_best - # https://github.com/dmlc/xgboost/issues/5531 - assert self.save_best is False, 'save best is not yet supported.' - self.maximize = maximize self.stopping_history = {} @@ -525,7 +536,7 @@ def _update_rounds(self, score, name, metric, model, epoch): return True return False - def after_iteration(self, model, epoch, evals_log): + def after_iteration(self, model: Booster, epoch, evals_log): msg = 'Must have at least 1 validation dataset for early stopping.' assert len(evals_log.keys()) >= 1, msg data_name = '' @@ -551,6 +562,14 @@ def after_iteration(self, model, epoch, evals_log): score = data_log[metric_name][-1] return self._update_rounds(score, data_name, metric_name, model, epoch) + def after_training(self, model: Booster): + try: + if self.save_best: + model = model[: int(model.attr('best_iteration'))] + except XGBoostError as e: + raise XGBoostError('`save_best` is not applicable to current booster') from e + return model + class EvaluationMonitor(TrainingCallback): '''Print the evaluation result at each iteration. @@ -684,9 +703,11 @@ def __init__(self, callbacks, start_iteration, end_iteration, def before_training(self, model): '''Nothing to do for legacy callbacks''' + return model def after_training(self, model): '''Nothing to do for legacy callbacks''' + return model def before_iteration(self, model, epoch, dtrain, evals): '''Called before each iteration.''' diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index e834f409b6f2..9dd708be3f53 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -947,8 +947,8 @@ def __init__(self, params=None, cache=(), model_file=None): Parameters for boosters. cache : list List of cache items. - model_file : string or os.PathLike - Path to the model file. + model_file : string/os.PathLike/Booster/bytearray + Path to the model file if it's string or PathLike. """ for d in cache: if not isinstance(d, DMatrix): @@ -1024,6 +1024,43 @@ def __setstate__(self, state): state['handle'] = handle self.__dict__.update(state) + def __getitem__(self, val): + if isinstance(val, int): + val = slice(val, val+1) + if isinstance(val, tuple): + raise ValueError('Only supports slicing through 1 dimension.') + if not isinstance(val, slice): + msg = _expect((int, slice), type(val)) + raise TypeError(msg) + if isinstance(val.start, type(Ellipsis)) or val.start is None: + start = 0 + else: + start = val.start + if isinstance(val.stop, type(Ellipsis)) or val.stop is None: + stop = 0 + else: + stop = val.stop + if stop < start: + raise ValueError('Invalid slice', val) + + step = val.step if val.step is not None else 1 + + start = ctypes.c_int(start) + stop = ctypes.c_int(stop) + step = ctypes.c_int(step) + + sliced_handle = ctypes.c_void_p() + status = _LIB.XGBoosterSlice(self.handle, start, stop, step, + ctypes.byref(sliced_handle)) + if status == -2: + raise IndexError('Layer index out of range') + _check_call(status) + + sliced = Booster() + _check_call(_LIB.XGBoosterFree(sliced.handle)) + sliced.handle = sliced_handle + return sliced + def save_config(self): '''Output internal parameter configuration of Booster as a JSON string. diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py index 6b333e246d4a..7ca5922905dd 100644 --- a/python-package/xgboost/training.py +++ b/python-package/xgboost/training.py @@ -101,7 +101,7 @@ def _train_internal(params, dtrain, num_boost_round, feval, evals_result, callbacks, show_stdv=False, cvfolds=None) - callbacks.before_training(bst) + bst = callbacks.before_training(bst) for i in range(start_iteration, num_boost_round): if callbacks.before_iteration(bst, i, dtrain, evals): break @@ -123,7 +123,7 @@ def _train_internal(params, dtrain, bst.save_rabit_checkpoint() version += 1 - callbacks.after_training(bst) + bst = callbacks.after_training(bst) if evals_result is not None and is_new_callback: evals_result.update(callbacks.history) @@ -493,9 +493,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None verbose_eval, early_stopping_rounds, maximize, 0, num_boost_round, feval, None, callbacks, show_stdv=show_stdv, cvfolds=cvfolds) - callbacks.before_training(cvfolds) - booster = _PackedBooster(cvfolds) + callbacks.before_training(booster) for i in range(num_boost_round): if callbacks.before_iteration(booster, i, dtrain, None): @@ -522,4 +521,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None results = pd.DataFrame.from_dict(results) except ImportError: pass + + callbacks.after_training(booster) + return results diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index d91b179f0a5c..cf0bbebdeaa7 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -730,6 +730,22 @@ XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) { API_END(); } +XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, + int end_layer, int step, + BoosterHandle *out) { + API_BEGIN(); + CHECK_HANDLE(); + auto* learner = static_cast(handle); + bool out_of_bound = false; + auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound); + if (out_of_bound) { + return -2; + } + CHECK(p_out); + *out = p_out; + API_END(); +} + inline void XGBoostDumpModelImpl(BoosterHandle handle, const FeatureMap &fmap, int with_stats, const char *format, xgboost::bst_ulong *len, diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 5d88b4d34ac2..6142cb010dde 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -398,6 +398,38 @@ void GBTree::SaveModel(Json* p_out) const { model_.SaveModel(&model); } +void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step, + GradientBooster *out, bool* out_of_bound) const { + CHECK(configured_); + CHECK(out); + + auto p_gbtree = dynamic_cast(out); + CHECK(p_gbtree); + GBTreeModel &out_model = p_gbtree->model_; + auto layer_trees = this->LayerTrees(); + + layer_end = layer_end == 0 ? model_.trees.size() / layer_trees : layer_end; + CHECK_GE(layer_end, layer_begin); + CHECK_GE(step, 1); + int32_t n_layers = (layer_end - layer_begin) / step; + std::vector> &out_trees = out_model.trees; + out_trees.resize(layer_trees * n_layers); + std::vector &out_trees_info = out_model.tree_info; + out_trees_info.resize(layer_trees * n_layers); + out_model.param.num_trees = out_model.trees.size(); + CHECK(this->model_.trees_to_update.empty()); + + *out_of_bound = detail::SliceTrees( + layer_begin, layer_end, step, this->model_, tparam_, layer_trees, + [&](auto const &in_it, auto const &out_it) { + auto new_tree = + std::make_unique(*this->model_.trees.at(in_it)); + bst_group_t group = this->model_.tree_info[in_it]; + out_trees.at(out_it) = std::move(new_tree); + out_trees_info.at(out_it) = group; + }); +} + void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool, @@ -494,6 +526,22 @@ class Dart : public GBTree { dparam_.UpdateAllowUnknown(cfg); } + void Slice(int32_t layer_begin, int32_t layer_end, int32_t step, + GradientBooster *out, bool* out_of_bound) const final { + GBTree::Slice(layer_begin, layer_end, step, out, out_of_bound); + if (*out_of_bound) { + return; + } + auto p_dart = dynamic_cast(out); + CHECK(p_dart); + CHECK(p_dart->weight_drop_.empty()); + detail::SliceTrees( + layer_begin, layer_end, step, model_, tparam_, this->LayerTrees(), + [&](auto const& in_it, auto const&) { + p_dart->weight_drop_.push_back(this->weight_drop_.at(in_it)); + }); + } + void SaveModel(Json *p_out) const override { auto &out = *p_out; out["name"] = String("dart"); diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index f96a895aef9e..b2a990dbe304 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -152,6 +152,50 @@ struct DartTrainParam : public XGBoostParameter { } }; +namespace detail { +// From here on, layer becomes concrete trees. +inline std::pair LayerToTree(gbm::GBTreeModel const &model, + GBTreeTrainParam const &tparam, + size_t layer_begin, + size_t layer_end) { + bst_group_t groups = model.learner_model_param->num_output_group; + uint32_t tree_begin = layer_begin * groups * tparam.num_parallel_tree; + uint32_t tree_end = layer_end * groups * tparam.num_parallel_tree; + if (tree_end == 0) { + tree_end = static_cast(model.trees.size()); + } + CHECK_LT(tree_begin, tree_end); + return {tree_begin, tree_end}; +} + +// Call fn for each pair of input output tree. Return true if index is out of bound. +template +inline bool SliceTrees(int32_t layer_begin, int32_t layer_end, int32_t step, + GBTreeModel const &model, GBTreeTrainParam const &tparam, + uint32_t layer_trees, Func fn) { + uint32_t tree_begin, tree_end; + std::tie(tree_begin, tree_end) = detail::LayerToTree(model, tparam, layer_begin, layer_end); + if (tree_end > model.trees.size()) { + return true; + } + + layer_end = layer_end == 0 ? model.trees.size() / layer_trees : layer_end; + uint32_t n_layers = (layer_end - layer_begin) / step; + int32_t in_it = tree_begin; + int32_t out_it = 0; + for (uint32_t l = 0; l < n_layers; ++l) { + for (uint32_t i = 0; i < layer_trees; ++i) { + CHECK_LT(in_it, tree_end); + fn(in_it, out_it); + out_it++; + in_it++; + } + in_it += (step - 1) * layer_trees; + } + return false; +} +} // namespace detail + // gradient boosted trees class GBTree : public GradientBooster { public: @@ -200,6 +244,15 @@ class GBTree : public GradientBooster { return model_.learner_model_param->num_output_group == 1; } + // Number of trees per layer. + auto LayerTrees() const { + auto n_trees = model_.learner_model_param->num_output_group * tparam_.num_parallel_tree; + return n_trees; + } + // slice the trees, out must be already allocated + void Slice(int32_t layer_begin, int32_t layer_end, int32_t step, + GradientBooster *out, bool* out_of_bound) const override; + void PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool training, @@ -210,13 +263,8 @@ class GBTree : public GradientBooster { uint32_t layer_begin, unsigned layer_end) const override { CHECK(configured_); - // From here on, layer becomes concrete trees. - bst_group_t groups = model_.learner_model_param->num_output_group; - uint32_t tree_begin = layer_begin * groups * tparam_.num_parallel_tree; - uint32_t tree_end = layer_end * groups * tparam_.num_parallel_tree; - if (tree_end == 0 || tree_end > model_.trees.size()) { - tree_end = static_cast(model_.trees.size()); - } + uint32_t tree_begin, tree_end; + std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); this->GetPredictor()->InplacePredict(x, model_, missing, out_preds, tree_begin, tree_end); } diff --git a/src/gbm/gbtree_model.cc b/src/gbm/gbtree_model.cc index 4a20b48f7d1d..e56dc0ad3a59 100644 --- a/src/gbm/gbtree_model.cc +++ b/src/gbm/gbtree_model.cc @@ -6,10 +6,10 @@ #include "xgboost/json.h" #include "xgboost/logging.h" #include "gbtree_model.h" +#include "gbtree.h" namespace xgboost { namespace gbm { - void GBTreeModel::Save(dmlc::Stream* fo) const { CHECK_EQ(param.num_trees, static_cast(trees.size())); diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h index 5a89878d3816..2d07ec198a79 100644 --- a/src/gbm/gbtree_model.h +++ b/src/gbm/gbtree_model.h @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2019 by Contributors + * Copyright 2017-2020 by Contributors * \file gbtree_model.h */ #ifndef XGBOOST_GBM_GBTREE_MODEL_H_ @@ -22,6 +22,7 @@ namespace xgboost { class Json; namespace gbm { + /*! \brief model parameters */ struct GBTreeModelParam : public dmlc::Parameter { public: diff --git a/src/learner.cc b/src/learner.cc index 85ca3a503260..ce652a420e78 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -971,6 +971,26 @@ class LearnerImpl : public LearnerIO { return gbm_->DumpModel(fmap, with_stats, format); } + Learner *Slice(int32_t begin_layer, int32_t end_layer, int32_t step, + bool *out_of_bound) override { + this->Configure(); + CHECK_GE(begin_layer, 0); + auto *out_impl = new LearnerImpl({}); + auto gbm = std::unique_ptr(GradientBooster::Create( + this->tparam_.booster, &this->generic_parameters_, + &this->learner_model_param_)); + this->gbm_->Slice(begin_layer, end_layer, step, gbm.get(), out_of_bound); + out_impl->gbm_ = std::move(gbm); + Json config { Object() }; + this->SaveConfig(&config); + out_impl->mparam_ = this->mparam_; + out_impl->attributes_ = this->attributes_; + out_impl->learner_model_param_ = this->learner_model_param_; + out_impl->LoadConfig(config); + out_impl->Configure(); + return out_impl; + } + void UpdateOneIter(int iter, std::shared_ptr train) override { monitor_.Start("UpdateOneIter"); TrainingObserver::Instance().Update(iter); diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index 463253aea019..64a94e736800 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -154,9 +154,9 @@ TEST(GBTree, JsonIO) { ASSERT_EQ(get(model["model"]["name"]), "gbtree"); auto const& gbtree_model = model["model"]["model"]; - ASSERT_EQ(get(gbtree_model["trees"]).size(), 1); + ASSERT_EQ(get(gbtree_model["trees"]).size(), 1ul); ASSERT_EQ(get(get(get(gbtree_model["trees"]).front()).at("id")), 0); - ASSERT_EQ(get(gbtree_model["tree_info"]).size(), 1); + ASSERT_EQ(get(gbtree_model["tree_info"]).size(), 1ul); auto j_train_param = model["config"]["gbtree_train_param"]; ASSERT_EQ(get(j_train_param["num_parallel_tree"]), "1"); @@ -194,7 +194,7 @@ TEST(Dart, JsonIO) { ASSERT_EQ(get(model["model"]["name"]), "dart") << model; ASSERT_EQ(get(model["config"]["name"]), "dart"); ASSERT_TRUE(IsA(model["model"]["gbtree"])); - ASSERT_NE(get(model["model"]["weight_drop"]).size(), 0); + ASSERT_NE(get(model["model"]["weight_drop"]).size(), 0ul); } TEST(Dart, Prediction) { @@ -230,4 +230,122 @@ TEST(Dart, Prediction) { ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps); } } + +std::pair TestModelSlice(std::string booster) { + size_t constexpr kRows = 1000, kCols = 100, kForest = 2, kClasses = 3; + auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true, false, kClasses); + + int32_t kIters = 10; + std::unique_ptr learner { + Learner::Create({m}) + }; + learner->SetParams(Args{{"booster", booster}, + {"tree_method", "hist"}, + {"num_parallel_tree", std::to_string(kForest)}, + {"num_class", std::to_string(kClasses)}, + {"subsample", "0.5"}, + {"max_depth", "2"}}); + + for (auto i = 0; i < kIters; ++i) { + learner->UpdateOneIter(i, m); + } + + Json model{Object()}; + Json config{Object()}; + learner->SaveModel(&model); + learner->SaveConfig(&config); + bool out_of_bound = false; + + size_t constexpr kSliceStart = 2, kSliceEnd = 8, kStep = 3; + std::unique_ptr sliced {learner->Slice(kSliceStart, kSliceEnd, kStep, &out_of_bound)}; + Json sliced_model{Object()}; + sliced->SaveModel(&sliced_model); + + auto get_shape = [&](Json const& model) { + if (booster == "gbtree") { + return get(model["learner"]["gradient_booster"]["model"]["gbtree_model_param"]); + } else { + return get(model["learner"]["gradient_booster"]["gbtree"]["model"]["gbtree_model_param"]); + } + }; + + auto const& model_shape = get_shape(sliced_model); + CHECK_EQ(get(model_shape.at("num_trees")), std::to_string(2 * kClasses * kForest)); + + Json sliced_config {Object()}; + sliced->SaveConfig(&sliced_config); + CHECK_EQ(sliced_config, config); + + auto get_trees = [&](Json const& model) { + if (booster == "gbtree") { + return get(model["learner"]["gradient_booster"]["model"]["trees"]); + } else { + return get(model["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]); + } + }; + + auto get_info = [&](Json const& model) { + if (booster == "gbtree") { + return get(model["learner"]["gradient_booster"]["model"]["tree_info"]); + } else { + return get(model["learner"]["gradient_booster"]["gbtree"]["model"]["tree_info"]); + } + }; + + auto const &sliced_trees = get_trees(sliced_model); + CHECK_EQ(sliced_trees.size(), 2 * kClasses * kForest); + + auto constexpr kLayerSize = kClasses * kForest; + auto const &sliced_info = get_info(sliced_model); + + for (size_t layer = 0; layer < 2; ++layer) { + for (size_t j = 0; j < kClasses; ++j) { + for (size_t k = 0; k < kForest; ++k) { + auto idx = layer * kLayerSize + j * kForest + k; + auto const &group = get(sliced_info.at(idx)); + CHECK_EQ(static_cast(group), j); + } + } + } + + auto const& trees = get_trees(model); + + // Sliced layers are [2, 5] + auto begin = kLayerSize * kSliceStart; + auto end = begin + kLayerSize; + auto j = 0; + for (size_t i = begin; i < end; ++i) { + Json tree = trees[i]; + tree["id"] = Integer(0); // id is different, we set it to 0 to allow comparison. + auto sliced_tree = sliced_trees[j]; + sliced_tree["id"] = Integer(0); + CHECK_EQ(tree, sliced_tree); + j++; + } + + begin = kLayerSize * (kSliceStart + kStep); + end = begin + kLayerSize; + for (size_t i = begin; i < end; ++i) { + Json tree = trees[i]; + tree["id"] = Integer(0); + auto sliced_tree = sliced_trees[j]; + sliced_tree["id"] = Integer(0); + CHECK_EQ(tree, sliced_tree); + j++; + } + + return std::make_pair(model, sliced_model); +} + +TEST(GBTree, Slice) { + TestModelSlice("gbtree"); +} + +TEST(Dart, Slice) { + Json model, sliced_model; + std::tie(model, sliced_model) = TestModelSlice("dart"); + auto const& weights = get(model["learner"]["gradient_booster"]["weight_drop"]); + auto const& trees = get(model["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]); + ASSERT_EQ(weights.size(), trees.size()); +} } // namespace xgboost diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 56e4a95ece42..ff1a7c7cd79d 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -118,7 +118,7 @@ TEST(Learner, Configuration) { // eval_metric is not part of configuration auto attr_names = learner->GetConfigurationArguments(); - ASSERT_EQ(attr_names.size(), 1); + ASSERT_EQ(attr_names.size(), 1ul); ASSERT_EQ(attr_names.find(emetric), attr_names.cend()); ASSERT_EQ(attr_names.at("foo"), "bar"); } @@ -127,7 +127,7 @@ TEST(Learner, Configuration) { std::unique_ptr learner { Learner::Create({nullptr}) }; learner->SetParams({{"foo", "bar"}, {emetric, "auc"}, {emetric, "entropy"}, {emetric, "KL"}}); auto attr_names = learner->GetConfigurationArguments(); - ASSERT_EQ(attr_names.size(), 1); + ASSERT_EQ(attr_names.size(), 1ul); ASSERT_EQ(attr_names.at("foo"), "bar"); } } @@ -181,7 +181,7 @@ TEST(Learner, JsonModelIO) { learner->SaveModel(&new_in); ASSERT_TRUE(IsA(out["learner"]["attributes"])); - ASSERT_EQ(get(out["learner"]["attributes"]).size(), 1); + ASSERT_EQ(get(out["learner"]["attributes"]).size(), 1ul); ASSERT_EQ(out, new_in); } } @@ -333,5 +333,4 @@ TEST(Learner, Seed) { ASSERT_EQ(std::to_string(seed), get(config["learner"]["generic_param"]["seed"])); } - } // namespace xgboost diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index dc3a2778af02..9744eec34d40 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -29,7 +29,7 @@ def json_model(model_path, parameters): return model -class TestModels(unittest.TestCase): +class TestModels: def test_glm(self): param = {'verbosity': 0, 'objective': 'binary:logistic', 'booster': 'gblinear', 'alpha': 0.0001, 'lambda': 1, @@ -209,12 +209,14 @@ def test_feature_names_validation(self): bst = xgb.train([], dm1) bst.predict(dm1) # success - self.assertRaises(ValueError, bst.predict, dm2) + with pytest.raises(ValueError): + bst.predict(dm2) bst.predict(dm1) # success bst = xgb.train([], dm2) bst.predict(dm2) # success - self.assertRaises(ValueError, bst.predict, dm1) + with pytest.raises(ValueError): + bst.predict(dm1) bst.predict(dm2) # success def test_model_binary_io(self): @@ -325,3 +327,96 @@ def validate_model(parameters): parameters = {'tree_method': 'hist', 'booster': 'dart', 'objective': 'multi:softmax'} validate_model(parameters) + + @pytest.mark.parametrize('booster', ['gbtree', 'dart']) + def test_slice(self, booster): + from sklearn.datasets import make_classification + num_classes = 3 + X, y = make_classification(n_samples=1000, n_informative=5, + n_classes=num_classes) + dtrain = xgb.DMatrix(data=X, label=y) + num_parallel_tree = 4 + num_boost_round = 16 + total_trees = num_parallel_tree * num_classes * num_boost_round + booster = xgb.train({ + 'num_parallel_tree': 4, 'subsample': 0.5, 'num_class': 3, 'booster': booster, + 'objective': 'multi:softprob'}, + num_boost_round=num_boost_round, dtrain=dtrain) + assert len(booster.get_dump()) == total_trees + beg = 3 + end = 7 + sliced: xgb.Booster = booster[beg: end] + + sliced_trees = (end - beg) * num_parallel_tree * num_classes + assert sliced_trees == len(sliced.get_dump()) + + sliced_trees = sliced_trees // 2 + sliced: xgb.Booster = booster[beg: end: 2] + assert sliced_trees == len(sliced.get_dump()) + + sliced: xgb.Booster = booster[beg: ...] + sliced_trees = (num_boost_round - beg) * num_parallel_tree * num_classes + assert sliced_trees == len(sliced.get_dump()) + + sliced: xgb.Booster = booster[beg:] + sliced_trees = (num_boost_round - beg) * num_parallel_tree * num_classes + assert sliced_trees == len(sliced.get_dump()) + + sliced: xgb.Booster = booster[:end] + sliced_trees = end * num_parallel_tree * num_classes + assert sliced_trees == len(sliced.get_dump()) + + sliced: xgb.Booster = booster[...:end] + sliced_trees = end * num_parallel_tree * num_classes + assert sliced_trees == len(sliced.get_dump()) + + with pytest.raises(ValueError, match=r'>= 0'): + booster[-1: 0] + + # we do not accept empty slice. + with pytest.raises(ValueError): + booster[1:1] + # stop can not be smaller than begin + with pytest.raises(ValueError, match=r'Invalid.*'): + booster[3:0] + with pytest.raises(ValueError, match=r'Invalid.*'): + booster[3:-1] + # negative step is not supported. + with pytest.raises(ValueError, match=r'.*>= 1.*'): + booster[0:2:-1] + # step can not be 0. + with pytest.raises(ValueError, match=r'.*>= 1.*'): + booster[0:2:0] + + trees = [_ for _ in booster] + assert len(trees) == num_boost_round + + with pytest.raises(TypeError): + booster["wrong type"] + with pytest.raises(IndexError): + booster[:num_boost_round+1] + with pytest.raises(ValueError): + booster[1, 2] # too many dims + # setitem is not implemented as model is immutable during slicing. + with pytest.raises(TypeError): + booster[...:end] = booster + + sliced_0 = booster[1:3] + sliced_1 = booster[3:7] + + predt_0 = sliced_0.predict(dtrain, output_margin=True) + predt_1 = sliced_1.predict(dtrain, output_margin=True) + + merged = predt_0 + predt_1 - 0.5 # base score. + single = booster[1:7].predict(dtrain, output_margin=True) + np.testing.assert_allclose(merged, single, atol=1e-6) + + sliced_0 = booster[1:7:2] # 1,3,5 + sliced_1 = booster[2:8:2] # 2,4,6 + + predt_0 = sliced_0.predict(dtrain, output_margin=True) + predt_1 = sliced_1.predict(dtrain, output_margin=True) + + merged = predt_0 + predt_1 - 0.5 + single = booster[1:7].predict(dtrain, output_margin=True) + np.testing.assert_allclose(merged, single, atol=1e-6) diff --git a/tests/python/test_callback.py b/tests/python/test_callback.py index 2b482bb99c5e..9f444e60c668 100644 --- a/tests/python/test_callback.py +++ b/tests/python/test_callback.py @@ -113,6 +113,35 @@ def test_early_stopping_custom_eval_skl(self): dump = booster.get_dump(dump_format='json') assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 + def test_early_stopping_save_best_model(self): + from sklearn.datasets import load_breast_cancer + X, y = load_breast_cancer(return_X_y=True) + n_estimators = 100 + cls = xgb.XGBClassifier(n_estimators=n_estimators) + early_stopping_rounds = 5 + early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, + save_best=True) + cls.fit(X, y, eval_set=[(X, y)], + eval_metric=tm.eval_error_metric, callbacks=[early_stop]) + booster = cls.get_booster() + dump = booster.get_dump(dump_format='json') + assert len(dump) == booster.best_iteration + + early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, + save_best=True) + cls = xgb.XGBClassifier(booster='gblinear', n_estimators=10) + self.assertRaises(ValueError, lambda: cls.fit(X, y, eval_set=[(X, y)], + eval_metric=tm.eval_error_metric, + callbacks=[early_stop])) + + # No error + early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, + save_best=False) + xgb.XGBClassifier(booster='gblinear', n_estimators=10).fit( + X, y, eval_set=[(X, y)], + eval_metric=tm.eval_error_metric, + callbacks=[early_stop]) + def run_eta_decay(self, tree_method, deprecated_callback): if deprecated_callback: scheduler = xgb.callback.reset_learning_rate