From 36a38c66d1d0cceb17bcc859ec9eb029ab3253ae Mon Sep 17 00:00:00 2001
From: fis
Date: Sun, 19 Dec 2021 11:59:55 +0800
Subject: [PATCH 1/4] Initial support for multi-label classification.

---
 doc/tutorials/index.rst           |  1 +
 doc/tutorials/multioutput.rst     | 33 +++++++++++++++++++++++++++++++
 python-package/xgboost/sklearn.py | 14 +++++++++++++-
 tests/python/test_with_sklearn.py | 18 +++++++++++++++++
 4 files changed, 65 insertions(+), 1 deletion(-)
 create mode 100644 doc/tutorials/multioutput.rst

diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst
index 4aa636d64de5..d2cf979e39f3 100644
--- a/doc/tutorials/index.rst
+++ b/doc/tutorials/index.rst
@@ -27,3 +27,4 @@ See `Awesome XGBoost `_ for mo
   external_memory
   custom_metric_obj
   categorical
+  multioutput

diff --git a/doc/tutorials/multioutput.rst b/doc/tutorials/multioutput.rst
new file mode 100644
index 000000000000..f42559a78167
--- /dev/null
+++ b/doc/tutorials/multioutput.rst
@@ -0,0 +1,33 @@
+################
+Multiple Outputs
+################
+
+.. versionadded:: 1.6
+
+Starting from version 1.6, XGBoost has experimental support for multi-output regression
+and multi-label classification. For the terminology, please refer to the `scikit-learn
+user guide `_.
+
+Internally, XGBoost builds one model for each target, similar to sklearn meta-estimators,
+with the added benefit of reusing data and custom objective support. For a worked example
+of regression, see :ref:`sphx_glr_python_examples_multioutput_regression.py`. For
+multi-label classification, the binary relevance strategy is used. Since classes are not
+mutually exclusive, XGBoost trains one binary classifier for each target. Input ``y``
+should be of shape ``(n_samples, n_classes)`` with each column having a value of 0 or 1
+to specify whether the sample is labeled as positive.
+
+.. code-block:: python
+
+    from sklearn.datasets import make_multilabel_classification
+    import numpy as np
+    import xgboost as xgb
+
+    X, y = make_multilabel_classification(
+        n_samples=32, n_classes=5, n_labels=3, random_state=0
+    )
+    clf = xgb.XGBClassifier(tree_method="hist")
+    clf.fit(X, y)
+    np.testing.assert_allclose(clf.predict(X), y)
+
+
+The feature is still under development and might contain unknown issues.
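A note on the example above: ``predict`` is backed by per-column probabilities, so the
binary relevance scheme can also be inspected through ``predict_proba``. The following is
a minimal sketch, assuming the estimator interface added by this patch; it mirrors the
new test below rather than defining any additional API:

.. code-block:: python

    import numpy as np
    from sklearn.datasets import make_multilabel_classification

    import xgboost as xgb

    X, y = make_multilabel_classification(
        n_samples=32, n_classes=5, n_labels=3, random_state=0
    )
    clf = xgb.XGBClassifier(tree_method="hist").fit(X, y)

    # One positive-class probability per label column: shape (n_samples, n_classes).
    proba = clf.predict_proba(X)
    assert proba.shape == (32, 5)

    # Binary relevance: each column is thresholded at 0.5 independently, which
    # reproduces the hard labels returned by predict().
    np.testing.assert_allclose(clf.predict(X), (proba > 0.5).astype(np.int64))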
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 949dae7b46c7..dbcd5f8755fb 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -1215,6 +1215,14 @@ def intercept_(self) -> np.ndarray:
 def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) -> PredtT:
     assert len(prediction.shape) <= 2
     if len(prediction.shape) == 2 and prediction.shape[1] == n_classes:
+        # multi-class
         return prediction
+    if (
+        len(prediction.shape) == 2
+        and n_classes == 2
+        and prediction.shape[1] >= n_classes
+    ):
+        # multi-label
+        return prediction
     # binary logistic function
     classone_probs = prediction
@@ -1374,9 +1382,13 @@ def predict(
             # If output_margin is active, simply return the scores
             return class_probs
 
-        if len(class_probs.shape) > 1:
+        if len(class_probs.shape) > 1 and self.n_classes_ != 2:
             # turns softprob into softmax
             column_indexes: np.ndarray = np.argmax(class_probs, axis=1)  # type: ignore
+        elif len(class_probs.shape) > 1:
+            # multi-label
+            column_indexes = np.zeros(class_probs.shape)
+            column_indexes[class_probs > 0.5] = 1
         else:
             # turns soft logit into class label
             column_indexes = np.repeat(0, class_probs.shape[0])
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index a5c0d8fe2cd9..83c73932b4a2 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -1194,6 +1194,24 @@ def test_estimator_type():
     cls.load_model(path)  # no error
 
 
+def test_multilabel_classification() -> None:
+    from sklearn.datasets import make_multilabel_classification
+
+    X, y = make_multilabel_classification(
+        n_samples=32, n_classes=5, n_labels=3, random_state=0
+    )
+    clf = xgb.XGBClassifier(tree_method="hist")
+    clf.fit(X, y)
+    booster = clf.get_booster()
+    learner = json.loads(booster.save_config())["learner"]
+    assert int(learner["learner_model_param"]["num_target"]) == 5
+
+    np.testing.assert_allclose(clf.predict(X), y)
+    predt = (clf.predict_proba(X) > 0.5).astype(np.int64)
+    np.testing.assert_allclose(clf.predict(X), predt)
+    assert predt.dtype == np.int64
+
+
 def run_data_initialization(DMatrix, model, X, y):
     """Assert that we don't create duplicated DMatrix."""
 

From 8754928fb2ad119cda0c1bc3ec3248116471d27c Mon Sep 17 00:00:00 2001
From: fis
Date: Sun, 19 Dec 2021 12:33:27 +0800
Subject: [PATCH 2/4] Ensure correct shape.

---
 python-package/xgboost/sklearn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index dbcd5f8755fb..f1592e438b4d 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -1385,7 +1385,7 @@ def predict(
         if len(class_probs.shape) > 1 and self.n_classes_ != 2:
             # turns softprob into softmax
             column_indexes: np.ndarray = np.argmax(class_probs, axis=1)  # type: ignore
-        elif len(class_probs.shape) > 1:
+        elif len(class_probs.shape) > 1 and class_probs.shape[1] != 1:
             # multi-label
             column_indexes = np.zeros(class_probs.shape)
             column_indexes[class_probs > 0.5] = 1

From faaa47ce4f0e38b49649d31cda98395c492ff3f8 Mon Sep 17 00:00:00 2001
From: fis
Date: Sun, 19 Dec 2021 12:47:27 +0800
Subject: [PATCH 3/4] doc.

---
 doc/tutorials/multioutput.rst | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/doc/tutorials/multioutput.rst b/doc/tutorials/multioutput.rst
index f42559a78167..b8a953c2e8f9 100644
--- a/doc/tutorials/multioutput.rst
+++ b/doc/tutorials/multioutput.rst
@@ -11,10 +11,10 @@ user guide `_.
 Internally, XGBoost builds one model for each target, similar to sklearn meta-estimators,
 with the added benefit of reusing data and custom objective support. For a worked example
 of regression, see :ref:`sphx_glr_python_examples_multioutput_regression.py`. For
-multi-label classification, the binary relevance strategy is used. Since classes are not
-mutually exclusive, XGBoost trains one binary classifier for each target. Input ``y``
-should be of shape ``(n_samples, n_classes)`` with each column having a value of 0 or 1
-to specify whether the sample is labeled as positive.
+multi-label classification, the binary relevance strategy is used. Input ``y`` should be
+of shape ``(n_samples, n_classes)`` with each column having a value of 0 or 1 to specify
+whether the sample is labeled as positive. At the moment, XGBoost supports only dense
+matrices for labels.
 
 .. code-block:: python
 

From e4b4123025bc9bfc5ad59e0f2b1a4f651996a6a4 Mon Sep 17 00:00:00 2001
From: fis
Date: Tue, 21 Dec 2021 00:33:49 +0800
Subject: [PATCH 4/4] Doc.

---
 doc/tutorials/multioutput.rst     | 17 +++++++++++------
 python-package/xgboost/sklearn.py |  2 +-
 2 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/doc/tutorials/multioutput.rst b/doc/tutorials/multioutput.rst
index b8a953c2e8f9..d9af9313e475 100644
--- a/doc/tutorials/multioutput.rst
+++ b/doc/tutorials/multioutput.rst
@@ -5,16 +5,21 @@ Multiple Outputs
 .. versionadded:: 1.6
 
 Starting from version 1.6, XGBoost has experimental support for multi-output regression
-and multi-label classification. For the terminology, please refer to the `scikit-learn
-user guide `_.
+and multi-label classification with the Python package. Multi-label classification
+usually refers to targets that have multiple non-exclusive class labels. For instance, a
+movie can be simultaneously classified as both sci-fi and comedy. For a detailed
+explanation of the terminology related to different multi-output models, please refer to
+the `scikit-learn user guide `_.
 
 Internally, XGBoost builds one model for each target, similar to sklearn meta-estimators,
 with the added benefit of reusing data and custom objective support. For a worked example
 of regression, see :ref:`sphx_glr_python_examples_multioutput_regression.py`. For
 multi-label classification, the binary relevance strategy is used. Input ``y`` should be
-of shape ``(n_samples, n_classes)`` with each column having a value of 0 or 1 to specify
-whether the sample is labeled as positive. At the moment, XGBoost supports only dense
-matrices for labels.
+of shape ``(n_samples, n_classes)`` with each column having a value of 0 or 1 to specify
+whether the sample is labeled as positive for the respective class. Given a sample with
+3 output classes and 2 labels, ``y`` should be encoded as ``[1, 0, 1]``, with the second
+class labeled as negative and the rest labeled as positive. At the moment, XGBoost
+supports only dense matrices for labels.
 
 .. code-block:: python
 
@@ -30,4 +35,4 @@ matrices for labels.
     np.testing.assert_allclose(clf.predict(X), y)
 
 
-The feature is still under development and might contain unknown issues.
+The feature is still under development with limited support from objectives and metrics.
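Since the tutorial now states that only dense label matrices are supported, here is a
short sketch of the conversion needed when labels arrive sparse; it relies on
scikit-learn's ``return_indicator="sparse"`` option and SciPy's ``.toarray()``, and is an
illustration rather than part of the patch:

.. code-block:: python

    from sklearn.datasets import make_multilabel_classification

    import xgboost as xgb

    # Sparse indicator labels, shape (n_samples, n_classes).
    X, y_sparse = make_multilabel_classification(
        n_samples=32, n_classes=3, n_labels=2, return_indicator="sparse", random_state=0
    )

    clf = xgb.XGBClassifier(tree_method="hist")
    # XGBoost expects a dense 0/1 indicator matrix, so densify before fitting.
    clf.fit(X, y_sparse.toarray())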
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index f1592e438b4d..a2589060608c 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -1383,7 +1383,7 @@ def predict(
             return class_probs
 
         if len(class_probs.shape) > 1 and self.n_classes_ != 2:
-            # turns softprob into softmax
+            # multi-class, turns softprob into softmax
             column_indexes: np.ndarray = np.argmax(class_probs, axis=1)  # type: ignore
         elif len(class_probs.shape) > 1 and class_probs.shape[1] != 1:
             # multi-label
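For reference, a standalone sketch of the label post-processing that ``predict`` performs
after this series; the function name is illustrative and not part of the library, and the
binary branch condenses the ``np.repeat``/threshold steps from patch 1 into an equivalent
form:

.. code-block:: python

    import numpy as np

    def postprocess_predictions(class_probs: np.ndarray, n_classes: int) -> np.ndarray:
        if len(class_probs.shape) > 1 and n_classes != 2:
            # multi-class: softprob matrix -> index of the most probable class
            return np.argmax(class_probs, axis=1)
        if len(class_probs.shape) > 1 and class_probs.shape[1] != 1:
            # multi-label: one independent 0/1 decision per class column
            labels = np.zeros(class_probs.shape)
            labels[class_probs > 0.5] = 1
            return labels
        # binary: 1-d positive-class probabilities -> 0/1 labels
        labels = np.zeros(class_probs.shape[0], dtype=np.int64)
        labels[class_probs.reshape(-1) > 0.5] = 1
        return labels

In the multi-label case each target is binary, so ``n_classes`` is 2 and a
``(n_samples, n_classes)`` probability matrix takes the middle branch, matching the
behaviour exercised by ``test_multilabel_classification`` above.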