ENH Add Friedman's H-squared #28375

Open
wants to merge 59 commits into base: main
Changes from 57 commits

Commits (59)
2f06cbc
Add Friedman's H-squared of pairwise interaction statistics
mayer79 Feb 6, 2024
38f3d6c
run black
mayer79 Feb 6, 2024
f173c51
np.unique() does not work well for non-standard values
mayer79 Feb 7, 2024
2e7b731
Reorganize imports
mayer79 Feb 7, 2024
64581e2
run ruff
mayer79 Feb 7, 2024
8e50637
check weights and is fitted
Feb 8, 2024
efa4071
Replace compression logic by try except
Feb 8, 2024
f53b838
Switch to Bunch output
Feb 8, 2024
032e661
Clip small numerators
Feb 8, 2024
1335f05
Apply suggestions from code review
mayer79 Feb 9, 2024
a28e8e0
use sample_without_replacement, plus some docstring improvements
mayer79 Feb 9, 2024
66f7bdd
fix existing problems
mayer79 Feb 9, 2024
6db5ac7
Merge branch 'main' into friedmans-h
mayer79 Feb 9, 2024
5beb941
Rename things and fix imports
mayer79 Feb 9, 2024
9823fe5
More compact output organization, faster example
mayer79 Feb 10, 2024
5aebb4c
Add formula to docstring
mayer79 Feb 10, 2024
36ed7a8
Add preliminary unit tests
mayer79 Feb 10, 2024
27e3540
Compare against two R packages
mayer79 Feb 11, 2024
e5f6e53
Split calculate_pd_over_data into two plus some optimizations
mayer79 Feb 12, 2024
a2b659d
Fix typos in docstring
mayer79 Feb 12, 2024
69111e3
Fix example output in docstring
mayer79 Feb 12, 2024
0416de3
Apply suggestions from code review
mayer79 Apr 13, 2024
3fe5d64
More changed from review
mayer79 Apr 13, 2024
c583aeb
add validate_params()
mayer79 Apr 13, 2024
a443f02
Add h_statistic to test_public_functions.py
mayer79 Apr 14, 2024
a3aaed4
Possession apostrophs
mayer79 Apr 14, 2024
95ed5de
Add docu
mayer79 Apr 14, 2024
7f4527f
Add entry to classes.rst
mayer79 Apr 16, 2024
0ee0f7a
reorder position in classes.rst
mayer79 Apr 16, 2024
309edef
Merge branch 'main' into friedmans-h
mayer79 Apr 19, 2024
a93a4c9
safe assign and indexing have moved
mayer79 Apr 19, 2024
6b63a55
Fix doctest failure
mayer79 Apr 19, 2024
3923d73
fix docstring failure attempt 2
mayer79 Apr 19, 2024
d406dc6
doc tests do not seem to allow multiline command in parantheses
mayer79 Apr 19, 2024
70b9f68
docstring checks do not like black
mayer79 Apr 19, 2024
3ed07af
assign result of plot in docu
mayer79 Apr 19, 2024
c7b798e
superfluous newline in docstring of function
mayer79 Apr 19, 2024
7e4ff8b
Replace plot by print()
mayer79 Apr 19, 2024
c5e56ab
rst docu: reformat code
mayer79 Apr 19, 2024
9ed9b2b
Change intendation in example output of rst docu
mayer79 Apr 19, 2024
0b29fc8
doctest failure
mayer79 Apr 19, 2024
5845331
documentation: image is better than print
mayer79 Apr 20, 2024
91c2d65
Intendation fix
mayer79 Apr 20, 2024
25a3f6d
fix doctest issues
mayer79 Apr 20, 2024
72ce9c7
Review Lorentzen
mayer79 Apr 26, 2024
cde7dae
switch to pred_fun argument in helper
mayer79 Apr 26, 2024
a3f1beb
Initialize all resulting numpy arrays
mayer79 Apr 26, 2024
ffc6b77
Too long line in docstring
mayer79 Apr 26, 2024
ef20380
Doctest failure
mayer79 Apr 27, 2024
3a34272
move example from rst file to plot_partial_dependence.py
mayer79 Apr 27, 2024
4658ee6
Reformat example output
mayer79 Apr 27, 2024
9bc1ff8
Fixing plot
mayer79 Apr 27, 2024
9b8a1d6
Fix reference
mayer79 Apr 27, 2024
a5cfa08
maybe we need copy()
mayer79 Apr 27, 2024
f36a913
Try second copy()
mayer79 Apr 27, 2024
54e98c8
Remove copy again
mayer79 Apr 27, 2024
131910b
fix dupe index issue with old pandas
mayer79 May 2, 2024
00f0eed
Fix typo
mayer79 May 17, 2024
90ffd7b
add column names to pandas unit test
mayer79 May 17, 2024
Binary file added doc/images/h_statistic.png
1 change: 1 addition & 0 deletions doc/inspection.rst
@@ -29,3 +29,4 @@ to diagnose issues with model performance.

modules/partial_dependence
modules/permutation_importance
modules/h_statistic
1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -672,6 +672,7 @@ Kernels
:toctree: generated/
:template: function.rst

inspection.h_statistic
inspection.partial_dependence
inspection.permutation_importance

130 changes: 130 additions & 0 deletions doc/modules/h_statistic.rst
@@ -0,0 +1,130 @@

.. _h_statistic:

===============================================================
Friedman and Popescu's H-Statistic
===============================================================

.. currentmodule:: sklearn.inspection

What is the difference between a white box model and a black box model?
It is the many and complicated interaction effects of the latter.

Such interaction effects can be visualized by two-dimensional or stratified
partial dependence plots (PDP). But how can we find out *between which feature
pairs* the strongest interactions occur?

One approach is to study pairwise H-statistics, introduced by Friedman and Popescu
in [F2008]_. The H-statistic of two features measures the proportion of their joint
effect variability that comes from their pairwise interaction.

The figure below shows H-statistics and their unnormalized counterparts for
the bike sharing dataset, with a
:class:`~sklearn.ensemble.HistGradientBoostingRegressor`:

.. figure:: ../auto_examples/inspection/images/sphx_glr_plot_partial_dependence_010.png
:target: ../auto_examples/inspection/plot_partial_dependence.html
:align: center
:scale: 70

The statistics have been computed for the five most important features.

Mathematical definition
=======================

**Partial dependence**

Let :math:`F: \mathbb{R}^p \to \mathbb{R}` denote the prediction function that
maps the :math:`p`-dimensional feature vector :math:`\mathbf x = (x_1, \dots, x_p)`
to its prediction.
Furthermore, let :math:`F_s(\mathbf x_s) = E_{\mathbf x_{\setminus s}}(F(\mathbf x_s, \mathbf x_{\setminus s}))`
be the partial dependence function of :math:`F` on the feature subset
:math:`\mathbf x_s`, where :math:`s \subseteq \{1, \dots, p\}`.
Here, the expectation runs over the joint marginal distribution of features
:math:`\mathbf x_{\setminus s}` not in :math:`\mathbf x_s`.

Given data, :math:`F_s(\mathbf x_s)` can be estimated by the empirical partial
dependence function

.. math::
\hat F_s(\mathbf x_s) = \frac{1}{n} \sum_{i = 1}^n F(\mathbf x_s, \mathbf x_{i \setminus s}),

where :math:`\mathbf x_{i\setminus s}`, :math:`i = 1, \dots, n`,
are the observed values of :math:`\mathbf x_{\setminus s}` in some "background" dataset.
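
A minimal sketch of this estimator, assuming a fitted scikit-learn estimator
``model`` with a ``predict`` method and a NumPy background matrix
``X_background`` (names are illustrative)::

    def empirical_pd(model, X_background, feature_idx, value):
        """Empirical partial dependence of one feature at a single grid value."""
        X_mod = X_background.copy()
        # Fix the feature of interest at ``value`` in every background row ...
        X_mod[:, feature_idx] = value
        # ... and average the model predictions over the background data.
        return model.predict(X_mod).mean()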

**Pairwise H-statistic**

Following [F2008]_, if there are no interaction effects between features
:math:`x_j` and :math:`x_k`, their two-dimensional partial dependence function
:math:`F_{jk}` can be written as the sum of the univariate partial dependencies, i.e.,

.. math::
F_{jk}(x_j, x_k) = F_j(x_j) + F_k(x_k).

Correspondingly, Friedman and Popescu's H-statistic of pairwise interaction strength
is defined as

.. math::

H_{jk}^2 = A_{jk} / B_{jk},

where

.. math::

A_{jk} = \frac{1}{n} \sum_{i = 1}^n\big[\hat F_{jk}(x_{ij}, x_{ik}) - \hat F_j(x_{ij}) - \hat F_k(x_{ik})\big]^2

and

.. math::

B_{jk} = \frac{1}{n} \sum_{i = 1}^n\big[\hat F_{jk}(x_{ij}, x_{ik})\big]^2.
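
In code, the estimator can be sketched as follows, assuming mean-centered
arrays ``pd_jk``, ``pd_j`` and ``pd_k`` of length :math:`n` that hold the
empirical partial dependence values evaluated at the observed feature values::

    import numpy as np

    def h_squared(pd_jk, pd_j, pd_k):
        """Friedman and Popescu's pairwise H-squared statistic."""
        # A_jk: joint effect variability not explained by the two main effects.
        numerator = np.mean((pd_jk - pd_j - pd_k) ** 2)
        # B_jk: total variability of the joint effect.
        denominator = np.mean(pd_jk**2)
        return numerator / denominator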

Remarks
=======

1. Partial dependence functions and :math:`F` are centered to mean 0.
2. Partial dependence functions and :math:`F` are evaluated over the data distribution.
This differs from partial dependence plots, where a fixed grid is used.
3. Weighted versions follow by replacing all arithmetic means by the corresponding
weighted means (see the sketch after this list).
4. Multi-output prediction (e.g., probabilistic classification) is handled component-wise.
5. Due to undesired extrapolation of partial dependence functions, values above 1 may occur.
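
For instance, a sketch of remark 3: a weighted version of the numerator
:math:`A_{jk}` replaces the arithmetic mean in the sketch above by a weighted
mean with per-row ``sample_weight``::

    numerator = np.average((pd_jk - pd_j - pd_k) ** 2, weights=sample_weight)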

Interpretation
==============

* The statistic provides the proportion of joint effect variability explained by the interaction.
* A value of 0 means no interaction.
* If the main effects are weak, even a small interaction effect can produce a high
value of the statistic. Therefore, it often makes sense to study the unnormalized
statistic :math:`A_{jk}`, or :math:`\sqrt{A_{jk}}` to stay on the scale of the predictions.

Workflow
========

Calculating all pairwise H-statistics has computational complexity :math:`O(n^2p^2)`.
Therefore, our implementation randomly selects ``n_max = 500`` rows from the provided dataset ``X``.
Furthermore, if the number of features :math:`p` is large, use some feature importance measure
to select the most important features and pass them via the ``features`` argument.
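
A minimal sketch of this workflow, assuming a fitted estimator ``model``, a
feature matrix ``X``, and a pre-selected list ``top_features`` of important
feature names::

    from sklearn.inspection import h_statistic

    # Compute pairwise H-statistics for the selected features only;
    # n_max (default 500) limits the number of randomly selected rows.
    result = h_statistic(model, X=X, features=top_features, n_max=500, random_state=0)

    # The result is a Bunch; e.g., the normalized pairwise statistics:
    for pair, h2 in zip(result["feature_pairs"], result["h_squared_pairwise"]):
        print(pair, h2)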

Limitations
===========

1. H-statistics are based on partial dependence estimates and are therefore only
as good as these. The major problem of partial dependence is the application of
the model to unseen and/or impossible feature combinations.
H-statistics, which should actually lie in the range between 0 and 1,
can become greater than 1 in extreme cases.
2. Due to their computational complexity, H-statistics are usually evaluated on
relatively small subsets of the data. Consequently, the estimates are
typically not very robust.

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_inspection_plot_partial_dependence.py`

.. topic:: References

.. [F2008] J. H. Friedman and B. E. Popescu,
"Predictive Learning via Rule Ensembles",
The Annals of Applied Statistics, 2(3), 916-954, 2008.
58 changes: 55 additions & 3 deletions examples/inspection/plot_partial_dependence.py
@@ -16,9 +16,12 @@
feature for each :term:`sample` separately, with one line per sample.
Only one feature of interest is supported for ICE plots.

This example shows how to obtain partial dependence and ICE plots from a
:class:`~sklearn.neural_network.MLPRegressor` and a
:class:`~sklearn.ensemble.HistGradientBoostingRegressor` trained on the
Furthermore, the concept of partial dependence is used to measure the interaction
strength of feature pairs. The corresponding measure is called *H-statistic*
and was introduced in [4]_.

To illustrate these methods, we use a :class:`~sklearn.neural_network.MLPRegressor`
and a :class:`~sklearn.ensemble.HistGradientBoostingRegressor` trained on the
bike sharing dataset. The example is inspired by [1]_.

.. [1] `Molnar, Christoph. "Interpretable machine learning.
@@ -32,6 +35,10 @@
"Peeking Inside the Black Box: Visualizing Statistical Learning With Plots of
Individual Conditional Expectation". Journal of Computational and
Graphical Statistics, 24(1): 44-65 <1309.6392>`

.. [4] Friedman, J. H. and Popescu, B. E. (2008).
"Predictive Learning via Rule Ensembles".
The Annals of Applied Statistics, 2(3), 916-954.
"""

# %%
@@ -567,3 +574,48 @@
plt.show()

# %%
# Interaction strength
# --------------------
#
# The considerations above show that comparing 2D PDPs with their univariate
# versions gives hints about feature interactions. This idea is formalized by
# Friedman and Popescu's (pairwise) H-statistic, see [4]_.
# It measures how well the 2D partial dependence function can be approximated
# by the two one-dimensional partial dependence functions. The resulting value
# is then normalized and can be interpreted as the proportion of effect
# variability explained by the interaction. Besides this relative measure, we
# advocate also considering unnormalized statistics: they can be compared
# directly between feature pairs to see which interactions are strongest.
#
# Since the computational burden is high, H-statistics are usually calculated only for
# important features, e.g., selected by permutation importance. What do we get for
# our model?
print("Select five important features and crunch H-statistics...")

from sklearn.inspection import h_statistic, permutation_importance

tic = time()
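# Select the five most important features according to permutation importance.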
imp = permutation_importance(hgbdt_model, X=X_train, y=y_train, random_state=0)
features = X_train.columns[np.argsort(imp.importances_mean)][-5:]

H = h_statistic(hgbdt_model, X=X_train, features=features, random_state=0)

print(f"done in {time() - tic:.3f}s")

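# Plot the normalized statistic $H^2$ and the unnormalized statistic side by side.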
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
bar_labels = np.array([str(pair) for pair in H["feature_pairs"]])
stats = (H["h_squared_pairwise"], np.sqrt(H["numerator_pairwise"]))

for ax, stat, name in zip(axes, stats, ("$H^2$", "Unnormalized $H$")):
stat = stat.ravel()
idx = np.argsort(stat)
ax.barh(bar_labels[idx], stat[idx], color="orange")
ax.set(xlabel=name, title=name)
_ = fig.tight_layout()

# %%
# **The left plot** shows that the interaction between 'workingday' and 'hour'
# explains about 8% of their joint effect variability. For the other pairs, it is
# less than 5%. **The right plot** additionally shows that the interaction between
# 'workingday' and 'hour' is also largest in absolute terms (on the scale of the
# predictions).
2 changes: 2 additions & 0 deletions sklearn/inspection/__init__.py
@@ -1,11 +1,13 @@
"""The :mod:`sklearn.inspection` module includes tools for model inspection."""

from ._h_statistic import h_statistic
from ._partial_dependence import partial_dependence
from ._permutation_importance import permutation_importance
from ._plot.decision_boundary import DecisionBoundaryDisplay
from ._plot.partial_dependence import PartialDependenceDisplay

__all__ = [
"h_statistic",
"partial_dependence",
"permutation_importance",
"PartialDependenceDisplay",