Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support categorical split in tree model dump. #7036

Merged
merged 2 commits into from Jun 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 4 additions & 2 deletions include/xgboost/feature_map.h
@@ -1,5 +1,5 @@
/*!
* Copyright 2014 by Contributors
* Copyright 2014-2021 by Contributors
* \file feature_map.h
* \brief Feature map data structure to help visualization and model dump.
* \author Tianqi Chen
Expand All @@ -26,7 +26,8 @@ class FeatureMap {
kIndicator = 0,
kQuantitive = 1,
kInteger = 2,
kFloat = 3
kFloat = 3,
kCategorical = 4
};
/*!
* \brief load feature map from input stream
Expand Down Expand Up @@ -82,6 +83,7 @@ class FeatureMap {
if (!strcmp("q", tname)) return kQuantitive;
if (!strcmp("int", tname)) return kInteger;
if (!strcmp("float", tname)) return kFloat;
if (!strcmp("categorical", tname)) return kCategorical;
LOG(FATAL) << "unknown feature type, use i for indicator and q for quantity";
return kIndicator;
}
Expand Down
3 changes: 2 additions & 1 deletion python-package/xgboost/plotting.py
Expand Up @@ -3,6 +3,7 @@
# coding: utf-8
"""Plotting Library."""
from io import BytesIO
import json
import numpy as np
from .core import Booster
from .sklearn import XGBModel
Expand Down Expand Up @@ -203,7 +204,7 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir=None,

if kwargs:
parameters += ':'
parameters += str(kwargs)
parameters += json.dumps(kwargs)
tree = booster.get_dump(
fmap=fmap,
dump_format=parameters)[num_trees]
Expand Down
5 changes: 0 additions & 5 deletions src/predictor/cpu_predictor.cc
Expand Up @@ -52,11 +52,6 @@ bst_float PredValue(const SparsePage::Inst &inst,
if (tree_info[i] == bst_group) {
auto const &tree = *trees[i];
bool has_categorical = tree.HasCategoricalSplit();

auto categories = common::Span<uint32_t const>{tree.GetSplitCategories()};
auto split_types = tree.GetSplitTypes();
auto categories_ptr =
common::Span<RegTree::Segment const>{tree.GetSplitCategoriesPtr()};
trivialfis marked this conversation as resolved.
Show resolved Hide resolved
auto cats = tree.GetCategoriesMatrix();
bst_node_t nidx = -1;
if (has_categorical) {
Expand Down
182 changes: 146 additions & 36 deletions src/tree/tree_model.cc
@@ -1,5 +1,5 @@
/*!
* Copyright 2015-2020 by Contributors
* Copyright 2015-2021 by Contributors
* \file tree_model.cc
* \brief model structure for tree
*/
Expand Down Expand Up @@ -74,6 +74,7 @@ class TreeGenerator {
int32_t /*nid*/, uint32_t /*depth*/) const {
return "";
}
virtual std::string Categorical(RegTree const&, int32_t, uint32_t) const = 0;
virtual std::string Integer(RegTree const& /*tree*/,
int32_t /*nid*/, uint32_t /*depth*/) const {
return "";
Expand All @@ -92,26 +93,51 @@ class TreeGenerator {
/*!
 * \brief Render one split node of the tree.
 *
 * When the split feature has an entry in the feature map, the generator method
 * is chosen from the feature-map type, after checking that this type agrees
 * with the split type stored in the tree node.  When the feature is not in the
 * feature map, the node's own split type decides between Categorical and
 * PlainNode.
 *
 * \param tree  Tree that owns the node.
 * \param nid   Index of the split node.
 * \param depth Depth of the node, forwarded to the per-type generator.
 * \return Output of the selected per-type generator.
 */
virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) {
  auto const split_index = tree[nid].SplitIndex();
  auto const is_categorical = tree.GetSplitTypes()[nid] == FeatureType::kCategorical;

  // Feature is not described by the feature map: rely on the tree node alone.
  if (split_index >= fmap_.Size()) {
    if (is_categorical) {
      return this->Categorical(tree, nid, depth);
    }
    return this->PlainNode(tree, nid, depth);
  }

  // Called when the feature map declares the feature categorical; fails if the
  // tree node was not trained as a categorical split.
  auto check_categorical = [&]() {
    CHECK(is_categorical)
        << fmap_.Name(split_index)
        << " in feature map is categorical but tree node is numerical.";
  };
  // Called when the feature map declares the feature numerical; fails if the
  // tree node was trained as a categorical split.
  auto check_numerical = [&]() {
    CHECK(!is_categorical)
        << fmap_.Name(split_index)
        << " in feature map is numerical but tree node is categorical.";
  };

  std::string result;
  switch (fmap_.TypeOf(split_index)) {
    case FeatureMap::kCategorical: {
      check_categorical();
      result = this->Categorical(tree, nid, depth);
      break;
    }
    case FeatureMap::kIndicator: {
      check_numerical();
      result = this->Indicator(tree, nid, depth);
      break;
    }
    case FeatureMap::kInteger: {
      check_numerical();
      result = this->Integer(tree, nid, depth);
      break;
    }
    case FeatureMap::kFloat:
    case FeatureMap::kQuantitive: {
      check_numerical();
      result = this->Quantitive(tree, nid, depth);
      break;
    }
    default:
      LOG(FATAL) << "Unknown feature map type.";
  }
  return result;
}
Expand Down Expand Up @@ -179,6 +205,32 @@ TreeGenerator* TreeGenerator::Create(std::string const& attrs, FeatureMap const&
__make_ ## TreeGenReg ## _ ## UniqueId ## __ = \
::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name)

/*!
 * \brief Collect the categories stored for one tree node.
 *
 * Takes the node's segment of the tree's categories matrix, wraps it in a
 * common::KCatBitField and returns every index for which Check() is true, in
 * ascending order.
 *
 * \param tree Tree to read the categories matrix from.
 * \param nidx Index of the node.
 * \return Indices whose bit is set in the node's segment.
 */
std::vector<bst_cat_t> GetSplitCategories(RegTree const &tree, int32_t nidx) {
  auto const &cat_matrix = tree.GetCategoriesMatrix();
  auto node_segment = cat_matrix.node_ptr[nidx];
  common::KCatBitField bits{
      cat_matrix.categories.subspan(node_segment.beg, node_segment.size)};

  std::vector<bst_cat_t> selected;
  size_t bit = 0;
  while (bit < bits.Size()) {
    if (bits.Check(bit)) {
      selected.push_back(static_cast<bst_cat_t>(bit));
    }
    ++bit;
  }
  return selected;
}

std::string PrintCatsAsSet(std::vector<bst_cat_t> const &cats) {
std::stringstream ss;
ss << "{";
for (size_t i = 0; i < cats.size(); ++i) {
ss << cats[i];
if (i != cats.size() - 1) {
ss << ",";
}
}
ss << "}";
return ss.str();
}

class TextGenerator : public TreeGenerator {
using SuperT = TreeGenerator;
Expand Down Expand Up @@ -258,6 +310,17 @@ class TextGenerator : public TreeGenerator {
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
}

/*!
 * \brief Text dump of a categorical split node.
 *
 * The node's categories, printed as a set, are passed to SplitNodeImpl as the
 * condition string.  Note that this template pairs "yes" with the {right}
 * placeholder and "no" with the {left} placeholder.
 */
std::string Categorical(RegTree const &tree, int32_t nid,
                        uint32_t depth) const override {
  static std::string const kNodeTemplate =
      "{tabs}{nid}:[{fname}:{cond}] yes={right},no={left},missing={missing}";
  std::string condition = PrintCatsAsSet(GetSplitCategories(tree, nid));
  return SplitNodeImpl(tree, nid, kNodeTemplate, condition, depth);
}

std::string NodeStat(RegTree const& tree, int32_t nid) const override {
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
std::string const result = SuperT::Match(
Expand Down Expand Up @@ -343,6 +406,24 @@ class JsonGenerator : public TreeGenerator {
return result;
}

/*!
 * \brief JSON dump of a categorical split node.
 *
 * The node's categories are rendered as a JSON array ("[a, b, c]") and passed
 * to SplitNodeImpl as the split condition.  Note that this template pairs
 * "yes" with the {right} placeholder and "no" with the {left} placeholder.
 */
std::string Categorical(RegTree const& tree, int32_t nid, uint32_t depth) const override {
  static std::string const kCategoryTemplate =
      R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
      R"I("split_condition": {cond}, "yes": {right}, "no": {left}, )I"
      R"I("missing": {missing})I";

  auto const categories = GetSplitCategories(tree, nid);
  std::string condition{"["};
  bool is_first = true;
  for (auto const cat : categories) {
    // Every element except the first is preceded by a separator.
    if (!is_first) {
      condition += ", ";
    }
    condition += std::to_string(cat);
    is_first = false;
  }
  condition += "]";

  return SplitNodeImpl(tree, nid, kCategoryTemplate, condition, depth);
}

std::string SplitNodeImpl(RegTree const &tree, int32_t nid,
std::string const &template_str, std::string cond,
uint32_t depth) const {
Expand Down Expand Up @@ -534,6 +615,27 @@ class GraphvizGenerator : public TreeGenerator {
}

protected:
/*!
 * \brief Emit the graphviz edge from a split node to one of its children.
 *
 * \tparam is_categorical Whether the parent is a categorical split.  For a
 *         numerical split the left edge is labelled "yes" and the right edge
 *         "no"; for a categorical split the labels are swapped.
 * \param tree  Tree that owns the nodes.
 * \param nid   Index of the parent node.
 * \param child Index of the child node the edge points to.
 * \param left  Whether \p child is the left child.
 * \return One edge statement.  ", missing" is appended to the label when the
 *         child is the parent's default child; that edge is drawn with
 *         param_.yes_color, any other edge with param_.no_color.
 */
template <bool is_categorical>
std::string BuildEdge(RegTree const &tree, bst_node_t nid, int32_t child, bool left) const {
  static std::string const kEdgeTemplate =
      " {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n";

  // The default child is the one taken when the feature value is missing.
  bool const takes_missing = tree[nid].DefaultChild() == child;
  bool const is_yes_edge = is_categorical ? !left : left;

  std::string label = is_yes_edge ? "yes" : "no";
  if (takes_missing) {
    label += ", missing";
  }

  return SuperT::Match(kEdgeTemplate,
                       {{"{nid}", std::to_string(nid)},
                        {"{child}", std::to_string(child)},
                        {"{color}", takes_missing ? param_.yes_color : param_.no_color},
                        {"{branch}", label}});
}

// Only indicator is different, so we combine all different node types into this
// function.
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override {
Expand All @@ -552,27 +654,32 @@ class GraphvizGenerator : public TreeGenerator {
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
{"{params}", param_.condition_node_params}});

static std::string const kEdgeTemplate =
" {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n";
auto MatchFn = SuperT::Match; // mingw failed to capture protected fn.
auto BuildEdge =
[&tree, nid, MatchFn, this](int32_t child, bool left) {
// Is this the default child for missing value?
bool is_missing = tree[nid].DefaultChild() == child;
std::string branch = std::string {left ? "yes" : "no"} +
std::string {is_missing ? ", missing" : ""};
std::string buffer = MatchFn(kEdgeTemplate, {
{"{nid}", std::to_string(nid)},
{"{child}", std::to_string(child)},
{"{color}", is_missing ? param_.yes_color : param_.no_color},
{"{branch}", branch}});
return buffer;
};
result += BuildEdge(tree[nid].LeftChild(), true);
result += BuildEdge(tree[nid].RightChild(), false);
result += BuildEdge<false>(tree, nid, tree[nid].LeftChild(), true);
result += BuildEdge<false>(tree, nid, tree[nid].RightChild(), false);

return result;
};

/*!
 * \brief Graphviz output for a categorical split node.
 *
 * Produces the node statement, labelled "<feature name>:<category set>",
 * followed by the edges to the left and right children.  The feature name
 * comes from the feature map when the split index is covered by it, otherwise
 * it is 'f' followed by the split index.
 */
std::string Categorical(RegTree const& tree, int32_t nid, uint32_t) const override {
  static std::string const kLabelTemplate =
      " {nid} [ label=\"{fname}:{cond}\" {params}]\n";

  auto const feature = tree[nid].SplitIndex();
  std::string feature_name;
  if (feature < fmap_.Size()) {
    feature_name = fmap_.Name(feature);
  } else {
    feature_name = 'f' + std::to_string(feature);
  }
  auto const condition = PrintCatsAsSet(GetSplitCategories(tree, nid));

  std::string graph = SuperT::Match(
      kLabelTemplate,
      {{"{nid}", std::to_string(nid)},
       {"{fname}", feature_name},
       {"{cond}", condition},
       {"{params}", param_.condition_node_params}});

  // Node statement first, then the left edge, then the right edge.
  graph += BuildEdge<true>(tree, nid, tree[nid].LeftChild(), true);
  graph += BuildEdge<true>(tree, nid, tree[nid].RightChild(), false);
  return graph;
}

std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override {
static std::string const kLeafTemplate =
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
Expand All @@ -588,9 +695,12 @@ class GraphvizGenerator : public TreeGenerator {
return this->LeafNode(tree, nid, depth);
}
static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
auto node = tree.GetSplitTypes()[nid] == FeatureType::kCategorical
? this->Categorical(tree, nid, depth)
: this->PlainNode(tree, nid, depth);
auto result = SuperT::Match(
kNodeTemplate,
{{"{parent}", this->PlainNode(tree, nid, depth)},
{{"{parent}", node},
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
return result;
Expand Down