Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add decision path dump feature #9151

Open
wants to merge 4 commits into
base: master
Choose a base branch
from

Conversation

Anarion-zuo
Copy link

@Anarion-zuo Anarion-zuo commented May 10, 2023

This patch follows this issue #9128. I have already implemented a prototype for dumping decision paths in tree models. XGBoosterPredictFromDMatrix calls this new version of predict and works as expected when called from Python.

A dumped decision path looks like the following. For now, dumping is directed onto stderr.

[{"row_index": 0, "decision_path": [
{ "nodeid": 0, "depth": 0, "split": "feat_slot_8", "split_condition": 0.18880336, "yes": 1, "no": 2, "missing": 2 , "gain": 3.37755585, "cover": 46.0216866},
{ "nodeid": 2, "depth": 1, "split": "feat_slot_1", "split_condition": 0.0665793717, "yes": 5, "no": 6, "missing": 6 , "gain": 1.01384354, "cover": 34.8793144},
{ "nodeid": 5, "depth": 2, "split": "feat_slot_5", "split_condition": 0.137061194, "yes": 11, "no": 12, "missing": 12 , "gain": 0.0831073523, "cover": 2.51199985}
]},
{"row_index": 1, "decision_path": [
{ "nodeid": 0, "depth": 0, "split": "feat_slot_8", "split_condition": 0.18880336, "yes": 1, "no": 2, "missing": 2 , "gain": 3.37755585, "cover": 46.0216866},
{ "nodeid": 2, "depth": 1, "split": "feat_slot_1", "split_condition": 0.0665793717, "yes": 5, "no": 6, "missing": 6 , "gain": 1.01384354, "cover": 34.8793144},
{ "nodeid": 6, "depth": 2, "split": "feat_slot_7", "split_condition": 0.579003334, "yes": 13, "no": 14, "missing": 14 , "gain": 1.50590134, "cover": 32.3673134},
{ "nodeid": 13, "depth": 3, "split": "feat_slot_8", "split_condition": 0.632688642, "yes": 15, "no": 16, "missing": 16 , "gain": 0.154432297, "cover": 25.213913},
{ "nodeid": 15, "depth": 4, "split": "feat_slot_7", "split_condition": 0.119174264, "yes": 19, "no": 20, "missing": 20 , "gain": 0.0240440369, "cover": 16.2166805}
]},
{"row_index": 2, "decision_path": [
{ "nodeid": 0, "depth": 0, "split": "feat_slot_8", "split_condition": 0.18880336, "yes": 1, "no": 2, "missing": 2 , "gain": 3.37755585, "cover": 46.0216866},
{ "nodeid": 2, "depth": 1, "split": "feat_slot_1", "split_condition": 0.0665793717, "yes": 5, "no": 6, "missing": 6 , "gain": 1.01384354, "cover": 34.8793144},
{ "nodeid": 6, "depth": 2, "split": "feat_slot_7", "split_condition": 0.579003334, "yes": 13, "no": 14, "missing": 14 , "gain": 1.50590134, "cover": 32.3673134},
{ "nodeid": 13, "depth": 3, "split": "feat_slot_8", "split_condition": 0.632688642, "yes": 15, "no": 16, "missing": 16 , "gain": 0.154432297, "cover": 25.213913},
{ "nodeid": 15, "depth": 4, "split": "feat_slot_7", "split_condition": 0.119174264, "yes": 19, "no": 20, "missing": 20 , "gain": 0.0240440369, "cover": 16.2166805}
]},
{"row_index": 3, "decision_path": [
{ "nodeid": 0, "depth": 0, "split": "feat_slot_8", "split_condition": 0.18880336, "yes": 1, "no": 2, "missing": 2 , "gain": 3.37755585, "cover": 46.0216866},
{ "nodeid": 2, "depth": 1, "split": "feat_slot_1", "split_condition": 0.0665793717, "yes": 5, "no": 6, "missing": 6 , "gain": 1.01384354, "cover": 34.8793144},
{ "nodeid": 6, "depth": 2, "split": "feat_slot_7", "split_condition": 0.579003334, "yes": 13, "no": 14, "missing": 14 , "gain": 1.50590134, "cover": 32.3673134},
{ "nodeid": 13, "depth": 3, "split": "feat_slot_8", "split_condition": 0.632688642, "yes": 15, "no": 16, "missing": 16 , "gain": 0.154432297, "cover": 25.213913},
{ "nodeid": 16, "depth": 4, "split": "feat_slot_6", "split_condition": 0.774850428, "yes": 21, "no": 22, "missing": 22 , "gain": 0.0456981659, "cover": 8.99723244}
]},...

Some future work direction:

  • Fix format issues & add doc comments.
  • Add path dumping in other formats (text, graphviz)
  • Add api support in other programming languages
  • Add support when using GPU
  • Add tests on the new feature

@trivialfis
Copy link
Member

Hi, thank you for working on the new feature, I hope you enjoy hacking the code!

One question comes to mind, should we align the API to what's in sklearn?

@Anarion-zuo
Copy link
Author

Personally, I don't think it's necessary. Nonetheless, it can be implemented when all else is ready, if you care to insist upon it.

@Anarion-zuo
Copy link
Author

Anarion-zuo commented May 11, 2023

Calling predict cpu version from Python API is presently fully functional as I personally expected. For each row in the input dmatrix, the nodes on the decision path collected when predicting the row is dumped into a given file in either json or text format.

There are several issues that I am not so sure about. Please advise!

  • I have trouble adding support on other programming languages as I am not familiar with them. Perhaps someone can help out on this.
  • I do not believe predict's gpu version should have this feature, as the gpu version is meant to be efficient.
  • I could not find an existing way of testing APIs like DumpModel. Is a rigorous test required?

@trivialfis
Copy link
Member

Personally, I don't think it's necessary. Nonetheless, it can be implemented when all else is ready, if you care to insist upon it.

I think the scikit-learn representation makes sense. It represents the decision path with a numeric data structure instead of strings. The former is more efficient and reusable, one can easily build a human-readable text representation based on the matrix output when needed. In addition, the result can be returned back to Python instead of being kept as C++-only structure.

I have trouble adding support on other programming languages as I am not familiar with them. Perhaps someone can help out on this.

It's not necessary to have support for all language bindings at the moment. We can iterate on that.

I do not believe predict's gpu version should have this feature, as the gpu version is meant to be efficient.

We can work on it later.

I could not find an existing way of testing APIs like DumpModel. Is a rigorous test required?

There are some tests in test_tree_model.cc.

@Anarion-zuo
Copy link
Author

Then I must refactor things to align it with sklearn, and add tests.

@trivialfis
Copy link
Member

Let me know if there's anything I can help.

@Anarion-zuo
Copy link
Author

Anarion-zuo commented May 12, 2023

Please correct me if I'm mistaken. I presently share my thoughts on the format of the returned representation.

Sklearn's random forest allocates node IDs globally, whereas xgboost's tree node IDs are local to each tree. Thus, different trees could have the same ID for the nodes. Therefore, our decision_path cannot be in exactly the same way as sklearn's.

My proposal is then in two forms of the same idea.

  1. For each tree, predict in xgboost returns an indicator matrix for each row of data. Hence, each matrix is in shape (n_samples, n_node), where n_node is the number of nodes in the tree. We would have n_trees of such matrices.
  2. For each row of data, predict in xgboost returns an indicator matrix for each tree. Hence, each matrix is in shape (n_trees, n_node). We would have n_rows of such matrices.

On second thought, the second one would not be feasible because the n_node value might not be the same for different trees.

@Anarion-zuo
Copy link
Author

I have done as the first one indicated.

@trivialfis
Copy link
Member

Apologies for the slow reply. I will mention the issue with @hcho3 and @RAMitchell today. In the meanwhile, I will look into your solution. Thank you for the nice work!

@Anarion-zuo
Copy link
Author

Is there any update here?

@hcho3
Copy link
Collaborator

hcho3 commented May 18, 2023

Sklearn's random forest allocates node IDs globally, whereas xgboost's tree node IDs are local to each tree. Thus, different trees could have the same ID for the nodes. Therefore, our decision_path cannot be in exactly the same way as sklearn's.

For each tree, predict in xgboost returns an indicator matrix for each row of data. Hence, each matrix is in shape (n_samples, n_node), where n_node is the number of nodes in the tree. We would have n_trees of such matrices.

I'm not sure why XGBoost should return multiple indicator matrices. In my opinion, it should be possible to return a single (sparse) matrix, along with an array n_nodes_ptr.

@Anarion-zuo
Copy link
Author

For sklearn

indicatorsparse matrix of shape (n_samples, n_nodes)

Let me explain my understanding of this API. This second dimension accepts node id as index. [n_nodes_ptr[i]:n_nodes_ptr[i+1]] points to the range of node indices of estimator i. This would require node indices to be global, i.e. different trees do not have the same node id for their nodes.

Does my understanding align with what sklearn has? If so, xgboost's trees do not shared node id globally, to my knowledge, and therefore cannot return the nodes ids as the trees have them. Nonetheless, we could have a mapping from (tree_id, node_id) to some global node id for each node.

@hcho3
Copy link
Collaborator

hcho3 commented May 18, 2023

From the scikit-learn doc:

n_nodes_ptr: ndarray of shape (n_estimators + 1,)
The columns from indicator[n_nodes_ptr[i]:n_nodes_ptr[i+1]] gives the indicator value for the i-th estimator.

So n_nodes_ptr only tells you which columns in indicator matrix corresponds to the i-th trees. It does not require the global IDs

@Anarion-zuo
Copy link
Author

Anarion-zuo commented May 18, 2023

The wording here is a bit strange to me. I thought indicator[n_nodes_ptr[i]:n_nodes_ptr[i+1]] refers to a range rows, and The columns refers to the columns of these rows. Now that you mentioned it, it is clear to me that it refers to a range of cols.

I should alter my implementation accordingly.

@hcho3
Copy link
Collaborator

hcho3 commented May 18, 2023

The current implementation is too complex and cannot be admitted as it is. Here's my guidance for implementing decision_path.

  1. Implement two C API functions:
  • int XGBoosterQueryNumNodes(BoosterHandle handle, const int** num_nodes): returns the number of nodes per tree
  • int XGBoosterGetDecisionPath(BoosterHandle handle, int tree_id, int node_id, const int** decision_path): returns the list of all ancestors of a given node.
  1. Other than the two C API functions, implement decision_path in the Python layer only. We don't need to support this feature in other language bindings. The current PR tries to modify too many places in the C++ codebase, raising the risk of breaking existing functionalities.
  2. Allocate a single 2D dense matrix filled with zero, with shape (n_rows, n_total_nodes), where n_total_nodes is the total number of all nodes in all trees.
n_total_nodes = #  ... call C API to compute this
indicator = np.zeros((n_rows, n_total_nodes), dtype=np.int8)
  1. Allocate a single 1D array of length n_trees + 1. We call this array n_nodes_ptr.
  2. Fill n_nodes_ptr as follows:
n_trees = len(bst.get_dump(dump_format="json"))
n_nodes_ptr[0] = 0
for i in range(1, n_trees + 1):
    n_nodes_ptr[i] = n_nodes_ptr[i - 1] + get_n_nodes_for_tree(i)

Use XGBoosterQueryNumNodes to implement get_n_nodes_for_tree().
6. Obtain (local) leaf IDs by running Booster.predict with pred_leaf=True:

leaf_ids = bst.predict(X, pred_leaf=True)
  1. For each leaf_id[row_id] (which is the leaf output for row row_id), get the list of ancestor nodes by calling XGBoosterGetDecisionPath. Then set the corresponding entries in the indicator matrix:
for row_id in range(X.shape[0]):
    for tree_id in range(n_trees):
        ancestor_list = # ... 
        for ancestor in ancestor_list:
            indicator[row_id, n_nodes_ptr[tree_id] + ancestor] = 1

Note the use of n_nodes_ptr[tree_id] in the second index.
8. Now return indicator and n_nodes_ptr.

Let me know if you need any help. In particular, I will help you with creating "C API functions" and how to call them from the Python side.

@Anarion-zuo
Copy link
Author

Thanks for the suggestion. I should fork another branch to do this.

@hcho3
Copy link
Collaborator

hcho3 commented May 18, 2023

Feel free to reach out to me for help

@Anarion-zuo
Copy link
Author

Sorry about the delay. I was busy with other work. Now that I a bit more free, I would submit another PR in the recent days.

@trivialfis
Copy link
Member

trivialfis commented Jul 9, 2023 via email

@trivialfis trivialfis added this to To do in 2.1 Roadmap Aug 17, 2023
@Anarion-zuo
Copy link
Author

Is this feature written by someone else yet? I am truly free now. Here to fulfill my long-awaited destiny...

@trivialfis
Copy link
Member

trivialfis commented Aug 22, 2023

I don't think any of the core contributor is actively working on it. But do let me refresh, I want to have a deeper look before giving any advice.

@Anarion-zuo
Copy link
Author

Is there any update?

@trivialfis
Copy link
Member

trivialfis commented Aug 24, 2023

Here's another proposal for the API. Looking deeper into the decision_path function from sklearn, I think we might be able to improve upon it so that the returned matrix is easier for others to parse. My proposal is slightly different from the one found in sklearn.

Format

Let's assume we have two trees $t_0$, $t_1$, and two samples $x_0$, $x_1$ for demonstration. $x_0$ goes through 0 -> 1 -> 4 nodes in $t_0$, and 0 -> 2 -> 3 -> 5 nodes in $t_1$. $x_1$ goes through 0 -> 1 -> 3 -> 5 of $t_0$ and 0 -> 2 of $t_1$. With this, we have the following table:

$t_0$ $t_1$
$x_0$ 0, 1, 4 0, 2, 3, 5
$x_1$ 0, 1, 3, 5 0, 2

The question is now how do we represent this table in Python efficiently while maintaining some level of interpretability. My suggestion is we return a list of CSR matrics with a size equal to the number of trees: $|trees|$. Each CSR inside the list represents one column in the table.

decision_path = [
    CSR(values=[0, 1, 4, 0, 1, 3, 5], indptr=[0, 3, 7]),
    CSR(values=[0, 2, 3, 5, 0, 2], indptr=[0, 5, 6]),
]

However, complications arise when we have multi-class models or boosted random forests as it's quite difficult for users to find the correct tree. Without further grouping strategies, it's almost impossible for a normal user to find "the trees that represent the second class".

Implementation

I agree with @hcho3 that we need a 2-phrase API for this feature. Using thread-local memory would be too expensive. However, I haven't decided what exactly the C function should return. For instance, we can return a single tree and let the Python function loop over all trees, or we can return all the paths in one go. Repeating the concern in the previous section, we might use the iteration_range like the predict method.

Would love to hear others' opinions on these issues.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Development

Successfully merging this pull request may close these issues.

None yet

3 participants