diff --git a/.gitignore b/.gitignore index 0bb730f493a56..fe4337d090cf9 100644 --- a/.gitignore +++ b/.gitignore @@ -82,3 +82,6 @@ sklearn/utils/_seq_dataset.pxd sklearn/utils/_weight_vector.pyx sklearn/utils/_weight_vector.pxd sklearn/linear_model/_sag_fast.pyx + +# Jupyter Notebook +.ipynb_checkpoints diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 83395c4180c44..4658d83a67d21 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -254,6 +254,13 @@ Changelog Setting a transformer to "passthrough" will pass the features unchanged. :pr:`20860` by :user:`Shubhraneel Pal `. +:mod:`sklearn.tree` +.................................. + +- |Enhancement| Added :func:`partial_fit` to :class:`tree.DecisionTreeClassifier` + and :class:`tree.ExtraTreeClassifier`. + :pr:`18889` by :user:`Haoyin Xu `. + :mod:`sklearn.preprocessing` ............................ diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 3cd0e000bd4dd..ca6de9060fe95 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -11,6 +11,7 @@ # Joly Arnaud # Fares Hedayati # Nelson Liu +# Haoyin Xu # # License: BSD 3 clause @@ -36,6 +37,7 @@ from ..utils.validation import _check_sample_weight from ..utils import compute_sample_weight from ..utils.multiclass import check_classification_targets +from ..utils.multiclass import _check_partial_fit_first_call from ..utils.validation import check_is_fitted from ._criterion import Criterion @@ -147,7 +149,14 @@ def get_n_leaves(self): check_is_fitted(self) return self.tree_.n_leaves - def fit(self, X, y, sample_weight=None, check_input=True): + def fit( + self, + X, + y, + sample_weight=None, + check_input=True, + classes=None, + ): random_state = check_random_state(self.random_state) @@ -201,24 +210,35 @@ def fit(self, X, y, sample_weight=None, check_input=True): check_classification_targets(y) y = np.copy(y) - self.classes_ = [] - self.n_classes_ = [] - if self.class_weight is not None: y_original = np.copy(y) - - y_encoded = np.zeros(y.shape, dtype=int) - for k in range(self.n_outputs_): - classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True) - self.classes_.append(classes_k) - self.n_classes_.append(classes_k.shape[0]) - y = y_encoded - - if self.class_weight is not None: expanded_class_weight = compute_sample_weight( self.class_weight, y_original ) + self.classes_ = [] + self.n_classes_ = [] + + y_encoded = np.zeros(y.shape, dtype=int) + if classes is not None: + classes = np.atleast_1d(classes) + if classes.ndim == 1: + classes = np.array([classes]) + + for k in classes: + self.classes_.append(np.array(k)) + self.n_classes_.append(np.array(k).shape[0]) + + for i in range(n_samples): + for j in range(self.n_outputs_): + y_encoded[i, j] = np.where(self.classes_[j] == y[i, j])[0][0] + else: + for k in range(self.n_outputs_): + classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True) + self.classes_.append(classes_k) + self.n_classes_.append(classes_k.shape[0]) + + y = y_encoded self.n_classes_ = np.array(self.n_classes_, dtype=np.intp) if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: @@ -374,7 +394,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): random_state, ) - if is_classifier(self): + if is_classification: self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_) else: self.tree_ = Tree( @@ -386,7 +406,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise if max_leaf_nodes < 0: - builder = DepthFirstTreeBuilder( + self.builder_ = DepthFirstTreeBuilder( splitter, min_samples_split, min_samples_leaf, @@ -395,7 +415,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): self.min_impurity_decrease, ) else: - builder = BestFirstTreeBuilder( + self.builder_ = BestFirstTreeBuilder( splitter, min_samples_split, min_samples_leaf, @@ -405,9 +425,9 @@ def fit(self, X, y, sample_weight=None, check_input=True): self.min_impurity_decrease, ) - builder.build(self.tree_, X, y, sample_weight) + self.builder_.build(self.tree_, X, y, sample_weight) - if self.n_outputs_ == 1 and is_classifier(self): + if self.n_outputs_ == 1 and is_classification: self.n_classes_ = self.n_classes_[0] self.classes_ = self.classes_[0] @@ -808,6 +828,9 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` for basic usage of these attributes. + builder_ : TreeBuilder instance + The underlying TreeBuilder object. + See Also -------- DecisionTreeRegressor : A decision tree regressor. @@ -884,7 +907,14 @@ def __init__( ccp_alpha=ccp_alpha, ) - def fit(self, X, y, sample_weight=None, check_input=True): + def fit( + self, + X, + y, + sample_weight=None, + check_input=True, + classes=None, + ): """Build a decision tree classifier from the training set (X, y). Parameters @@ -908,6 +938,11 @@ def fit(self, X, y, sample_weight=None, check_input=True): Allow to bypass several input checking. Don't use this parameter unless you know what you do. + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + Must be provided at the first call to partial_fit, can be omitted + in subsequent calls. + Returns ------- self : DecisionTreeClassifier @@ -919,9 +954,109 @@ def fit(self, X, y, sample_weight=None, check_input=True): y, sample_weight=sample_weight, check_input=check_input, + classes=classes, ) return self + def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): + """Update a decision tree classifier from the training set (X, y). + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + The target values (class labels) as integers or strings. + + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + Must be provided at the first call to partial_fit, can be omitted + in subsequent calls. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + check_input : bool, default=True + Allow to bypass several input checking. + Don't use this parameter unless you know what you do. + + Returns + ------- + self : DecisionTreeClassifier + Fitted estimator. + """ + + first_call = _check_partial_fit_first_call(self, classes=classes) + + # Fit if no tree exists yet + if first_call: + self.fit( + X, + y, + sample_weight=sample_weight, + check_input=check_input, + classes=classes, + ) + return self + + if check_input: + # Need to validate separately here. + # We can't pass multi_ouput=True because that would allow y to be + # csr. + check_X_params = dict(dtype=DTYPE, accept_sparse="csc") + check_y_params = dict(ensure_2d=False, dtype=None) + X, y = self._validate_data( + X, y, reset=False, validate_separately=(check_X_params, check_y_params) + ) + if issparse(X): + X.sort_indices() + + if X.indices.dtype != np.intc or X.indptr.dtype != np.intc: + raise ValueError( + "No support for np.int64 index based sparse matrices" + ) + + if X.shape[1] != self.n_features_in_: + msg = "Number of features %d does not match previous data %d." + raise ValueError(msg % (X.shape[1], self.n_features_in_)) + + y = np.atleast_1d(y) + + if y.ndim == 1: + # reshape is necessary to preserve the data contiguity against vs + # [:, np.newaxis] that does not. + y = np.reshape(y, (-1, 1)) + + check_classification_targets(y) + y = np.copy(y) + + classes = self.classes_ + if self.n_outputs_ == 1: + classes = [classes] + + y_encoded = np.zeros(y.shape, dtype=int) + for i in range(X.shape[0]): + for j in range(self.n_outputs_): + y_encoded[i, j] = np.where(classes[j] == y[i, j])[0][0] + y = y_encoded + + if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: + y = np.ascontiguousarray(y, dtype=DOUBLE) + + # Update tree + self.builder_.update(self.tree_, X, y, sample_weight) + + self._prune_tree() + + return self + def predict_proba(self, X, check_input=True): """Predict class probabilities of the input samples X. @@ -1185,6 +1320,9 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` for basic usage of these attributes. + builder_ : TreeBuilder instance + The underlying TreeBuilder object. + See Also -------- DecisionTreeClassifier : A decision tree classifier. @@ -1254,7 +1392,14 @@ def __init__( ccp_alpha=ccp_alpha, ) - def fit(self, X, y, sample_weight=None, check_input=True): + def fit( + self, + X, + y, + sample_weight=None, + check_input=True, + classes=None, + ): """Build a decision tree regressor from the training set (X, y). Parameters @@ -1277,6 +1422,9 @@ def fit(self, X, y, sample_weight=None, check_input=True): Allow to bypass several input checking. Don't use this parameter unless you know what you do. + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + Returns ------- self : DecisionTreeRegressor @@ -1512,6 +1660,9 @@ class ExtraTreeClassifier(DecisionTreeClassifier): :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` for basic usage of these attributes. + builder_ : TreeBuilder instance + The underlying TreeBuilder object. + See Also -------- ExtraTreeRegressor : An extremely randomized tree regressor. @@ -1752,6 +1903,9 @@ class ExtraTreeRegressor(DecisionTreeRegressor): :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` for basic usage of these attributes. + builder_ : TreeBuilder instance + The underlying TreeBuilder object. + See Also -------- ExtraTreeClassifier : An extremely randomized tree classifier. diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 2c115d0bd6ea1..a690049c54157 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -203,12 +203,30 @@ cdef class Criterion: - (self.weighted_n_left / self.weighted_n_node_samples * impurity_left))) +def _check_n_classes(n_classes, expected_dtype): + if n_classes.ndim != 1: + raise ValueError( + f"Wrong dimensions for n_classes from the pickle: " + f"expected 1, got {n_classes.ndim}" + ) + + if n_classes.dtype == expected_dtype: + return n_classes + + # Handles both different endianness and different bitness + if n_classes.dtype.kind == "i" and n_classes.dtype.itemsize in [4, 8]: + return n_classes.astype(expected_dtype, casting="same_kind") + + raise ValueError( + "n_classes from the pickle has an incompatible dtype:\n" + f"- expected: {expected_dtype}\n" + f"- got: {n_classes.dtype}" + ) cdef class ClassificationCriterion(Criterion): """Abstract criterion for classification.""" - def __cinit__(self, SIZE_t n_outputs, - np.ndarray[SIZE_t, ndim=1] n_classes): + def __cinit__(self, SIZE_t n_outputs, np.ndarray n_classes): """Initialize attributes for this criterion. Parameters @@ -218,6 +236,10 @@ cdef class ClassificationCriterion(Criterion): n_classes : numpy.ndarray, dtype=SIZE_t The number of unique classes in each target """ + cdef SIZE_t dummy = 0 + size_t_dtype = np.array(dummy).dtype + + n_classes = _check_n_classes(n_classes, size_t_dtype) self.sample_weight = NULL self.samples = NULL diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 0874187ee98ae..89a73620ba02d 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -5,6 +5,7 @@ # Arnaud Joly # Jacob Schreiber # Nelson Liu +# Haoyin Xu # # License: BSD 3 clause @@ -59,6 +60,10 @@ cdef class Tree: SIZE_t feature, double threshold, double impurity, SIZE_t n_node_samples, double weighted_n_node_samples) nogil except -1 + cdef SIZE_t _update_node(self, SIZE_t parent, bint is_left, bint is_leaf, + SIZE_t feature, double threshold, double impurity, + SIZE_t n_node_samples, + double weighted_n_node_samples) nogil except -1 cdef int _resize(self, SIZE_t capacity) nogil except -1 cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1 @@ -100,4 +105,6 @@ cdef class TreeBuilder: cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=*) + cpdef update(self, Tree tree, object X, np.ndarray y, + np.ndarray sample_weight=*) cdef _check_input(self, object X, np.ndarray y, np.ndarray sample_weight) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 3d60f16c51062..e57e7b45097f7 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -9,6 +9,7 @@ # Fares Hedayati # Jacob Schreiber # Nelson Liu +# Haoyin Xu # # License: BSD 3 clause @@ -84,6 +85,11 @@ cdef class TreeBuilder: """Build a decision tree from the training set (X, y).""" pass + cpdef update(self, Tree tree, object X, np.ndarray y, + np.ndarray sample_weight=None): + """Update a decision tree with the training set (X, y).""" + pass + cdef inline _check_input(self, object X, np.ndarray y, np.ndarray sample_weight): """Check input dtype, layout and format""" @@ -128,6 +134,14 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): self.max_depth = max_depth self.min_impurity_decrease = min_impurity_decrease + def __reduce__(self): + """Reduce re-implementation, for pickling.""" + return(DepthFirstTreeBuilder, (self.splitter, self.min_samples_split, + self.min_samples_leaf, + self.min_weight_leaf, + self.max_depth, + self.min_impurity_decrease)) + cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None): """Build a decision tree from the training set (X, y).""" @@ -260,6 +274,175 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): if rc == -1: raise MemoryError() + cpdef update(self, Tree tree, object X, np.ndarray y, + np.ndarray sample_weight=None): + """Update a decision tree with the training set (X, y).""" + + # check input + X, y, sample_weight = self._check_input(X, y, sample_weight) + + cdef DOUBLE_t* sample_weight_ptr = NULL + if sample_weight is not None: + sample_weight_ptr = sample_weight.data + + # organize samples by decision paths + paths = tree.decision_path(X) + cdef int PARENT + cdef int CHILD + false_roots = {} + X_copy = {} + y_copy = {} + for i in range(X.shape[0]): + depth_i = paths[i].indices.shape[0] - 1 + PARENT = depth_i - 1 + CHILD = depth_i + + if PARENT < 0: + parent_i = 0 + else: + parent_i = paths[i].indices[PARENT] + child_i = paths[i].indices[CHILD] + left = 0 + if tree.children_left[parent_i] == child_i: + left = 1 + + if (parent_i, left) in false_roots: + false_roots[(parent_i, left)][0] += 1 + X_copy[(parent_i, left)].append(X[i]) + y_copy[(parent_i, left)].append(y[i]) + else: + false_roots[(parent_i, left)] = [1, depth_i] + X_copy[(parent_i, left)] = [X[i]] + y_copy[(parent_i, left)] = [y[i]] + + X_list = [] + y_list = [] + for key, value in reversed(sorted(X_copy.items())): + X_list = X_list + value + y_list = y_list + y_copy[key] + cdef object X_new = np.array(X_list) + cdef np.ndarray y_new = np.array(y_list) + + # Parameters + cdef Splitter splitter = self.splitter + cdef SIZE_t max_depth = self.max_depth + cdef SIZE_t min_samples_leaf = self.min_samples_leaf + cdef double min_weight_leaf = self.min_weight_leaf + cdef SIZE_t min_samples_split = self.min_samples_split + cdef double min_impurity_decrease = self.min_impurity_decrease + + # Recursive partition (without actual recursion) + splitter.init(X_new, y_new, sample_weight_ptr) + + cdef SIZE_t start = 0 + cdef SIZE_t end = 0 + cdef SIZE_t depth + cdef SIZE_t parent + cdef bint is_left + cdef SIZE_t n_node_samples = splitter.n_samples + cdef double weighted_n_samples = splitter.weighted_n_samples + cdef double weighted_n_node_samples + cdef SplitRecord split + cdef SIZE_t node_id + + cdef double impurity + cdef SIZE_t n_constant_features + cdef bint is_leaf + cdef SIZE_t max_depth_seen = tree.max_depth + cdef int rc = 0 + + cdef Stack stack = Stack(INITIAL_STACK_SIZE) + cdef StackRecord stack_record + + # push reached leaf nodes onto stack + for key, value in reversed(sorted(false_roots.items())): + end += value[0] + rc = stack.push(start, end, value[1], key[0], key[1], + tree.impurity[key[0]], 0) + start += value[0] + if rc == -1: + # got return code -1 - out-of-memory + raise MemoryError() + + with nogil: + while not stack.is_empty(): + stack.pop(&stack_record) + + start = stack_record.start + end = stack_record.end + depth = stack_record.depth + parent = stack_record.parent + is_left = stack_record.is_left + impurity = stack_record.impurity + n_constant_features = stack_record.n_constant_features + + n_node_samples = end - start + splitter.node_reset(start, end, &weighted_n_node_samples) + + is_leaf = (depth >= max_depth or + n_node_samples < min_samples_split or + n_node_samples < 2 * min_samples_leaf or + weighted_n_node_samples < 2 * min_weight_leaf) + + if first: + impurity = splitter.node_impurity() + first = 0 + + # impurity == 0 with tolerance due to rounding errors + is_leaf = is_leaf or impurity <= EPSILON + + if not is_leaf: + splitter.node_split(impurity, &split, &n_constant_features) + # If EPSILON=0 in the below comparison, float precision + # issues stop splitting, producing trees that are + # dissimilar to v0.18 + is_leaf = (is_leaf or split.pos >= end or + (split.improvement + EPSILON < + min_impurity_decrease)) + + with gil: + if parent in false_roots: + node_id = tree._update_node(parent, is_left, is_leaf, + split.feature, split.threshold, + impurity, n_node_samples, + weighted_n_node_samples) + else: + node_id = tree._add_node(parent, is_left, is_leaf, + split.feature, split.threshold, + impurity, n_node_samples, + weighted_n_node_samples) + + if node_id == SIZE_MAX: + rc = -1 + break + + # Store value for all nodes, to facilitate tree/model + # inspection and interpretation + splitter.node_value(tree.value + node_id * tree.value_stride) + + if not is_leaf: + # Push right child on stack + rc = stack.push(split.pos, end, depth + 1, node_id, 0, + split.impurity_right, n_constant_features) + if rc == -1: + break + + # Push left child on stack + rc = stack.push(start, split.pos, depth + 1, node_id, 1, + split.impurity_left, n_constant_features) + if rc == -1: + break + + if depth > max_depth_seen: + max_depth_seen = depth + + if rc >= 0: + rc = tree._resize_c(tree.node_count) + + if rc >= 0: + tree.max_depth = max_depth_seen + if rc == -1: + raise MemoryError() # Best first builder ---------------------------------------------------------- @@ -295,6 +478,14 @@ cdef class BestFirstTreeBuilder(TreeBuilder): self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease + def __reduce__(self): + """Reduce re-implementation, for pickling.""" + return(BestFirstTreeBuilder, (self.splitter, self.min_samples_split, + self.min_samples_leaf, + self.min_weight_leaf, self.max_depth, + self.max_leaf_nodes, + self.min_impurity_decrease)) + cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None): """Build a decision tree from the training set (X, y).""" @@ -407,6 +598,169 @@ cdef class BestFirstTreeBuilder(TreeBuilder): if rc == -1: raise MemoryError() + cpdef update(self, Tree tree, object X, np.ndarray y, + np.ndarray sample_weight=None): + """Update a decision tree with the training set (X, y).""" + + # check input + X, y, sample_weight = self._check_input(X, y, sample_weight) + + cdef DOUBLE_t* sample_weight_ptr = NULL + if sample_weight is not None: + sample_weight_ptr = sample_weight.data + + # organize samples by decision paths + paths = tree.decision_path(X) + cdef int PARENT + cdef int CHILD + false_roots = {} + X_copy = {} + y_copy = {} + for i in range(X.shape[0]): + depth_i = paths[i].indices.shape[0] - 1 + PARENT = depth_i - 1 + CHILD = depth_i + + if PARENT < 0: + parent_i = _TREE_UNDEFINED + else: + parent_i = paths[i].indices[PARENT] + child_i = paths[i].indices[CHILD] + left = 0 + if tree.children_left[parent_i] == child_i: + left = 1 + + if (parent_i, left) in false_roots: + false_roots[(parent_i, left)][0] += 1 + X_copy[(parent_i, left)].append(X[i]) + y_copy[(parent_i, left)].append(y[i]) + else: + false_roots[(parent_i, left)] = [1, depth_i] + X_copy[(parent_i, left)] = [X[i]] + y_copy[(parent_i, left)] = [y[i]] + + X_list = [] + y_list = [] + for key, value in sorted(X_copy.items()): + X_list = X_list + value + y_list = y_list + y_copy[key] + cdef object X_new = np.array(X_list) + cdef np.ndarray y_new = np.array(y_list) + + # Parameters + cdef Splitter splitter = self.splitter + cdef SIZE_t max_leaf_nodes = self.max_leaf_nodes + cdef SIZE_t min_samples_leaf = self.min_samples_leaf + cdef double min_weight_leaf = self.min_weight_leaf + cdef SIZE_t min_samples_split = self.min_samples_split + + # Recursive partition (without actual recursion) + splitter.init(X_new, y_new, sample_weight_ptr) + + cdef PriorityHeap frontier = PriorityHeap(INITIAL_STACK_SIZE) + cdef PriorityHeapRecord record + cdef PriorityHeapRecord split_node_left + cdef PriorityHeapRecord split_node_right + + cdef SIZE_t n_node_samples = splitter.n_samples + cdef SIZE_t max_split_nodes = max_leaf_nodes - 1 + cdef bint is_leaf + cdef SIZE_t max_depth_seen = tree.max_depth + cdef int rc = 0 + cdef Node* node + + # Initial capacity + cdef SIZE_t init_capacity = max_split_nodes + max_leaf_nodes + tree._resize(init_capacity) + + # add reached leaf nodes to frontier + cdef SIZE_t start = 0 + cdef SIZE_t end = 0 + for key, value in sorted(false_roots.items()): + end += value[0] + if key[1]: + rc = self._update_split_node(splitter, tree, start, end, + tree.impurity[key[0]], IS_NOT_FIRST, + IS_LEFT, &tree.nodes[key[0]], + value[1], &split_node_left) + if rc >= 0: + rc = _add_to_frontier(&split_node_left, frontier) + else: + rc = self._update_split_node(splitter, tree, start, end, + tree.impurity[key[0]], IS_NOT_FIRST, + IS_NOT_LEFT, &tree.nodes[key[0]], + value[1], &split_node_right) + if rc >= 0: + rc = _add_to_frontier(&split_node_right, frontier) + start += value[0] + if rc == -1: + # got return code -1 - out-of-memory + raise MemoryError() + + with nogil: + while not frontier.is_empty(): + frontier.pop(&record) + + node = &tree.nodes[record.node_id] + is_leaf = (record.is_leaf or max_split_nodes <= 0) + + if is_leaf: + # Node is not expandable; set node as leaf + node.left_child = _TREE_LEAF + node.right_child = _TREE_LEAF + node.feature = _TREE_UNDEFINED + node.threshold = _TREE_UNDEFINED + + else: + # Node is expandable + + # Decrement number of split nodes available + max_split_nodes -= 1 + + # Compute left split node + rc = self._add_split_node(splitter, tree, + record.start, record.pos, + record.impurity_left, + IS_NOT_FIRST, IS_LEFT, node, + record.depth + 1, + &split_node_left) + if rc == -1: + break + + # tree.nodes may have changed + node = &tree.nodes[record.node_id] + + # Compute right split node + rc = self._add_split_node(splitter, tree, record.pos, + record.end, + record.impurity_right, + IS_NOT_FIRST, IS_NOT_LEFT, node, + record.depth + 1, + &split_node_right) + if rc == -1: + break + + # Add nodes to queue + rc = _add_to_frontier(&split_node_left, frontier) + if rc == -1: + break + + rc = _add_to_frontier(&split_node_right, frontier) + if rc == -1: + break + + if record.depth > max_depth_seen: + max_depth_seen = record.depth + + if rc >= 0: + rc = tree._resize_c(tree.node_count) + + if rc >= 0: + tree.max_depth = max_depth_seen + + if rc == -1: + raise MemoryError() + cdef inline int _add_split_node(self, Splitter splitter, Tree tree, SIZE_t start, SIZE_t end, double impurity, bint is_first, bint is_left, Node* parent, @@ -480,6 +834,79 @@ cdef class BestFirstTreeBuilder(TreeBuilder): return 0 + cdef inline int _update_split_node(self, Splitter splitter, Tree tree, + SIZE_t start, SIZE_t end, double impurity, + bint is_first, bint is_left, Node* parent, + SIZE_t depth, + PriorityHeapRecord* res) nogil except -1: + """Updates node w/ partition ``[start, end)`` to the frontier. """ + cdef SplitRecord split + cdef SIZE_t node_id + cdef SIZE_t n_node_samples + cdef SIZE_t n_constant_features = 0 + cdef double weighted_n_samples = splitter.weighted_n_samples + cdef double min_impurity_decrease = self.min_impurity_decrease + cdef double weighted_n_node_samples + cdef bint is_leaf + cdef SIZE_t n_left, n_right + cdef double imp_diff + + splitter.node_reset(start, end, &weighted_n_node_samples) + + if is_first: + impurity = splitter.node_impurity() + + n_node_samples = end - start + is_leaf = (depth >= self.max_depth or + n_node_samples < self.min_samples_split or + n_node_samples < 2 * self.min_samples_leaf or + weighted_n_node_samples < 2 * self.min_weight_leaf or + impurity <= EPSILON # impurity == 0 with tolerance + ) + if not is_leaf: + splitter.node_split(impurity, &split, &n_constant_features) + # If EPSILON=0 in the below comparison, float precision issues stop + # splitting early, producing trees that are dissimilar to v0.18 + is_leaf = (is_leaf or split.pos >= end or + split.improvement + EPSILON < min_impurity_decrease) + + node_id = tree._update_node(parent - tree.nodes + if parent != NULL + else _TREE_UNDEFINED, + is_left, is_leaf, + split.feature, split.threshold, + impurity, n_node_samples, + weighted_n_node_samples) + if node_id == SIZE_MAX: + return -1 + + # compute values also for split nodes (might become leafs later). + splitter.node_value(tree.value + node_id * tree.value_stride) + + res.node_id = node_id + res.start = start + res.end = end + res.depth = depth + res.impurity = impurity + + if not is_leaf: + # is split node + res.pos = split.pos + res.is_leaf = 0 + res.improvement = split.improvement + res.impurity_left = split.impurity_left + res.impurity_right = split.impurity_right + + else: + # is leaf => 0 improvement + res.pos = end + res.is_leaf = 1 + res.improvement = 0.0 + res.impurity_left = impurity + res.impurity_right = impurity + + return 0 + # ============================================================================= # Tree @@ -750,6 +1177,40 @@ cdef class Tree: return node_id + cdef SIZE_t _update_node(self, SIZE_t parent, bint is_left, bint is_leaf, + SIZE_t feature, double threshold, double impurity, + SIZE_t n_node_samples, + double weighted_n_node_samples) nogil except -1: + """Update a node on the tree. + The updated node remains on the same position. + Returns (size_t)(-1) on error. + """ + cdef SIZE_t node_id + if is_left: + node_id = self.nodes[parent].left_child + else: + node_id = self.nodes[parent].right_child + + if node_id >= self.capacity: + if self._resize_c() != 0: + return SIZE_MAX + + cdef Node* node = &self.nodes[node_id] + node.impurity = impurity + node.n_node_samples = n_node_samples + node.weighted_n_node_samples = weighted_n_node_samples + + if is_leaf: + node.left_child = _TREE_LEAF + node.right_child = _TREE_LEAF + node.feature = _TREE_UNDEFINED + node.threshold = _TREE_UNDEFINED + else: + node.feature = feature + node.threshold = threshold + + return node_id + cpdef np.ndarray predict(self, object X): """Predict target for X.""" out = self._get_value_ndarray().take(self.apply(X), axis=0,