Add document for categorical data. #7307

Merged
merged 1 commit into from Oct 12, 2021
2 changes: 2 additions & 0 deletions demo/guide-python/README.md
@@ -16,3 +16,5 @@ XGBoost Python Feature Walkthrough
* [External Memory](external_memory.py)
* [Training continuation](continuation.py)
* [Feature weights for column sampling](feature_weights.py)
* [Basic Categorical data support](categorical.py)
* [Compare builtin categorical data support with one-hot encoding](cat_in_the_dat.py)
118 changes: 118 additions & 0 deletions demo/guide-python/cat_in_the_dat.py
@@ -0,0 +1,118 @@
"""A simple demo for categorical data support using dataset from Kaggle categorical data
tutorial.

The excellent tutorial is at:
https://www.kaggle.com/shahules/an-overview-of-encoding-techniques

And the data can be found at:
https://www.kaggle.com/shahules/an-overview-of-encoding-techniques/data

.. versionadded:: 1.6.0

"""

from __future__ import annotations
from time import time
import os
from tempfile import TemporaryDirectory

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

import xgboost as xgb


def load_cat_in_the_dat() -> tuple[pd.DataFrame, pd.Series]:
"""Assuming you have already downloaded the data into `input` directory."""

df_train = pd.read_csv("./input/cat-in-the-dat/train.csv")

print(
"train data set has got {} rows and {} columns".format(
df_train.shape[0], df_train.shape[1]
)
)
X = df_train.drop(["target"], axis=1)
y = df_train["target"]

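    # Mark the binary, low-cardinality nominal, and ordinal columns as categorical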
for i in range(0, 5):
X["bin_" + str(i)] = X["bin_" + str(i)].astype("category")

for i in range(0, 5):
X["nom_" + str(i)] = X["nom_" + str(i)].astype("category")

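    # nom_5 to nom_9 are hexadecimal strings; decode them into plain integers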
for i in range(5, 10):
X["nom_" + str(i)] = X["nom_" + str(i)].apply(int, base=16)

for i in range(0, 6):
X["ord_" + str(i)] = X["ord_" + str(i)].astype("category")

print(
"train data set has got {} rows and {} columns".format(X.shape[0], X.shape[1])
)
return X, y


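# gpu_hist is currently required for XGBoost's categorical data support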
params = {"tree_method": "gpu_hist", "use_label_encoder": False, "n_estimators": 32}


def categorical_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None:
"""Train using builtin categorical data support from XGBoost"""
X_train, X_test, y_train, y_test = train_test_split(
X, y, random_state=1994, test_size=0.2
)

clf = xgb.XGBClassifier(**params, enable_categorical=True)
clf.fit(
X_train,
y_train,
eval_set=[(X_test, y_test), (X_train, y_train)],
eval_metric="auc",
)
print(clf.n_classes_)
clf.save_model(os.path.join(output_dir, "categorical.json"))

y_score = clf.predict_proba(X_test)[:, 1] # proba of positive samples
auc = roc_auc_score(y_test, y_score)
print("AUC of using builtin categorical data support:", auc)


def onehot_encoding_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None:
"""Train using one-hot encoded data."""

X_train, X_test, y_train, y_test = train_test_split(
X, y, random_state=42, test_size=0.2
)
print(X_train.shape, y_train.shape)

clf = xgb.XGBClassifier(**params, enable_categorical=False)
clf.fit(
X_train,
y_train,
eval_set=[(X_test, y_test), (X_train, y_train)],
eval_metric="auc",
)
clf.save_model(os.path.join(output_dir, "one-hot.json"))

y_score = clf.predict_proba(X_test)[:, 1] # proba of positive samples
auc = roc_auc_score(y_test, y_score)
print("AUC of using onehot encoding:", auc)


if __name__ == "__main__":
X, y = load_cat_in_the_dat()

with TemporaryDirectory() as tmpdir:
start = time()
categorical_model(X, y, tmpdir)
end = time()
print("Duration:categorical", end - start)

X = pd.get_dummies(X)
start = time()
onehot_encoding_model(X, y, tmpdir)
end = time()
print("Duration:onehot", end - start)
118 changes: 118 additions & 0 deletions doc/tutorials/categorical.rst
@@ -0,0 +1,118 @@
################
Categorical Data
################

Starting from version 1.5, XGBoost has experimental support for categorical data
available for public testing. At the moment, the support is implemented as one-hot
encoding based categorical tree splits. For numerical data, the split condition is
defined as :math:`value < threshold`, while for categorical data the split is defined
as :math:`value == category`, where ``category`` is a discrete value. A more advanced
categorical split strategy is planned for future releases, and this tutorial details
how to inform XGBoost about the data type. Also, the current support for training is
limited to the ``gpu_hist`` tree method.

************************************
Training with scikit-learn Interface
************************************

The easiest way to pass categorical data into XGBoost is to use a dataframe and the
``scikit-learn`` interface like :class:`XGBClassifier <xgboost.XGBClassifier>`. To
prepare the data, users need to specify the data type of the input predictors as
``category``. For a ``pandas/cuDF DataFrame``, this can be achieved by

.. code:: python

X["cat_feature"].astype("category")

for all columns that represent categorical features. After that, users can tell
XGBoost to enable training with categorical data. Assuming that you are using
:class:`XGBClassifier <xgboost.XGBClassifier>` for a classification problem, specify
the parameter ``enable_categorical``:

.. code:: python

# Only gpu_hist is supported for categorical data as mentioned previously
clf = xgb.XGBClassifier(
tree_method="gpu_hist", enable_categorical=True, use_label_encoder=False
)
  # X is the dataframe we created in the previous snippet
clf.fit(X, y)
# Must use JSON for serialization, otherwise the information is lost
clf.save_model("categorical-model.json")


Once training is finished, most of the other features can utilize the model. For
instance, one can plot the model and calculate the global feature importance:


.. code:: python

# Get a graph
graph = xgb.to_graphviz(clf, num_trees=1)
# Or get a matplotlib axis
  ax = xgb.plot_tree(clf, num_trees=1)
# Get feature importances
clf.feature_importances_


The ``scikit-learn`` interface from Dask is similar to the single-node version. The
basic idea is to create a dataframe with categorical feature types and tell XGBoost to
use ``gpu_hist`` together with the parameter ``enable_categorical``. See `this demo
<https://github.com/dmlc/xgboost/blob/master/demo/guide-python/categorical.py>`_ for a
worked example of using categorical data with the ``scikit-learn`` interface. For using
it with the Kaggle tutorial dataset, see `this demo
<https://github.com/dmlc/xgboost/blob/master/demo/guide-python/cat_in_the_dat.py>`_.
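
Below is a minimal sketch (not one of this PR's demos) of the Dask ``scikit-learn``
interface with categorical data, assuming the Dask estimators accept the same
``enable_categorical`` parameter; the cluster setup, file name, and column names are
illustrative:

.. code:: python

  from dask import dataframe as dd
  from dask.distributed import Client

  import xgboost as xgb

  client = Client()  # connect to your Dask cluster

  df = dd.read_csv("train.csv")  # hypothetical input file
  X = df.drop(columns=["target"])
  y = df["target"]
  # Dask requires the categories to be known up front, hence ``cat.as_known()``
  X["cat_feature"] = X["cat_feature"].astype("category").cat.as_known()

  clf = xgb.dask.DaskXGBClassifier(tree_method="gpu_hist", enable_categorical=True)
  clf.client = client
  clf.fit(X, y)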


**********************
Using native interface
**********************

The ``scikit-learn`` interface is user friendly, but lacks some features that are only
available in the native interface. For instance, users cannot compute SHAP values
directly or use the quantized ``DMatrix``. Also, the native interface supports data
types other than dataframes, like ``numpy/cupy array``. To use the native interface
with categorical data, we need to pass similar parameters to ``DMatrix`` and the
``train`` function. For dataframe input:

.. code:: python

  # X is a dataframe we created in the previous snippet
Xy = xgb.DMatrix(X, y, enable_categorical=True)
booster = xgb.train({"tree_method": "gpu_hist"}, Xy)
# Must use JSON for serialization, otherwise the information is lost
booster.save_model("categorical-model.json")
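
The quantized ``DMatrix`` mentioned earlier can be constructed in a similar way. The
following is a hedged sketch (not part of this PR), assuming
:class:`DeviceQuantileDMatrix <xgboost.DeviceQuantileDMatrix>` accepts the same
``enable_categorical`` flag; note that it requires device (cuDF/cupy) input:

.. code:: python

  # X here is assumed to be a cuDF dataframe with ``category`` typed columns
  Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
  booster = xgb.train({"tree_method": "gpu_hist"}, Xy)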

SHAP value computation:

.. code:: python

SHAP = booster.predict(Xy, pred_interactions=True)

# categorical features are listed as "c"
print(booster.feature_types)
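
Note that with ``pred_interactions=True`` the result contains SHAP interaction values,
with one extra row and column for the bias term; passing ``pred_contribs=True`` instead
yields per-feature SHAP contributions.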


For other types of input, such as a ``numpy array``, we can tell XGBoost about the
feature types by using the ``feature_types`` parameter in
:class:`DMatrix <xgboost.DMatrix>`:

.. code:: python

# "q" is numerical feature, while "c" is categorical feature
ft = ["q", "c", "c"]
  X: np.ndarray = load_my_data()  # hypothetical helper returning the features
  y = load_my_labels()  # hypothetical helper returning the labels
  assert X.shape[1] == 3
Xy = xgb.DMatrix(X, y, feature_types=ft, enable_categorical=True)

For numerical data, the feature type can be ``"q"`` or ``"float"``, while for
categorical features it's specified as ``"c"``. The Dask module in XGBoost has the
same interface, so ``dask.Array`` can also be used for categorical data.
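
As a hedged sketch (not from this PR), and assuming
:class:`DaskDMatrix <xgboost.dask.DaskDMatrix>` accepts the same ``feature_types`` and
``enable_categorical`` parameters, categorical ``dask.Array`` input could look like:

.. code:: python

  import dask.array as da
  from dask.distributed import Client

  import xgboost as xgb

  client = Client()  # connect to your Dask cluster
  # hypothetical data: three features, the last two marked as categorical
  X = da.random.randint(0, 10, size=(1000, 3), chunks=(100, 3))
  y = da.random.randint(0, 2, size=(1000,), chunks=(100,))
  Xy = xgb.dask.DaskDMatrix(
      client, X, y, feature_types=["q", "c", "c"], enable_categorical=True
  )
  output = xgb.dask.train(client, {"tree_method": "gpu_hist"}, Xy)
  booster = output["booster"]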


**********
Next Steps
**********

As of XGBoost 1.5, the feature is highly experimental and has limitations; for
example, CPU training is not yet supported. Please see `this issue
<https://github.com/dmlc/xgboost/issues/6503>`_ for progress.
1 change: 1 addition & 0 deletions doc/tutorials/index.rst
@@ -26,3 +26,4 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
param_tuning
external_memory
custom_metric_obj
categorical
11 changes: 9 additions & 2 deletions python-package/xgboost/sklearn.py
@@ -807,7 +807,11 @@ def _can_use_inplace_predict(self) -> bool:
# Inplace predict doesn't handle as many data types as DMatrix, but it's
        # sufficient for dask interface where input is simpler.
predictor = self.get_params().get("predictor", None)
if predictor in ("auto", None) and self.booster != "gblinear":
if (
not self.enable_categorical
and predictor in ("auto", None)
and self.booster != "gblinear"
):
return True
return False

@@ -886,7 +890,10 @@ def predict(
pass

test = DMatrix(
X, base_margin=base_margin, missing=self.missing, nthread=self.n_jobs
X, base_margin=base_margin,
missing=self.missing,
nthread=self.n_jobs,
enable_categorical=self.enable_categorical
)
return self.get_booster().predict(
data=test,