From b4340abf5694c790ea5730f963680a1048fa172c Mon Sep 17 00:00:00 2001
From: Philip Hyunsu Cho
Date: Sat, 29 Jan 2022 15:54:49 -0800
Subject: [PATCH] Add special handling for multi:softmax in sklearn predict
 (#7607)

* Add special handling for multi:softmax in sklearn predict

* Add test coverage
---
 python-package/xgboost/sklearn.py | 2 ++
 tests/python/test_with_sklearn.py | 5 +++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index d6ab64ef2917..0a9b9c923135 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -1419,6 +1419,8 @@ def predict(
                 # multi-label
                 column_indexes = np.zeros(class_probs.shape)
                 column_indexes[class_probs > 0.5] = 1
+            elif self.objective == "multi:softmax":
+                return class_probs.astype(np.int32)
             else:
                 # turns soft logit into class label
                 column_indexes = np.repeat(0, class_probs.shape[0])
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 149e77ed72df..98ae09ed7bdb 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -36,7 +36,8 @@ def test_binary_classification():
     assert err < 0.1
 
 
-def test_multiclass_classification():
+@pytest.mark.parametrize('objective', ['multi:softmax', 'multi:softprob'])
+def test_multiclass_classification(objective):
     from sklearn.datasets import load_iris
     from sklearn.model_selection import KFold
 
@@ -54,7 +55,7 @@ def check_pred(preds, labels, output_margin):
     X = iris['data']
     kf = KFold(n_splits=2, shuffle=True, random_state=rng)
     for train_index, test_index in kf.split(X, y):
-        xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
+        xgb_model = xgb.XGBClassifier(objective=objective).fit(X[train_index], y[train_index])
         assert (xgb_model.get_booster().num_boosted_rounds() ==
                 xgb_model.n_estimators)
         preds = xgb_model.predict(X[test_index])