[MRG] Add partial_fit function to DecisionTreeClassifier #18889

Status: Open · wants to merge 58 commits into main

Commits (58, showing changes from all commits):
1542765  Start implementing the update function for trees (PSSF23, Nov 7, 2020)
8ded0f7  Update _tree.pxd (PSSF23, Nov 7, 2020)
d6d5879  Remove unused attribute (PSSF23, Nov 7, 2020)
0ed0819  Remove duplicate operations (PSSF23, Nov 7, 2020)
bebe2bc  Keep whole function for reference (PSSF23, Nov 8, 2020)
6ca6725  Catch AttributeError (PSSF23, Nov 11, 2020)
a403f5b  Evaluate tree building logic (PSSF23, Nov 15, 2020)
cb4cf43  Follow node addition logic (PSSF23, Nov 15, 2020)
eb7af31  Work with counting issues and overflowing trees (PSSF23, Nov 15, 2020)
c24c87a  Work with high variability (PSSF23, Nov 16, 2020)
5e6685c  Fix y coordinates (PSSF23, Nov 16, 2020)
5f6c373  Duplicate sample organization (PSSF23, Nov 18, 2020)
7ac15f2  Add _update_split_node function for BestFirstTree (PSSF23, Nov 18, 2020)
2a94fa2  Work without max_leaf_nodes limit (PSSF23, Nov 18, 2020)
d6c03a7  Update .gitignore (PSSF23, Nov 18, 2020)
7a3985a  Remove capacity resetting (PSSF23, Nov 29, 2020)
4f8605e  Resolve 1 node tree problem (PSSF23, Dec 7, 2020)
11764a1  Optimize node order (PSSF23, Dec 20, 2020)
02ca737  Update _tree.pyx (PSSF23, Jan 18, 2021)
92f7e18  Optimize partial_fit api (PSSF23, Jan 21, 2021)
ab51a53  Update from main branch to stream branch (PSSF23, Feb 2, 2021)
f05a3b2  Fix linting (PSSF23, Feb 2, 2021)
e1b6658  FIX add __reduce__ functions (PSSF23, Sep 14, 2021)
f1a4174  Merge branch 'main' into stream (PSSF23, Sep 14, 2021)
0a5420c  FIX black format the code (PSSF23, Sep 14, 2021)
19893c3  FIX remove min_impurity_split (PSSF23, Sep 14, 2021)
fdd1dfd  FIX update deprecated attribute (PSSF23, Sep 14, 2021)
b4cbfa4  FIX optimize api & correct __cinit__ (PSSF23, Sep 14, 2021)
8f4b664  FIX optimize first partial_fit test (PSSF23, Sep 14, 2021)
3562219  FIX remove FutureWarning filter (PSSF23, Sep 14, 2021)
93ead2d  FIX modify partial_fit parameter (PSSF23, Sep 14, 2021)
bfaa18c  FIX correct partial_fit parameter (PSSF23, Sep 14, 2021)
73779c2  Revert "FIX remove FutureWarning filter" (PSSF23, Sep 15, 2021)
7d724c1  FIX prevent feature number reset (PSSF23, Sep 15, 2021)
d3f15ad  MAINT remove duplicate category (PSSF23, Sep 15, 2021)
992e34a  FIX correct regressor partial_fit checks (PSSF23, Sep 15, 2021)
2a72c8f  Revert "MAINT remove duplicate category" (PSSF23, Sep 15, 2021)
68d2d7b  FIX change parameter order (PSSF23, Sep 15, 2021)
631a953  DOC add classes parameter docstring (PSSF23, Sep 15, 2021)
9ae93b8  EHN pass classes into first fit (PSSF23, Sep 15, 2021)
23fd392  FIX add class indices (PSSF23, Sep 15, 2021)
665ceef  FIX revert class changes (PSSF23, Sep 15, 2021)
85689d2  EHN pass classes into first fit (PSSF23, Sep 15, 2021)
71265db  FIX restrict partial_fit to classifiers (PSSF23, Sep 16, 2021)
005e5fe  Merge branch 'scikit-learn:main' into stream (PSSF23, Oct 13, 2021)
cd8864e  Merge branch 'scikit-learn:main' into stream (PSSF23, Oct 18, 2021)
30a6237  Merge branch 'scikit-learn:main' into stream (PSSF23, Oct 21, 2021)
814e67e  DOC add changelog (PSSF23, Oct 21, 2021)
46a9ccc  Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i… (PSSF23, Nov 5, 2021)
f0d0eb0  DOC optimize log format (PSSF23, Nov 17, 2021)
e9e62e4  Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i… (PSSF23, Nov 17, 2021)
aef4f84  FIX remove deprecated parameter (PSSF23, Nov 17, 2021)
7d9ff8b  Merge branch 'main' into stream (PSSF23, Nov 26, 2021)
9ba887f  Merge branch 'scikit-learn:main' into stream (PSSF23, Nov 29, 2021)
55a6b4b  FIX optimize n_classes format (PSSF23, Nov 30, 2021)
8d3f5c7  FIX add internal function (PSSF23, Nov 30, 2021)
c47bbb2  MNT remove unnecessary checks (PSSF23, Dec 1, 2021)
a2aab5f  Merge branch 'scikit-learn:main' into stream (PSSF23, Dec 2, 2021)
3 changes: 3 additions & 0 deletions .gitignore
@@ -82,3 +82,6 @@ sklearn/utils/_seq_dataset.pxd
sklearn/utils/_weight_vector.pyx
sklearn/utils/_weight_vector.pxd
sklearn/linear_model/_sag_fast.pyx

# Jupyter Notebook
.ipynb_checkpoints
7 changes: 7 additions & 0 deletions doc/whats_new/v1.1.rst
@@ -254,6 +254,13 @@ Changelog
Setting a transformer to "passthrough" will pass the features unchanged.
:pr:`20860` by :user:`Shubhraneel Pal <shubhraneel>`.

:mod:`sklearn.tree`
...................

- |Enhancement| Added :func:`partial_fit` to :class:`tree.DecisionTreeClassifier`
and :class:`tree.ExtraTreeClassifier`.
:pr:`18889` by :user:`Haoyin Xu <PSSF23>`.

:mod:`sklearn.preprocessing`
............................

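For context, here is a minimal usage sketch of the API this changelog entry describes. It assumes a build of this PR's branch; the synthetic data and variable names are illustrative only.

import numpy as np
from sklearn.tree import DecisionTreeClassifier

X1, y1 = np.random.rand(20, 4), np.random.randint(0, 3, 20)
X2, y2 = np.random.rand(20, 4), np.random.randint(0, 3, 20)

clf = DecisionTreeClassifier(random_state=0)
# The first call must list every class that any future batch may contain.
clf.partial_fit(X1, y1, classes=[0, 1, 2])
# Later calls grow the existing tree instead of refitting from scratch.
clf.partial_fit(X2, y2)
clf.predict(X2[:5])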
194 changes: 174 additions & 20 deletions sklearn/tree/_classes.py
@@ -11,6 +11,7 @@
# Joly Arnaud <arnaud.v.joly@gmail.com>
# Fares Hedayati <fares.hedayati@gmail.com>
# Nelson Liu <nelson@nelsonliu.me>
# Haoyin Xu <haoyinxu@gmail.com>
#
# License: BSD 3 clause

@@ -36,6 +37,7 @@
from ..utils.validation import _check_sample_weight
from ..utils import compute_sample_weight
from ..utils.multiclass import check_classification_targets
from ..utils.multiclass import _check_partial_fit_first_call
from ..utils.validation import check_is_fitted

from ._criterion import Criterion
@@ -147,7 +149,14 @@ def get_n_leaves(self):
check_is_fitted(self)
return self.tree_.n_leaves

def fit(self, X, y, sample_weight=None, check_input=True):
def fit(
self,
X,
y,
sample_weight=None,
check_input=True,
classes=None,
):

random_state = check_random_state(self.random_state)

@@ -201,24 +210,35 @@ def fit(self, X, y, sample_weight=None, check_input=True):
check_classification_targets(y)
y = np.copy(y)

self.classes_ = []
self.n_classes_ = []

if self.class_weight is not None:
y_original = np.copy(y)

y_encoded = np.zeros(y.shape, dtype=int)
for k in range(self.n_outputs_):
classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True)
self.classes_.append(classes_k)
self.n_classes_.append(classes_k.shape[0])
y = y_encoded

if self.class_weight is not None:
expanded_class_weight = compute_sample_weight(
self.class_weight, y_original
)

self.classes_ = []
self.n_classes_ = []

y_encoded = np.zeros(y.shape, dtype=int)
if classes is not None:
classes = np.atleast_1d(classes)
if classes.ndim == 1:
classes = np.array([classes])

for k in classes:
self.classes_.append(np.array(k))
self.n_classes_.append(np.array(k).shape[0])

for i in range(n_samples):
for j in range(self.n_outputs_):
y_encoded[i, j] = np.where(self.classes_[j] == y[i, j])[0][0]
else:
for k in range(self.n_outputs_):
classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True)
self.classes_.append(classes_k)
self.n_classes_.append(classes_k.shape[0])

y = y_encoded
self.n_classes_ = np.array(self.n_classes_, dtype=np.intp)

if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
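An aside on the encoding loop above: when the user supplies classes, each y[i, j] is replaced by its index in the class list via np.where. A small NumPy sketch of the same mapping (illustrative values only; np.searchsorted is a vectorized equivalent whenever the class list is sorted, as np.unique guarantees):

import numpy as np

classes_k = np.array(["a", "b", "c"])  # classes for one output
y_k = np.array(["c", "a", "c"])        # raw labels for that output

# Per-element form used in the loop: np.where(classes_k == label)[0][0]
y_encoded = np.searchsorted(classes_k, y_k)
print(y_encoded)  # [2 0 2]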
@@ -374,7 +394,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
random_state,
)

if is_classifier(self):
if is_classification:
self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_)
else:
self.tree_ = Tree(
@@ -386,7 +406,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):

# Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
if max_leaf_nodes < 0:
builder = DepthFirstTreeBuilder(
self.builder_ = DepthFirstTreeBuilder(
splitter,
min_samples_split,
min_samples_leaf,
Expand All @@ -395,7 +415,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
self.min_impurity_decrease,
)
else:
builder = BestFirstTreeBuilder(
self.builder_ = BestFirstTreeBuilder(
splitter,
min_samples_split,
min_samples_leaf,
@@ -405,9 +425,9 @@
self.min_impurity_decrease,
)

builder.build(self.tree_, X, y, sample_weight)
self.builder_.build(self.tree_, X, y, sample_weight)

if self.n_outputs_ == 1 and is_classifier(self):
if self.n_outputs_ == 1 and is_classification:
self.n_classes_ = self.n_classes_[0]
self.classes_ = self.classes_[0]

Expand Down Expand Up @@ -808,6 +828,9 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree):
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
for basic usage of these attributes.

builder_ : TreeBuilder instance
The underlying TreeBuilder object.

See Also
--------
DecisionTreeRegressor : A decision tree regressor.
@@ -884,7 +907,14 @@ def __init__(
ccp_alpha=ccp_alpha,
)

def fit(self, X, y, sample_weight=None, check_input=True):
def fit(
self,
X,
y,
sample_weight=None,
check_input=True,
classes=None,
):
"""Build a decision tree classifier from the training set (X, y).

Parameters
@@ -908,6 +938,11 @@ def fit(self, X, y, sample_weight=None, check_input=True):
Allows bypassing several input checks.
Don't use this parameter unless you know what you are doing.

classes : array-like of shape (n_classes,), default=None
List of all the classes that can possibly appear in the y vector.
Must be provided at the first call to partial_fit; it can be omitted
in subsequent calls.

Returns
-------
self : DecisionTreeClassifier
@@ -919,9 +954,109 @@
y,
sample_weight=sample_weight,
check_input=check_input,
classes=classes,
)
return self

def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True):
"""Update a decision tree classifier from the training set (X, y).

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
The training input samples. Internally, it will be converted to
``dtype=np.float32`` and if a sparse matrix is provided
to a sparse ``csc_matrix``.

y : array-like of shape (n_samples,) or (n_samples, n_outputs)
The target values (class labels) as integers or strings.

classes : array-like of shape (n_classes,), default=None
List of all the classes that can possibly appear in the y vector.
Must be provided at the first call to partial_fit; it can be omitted
in subsequent calls.

sample_weight : array-like of shape (n_samples,), default=None
Sample weights. If None, then samples are equally weighted. Splits
that would create child nodes with net zero or negative weight are
ignored while searching for a split in each node. Splits are also
ignored if they would result in any single class carrying a
negative weight in either child node.

check_input : bool, default=True
Allows bypassing several input checks.
Don't use this parameter unless you know what you are doing.

Returns
-------
self : DecisionTreeClassifier
Fitted estimator.
"""

first_call = _check_partial_fit_first_call(self, classes=classes)

# Fit if no tree exists yet
if first_call:
self.fit(
X,
y,
sample_weight=sample_weight,
check_input=check_input,
classes=classes,
)
return self

if check_input:
# Need to validate separately here.
# We can't pass multi_output=True because that would allow y to be
# csr.
check_X_params = dict(dtype=DTYPE, accept_sparse="csc")
check_y_params = dict(ensure_2d=False, dtype=None)
X, y = self._validate_data(
X, y, reset=False, validate_separately=(check_X_params, check_y_params)
)
if issparse(X):
X.sort_indices()

if X.indices.dtype != np.intc or X.indptr.dtype != np.intc:
raise ValueError(
"No support for np.int64 index based sparse matrices"
)

if X.shape[1] != self.n_features_in_:
msg = "Number of features %d does not match previous data %d."
raise ValueError(msg % (X.shape[1], self.n_features_in_))

y = np.atleast_1d(y)

if y.ndim == 1:
# reshape is necessary to preserve the data contiguity;
# [:, np.newaxis] does not.
y = np.reshape(y, (-1, 1))

check_classification_targets(y)
y = np.copy(y)

classes = self.classes_
if self.n_outputs_ == 1:
classes = [classes]

y_encoded = np.zeros(y.shape, dtype=int)
for i in range(X.shape[0]):
for j in range(self.n_outputs_):
y_encoded[i, j] = np.where(classes[j] == y[i, j])[0][0]
y = y_encoded

if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
y = np.ascontiguousarray(y, dtype=DOUBLE)

# Update tree
self.builder_.update(self.tree_, X, y, sample_weight)

self._prune_tree()

return self
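To summarize the flow above: the first partial_fit call delegates to fit, which builds the tree and caches builder_; every later call validates the batch against the fitted state and updates the same tree in place. A streaming sketch, again assuming a build of this branch (the data and loop are illustrative):

import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.RandomState(0)
clf = DecisionTreeClassifier(random_state=0)

for batch in range(5):
    X = rng.rand(50, 3)
    y = rng.randint(0, 2, 50)
    # classes is required on the first call only; later calls reuse it.
    clf.partial_fit(X, y, classes=[0, 1] if batch == 0 else None)

# A batch with a different feature count raises ValueError, and labels
# outside the declared classes fail to encode.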

def predict_proba(self, X, check_input=True):
"""Predict class probabilities of the input samples X.

@@ -1185,6 +1320,9 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
for basic usage of these attributes.

builder_ : TreeBuilder instance
The underlying TreeBuilder object.

See Also
--------
DecisionTreeClassifier : A decision tree classifier.
@@ -1254,7 +1392,14 @@ def __init__(
ccp_alpha=ccp_alpha,
)

def fit(self, X, y, sample_weight=None, check_input=True):
def fit(
self,
X,
y,
sample_weight=None,
check_input=True,
classes=None,
):
"""Build a decision tree regressor from the training set (X, y).

Parameters
@@ -1277,6 +1422,9 @@ def fit(self, X, y, sample_weight=None, check_input=True):
Allows bypassing several input checks.
Don't use this parameter unless you know what you are doing.

classes : array-like of shape (n_classes,), default=None
List of all the classes that can possibly appear in the y vector.

Returns
-------
self : DecisionTreeRegressor
@@ -1512,6 +1660,9 @@ class ExtraTreeClassifier(DecisionTreeClassifier):
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
for basic usage of these attributes.

builder_ : TreeBuilder instance
The underlying TreeBuilder object.

See Also
--------
ExtraTreeRegressor : An extremely randomized tree regressor.
@@ -1752,6 +1903,9 @@ class ExtraTreeRegressor(DecisionTreeRegressor):
:ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
for basic usage of these attributes.

builder_ : TreeBuilder instance
The underlying TreeBuilder object.

See Also
--------
ExtraTreeClassifier : An extremely randomized tree classifier.
26 changes: 24 additions & 2 deletions sklearn/tree/_criterion.pyx
@@ -203,12 +203,30 @@ cdef class Criterion:
- (self.weighted_n_left /
self.weighted_n_node_samples * impurity_left)))

def _check_n_classes(n_classes, expected_dtype):
if n_classes.ndim != 1:
raise ValueError(
f"Wrong dimensions for n_classes from the pickle: "
f"expected 1, got {n_classes.ndim}"
)

if n_classes.dtype == expected_dtype:
return n_classes

# Handles both different endianness and different bitness
if n_classes.dtype.kind == "i" and n_classes.dtype.itemsize in [4, 8]:
return n_classes.astype(expected_dtype, casting="same_kind")

raise ValueError(
"n_classes from the pickle has an incompatible dtype:\n"
f"- expected: {expected_dtype}\n"
f"- got: {n_classes.dtype}"
)
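To make the dtype rule above concrete, here is a pure-Python illustration of what _check_n_classes accepts (values are hypothetical; np.intp stands in for the compiled SIZE_t dtype):

import numpy as np

expected = np.dtype(np.intp)           # stand-in for the SIZE_t dtype
ok = np.array([2, 3], dtype=np.int32)  # integer kind "i", different bitness
print(ok.astype(expected, casting="same_kind"))  # cast is allowed

bad = np.array([2.0, 3.0])  # kind "f": _check_n_classes raises ValueError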

cdef class ClassificationCriterion(Criterion):
"""Abstract criterion for classification."""

def __cinit__(self, SIZE_t n_outputs,
np.ndarray[SIZE_t, ndim=1] n_classes):
def __cinit__(self, SIZE_t n_outputs, np.ndarray n_classes):
"""Initialize attributes for this criterion.

Parameters
@@ -218,6 +236,10 @@ cdef class ClassificationCriterion(Criterion):
n_classes : numpy.ndarray, dtype=SIZE_t
The number of unique classes in each target
"""
cdef SIZE_t dummy = 0
size_t_dtype = np.array(dummy).dtype

n_classes = _check_n_classes(n_classes, size_t_dtype)
self.sample_weight = NULL

self.samples = NULL
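A note on the dummy trick above: wrapping a typed zero in np.array is how the Cython code discovers, at runtime, which NumPy dtype matches the compiled SIZE_t. A quick pure-Python illustration (the exact dtype is platform-dependent):

import numpy as np

dummy = 0
size_t_dtype = np.array(dummy).dtype
print(size_t_dtype)  # e.g. int64 on common 64-bit platforms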