diff --git a/demo/guide-python/README.md b/demo/guide-python/README.md
index 092931fd400c..6fd6a090c8d5 100644
--- a/demo/guide-python/README.md
+++ b/demo/guide-python/README.md
@@ -16,3 +16,5 @@ XGBoost Python Feature Walkthrough
 * [External Memory](external_memory.py)
 * [Training continuation](continuation.py)
 * [Feature weights for column sampling](feature_weights.py)
+* [Basic categorical data support](categorical.py)
+* [Compare builtin categorical data support with one-hot encoding](cat_in_the_dat.py)
\ No newline at end of file
diff --git a/demo/guide-python/cat_in_the_dat.py b/demo/guide-python/cat_in_the_dat.py
new file mode 100644
index 000000000000..551f4f535714
--- /dev/null
+++ b/demo/guide-python/cat_in_the_dat.py
@@ -0,0 +1,118 @@
+"""A simple demo for categorical data support using the dataset from the Kaggle
+categorical data tutorial.
+
+The excellent tutorial is at:
+https://www.kaggle.com/shahules/an-overview-of-encoding-techniques
+
+And the data can be found at:
+https://www.kaggle.com/shahules/an-overview-of-encoding-techniques/data
+
+    .. versionadded:: 1.6.0
+
+"""
+
+from __future__ import annotations
+
+import os
+from tempfile import TemporaryDirectory
+from time import time
+
+import pandas as pd
+from sklearn.metrics import roc_auc_score
+from sklearn.model_selection import train_test_split
+
+import xgboost as xgb
+
+
+def load_cat_in_the_dat() -> tuple[pd.DataFrame, pd.Series]:
+    """Assuming you have already downloaded the data into the `input` directory."""
+
+    df_train = pd.read_csv("./input/cat-in-the-dat/train.csv")
+
+    print(
+        "train data set has {} rows and {} columns".format(
+            df_train.shape[0], df_train.shape[1]
+        )
+    )
+    X = df_train.drop(["target"], axis=1)
+    y = df_train["target"]
+
+    # bin_*, nom_0 to nom_4 and ord_* are categorical features.
+    for i in range(0, 5):
+        X["bin_" + str(i)] = X["bin_" + str(i)].astype("category")
+
+    for i in range(0, 5):
+        X["nom_" + str(i)] = X["nom_" + str(i)].astype("category")
+
+    # nom_5 to nom_9 are hexadecimal strings, convert them into integers.
+    for i in range(5, 10):
+        X["nom_" + str(i)] = X["nom_" + str(i)].apply(int, base=16)
+
+    for i in range(0, 6):
+        X["ord_" + str(i)] = X["ord_" + str(i)].astype("category")
+
+    print(
+        "preprocessed data set has {} rows and {} columns".format(
+            X.shape[0], X.shape[1]
+        )
+    )
+    return X, y
+
+
+# Categorical data support is currently limited to the gpu_hist tree method.
+params = {"tree_method": "gpu_hist", "use_label_encoder": False, "n_estimators": 32}
+
+
+def categorical_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None:
+    """Train using builtin categorical data support from XGBoost."""
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, random_state=1994, test_size=0.2
+    )
+
+    clf = xgb.XGBClassifier(**params, enable_categorical=True)
+    clf.fit(
+        X_train,
+        y_train,
+        eval_set=[(X_test, y_test), (X_train, y_train)],
+        eval_metric="auc",
+    )
+    clf.save_model(os.path.join(output_dir, "categorical.json"))
+
+    y_score = clf.predict_proba(X_test)[:, 1]  # probability of positive samples
+    auc = roc_auc_score(y_test, y_score)
+    print("AUC of using builtin categorical data support:", auc)
+
+
+def onehot_encoding_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None:
+    """Train using one-hot encoded data."""
+    # Use the same split as the categorical model so that the two AUC scores are
+    # comparable.
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, random_state=1994, test_size=0.2
+    )
+
+    clf = xgb.XGBClassifier(**params, enable_categorical=False)
+    clf.fit(
+        X_train,
+        y_train,
+        eval_set=[(X_test, y_test), (X_train, y_train)],
+        eval_metric="auc",
+    )
+    clf.save_model(os.path.join(output_dir, "one-hot.json"))
+
+    y_score = clf.predict_proba(X_test)[:, 1]  # probability of positive samples
+    auc = roc_auc_score(y_test, y_score)
+    print("AUC of using onehot encoding:", auc)
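+
+
+# The driver below exercises both approaches on the same dataset: first XGBoost's
+# builtin categorical support on the raw dataframe, then training on a one-hot
+# encoded copy produced with pd.get_dummies. Each run reports its AUC and
+# wall-clock time.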
+if __name__ == "__main__":
+    X, y = load_cat_in_the_dat()
+
+    with TemporaryDirectory() as tmpdir:
+        start = time()
+        categorical_model(X, y, tmpdir)
+        end = time()
+        print("Duration:categorical", end - start)
+
+        X = pd.get_dummies(X)
+        start = time()
+        onehot_encoding_model(X, y, tmpdir)
+        end = time()
+        print("Duration:onehot", end - start)
diff --git a/doc/tutorials/categorical.rst b/doc/tutorials/categorical.rst
new file mode 100644
index 000000000000..6ee724d45b6a
--- /dev/null
+++ b/doc/tutorials/categorical.rst
@@ -0,0 +1,118 @@
+################
+Categorical Data
+################
+
+Starting from version 1.5, XGBoost has experimental support for categorical data
+available for public testing. At the moment, the support is implemented as one-hot
+encoding based categorical tree splits. For numerical data, the split condition is
+defined as :math:`value < threshold`, while for categorical data the split is defined
+as :math:`value == category`, where ``category`` is a discrete value. A more advanced
+categorical split strategy is planned for future releases. This tutorial details how
+to inform XGBoost about the data type. Also, the current support for training is
+limited to the ``gpu_hist`` tree method.
+
+************************************
+Training with scikit-learn Interface
+************************************
+
+The easiest way to pass categorical data into XGBoost is using a dataframe and the
+``scikit-learn`` interface like :class:`XGBClassifier <xgboost.XGBClassifier>`. To
+prepare the data, users need to specify the data type of the input predictors as
+``category``. For a ``pandas``/``cudf`` dataframe, this can be achieved by
+
+.. code:: python
+
+  X["cat_feature"].astype("category")
+
+for all columns that represent categorical features. After that, users can tell
+XGBoost to enable training with categorical data. Assuming that you are using
+:class:`XGBClassifier <xgboost.XGBClassifier>` for a classification problem, specify
+the parameter ``enable_categorical``:
+
+.. code:: python
+
+  # Only gpu_hist is supported for categorical data as mentioned previously
+  clf = xgb.XGBClassifier(
+      tree_method="gpu_hist", enable_categorical=True, use_label_encoder=False
+  )
+  # X is the dataframe we created in the previous snippet
+  clf.fit(X, y)
+  # Must use JSON for serialization, otherwise the information is lost
+  clf.save_model("categorical-model.json")
+
+
+Once training is finished, most other features can utilize the model. For instance,
+one can plot the model and calculate the global feature importance:
+
+
+.. code:: python
+
+  # Get a graph
+  graph = xgb.to_graphviz(clf, num_trees=1)
+  # Or get a matplotlib axis
+  ax = xgb.plot_tree(clf, num_trees=1)
+  # Get feature importances
+  clf.feature_importances_
+
+
+The ``scikit-learn`` interface from Dask is similar to the single-node version. The
+basic idea is to create a dataframe with categorical feature types and tell XGBoost to
+use ``gpu_hist`` with the parameter ``enable_categorical``. See `this demo
+<https://github.com/dmlc/xgboost/tree/master/demo/guide-python/categorical.py>`_ for a
+worked example using categorical data with the ``scikit-learn`` interface. For using
+it with the Kaggle tutorial dataset, see `the cat_in_the_dat.py demo
+<https://github.com/dmlc/xgboost/tree/master/demo/guide-python/cat_in_the_dat.py>`_.
+
+
+**********************
+Using native interface
+**********************
+
+The ``scikit-learn`` interface is user friendly, but lacks some features that are only
+available in the native interface. For instance, users cannot compute SHAP values
+directly or use a quantized ``DMatrix``. Also, the native interface supports data
+types other than dataframes, like ``numpy``/``cupy`` arrays. To use the native
+interface with categorical data, we need to pass the same parameter to ``DMatrix`` and
+the ``train`` function. For dataframe input:
+
+.. code:: python
+
+  # X is a dataframe we created in the previous snippet
+  Xy = xgb.DMatrix(X, y, enable_categorical=True)
+  booster = xgb.train({"tree_method": "gpu_hist"}, Xy)
+  # Must use JSON for serialization, otherwise the information is lost
+  booster.save_model("categorical-model.json")
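+
+As a sketch of the quantized ``DMatrix`` mentioned above, the native interface can
+consume device data through :class:`DeviceQuantileDMatrix <xgboost.DeviceQuantileDMatrix>`.
+The snippet below assumes a ``cudf`` dataframe and that your XGBoost version accepts
+``enable_categorical`` for this class; treat it as an illustration rather than a
+supported recipe:
+
+.. code:: python
+
+  import cudf
+
+  # Illustrative data: one numerical and one categorical column on the GPU.
+  df = cudf.DataFrame({"num": [1.0, 2.0, 3.0], "cat": ["x", "y", "x"]})
+  df["cat"] = df["cat"].astype("category")
+  y = cudf.Series([0, 1, 0])
+  # The quantized DMatrix is built directly from device data, which saves memory
+  # when training with gpu_hist.
+  Xy = xgb.DeviceQuantileDMatrix(df, y, enable_categorical=True)
+  booster = xgb.train({"tree_method": "gpu_hist"}, Xy)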
+
+SHAP value computation:
+
+.. code:: python
+
+  SHAP = booster.predict(Xy, pred_interactions=True)
+
+  # categorical features are listed as "c"
+  print(booster.feature_types)
+
+
+For other types of input, like ``numpy`` arrays, we can tell XGBoost about the feature
+types by using the ``feature_types`` parameter in :class:`DMatrix <xgboost.DMatrix>`:
+
+.. code:: python
+
+  # "q" is a numerical feature, while "c" is a categorical feature
+  ft = ["q", "c", "c"]
+  X: np.ndarray = load_my_data()
+  assert X.shape[1] == 3
+  Xy = xgb.DMatrix(X, y, feature_types=ft, enable_categorical=True)
+
+For numerical data, the feature type can be ``"q"`` or ``"float"``, while for
+categorical features it's specified as ``"c"``. The Dask module in XGBoost has the
+same interface, so ``dask.Array`` can also be used for categorical data.
+
+
+**********
+Next Steps
+**********
+
+As of XGBoost 1.5, the feature is highly experimental and has limitations; for
+instance, CPU training is not yet supported. Please see `the tracking issue
+<https://github.com/dmlc/xgboost/issues/6503>`_ for progress.
diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst
index 5363523315c9..4aa636d64de5 100644
--- a/doc/tutorials/index.rst
+++ b/doc/tutorials/index.rst
@@ -26,3 +26,4 @@ See `Awesome XGBoost `_ for mo
   param_tuning
   external_memory
   custom_metric_obj
+  categorical
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index e4b6f2f8fab4..01f3a4279074 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -807,7 +807,11 @@ def _can_use_inplace_predict(self) -> bool:
         # Inplace predict doesn't handle as many data types as DMatrix, but it's
         # sufficient for dask interface where input is simpiler.
         predictor = self.get_params().get("predictor", None)
-        if predictor in ("auto", None) and self.booster != "gblinear":
+        # Inplace predict does not support categorical data, so fall back to the
+        # DMatrix-based path when enable_categorical is set.
+        if (
+            not self.enable_categorical
+            and predictor in ("auto", None)
+            and self.booster != "gblinear"
+        ):
             return True
         return False
@@ -886,7 +890,10 @@ def predict(
             pass
         test = DMatrix(
-            X, base_margin=base_margin, missing=self.missing, nthread=self.n_jobs
+            X,
+            base_margin=base_margin,
+            missing=self.missing,
+            nthread=self.n_jobs,
+            enable_categorical=self.enable_categorical,
         )
         return self.get_booster().predict(
             data=test,