Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Handle categorical split in model histogram and dataframe. (#7065)
* Error on get_split_value_histogram when feature is categorical * Add a category column to output dataframe
- Loading branch information
1 parent
1cd20ef
commit a5d222f
Showing
3 changed files
with
96 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import sys | ||
import pytest | ||
import xgboost as xgb | ||
|
||
sys.path.append("tests/python") | ||
import testing as tm | ||
|
||
|
||
def test_tree_to_df_categorical(): | ||
X, y = tm.make_categorical(100, 10, 31, False) | ||
Xy = xgb.DMatrix(X, y, enable_categorical=True) | ||
booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=10) | ||
df = booster.trees_to_dataframe() | ||
for _, x in df.iterrows(): | ||
if x["Feature"] != "Leaf": | ||
assert len(x["Category"]) == 1 | ||
|
||
|
||
def test_split_value_histograms(): | ||
X, y = tm.make_categorical(1000, 10, 13, False) | ||
reg = xgb.XGBRegressor(tree_method="gpu_hist", enable_categorical=True) | ||
reg.fit(X, y) | ||
|
||
with pytest.raises(ValueError, match="doesn't"): | ||
reg.get_booster().get_split_value_histogram("3", bins=5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters