Skip to content

Commit

Permalink
Fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed May 27, 2021
1 parent a746404 commit fbd852b
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/predictor/gpu_predictor.cu
Expand Up @@ -223,8 +223,8 @@ __device__ bst_node_t GetLeafIndex(bst_row_t ridx, TreeView const &tree,
while (!n.IsLeaf()) {
float fvalue = loader->GetElement(ridx, n.SplitIndex());
bool is_missing = common::CheckNAN(fvalue);
nidx = GetNextNode<has_categorical>(tree.d_tree, nidx, fvalue, is_missing,
tree.cats);
nidx = GetNextNode<has_missing, has_categorical>(tree.d_tree, nidx, fvalue,
is_missing, tree.cats);
n = tree.d_tree[nidx];
}
return nidx;
Expand Down
2 changes: 1 addition & 1 deletion src/predictor/predict_fn.h
Expand Up @@ -8,7 +8,7 @@

namespace xgboost {
namespace predictor {
template <bool has_missing, bool has_categorical = true>
template <bool has_missing, bool has_categorical>
inline XGBOOST_DEVICE bst_node_t GetNextNode(
common::Span<RegTree::Node const> tree, bst_node_t nid, float fvalue,
bool is_missing, RegTree::CategoricalSplitMatrix const& cats) {
Expand Down
5 changes: 3 additions & 2 deletions src/tree/tree_model.cc
Expand Up @@ -1060,8 +1060,9 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,

while (!(*this)[nid].IsLeaf()) {
split_index = (*this)[nid].SplitIndex();
nid = predictor::GetNextNode<true>(nodes, nid, feat.GetFvalue(split_index),
feat.IsMissing(split_index), cats);
nid = predictor::GetNextNode<true, true>(nodes, nid,
feat.GetFvalue(split_index),
feat.IsMissing(split_index), cats);
bst_float new_value = this->node_mean_values_[nid];
// update feature weight
out_contribs[split_index] += new_value - node_value;
Expand Down
6 changes: 3 additions & 3 deletions src/tree/updater_refresh.cc
Expand Up @@ -129,9 +129,9 @@ class TreeRefresher: public TreeUpdater {
// traverse tree
while (!tree[pid].IsLeaf()) {
unsigned split_index = tree[pid].SplitIndex();
pid =
predictor::GetNextNode<true>(nodes, pid, feat.GetFvalue(split_index),
feat.IsMissing(split_index), cats);
pid = predictor::GetNextNode<true, true>(
nodes, pid, feat.GetFvalue(split_index), feat.IsMissing(split_index),
cats);
gstats[pid].Add(gpair[ridx]);
}
}
Expand Down

0 comments on commit fbd852b

Please sign in to comment.