Skip to content

Commit

Permalink
Add special handling for multi:softmax in sklearn predict (#7607)
Browse files Browse the repository at this point in the history
* Add special handling for multi:softmax in sklearn predict

* Add test coverage
  • Loading branch information
hcho3 committed Jan 29, 2022
1 parent 7f738e7 commit b4340ab
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 2 additions & 0 deletions python-package/xgboost/sklearn.py
Expand Up @@ -1419,6 +1419,8 @@ def predict(
# multi-label
column_indexes = np.zeros(class_probs.shape)
column_indexes[class_probs > 0.5] = 1
elif self.objective == "multi:softmax":
return class_probs.astype(np.int32)
else:
# turns soft logit into class label
column_indexes = np.repeat(0, class_probs.shape[0])
Expand Down
5 changes: 3 additions & 2 deletions tests/python/test_with_sklearn.py
Expand Up @@ -36,7 +36,8 @@ def test_binary_classification():
assert err < 0.1


def test_multiclass_classification():
@pytest.mark.parametrize('objective', ['multi:softmax', 'multi:softprob'])
def test_multiclass_classification(objective):
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold

Expand All @@ -54,7 +55,7 @@ def check_pred(preds, labels, output_margin):
X = iris['data']
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X, y):
xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
xgb_model = xgb.XGBClassifier(objective=objective).fit(X[train_index], y[train_index])
assert (xgb_model.get_booster().num_boosted_rounds() ==
xgb_model.n_estimators)
preds = xgb_model.predict(X[test_index])
Expand Down

0 comments on commit b4340ab

Please sign in to comment.