From 15427657d2ef2cdc27a6350f7e3c3ca27f21c265 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Fri, 6 Nov 2020 20:25:10 -0500 Subject: [PATCH 01/48] Start implementing the update function for trees --- sklearn/tree/_classes.py | 44 ++++++-- sklearn/tree/_tree.pyx | 238 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 271 insertions(+), 11 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 01d821d9e5c82..d9957f94be4f2 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -11,6 +11,7 @@ # Joly Arnaud # Fares Hedayati # Nelson Liu +# Haoyin Xu # # License: BSD 3 clause @@ -140,7 +141,7 @@ def get_n_leaves(self): return self.tree_.n_leaves def fit(self, X, y, sample_weight=None, check_input=True, - X_idx_sorted="deprecated"): + X_idx_sorted="deprecated", update_tree=False): random_state = check_random_state(self.random_state) @@ -213,6 +214,15 @@ def fit(self, X, y, sample_weight=None, check_input=True, if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: y = np.ascontiguousarray(y, dtype=DOUBLE) + if update_tree: + # TODO: find a way to build on previous tree + # Update tree + self.builder_.update(self.tree_, X, y, sample_weight) + + self._prune_tree() + + return self + # Check parameters max_depth = (np.iinfo(np.int32).max if self.max_depth is None else self.max_depth) @@ -355,7 +365,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, min_weight_leaf, random_state) - if is_classifier(self): + if is_classification: self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_) else: @@ -366,14 +376,14 @@ def fit(self, X, y, sample_weight=None, check_input=True, # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise if max_leaf_nodes < 0: - builder = DepthFirstTreeBuilder(splitter, min_samples_split, + self.builder_ = DepthFirstTreeBuilder(splitter, min_samples_split, min_samples_leaf, min_weight_leaf, max_depth, self.min_impurity_decrease, min_impurity_split) else: - builder = BestFirstTreeBuilder(splitter, min_samples_split, + self.builder_ = BestFirstTreeBuilder(splitter, min_samples_split, min_samples_leaf, min_weight_leaf, max_depth, @@ -381,9 +391,9 @@ def fit(self, X, y, sample_weight=None, check_input=True, self.min_impurity_decrease, min_impurity_split) - builder.build(self.tree_, X, y, sample_weight) + self.builder_.build(self.tree_, X, y, sample_weight) - if self.n_outputs_ == 1 and is_classifier(self): + if self.n_outputs_ == 1 and is_classification: self.n_classes_ = self.n_classes_[0] self.classes_ = self.classes_[0] @@ -778,6 +788,9 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` for basic usage of these attributes. + builder_ : TreeBuilder instance + The underlying TreeBuilder object. + See Also -------- DecisionTreeRegressor : A decision tree regressor. @@ -853,7 +866,7 @@ def __init__(self, *, ccp_alpha=ccp_alpha) def fit(self, X, y, sample_weight=None, check_input=True, - X_idx_sorted="deprecated"): + X_idx_sorted="deprecated", update_tree=False): """Build a decision tree classifier from the training set (X, y). Parameters @@ -883,6 +896,9 @@ def fit(self, X, y, sample_weight=None, check_input=True, .. deprecated :: 0.24 + update_tree : bool, default=False + Choice of updating the existing tree or creating a new one. + Returns ------- self : DecisionTreeClassifier @@ -893,7 +909,8 @@ def fit(self, X, y, sample_weight=None, check_input=True, X, y, sample_weight=sample_weight, check_input=check_input, - X_idx_sorted=X_idx_sorted) + X_idx_sorted=X_idx_sorted, + update_tree=update_tree) return self def predict_proba(self, X, check_input=True): @@ -1134,6 +1151,9 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` for basic usage of these attributes. + builder_ : TreeBuilder instance + The underlying TreeBuilder object. + See Also -------- DecisionTreeClassifier : A decision tree classifier. @@ -1202,7 +1222,7 @@ def __init__(self, *, ccp_alpha=ccp_alpha) def fit(self, X, y, sample_weight=None, check_input=True, - X_idx_sorted="deprecated"): + X_idx_sorted="deprecated", update_tree=False): """Build a decision tree regressor from the training set (X, y). Parameters @@ -1231,6 +1251,9 @@ def fit(self, X, y, sample_weight=None, check_input=True, .. deprecated :: 0.24 + update_tree : bool, default=False + Choice of updating the existing tree or creating a new one. + Returns ------- self : DecisionTreeRegressor @@ -1241,7 +1264,8 @@ def fit(self, X, y, sample_weight=None, check_input=True, X, y, sample_weight=sample_weight, check_input=check_input, - X_idx_sorted=X_idx_sorted) + X_idx_sorted=X_idx_sorted, + update_tree=update_tree) return self def _compute_partial_dependence_recursion(self, grid, target_features): diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index f4484ab1a3314..b4ede4275f3d3 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -13,6 +13,7 @@ # Fares Hedayati # Jacob Schreiber # Nelson Liu +# Haoyin Xu # # License: BSD 3 clause @@ -87,6 +88,11 @@ cdef class TreeBuilder: """Build a decision tree from the training set (X, y).""" pass + cpdef update(self, Tree tree, object X, np.ndarray y, + np.ndarray sample_weight=None): + """Update a decision tree with the training set (X, y).""" + pass + cdef inline _check_input(self, object X, np.ndarray y, np.ndarray sample_weight): """Check input dtype, layout and format""" @@ -266,6 +272,128 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): if rc == -1: raise MemoryError() + cpdef update(self, Tree tree, object X, np.ndarray y, + np.ndarray sample_weight=None): + """Update a decision tree with the training set (X, y).""" + + # check input + X, y, sample_weight = self._check_input(X, y, sample_weight) + + cdef DOUBLE_t* sample_weight_ptr = NULL + if sample_weight is not None: + sample_weight_ptr = sample_weight.data + + # Parameters + cdef Splitter splitter = self.splitter + cdef SIZE_t max_depth = self.max_depth + cdef SIZE_t min_samples_leaf = self.min_samples_leaf + cdef double min_weight_leaf = self.min_weight_leaf + cdef SIZE_t min_samples_split = self.min_samples_split + cdef double min_impurity_decrease = self.min_impurity_decrease + cdef double min_impurity_split = self.min_impurity_split + + # Recursive partition (without actual recursion) + splitter.init(X, y, sample_weight_ptr) + + cdef SIZE_t start + cdef SIZE_t end + cdef SIZE_t depth + cdef SIZE_t parent + cdef bint is_left + cdef SIZE_t n_node_samples = splitter.n_samples + cdef double weighted_n_samples = splitter.weighted_n_samples + cdef double weighted_n_node_samples + cdef SplitRecord split + cdef SIZE_t node_id + + cdef double impurity = INFINITY + cdef SIZE_t n_constant_features + cdef bint is_leaf + cdef bint first = 1 + cdef SIZE_t max_depth_seen = tree.max_depth + cdef int rc = 0 + + cdef Stack stack = Stack(INITIAL_STACK_SIZE) + cdef StackRecord stack_record + + with nogil: + # push root node onto stack + rc = stack.push(0, n_node_samples, 0, _TREE_UNDEFINED, 0, INFINITY, 0) + if rc == -1: + # got return code -1 - out-of-memory + with gil: + raise MemoryError() + + while not stack.is_empty(): + stack.pop(&stack_record) + + start = stack_record.start + end = stack_record.end + depth = stack_record.depth + parent = stack_record.parent + is_left = stack_record.is_left + impurity = stack_record.impurity + n_constant_features = stack_record.n_constant_features + + n_node_samples = end - start + splitter.node_reset(start, end, &weighted_n_node_samples) + + is_leaf = (depth >= max_depth or + n_node_samples < min_samples_split or + n_node_samples < 2 * min_samples_leaf or + weighted_n_node_samples < 2 * min_weight_leaf) + + if first: + impurity = splitter.node_impurity() + first = 0 + + is_leaf = (is_leaf or + (impurity <= min_impurity_split)) + + if not is_leaf: + splitter.node_split(impurity, &split, &n_constant_features) + # If EPSILON=0 in the below comparison, float precision + # issues stop splitting, producing trees that are + # dissimilar to v0.18 + is_leaf = (is_leaf or split.pos >= end or + (split.improvement + EPSILON < + min_impurity_decrease)) + + node_id = tree._add_node(parent, is_left, is_leaf, split.feature, + split.threshold, impurity, n_node_samples, + weighted_n_node_samples) + + if node_id == SIZE_MAX: + rc = -1 + break + + # Store value for all nodes, to facilitate tree/model + # inspection and interpretation + splitter.node_value(tree.value + node_id * tree.value_stride) + + if not is_leaf: + # Push right child on stack + rc = stack.push(split.pos, end, depth + 1, node_id, 0, + split.impurity_right, n_constant_features) + if rc == -1: + break + + # Push left child on stack + rc = stack.push(start, split.pos, depth + 1, node_id, 1, + split.impurity_left, n_constant_features) + if rc == -1: + break + + if depth > max_depth_seen: + max_depth_seen = depth + + if rc >= 0: + rc = tree._resize_c(tree.node_count) + + if rc >= 0: + tree.max_depth = max_depth_seen + if rc == -1: + raise MemoryError() # Best first builder ---------------------------------------------------------- @@ -414,6 +542,114 @@ cdef class BestFirstTreeBuilder(TreeBuilder): if rc == -1: raise MemoryError() + cpdef update(self, Tree tree, object X, np.ndarray y, + np.ndarray sample_weight=None): + """Update a decision tree with the training set (X, y).""" + + # check input + X, y, sample_weight = self._check_input(X, y, sample_weight) + + cdef DOUBLE_t* sample_weight_ptr = NULL + if sample_weight is not None: + sample_weight_ptr = sample_weight.data + + # Parameters + cdef Splitter splitter = self.splitter + cdef SIZE_t max_leaf_nodes = self.max_leaf_nodes + cdef SIZE_t min_samples_leaf = self.min_samples_leaf + cdef double min_weight_leaf = self.min_weight_leaf + cdef SIZE_t min_samples_split = self.min_samples_split + + # Recursive partition (without actual recursion) + splitter.init(X, y, sample_weight_ptr) + + cdef PriorityHeap frontier = PriorityHeap(INITIAL_STACK_SIZE) + cdef PriorityHeapRecord record + cdef PriorityHeapRecord split_node_left + cdef PriorityHeapRecord split_node_right + + cdef SIZE_t n_node_samples = splitter.n_samples + cdef SIZE_t max_split_nodes = max_leaf_nodes - 1 + cdef bint is_leaf + cdef SIZE_t max_depth_seen = tree.max_depth + cdef int rc = 0 + cdef Node* node + + with nogil: + # add root to frontier + rc = self._add_split_node(splitter, tree, 0, n_node_samples, + INFINITY, IS_FIRST, IS_LEFT, NULL, 0, + &split_node_left) + if rc >= 0: + rc = _add_to_frontier(&split_node_left, frontier) + + if rc == -1: + with gil: + raise MemoryError() + + while not frontier.is_empty(): + frontier.pop(&record) + + node = &tree.nodes[record.node_id] + is_leaf = (record.is_leaf or max_split_nodes <= 0) + + if is_leaf: + # Node is not expandable; set node as leaf + node.left_child = _TREE_LEAF + node.right_child = _TREE_LEAF + node.feature = _TREE_UNDEFINED + node.threshold = _TREE_UNDEFINED + + else: + # Node is expandable + + # Decrement number of split nodes available + max_split_nodes -= 1 + + # Compute left split node + rc = self._add_split_node(splitter, tree, + record.start, record.pos, + record.impurity_left, + IS_NOT_FIRST, IS_LEFT, node, + record.depth + 1, + &split_node_left) + if rc == -1: + break + + # tree.nodes may have changed + node = &tree.nodes[record.node_id] + + # Compute right split node + rc = self._add_split_node(splitter, tree, record.pos, + record.end, + record.impurity_right, + IS_NOT_FIRST, IS_NOT_LEFT, node, + record.depth + 1, + &split_node_right) + if rc == -1: + break + + # Add nodes to queue + rc = _add_to_frontier(&split_node_left, frontier) + if rc == -1: + break + + rc = _add_to_frontier(&split_node_right, frontier) + if rc == -1: + break + + if record.depth > max_depth_seen: + max_depth_seen = record.depth + + if rc >= 0: + rc = tree._resize_c(tree.node_count) + + if rc >= 0: + tree.max_depth = max_depth_seen + + if rc == -1: + raise MemoryError() + cdef inline int _add_split_node(self, Splitter splitter, Tree tree, SIZE_t start, SIZE_t end, double impurity, bint is_first, bint is_left, Node* parent, @@ -655,7 +891,7 @@ cdef class Tree: if (node_ndarray.dtype != NODE_DTYPE): # possible mismatch of big/little endian due to serialization - # on a different architecture. Try swapping the byte order. + # on a different architecture. Try swapping the byte order. node_ndarray = node_ndarray.byteswap().newbyteorder() if (node_ndarray.dtype != NODE_DTYPE): raise ValueError('Did not recognise loaded array dytpe') From 8ded0f74ed5383bcb04ca7606e12464e274d6164 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Fri, 6 Nov 2020 20:30:30 -0500 Subject: [PATCH 02/48] Update _tree.pxd --- sklearn/tree/_tree.pxd | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 8957f0342892a..f82e07dd5b167 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -101,4 +101,6 @@ cdef class TreeBuilder: cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=*) + cpdef update(self, Tree tree, object X, np.ndarray y, + np.ndarray sample_weight=*) cdef _check_input(self, object X, np.ndarray y, np.ndarray sample_weight) From d6d58799e73638a03533803898a63e8ab8182122 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Fri, 6 Nov 2020 21:08:22 -0500 Subject: [PATCH 03/48] Remove unused attribute --- sklearn/tree/_classes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index d9957f94be4f2..c5769850fe4a2 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -174,7 +174,6 @@ def fit(self, X, y, sample_weight=None, check_input=True, # Determine output settings n_samples, self.n_features_ = X.shape - self.n_features_in_ = self.n_features_ is_classification = is_classifier(self) y = np.atleast_1d(y) From 0ed0819ee664cdd357b7c5c958f4f6cfbdaa1846 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Fri, 6 Nov 2020 21:10:37 -0500 Subject: [PATCH 04/48] Remove duplicate operations --- sklearn/tree/_classes.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index c5769850fe4a2..93fdad397a874 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -184,6 +184,28 @@ def fit(self, X, y, sample_weight=None, check_input=True, # [:, np.newaxis] that does not. y = np.reshape(y, (-1, 1)) + if update_tree: + # TODO: find a way to build on previous tree + if is_classification: + check_classification_targets(y) + y = np.copy(y) + + y_encoded = np.zeros(y.shape, dtype=int) + for k in range(self.n_outputs_): + classes_k, y_encoded[:, k] = np.unique(y[:, k], + return_inverse=True) + y = y_encoded + + if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: + y = np.ascontiguousarray(y, dtype=DOUBLE) + + # Update tree + self.builder_.update(self.tree_, X, y, sample_weight) + + self._prune_tree() + + return self + self.n_outputs_ = y.shape[1] if is_classification: @@ -213,15 +235,6 @@ def fit(self, X, y, sample_weight=None, check_input=True, if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: y = np.ascontiguousarray(y, dtype=DOUBLE) - if update_tree: - # TODO: find a way to build on previous tree - # Update tree - self.builder_.update(self.tree_, X, y, sample_weight) - - self._prune_tree() - - return self - # Check parameters max_depth = (np.iinfo(np.int32).max if self.max_depth is None else self.max_depth) From bebe2bc7dfe83656ed7a4633a55ea7419a741432 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Sun, 8 Nov 2020 16:49:57 -0500 Subject: [PATCH 05/48] Keep whole function for reference --- sklearn/tree/_tree.pyx | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index b4ede4275f3d3..11f4c1210754e 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -282,6 +282,16 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef DOUBLE_t* sample_weight_ptr = NULL if sample_weight is not None: sample_weight_ptr = sample_weight.data + + # Initial capacity + cdef int init_capacity + + if tree.max_depth <= 10: + init_capacity = (2 ** (tree.max_depth + 1)) - 1 + else: + init_capacity = 2047 + + tree._resize(init_capacity) # Parameters cdef Splitter splitter = self.splitter @@ -575,6 +585,10 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef int rc = 0 cdef Node* node + # Initial capacity + cdef SIZE_t init_capacity = max_split_nodes + max_leaf_nodes + tree._resize(init_capacity) + with nogil: # add root to frontier rc = self._add_split_node(splitter, tree, 0, n_node_samples, From 6ca6725800a3a6c21e1f3d3d2a4c4a876300e161 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 11 Nov 2020 11:30:27 -0500 Subject: [PATCH 06/48] Catch AttributeError --- sklearn/tree/_classes.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 93fdad397a874..8d216360918e0 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -200,7 +200,11 @@ def fit(self, X, y, sample_weight=None, check_input=True, y = np.ascontiguousarray(y, dtype=DOUBLE) # Update tree - self.builder_.update(self.tree_, X, y, sample_weight) + try: + self.builder_.update(self.tree_, X, y, sample_weight) + except AttributeError: + print("No existing tree to update") + return self self._prune_tree() From a403f5b994ac8da2b9f06c8a52cd2f3048e57147 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Sun, 15 Nov 2020 15:17:47 -0500 Subject: [PATCH 07/48] Evaluate tree building logic --- sklearn/tree/_tree.pxd | 4 +++ sklearn/tree/_tree.pyx | 72 +++++++++++++++++++++++++++++++++--------- 2 files changed, 61 insertions(+), 15 deletions(-) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index f82e07dd5b167..97dd4d8b8edb4 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -59,6 +59,10 @@ cdef class Tree: SIZE_t feature, double threshold, double impurity, SIZE_t n_node_samples, double weighted_n_samples) nogil except -1 + cdef SIZE_t _update_node(self, SIZE_t node_id, bint is_leaf, SIZE_t feature, + double threshold, double impurity, + SIZE_t n_node_samples, + double weighted_n_node_samples) nogil except -1 cdef int _resize(self, SIZE_t capacity) nogil except -1 cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1 diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 11f4c1210754e..f8ec935f15139 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -282,7 +282,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef DOUBLE_t* sample_weight_ptr = NULL if sample_weight is not None: sample_weight_ptr = sample_weight.data - + # Initial capacity cdef int init_capacity @@ -293,6 +293,19 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): tree._resize(init_capacity) + # Organize samples by decision paths + paths = tree.decision_path(X) + leafs = {} + for i in range(X.shape[0]): + leaf = paths[i].indices[-1] + depth = paths[i].indices.shape[0] - 1 + if leaf in leafs: + leafs[leaf][0] += 1 + else: + leafs[leaf] = [1, depth] + + X_copy = [] + # Parameters cdef Splitter splitter = self.splitter cdef SIZE_t max_depth = self.max_depth @@ -305,8 +318,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # Recursive partition (without actual recursion) splitter.init(X, y, sample_weight_ptr) - cdef SIZE_t start - cdef SIZE_t end + cdef SIZE_t start = 0 + cdef SIZE_t end = 0 cdef SIZE_t depth cdef SIZE_t parent cdef bint is_left @@ -316,10 +329,9 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef SplitRecord split cdef SIZE_t node_id - cdef double impurity = INFINITY + cdef double impurity cdef SIZE_t n_constant_features cdef bint is_leaf - cdef bint first = 1 cdef SIZE_t max_depth_seen = tree.max_depth cdef int rc = 0 @@ -327,12 +339,15 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef StackRecord stack_record with nogil: - # push root node onto stack - rc = stack.push(0, n_node_samples, 0, _TREE_UNDEFINED, 0, INFINITY, 0) - if rc == -1: - # got return code -1 - out-of-memory - with gil: - raise MemoryError() + # push reached leaf nodes onto stack + for key, value in leafs.items(): + end += value[0] + rc = stack.push(start, end, value[1], key, 0, tree.impurity[key], 0) + start += value + if rc == -1: + # got return code -1 - out-of-memory + with gil: + raise MemoryError() while not stack.is_empty(): stack.pop(&stack_record) @@ -353,10 +368,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): n_node_samples < 2 * min_samples_leaf or weighted_n_node_samples < 2 * min_weight_leaf) - if first: - impurity = splitter.node_impurity() - first = 0 - is_leaf = (is_leaf or (impurity <= min_impurity_split)) @@ -1010,6 +1021,37 @@ cdef class Tree: return node_id + cdef SIZE_t _update_node(self, SIZE_t node_id, bint is_leaf, SIZE_t feature, + double threshold, double impurity, + SIZE_t n_node_samples, + double weighted_n_node_samples) nogil except -1: + """Update a node on the tree. + + The updated node remains on the same position. + + Returns (size_t)(-1) on error. + """ + if node_id >= self.capacity: + if self._resize_c() != 0: + return SIZE_MAX + + cdef Node* node = &self.nodes[node_id] + node.impurity = impurity + node.n_node_samples = n_node_samples + node.weighted_n_node_samples = weighted_n_node_samples + + if is_leaf: + node.left_child = _TREE_LEAF + node.right_child = _TREE_LEAF + node.feature = _TREE_UNDEFINED + node.threshold = _TREE_UNDEFINED + + else: + node.feature = feature + node.threshold = threshold + + return node_id + cpdef np.ndarray predict(self, object X): """Predict target for X.""" out = self._get_value_ndarray().take(self.apply(X), axis=0, From cb4cf43c5ec5e4674f0fdf50af257e2cec5a5810 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Sun, 15 Nov 2020 16:42:34 -0500 Subject: [PATCH 08/48] Follow node addition logic --- sklearn/tree/_tree.pxd | 4 --- sklearn/tree/_tree.pyx | 62 ++++++++++++------------------------------ 2 files changed, 17 insertions(+), 49 deletions(-) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 97dd4d8b8edb4..f82e07dd5b167 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -59,10 +59,6 @@ cdef class Tree: SIZE_t feature, double threshold, double impurity, SIZE_t n_node_samples, double weighted_n_samples) nogil except -1 - cdef SIZE_t _update_node(self, SIZE_t node_id, bint is_leaf, SIZE_t feature, - double threshold, double impurity, - SIZE_t n_node_samples, - double weighted_n_node_samples) nogil except -1 cdef int _resize(self, SIZE_t capacity) nogil except -1 cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1 diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index f8ec935f15139..37cd98216cea0 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -293,19 +293,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): tree._resize(init_capacity) - # Organize samples by decision paths - paths = tree.decision_path(X) - leafs = {} - for i in range(X.shape[0]): - leaf = paths[i].indices[-1] - depth = paths[i].indices.shape[0] - 1 - if leaf in leafs: - leafs[leaf][0] += 1 - else: - leafs[leaf] = [1, depth] - - X_copy = [] - # Parameters cdef Splitter splitter = self.splitter cdef SIZE_t max_depth = self.max_depth @@ -339,8 +326,24 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef StackRecord stack_record with nogil: + # Organize samples by decision paths + paths = tree.decision_path(X) + false_roots = {} + for i in range(X.shape[0]): + parent = paths[i].indices[-2] + leaf = paths[i].indices[-1] + depth = paths[i].indices.shape[0] - 1 + left = 0 + + if parent in false_roots: + false_roots[parent][0] += 1 + else: + if tree.children_left[parent] == leaf: + left = 1 + false_roots[parent] = [1, depth, left] + # push reached leaf nodes onto stack - for key, value in leafs.items(): + for key, value in sorted(false_roots.items()): end += value[0] rc = stack.push(start, end, value[1], key, 0, tree.impurity[key], 0) start += value @@ -1021,37 +1024,6 @@ cdef class Tree: return node_id - cdef SIZE_t _update_node(self, SIZE_t node_id, bint is_leaf, SIZE_t feature, - double threshold, double impurity, - SIZE_t n_node_samples, - double weighted_n_node_samples) nogil except -1: - """Update a node on the tree. - - The updated node remains on the same position. - - Returns (size_t)(-1) on error. - """ - if node_id >= self.capacity: - if self._resize_c() != 0: - return SIZE_MAX - - cdef Node* node = &self.nodes[node_id] - node.impurity = impurity - node.n_node_samples = n_node_samples - node.weighted_n_node_samples = weighted_n_node_samples - - if is_leaf: - node.left_child = _TREE_LEAF - node.right_child = _TREE_LEAF - node.feature = _TREE_UNDEFINED - node.threshold = _TREE_UNDEFINED - - else: - node.feature = feature - node.threshold = threshold - - return node_id - cpdef np.ndarray predict(self, object X): """Predict target for X.""" out = self._get_value_ndarray().take(self.apply(X), axis=0, From eb7af31988f1ddff1a0b543cc6232dc518cdfc8a Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Sun, 15 Nov 2020 18:17:35 -0500 Subject: [PATCH 09/48] Work with counting issues and overflowing trees --- sklearn/tree/_tree.pyx | 56 ++++++++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 37cd98216cea0..740defba1b66d 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -325,33 +325,37 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef Stack stack = Stack(INITIAL_STACK_SIZE) cdef StackRecord stack_record - with nogil: - # Organize samples by decision paths - paths = tree.decision_path(X) - false_roots = {} - for i in range(X.shape[0]): - parent = paths[i].indices[-2] - leaf = paths[i].indices[-1] - depth = paths[i].indices.shape[0] - 1 - left = 0 - - if parent in false_roots: - false_roots[parent][0] += 1 - else: - if tree.children_left[parent] == leaf: - left = 1 - false_roots[parent] = [1, depth, left] - - # push reached leaf nodes onto stack - for key, value in sorted(false_roots.items()): - end += value[0] - rc = stack.push(start, end, value[1], key, 0, tree.impurity[key], 0) - start += value - if rc == -1: - # got return code -1 - out-of-memory - with gil: - raise MemoryError() + # Organize samples by decision paths + paths = tree.decision_path(X) + cdef int PARENT + cdef int CHILD + false_roots = {} + for i in range(X.shape[0]): + depth = paths[i].indices.shape[0] - 1 + PARENT = depth - 1 + CHILD = depth + + parent = paths[i].indices[PARENT] + child = paths[i].indices[CHILD] + left = 0 + + if parent in false_roots: + false_roots[parent][0] += 1 + else: + if tree.children_left[parent] == child: + left = 1 + false_roots[parent] = [1, depth, left] + + # push reached leaf nodes onto stack + for key, value in sorted(false_roots.items()): + end += value[0] + rc = stack.push(start, end, value[1], key, value[2], tree.impurity[key], 0) + start += value[0] + if rc == -1: + # got return code -1 - out-of-memory + raise MemoryError() + with nogil: while not stack.is_empty(): stack.pop(&stack_record) From c24c87abb1799c3de4a51fee84adc5eff54aeef7 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Sun, 15 Nov 2020 20:23:59 -0500 Subject: [PATCH 10/48] Work with high variability --- sklearn/tree/_tree.pxd | 6 +++++- sklearn/tree/_tree.pyx | 48 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index f82e07dd5b167..dde2a6e7e9370 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -58,7 +58,11 @@ cdef class Tree: cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, SIZE_t feature, double threshold, double impurity, SIZE_t n_node_samples, - double weighted_n_samples) nogil except -1 + double weighted_n_node_samples) nogil except -1 + cdef SIZE_t _update_node(self, SIZE_t parent, bint is_left, bint is_leaf, + SIZE_t feature, double threshold, double impurity, + SIZE_t n_node_samples, + double weighted_n_node_samples) nogil except -1 cdef int _resize(self, SIZE_t capacity) nogil except -1 cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1 diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 740defba1b66d..1b8a23e49a950 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -387,9 +387,17 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): (split.improvement + EPSILON < min_impurity_decrease)) - node_id = tree._add_node(parent, is_left, is_leaf, split.feature, - split.threshold, impurity, n_node_samples, - weighted_n_node_samples) + with gil: + if parent in false_roots: + node_id = tree._update_node(parent, is_left, is_leaf, + split.feature, split.threshold, + impurity, n_node_samples, + weighted_n_node_samples) + else: + node_id = tree._add_node(parent, is_left, is_leaf, + split.feature, split.threshold, + impurity, n_node_samples, + weighted_n_node_samples) if node_id == SIZE_MAX: rc = -1 @@ -1028,6 +1036,40 @@ cdef class Tree: return node_id + cdef SIZE_t _update_node(self, SIZE_t parent, bint is_left, bint is_leaf, + SIZE_t feature, double threshold, double impurity, + SIZE_t n_node_samples, + double weighted_n_node_samples) nogil except -1: + """Update a node on the tree. + The updated node remains on the same position. + Returns (size_t)(-1) on error. + """ + cdef SIZE_t node_id + if is_left: + node_id = self.nodes[parent].left_child + else: + node_id = self.nodes[parent].right_child + + if node_id >= self.capacity: + if self._resize_c() != 0: + return SIZE_MAX + + cdef Node* node = &self.nodes[node_id] + node.impurity = impurity + node.n_node_samples = n_node_samples + node.weighted_n_node_samples = weighted_n_node_samples + + if is_leaf: + node.left_child = _TREE_LEAF + node.right_child = _TREE_LEAF + node.feature = _TREE_UNDEFINED + node.threshold = _TREE_UNDEFINED + else: + node.feature = feature + node.threshold = threshold + + return node_id + cpdef np.ndarray predict(self, object X): """Predict target for X.""" out = self._get_value_ndarray().take(self.apply(X), axis=0, From 5e6685ccf672ed558d7ae333b04928d0074e3d4a Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Mon, 16 Nov 2020 00:15:19 -0500 Subject: [PATCH 11/48] Fix y coordinates --- sklearn/tree/_classes.py | 13 ++++---- sklearn/tree/_tree.pxd | 1 + sklearn/tree/_tree.pyx | 64 +++++++++++++++++++++++++--------------- 3 files changed, 50 insertions(+), 28 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 8d216360918e0..e5506d9c80f40 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -186,6 +186,13 @@ def fit(self, X, y, sample_weight=None, check_input=True, if update_tree: # TODO: find a way to build on previous tree + # See if there is an existing tree + try: + self.tree_ + except AttributeError: + print("No existing tree to update") + return self + if is_classification: check_classification_targets(y) y = np.copy(y) @@ -200,11 +207,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, y = np.ascontiguousarray(y, dtype=DOUBLE) # Update tree - try: - self.builder_.update(self.tree_, X, y, sample_weight) - except AttributeError: - print("No existing tree to update") - return self + self.builder_.update(self.tree_, X, y, sample_weight) self._prune_tree() diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index dde2a6e7e9370..b55a127b3c96e 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -5,6 +5,7 @@ # Arnaud Joly # Jacob Schreiber # Nelson Liu +# Haoyin Xu # # License: BSD 3 clause diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 1b8a23e49a950..cf1a5aad3303e 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -279,6 +279,44 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # check input X, y, sample_weight = self._check_input(X, y, sample_weight) + # organize samples by decision paths + paths = tree.decision_path(X) + cdef int PARENT + cdef int CHILD + false_roots = {} + X_copy = {} + y_copy = {} + for i in range(X.shape[0]): + depth_i = paths[i].indices.shape[0] - 1 + PARENT = depth_i - 1 + CHILD = depth_i + + if PARENT < 0: + parent_i = _TREE_UNDEFINED + else: + parent_i = paths[i].indices[PARENT] + child_i = paths[i].indices[CHILD] + left = 0 + if tree.children_left[parent_i] == child_i: + left = 1 + + if (parent_i, left) in false_roots: + false_roots[(parent_i, left)][0] += 1 + X_copy[(parent_i, left)].append(X[i]) + y_copy[(parent_i, left)].append(y[i]) + else: + false_roots[(parent_i, left)] = [1, depth_i] + X_copy[(parent_i, left)] = [X[i]] + y_copy[(parent_i, left)] = [y[i]] + + X_list = [] + y_list = [] + for key, value in sorted(X_copy.items()): + X_list = X_list + value + y_list = y_list + y_copy[key] + cdef object X_new = np.array(X_list) + cdef np.ndarray y_new = np.array(y_list) + cdef DOUBLE_t* sample_weight_ptr = NULL if sample_weight is not None: sample_weight_ptr = sample_weight.data @@ -303,7 +341,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef double min_impurity_split = self.min_impurity_split # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight_ptr) + splitter.init(X_new, y_new, sample_weight_ptr) cdef SIZE_t start = 0 cdef SIZE_t end = 0 @@ -325,31 +363,11 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef Stack stack = Stack(INITIAL_STACK_SIZE) cdef StackRecord stack_record - # Organize samples by decision paths - paths = tree.decision_path(X) - cdef int PARENT - cdef int CHILD - false_roots = {} - for i in range(X.shape[0]): - depth = paths[i].indices.shape[0] - 1 - PARENT = depth - 1 - CHILD = depth - - parent = paths[i].indices[PARENT] - child = paths[i].indices[CHILD] - left = 0 - - if parent in false_roots: - false_roots[parent][0] += 1 - else: - if tree.children_left[parent] == child: - left = 1 - false_roots[parent] = [1, depth, left] - # push reached leaf nodes onto stack for key, value in sorted(false_roots.items()): end += value[0] - rc = stack.push(start, end, value[1], key, value[2], tree.impurity[key], 0) + rc = stack.push(start, end, value[1], key[0], key[1], + tree.impurity[key[0]], 0) start += value[0] if rc == -1: # got return code -1 - out-of-memory From 5f6c373a684182b7b6fb85287ca8b815401b608d Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 17 Nov 2020 21:24:12 -0500 Subject: [PATCH 12/48] Duplicate sample organization --- sklearn/tree/_tree.pyx | 48 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index cf1a5aad3303e..f3c9973a93890 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -279,6 +279,10 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # check input X, y, sample_weight = self._check_input(X, y, sample_weight) + cdef DOUBLE_t* sample_weight_ptr = NULL + if sample_weight is not None: + sample_weight_ptr = sample_weight.data + # organize samples by decision paths paths = tree.decision_path(X) cdef int PARENT @@ -317,10 +321,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef object X_new = np.array(X_list) cdef np.ndarray y_new = np.array(y_list) - cdef DOUBLE_t* sample_weight_ptr = NULL - if sample_weight is not None: - sample_weight_ptr = sample_weight.data - # Initial capacity cdef int init_capacity @@ -607,6 +607,44 @@ cdef class BestFirstTreeBuilder(TreeBuilder): if sample_weight is not None: sample_weight_ptr = sample_weight.data + # organize samples by decision paths + paths = tree.decision_path(X) + cdef int PARENT + cdef int CHILD + false_roots = {} + X_copy = {} + y_copy = {} + for i in range(X.shape[0]): + depth_i = paths[i].indices.shape[0] - 1 + PARENT = depth_i - 1 + CHILD = depth_i + + if PARENT < 0: + parent_i = _TREE_UNDEFINED + else: + parent_i = paths[i].indices[PARENT] + child_i = paths[i].indices[CHILD] + left = 0 + if tree.children_left[parent_i] == child_i: + left = 1 + + if (parent_i, left) in false_roots: + false_roots[(parent_i, left)][0] += 1 + X_copy[(parent_i, left)].append(X[i]) + y_copy[(parent_i, left)].append(y[i]) + else: + false_roots[(parent_i, left)] = [1, depth_i] + X_copy[(parent_i, left)] = [X[i]] + y_copy[(parent_i, left)] = [y[i]] + + X_list = [] + y_list = [] + for key, value in sorted(X_copy.items()): + X_list = X_list + value + y_list = y_list + y_copy[key] + cdef object X_new = np.array(X_list) + cdef np.ndarray y_new = np.array(y_list) + # Parameters cdef Splitter splitter = self.splitter cdef SIZE_t max_leaf_nodes = self.max_leaf_nodes @@ -615,7 +653,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef SIZE_t min_samples_split = self.min_samples_split # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight_ptr) + splitter.init(X_new, y_new, sample_weight_ptr) cdef PriorityHeap frontier = PriorityHeap(INITIAL_STACK_SIZE) cdef PriorityHeapRecord record From 7ac15f2019162be4d71fa459c72b229c76076f9d Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 18 Nov 2020 11:11:33 -0500 Subject: [PATCH 13/48] Add _update_split_node function for BestFirstTree --- sklearn/tree/_tree.pyx | 107 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 97 insertions(+), 10 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index f3c9973a93890..df10b6ac39414 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -671,18 +671,31 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef SIZE_t init_capacity = max_split_nodes + max_leaf_nodes tree._resize(init_capacity) - with nogil: - # add root to frontier - rc = self._add_split_node(splitter, tree, 0, n_node_samples, - INFINITY, IS_FIRST, IS_LEFT, NULL, 0, - &split_node_left) - if rc >= 0: - rc = _add_to_frontier(&split_node_left, frontier) - + # add reached leaf nodes onto stack + cdef SIZE_t start = 0 + cdef SIZE_t end = 0 + for key, value in sorted(false_roots.items()): + end += value[0] + if key[1]: + rc = self._update_split_node(splitter, tree, start, end, + tree.impurity[key[0]], IS_NOT_FIRST, + IS_LEFT, &tree.nodes[key[0]] + value[1], &split_node_left) + if rc >= 0: + rc = _add_to_frontier(&split_node_left, frontier) + else: + rc = self._update_split_node(splitter, tree, start, end, + tree.impurity[key[0]], IS_NOT_FIRST, + IS_NOT_LEFT, &tree.nodes[key[0]] + value[1], &split_node_right) + if rc >= 0: + rc = _add_to_frontier(&split_node_right, frontier) + start += value[0] if rc == -1: - with gil: - raise MemoryError() + # got return code -1 - out-of-memory + raise MemoryError() + with nogil: while not frontier.is_empty(): frontier.pop(&record) @@ -819,6 +832,80 @@ cdef class BestFirstTreeBuilder(TreeBuilder): return 0 + cdef inline int _update_split_node(self, Splitter splitter, Tree tree, + SIZE_t start, SIZE_t end, double impurity, + bint is_first, bint is_left, Node* parent, + SIZE_t depth, + PriorityHeapRecord* res) nogil except -1: + """Updates node w/ partition ``[start, end)`` to the frontier. """ + cdef SplitRecord split + cdef SIZE_t node_id + cdef SIZE_t n_node_samples + cdef SIZE_t n_constant_features = 0 + cdef double weighted_n_samples = splitter.weighted_n_samples + cdef double min_impurity_decrease = self.min_impurity_decrease + cdef double min_impurity_split = self.min_impurity_split + cdef double weighted_n_node_samples + cdef bint is_leaf + cdef SIZE_t n_left, n_right + cdef double imp_diff + + splitter.node_reset(start, end, &weighted_n_node_samples) + + if is_first: + impurity = splitter.node_impurity() + + n_node_samples = end - start + is_leaf = (depth >= self.max_depth or + n_node_samples < self.min_samples_split or + n_node_samples < 2 * self.min_samples_leaf or + weighted_n_node_samples < 2 * self.min_weight_leaf or + impurity <= min_impurity_split) + + if not is_leaf: + splitter.node_split(impurity, &split, &n_constant_features) + # If EPSILON=0 in the below comparison, float precision issues stop + # splitting early, producing trees that are dissimilar to v0.18 + is_leaf = (is_leaf or split.pos >= end or + split.improvement + EPSILON < min_impurity_decrease) + + node_id = tree._update_node(parent - tree.nodes + if parent != NULL + else _TREE_UNDEFINED, + is_left, is_leaf, + split.feature, split.threshold, + impurity, n_node_samples, + weighted_n_node_samples) + if node_id == SIZE_MAX: + return -1 + + # compute values also for split nodes (might become leafs later). + splitter.node_value(tree.value + node_id * tree.value_stride) + + res.node_id = node_id + res.start = start + res.end = end + res.depth = depth + res.impurity = impurity + + if not is_leaf: + # is split node + res.pos = split.pos + res.is_leaf = 0 + res.improvement = split.improvement + res.impurity_left = split.impurity_left + res.impurity_right = split.impurity_right + + else: + # is leaf => 0 improvement + res.pos = end + res.is_leaf = 1 + res.improvement = 0.0 + res.impurity_left = impurity + res.impurity_right = impurity + + return 0 + # ============================================================================= # Tree From 2a94fa2345469db0e066af34e1062069430f1d49 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 18 Nov 2020 11:34:54 -0500 Subject: [PATCH 14/48] Work without max_leaf_nodes limit --- sklearn/tree/_tree.pxd | 2 +- sklearn/tree/_tree.pyx | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index b55a127b3c96e..1d77d7b5215e6 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -107,5 +107,5 @@ cdef class TreeBuilder: cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=*) cpdef update(self, Tree tree, object X, np.ndarray y, - np.ndarray sample_weight=*) + np.ndarray sample_weight=*) cdef _check_input(self, object X, np.ndarray y, np.ndarray sample_weight) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index df10b6ac39414..faff8d693d6e1 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -671,7 +671,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef SIZE_t init_capacity = max_split_nodes + max_leaf_nodes tree._resize(init_capacity) - # add reached leaf nodes onto stack + # add reached leaf nodes to frontier cdef SIZE_t start = 0 cdef SIZE_t end = 0 for key, value in sorted(false_roots.items()): @@ -679,14 +679,14 @@ cdef class BestFirstTreeBuilder(TreeBuilder): if key[1]: rc = self._update_split_node(splitter, tree, start, end, tree.impurity[key[0]], IS_NOT_FIRST, - IS_LEFT, &tree.nodes[key[0]] + IS_LEFT, &tree.nodes[key[0]], value[1], &split_node_left) if rc >= 0: rc = _add_to_frontier(&split_node_left, frontier) else: rc = self._update_split_node(splitter, tree, start, end, tree.impurity[key[0]], IS_NOT_FIRST, - IS_NOT_LEFT, &tree.nodes[key[0]] + IS_NOT_LEFT, &tree.nodes[key[0]], value[1], &split_node_right) if rc >= 0: rc = _add_to_frontier(&split_node_right, frontier) From d6c03a7d4b0f5989fd4dc3e90f364aae08fb0b15 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 18 Nov 2020 12:34:33 -0500 Subject: [PATCH 15/48] Update .gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 3ebd8e2bb1699..2884e67b77db4 100644 --- a/.gitignore +++ b/.gitignore @@ -79,3 +79,6 @@ _configtest.o.d sklearn/utils/_seq_dataset.pyx sklearn/utils/_seq_dataset.pxd sklearn/linear_model/_sag_fast.pyx + +# Jupyter Notebook +.ipynb_checkpoints From 7a3985a1d9b0212b6e53157a34fed4f7d385d5f7 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Sat, 28 Nov 2020 23:39:05 -0500 Subject: [PATCH 16/48] Remove capacity resetting --- sklearn/tree/_tree.pyx | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index faff8d693d6e1..17b577956ef5c 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -321,16 +321,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef object X_new = np.array(X_list) cdef np.ndarray y_new = np.array(y_list) - # Initial capacity - cdef int init_capacity - - if tree.max_depth <= 10: - init_capacity = (2 ** (tree.max_depth + 1)) - 1 - else: - init_capacity = 2047 - - tree._resize(init_capacity) - # Parameters cdef Splitter splitter = self.splitter cdef SIZE_t max_depth = self.max_depth From 4f8605ed0ecbf5cec9694f8f4239015aaf6722c4 Mon Sep 17 00:00:00 2001 From: PSSF23 Date: Mon, 7 Dec 2020 17:36:04 -0500 Subject: [PATCH 17/48] Resolve 1 node tree problem --- sklearn/tree/_tree.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 17b577956ef5c..a87cae1aad6bb 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -296,7 +296,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): CHILD = depth_i if PARENT < 0: - parent_i = _TREE_UNDEFINED + parent_i = 0 else: parent_i = paths[i].indices[PARENT] child_i = paths[i].indices[CHILD] From 11764a179f9aeb9a090616e8027bcc81af2bef6f Mon Sep 17 00:00:00 2001 From: PSSF23 Date: Sun, 20 Dec 2020 12:50:27 -0500 Subject: [PATCH 18/48] Optimize node order --- sklearn/tree/_tree.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index a87cae1aad6bb..3fcd08cf9160a 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -315,7 +315,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): X_list = [] y_list = [] - for key, value in sorted(X_copy.items()): + for key, value in reversed(sorted(X_copy.items())): X_list = X_list + value y_list = y_list + y_copy[key] cdef object X_new = np.array(X_list) @@ -354,7 +354,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef StackRecord stack_record # push reached leaf nodes onto stack - for key, value in sorted(false_roots.items()): + for key, value in reversed(sorted(false_roots.items())): end += value[0] rc = stack.push(start, end, value[1], key[0], key[1], tree.impurity[key[0]], 0) From 02ca737f1f45e8e3beac5b1038e17eff2339ab5c Mon Sep 17 00:00:00 2001 From: PSSF23 Date: Mon, 18 Jan 2021 09:25:37 -0500 Subject: [PATCH 19/48] Update _tree.pyx --- sklearn/tree/_tree.pyx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 3fcd08cf9160a..5d1572e593116 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -89,7 +89,7 @@ cdef class TreeBuilder: pass cpdef update(self, Tree tree, object X, np.ndarray y, - np.ndarray sample_weight=None): + np.ndarray sample_weight=None): """Update a decision tree with the training set (X, y).""" pass @@ -273,7 +273,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): raise MemoryError() cpdef update(self, Tree tree, object X, np.ndarray y, - np.ndarray sample_weight=None): + np.ndarray sample_weight=None): """Update a decision tree with the training set (X, y).""" # check input @@ -587,7 +587,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): raise MemoryError() cpdef update(self, Tree tree, object X, np.ndarray y, - np.ndarray sample_weight=None): + np.ndarray sample_weight=None): """Update a decision tree with the training set (X, y).""" # check input From 92f7e1878bbfdad8595253ee917eff6ffdd8fdc7 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 20 Jan 2021 21:27:28 -0500 Subject: [PATCH 20/48] Optimize partial_fit api --- sklearn/tree/_classes.py | 238 ++++++++++++++++++++++++++++----------- 1 file changed, 175 insertions(+), 63 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index e5506d9c80f40..c212387f96af1 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -141,7 +141,7 @@ def get_n_leaves(self): return self.tree_.n_leaves def fit(self, X, y, sample_weight=None, check_input=True, - X_idx_sorted="deprecated", update_tree=False): + X_idx_sorted="deprecated"): random_state = check_random_state(self.random_state) @@ -184,35 +184,6 @@ def fit(self, X, y, sample_weight=None, check_input=True, # [:, np.newaxis] that does not. y = np.reshape(y, (-1, 1)) - if update_tree: - # TODO: find a way to build on previous tree - # See if there is an existing tree - try: - self.tree_ - except AttributeError: - print("No existing tree to update") - return self - - if is_classification: - check_classification_targets(y) - y = np.copy(y) - - y_encoded = np.zeros(y.shape, dtype=int) - for k in range(self.n_outputs_): - classes_k, y_encoded[:, k] = np.unique(y[:, k], - return_inverse=True) - y = y_encoded - - if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: - y = np.ascontiguousarray(y, dtype=DOUBLE) - - # Update tree - self.builder_.update(self.tree_, X, y, sample_weight) - - self._prune_tree() - - return self - self.n_outputs_ = y.shape[1] if is_classification: @@ -341,11 +312,13 @@ def fit(self, X, y, sample_weight=None, check_input=True, min_impurity_split = self.min_impurity_split if min_impurity_split is not None: - warnings.warn("The min_impurity_split parameter is deprecated. " - "Its default value has changed from 1e-7 to 0 in " - "version 0.23, and it will be removed in 0.25. " - "Use the min_impurity_decrease parameter instead.", - FutureWarning) + warnings.warn( + "The min_impurity_split parameter is deprecated. Its default " + "value has changed from 1e-7 to 0 in version 0.23, and it " + "will be removed in 1.0 (renaming of 0.25). Use the " + "min_impurity_decrease parameter instead.", + FutureWarning + ) if min_impurity_split < 0.: raise ValueError("min_impurity_split must be greater than " @@ -357,12 +330,15 @@ def fit(self, X, y, sample_weight=None, check_input=True, raise ValueError("min_impurity_decrease must be greater than " "or equal to 0") - # TODO: Remove in v0.26 + # TODO: Remove in 1.1 if X_idx_sorted != "deprecated": - warnings.warn("The parameter 'X_idx_sorted' is deprecated and has " - "no effect. It will be removed in v0.26. You can " - "suppress this warning by not passing any value to " - "the 'X_idx_sorted' parameter.", FutureWarning) + warnings.warn( + "The parameter 'X_idx_sorted' is deprecated and has no " + "effect. It will be removed in 1.1 (renaming of 0.26). You " + "can suppress this warning by not passing any value to the " + "'X_idx_sorted' parameter.", + FutureWarning + ) # Build tree criterion = self.criterion @@ -396,19 +372,19 @@ def fit(self, X, y, sample_weight=None, check_input=True, # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise if max_leaf_nodes < 0: self.builder_ = DepthFirstTreeBuilder(splitter, min_samples_split, - min_samples_leaf, - min_weight_leaf, - max_depth, - self.min_impurity_decrease, - min_impurity_split) + min_samples_leaf, + min_weight_leaf, + max_depth, + self.min_impurity_decrease, + min_impurity_split) else: self.builder_ = BestFirstTreeBuilder(splitter, min_samples_split, - min_samples_leaf, - min_weight_leaf, - max_depth, - max_leaf_nodes, - self.min_impurity_decrease, - min_impurity_split) + min_samples_leaf, + min_weight_leaf, + max_depth, + max_leaf_nodes, + self.min_impurity_decrease, + min_impurity_split) self.builder_.build(self.tree_, X, y, sample_weight) @@ -420,6 +396,78 @@ def fit(self, X, y, sample_weight=None, check_input=True, return self + def partial_fit(self, X, y, sample_weight=None, check_input=True): + # Fit if no tree exists yet + try: + self.tree_ + except AttributeError: + self.fit( + X, y, + sample_weight=sample_weight, + check_input=check_input) + return self + + random_state = check_random_state(self.random_state) + + if self.ccp_alpha < 0.0: + raise ValueError("ccp_alpha must be greater than or equal to 0") + + if check_input: + # Need to validate separately here. + # We can't pass multi_ouput=True because that would allow y to be + # csr. + check_X_params = dict(dtype=DTYPE, accept_sparse="csc") + check_y_params = dict(ensure_2d=False, dtype=None) + X, y = self._validate_data(X, y, + validate_separately=(check_X_params, + check_y_params)) + if issparse(X): + X.sort_indices() + + if X.indices.dtype != np.intc or X.indptr.dtype != np.intc: + raise ValueError("No support for np.int64 index based " + "sparse matrices") + + if self.criterion == "poisson": + if np.any(y < 0): + raise ValueError("Some value(s) of y are negative which is" + " not allowed for Poisson regression.") + if np.sum(y) <= 0: + raise ValueError("Sum of y is not positive which is " + "necessary for Poisson regression.") + + # Determine output settings + n_samples, self.n_features_ = X.shape + is_classification = is_classifier(self) + + y = np.atleast_1d(y) + expanded_class_weight = None + + if y.ndim == 1: + # reshape is necessary to preserve the data contiguity against vs + # [:, np.newaxis] that does not. + y = np.reshape(y, (-1, 1)) + + if is_classification: + check_classification_targets(y) + y = np.copy(y) + + y_encoded = np.zeros(y.shape, dtype=int) + for k in range(self.n_outputs_): + classes_k, y_encoded[:, k] = np.unique(y[:, k], + return_inverse=True) + y = y_encoded + + if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: + y = np.ascontiguousarray(y, dtype=DOUBLE) + + # Update tree + self.builder_.update(self.tree_, X, y, sample_weight) + + self._prune_tree() + + return self + def _validate_X_predict(self, X, check_input): """Validate the training data on predict (probabilities).""" if check_input: @@ -885,7 +933,7 @@ def __init__(self, *, ccp_alpha=ccp_alpha) def fit(self, X, y, sample_weight=None, check_input=True, - X_idx_sorted="deprecated", update_tree=False): + X_idx_sorted="deprecated"): """Build a decision tree classifier from the training set (X, y). Parameters @@ -915,9 +963,6 @@ def fit(self, X, y, sample_weight=None, check_input=True, .. deprecated :: 0.24 - update_tree : bool, default=False - Choice of updating the existing tree or creating a new one. - Returns ------- self : DecisionTreeClassifier @@ -928,8 +973,43 @@ def fit(self, X, y, sample_weight=None, check_input=True, X, y, sample_weight=sample_weight, check_input=check_input, - X_idx_sorted=X_idx_sorted, - update_tree=update_tree) + X_idx_sorted=X_idx_sorted) + return self + + def partial_fit(self, X, y, sample_weight=None, check_input=True): + """Update a decision tree classifier from the training set (X, y). + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + The target values (class labels) as integers or strings. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + check_input : bool, default=True + Allow to bypass several input checking. + Don't use this parameter unless you know what you do. + + Returns + ------- + self : DecisionTreeClassifier + Fitted estimator. + """ + + super().partial_fit( + X, y, + sample_weight=sample_weight, + check_input=check_input) return self def predict_proba(self, X, check_input=True): @@ -1241,7 +1321,7 @@ def __init__(self, *, ccp_alpha=ccp_alpha) def fit(self, X, y, sample_weight=None, check_input=True, - X_idx_sorted="deprecated", update_tree=False): + X_idx_sorted="deprecated"): """Build a decision tree regressor from the training set (X, y). Parameters @@ -1270,9 +1350,6 @@ def fit(self, X, y, sample_weight=None, check_input=True, .. deprecated :: 0.24 - update_tree : bool, default=False - Choice of updating the existing tree or creating a new one. - Returns ------- self : DecisionTreeRegressor @@ -1283,8 +1360,43 @@ def fit(self, X, y, sample_weight=None, check_input=True, X, y, sample_weight=sample_weight, check_input=check_input, - X_idx_sorted=X_idx_sorted, - update_tree=update_tree) + X_idx_sorted=X_idx_sorted) + return self + + def partial_fit(self, X, y, sample_weight=None, check_input=True): + """Update a decision tree classifier from the training set (X, y). + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like of shape (n_samples,) or (n_samples, n_outputs) + The target values (class labels) as integers or strings. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + check_input : bool, default=True + Allow to bypass several input checking. + Don't use this parameter unless you know what you do. + + Returns + ------- + self : DecisionTreeClassifier + Fitted estimator. + """ + + super().partial_fit( + X, y, + sample_weight=sample_weight, + check_input=check_input) return self def _compute_partial_dependence_recursion(self, grid, target_features): From f05a3b2d51a5d73d4df7eb5cea22397e0c51a2b3 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 2 Feb 2021 11:44:39 -0500 Subject: [PATCH 21/48] Fix linting --- sklearn/tree/_classes.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 37a6879a0f7ef..c14e9ab7bcfef 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -407,8 +407,6 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True): check_input=check_input) return self - random_state = check_random_state(self.random_state) - if self.ccp_alpha < 0.0: raise ValueError("ccp_alpha must be greater than or equal to 0") @@ -441,7 +439,6 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True): is_classification = is_classifier(self) y = np.atleast_1d(y) - expanded_class_weight = None if y.ndim == 1: # reshape is necessary to preserve the data contiguity against vs From e1b6658d2878c1718b0f469a1b74a796edaa7860 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 14 Sep 2021 09:15:50 -0400 Subject: [PATCH 22/48] FIX add __reduce__ functions --- sklearn/tree/_tree.pyx | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 5d1572e593116..7f721a99d8676 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -139,6 +139,13 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split + def __reduce__(self): + """Reduce re-implementation, for pickling.""" + return(DepthFirstTreeBuilder, (self.splitter, self.min_samples_split, + self.min_samples_leaf, self.max_depth, + self.min_impurity_decrease, + self.min_impurity_split)) + cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None): """Build a decision tree from the training set (X, y).""" @@ -474,6 +481,15 @@ cdef class BestFirstTreeBuilder(TreeBuilder): self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split + def __reduce__(self): + """Reduce re-implementation, for pickling.""" + return(BestFirstTreeBuilder, (self.splitter, self.min_samples_split, + self.min_samples_leaf, + self.min_weight_leaf, self.max_depth, + self.max_leaf_nodes, + self.min_impurity_decrease, + self.min_impurity_split)) + cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None): """Build a decision tree from the training set (X, y).""" From 0a5420c06914273cb3879f8baf1c4190e10db909 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 14 Sep 2021 09:30:50 -0400 Subject: [PATCH 23/48] FIX black format the code --- sklearn/tree/_classes.py | 69 ++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 34 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 283d181d4d32a..1be00507232c6 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -399,18 +399,24 @@ def fit( # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise if max_leaf_nodes < 0: - self.builder_ = DepthFirstTreeBuilder(splitter, min_samples_split, - min_samples_leaf, - min_weight_leaf, - max_depth, - self.min_impurity_decrease) + self.builder_ = DepthFirstTreeBuilder( + splitter, + min_samples_split, + min_samples_leaf, + min_weight_leaf, + max_depth, + self.min_impurity_decrease, + ) else: - self.builder_ = BestFirstTreeBuilder(splitter, min_samples_split, - min_samples_leaf, - min_weight_leaf, - max_depth, - max_leaf_nodes, - self.min_impurity_decrease) + self.builder_ = BestFirstTreeBuilder( + splitter, + min_samples_split, + min_samples_leaf, + min_weight_leaf, + max_depth, + max_leaf_nodes, + self.min_impurity_decrease, + ) self.builder_.build(self.tree_, X, y, sample_weight) @@ -427,10 +433,7 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True): try: self.tree_ except AttributeError: - self.fit( - X, y, - sample_weight=sample_weight, - check_input=check_input) + self.fit(X, y, sample_weight=sample_weight, check_input=check_input) return self if self.ccp_alpha < 0.0: @@ -442,23 +445,28 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True): # csr. check_X_params = dict(dtype=DTYPE, accept_sparse="csc") check_y_params = dict(ensure_2d=False, dtype=None) - X, y = self._validate_data(X, y, - validate_separately=(check_X_params, - check_y_params)) + X, y = self._validate_data( + X, y, validate_separately=(check_X_params, check_y_params) + ) if issparse(X): X.sort_indices() if X.indices.dtype != np.intc or X.indptr.dtype != np.intc: - raise ValueError("No support for np.int64 index based " - "sparse matrices") + raise ValueError( + "No support for np.int64 index based sparse matrices" + ) if self.criterion == "poisson": if np.any(y < 0): - raise ValueError("Some value(s) of y are negative which is" - " not allowed for Poisson regression.") + raise ValueError( + "Some value(s) of y are negative which is" + " not allowed for Poisson regression." + ) if np.sum(y) <= 0: - raise ValueError("Sum of y is not positive which is " - "necessary for Poisson regression.") + raise ValueError( + "Sum of y is not positive which is " + "necessary for Poisson regression." + ) # Determine output settings n_samples, self.n_features_ = X.shape @@ -477,8 +485,7 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True): y_encoded = np.zeros(y.shape, dtype=int) for k in range(self.n_outputs_): - classes_k, y_encoded[:, k] = np.unique(y[:, k], - return_inverse=True) + classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True) y = y_encoded if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: @@ -1040,10 +1047,7 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True): Fitted estimator. """ - super().partial_fit( - X, y, - sample_weight=sample_weight, - check_input=check_input) + super().partial_fit(X, y, sample_weight=sample_weight, check_input=check_input) return self def predict_proba(self, X, check_input=True): @@ -1457,10 +1461,7 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True): Fitted estimator. """ - super().partial_fit( - X, y, - sample_weight=sample_weight, - check_input=check_input) + super().partial_fit(X, y, sample_weight=sample_weight, check_input=check_input) return self def _compute_partial_dependence_recursion(self, grid, target_features): From 19893c3a92fa441f16c6dfed45a90c7dabcee393 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 14 Sep 2021 09:42:41 -0400 Subject: [PATCH 24/48] FIX remove min_impurity_split --- sklearn/tree/_tree.pyx | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index b18cdf3824170..ca04ecfac326f 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -140,8 +140,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): """Reduce re-implementation, for pickling.""" return(DepthFirstTreeBuilder, (self.splitter, self.min_samples_split, self.min_samples_leaf, self.max_depth, - self.min_impurity_decrease, - self.min_impurity_split)) + self.min_impurity_decrease)) cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None): @@ -331,7 +330,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef double min_weight_leaf = self.min_weight_leaf cdef SIZE_t min_samples_split = self.min_samples_split cdef double min_impurity_decrease = self.min_impurity_decrease - cdef double min_impurity_split = self.min_impurity_split # Recursive partition (without actual recursion) splitter.init(X_new, y_new, sample_weight_ptr) @@ -386,8 +384,12 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): n_node_samples < 2 * min_samples_leaf or weighted_n_node_samples < 2 * min_weight_leaf) - is_leaf = (is_leaf or - (impurity <= min_impurity_split)) + if first: + impurity = splitter.node_impurity() + first = 0 + + # impurity == 0 with tolerance due to rounding errors + is_leaf = is_leaf or impurity <= EPSILON if not is_leaf: splitter.node_split(impurity, &split, &n_constant_features) @@ -482,8 +484,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): self.min_samples_leaf, self.min_weight_leaf, self.max_depth, self.max_leaf_nodes, - self.min_impurity_decrease, - self.min_impurity_split)) + self.min_impurity_decrease)) cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None): @@ -845,7 +846,6 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef SIZE_t n_constant_features = 0 cdef double weighted_n_samples = splitter.weighted_n_samples cdef double min_impurity_decrease = self.min_impurity_decrease - cdef double min_impurity_split = self.min_impurity_split cdef double weighted_n_node_samples cdef bint is_leaf cdef SIZE_t n_left, n_right @@ -861,8 +861,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): n_node_samples < self.min_samples_split or n_node_samples < 2 * self.min_samples_leaf or weighted_n_node_samples < 2 * self.min_weight_leaf or - impurity <= min_impurity_split) - + impurity <= EPSILON # impurity == 0 with tolerance + ) if not is_leaf: splitter.node_split(impurity, &split, &n_constant_features) # If EPSILON=0 in the below comparison, float precision issues stop From fdd1dfd684f38a86d3c02ea62d9328b958845344 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 14 Sep 2021 10:52:19 -0400 Subject: [PATCH 25/48] FIX update deprecated attribute --- sklearn/tree/_classes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 1be00507232c6..354c3ad0aaa82 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -469,7 +469,7 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True): ) # Determine output settings - n_samples, self.n_features_ = X.shape + n_samples, self.n_features_in_ = X.shape is_classification = is_classifier(self) y = np.atleast_1d(y) From b4cbfa40183fd888e3f130e0c6c1d82b7c94c555 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 14 Sep 2021 11:24:55 -0400 Subject: [PATCH 26/48] FIX optimize api & correct __cinit__ --- sklearn/tree/_classes.py | 29 +++++++++++++++++++++-------- sklearn/tree/_tree.pyx | 4 +++- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 354c3ad0aaa82..b741dee023323 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -37,6 +37,7 @@ from ..utils.validation import _check_sample_weight from ..utils import compute_sample_weight from ..utils.multiclass import check_classification_targets +from ..utils.multiclass import _check_partial_fit_first_call from ..utils.validation import check_is_fitted from ._criterion import Criterion @@ -428,11 +429,9 @@ def fit( return self - def partial_fit(self, X, y, sample_weight=None, check_input=True): + def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): # Fit if no tree exists yet - try: - self.tree_ - except AttributeError: + if _check_partial_fit_first_call(self, classes): self.fit(X, y, sample_weight=sample_weight, check_input=check_input) return self @@ -1017,7 +1016,7 @@ def fit( ) return self - def partial_fit(self, X, y, sample_weight=None, check_input=True): + def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): """Update a decision tree classifier from the training set (X, y). Parameters @@ -1030,6 +1029,11 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True): y : array-like of shape (n_samples,) or (n_samples, n_outputs) The target values (class labels) as integers or strings. + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + Must be provided at the first call to partial_fit, can be omitted + in subsequent calls. + sample_weight : array-like of shape (n_samples,), default=None Sample weights. If None, then samples are equally weighted. Splits that would create child nodes with net zero or negative weight are @@ -1047,7 +1051,9 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True): Fitted estimator. """ - super().partial_fit(X, y, sample_weight=sample_weight, check_input=check_input) + super().partial_fit( + X, y, classes=classes, sample_weight=sample_weight, check_input=check_input + ) return self def predict_proba(self, X, check_input=True): @@ -1431,7 +1437,7 @@ def fit( ) return self - def partial_fit(self, X, y, sample_weight=None, check_input=True): + def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): """Update a decision tree classifier from the training set (X, y). Parameters @@ -1444,6 +1450,11 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True): y : array-like of shape (n_samples,) or (n_samples, n_outputs) The target values (class labels) as integers or strings. + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + Must be provided at the first call to partial_fit, can be omitted + in subsequent calls. + sample_weight : array-like of shape (n_samples,), default=None Sample weights. If None, then samples are equally weighted. Splits that would create child nodes with net zero or negative weight are @@ -1461,7 +1472,9 @@ def partial_fit(self, X, y, sample_weight=None, check_input=True): Fitted estimator. """ - super().partial_fit(X, y, sample_weight=sample_weight, check_input=check_input) + super().partial_fit( + X, y, classes=classes, sample_weight=sample_weight, check_input=check_input + ) return self def _compute_partial_dependence_recursion(self, grid, target_features): diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index ca04ecfac326f..92ff20b901dbb 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -139,7 +139,9 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): def __reduce__(self): """Reduce re-implementation, for pickling.""" return(DepthFirstTreeBuilder, (self.splitter, self.min_samples_split, - self.min_samples_leaf, self.max_depth, + self.min_samples_leaf, + self.min_weight_leaf, + self.max_depth, self.min_impurity_decrease)) cpdef build(self, Tree tree, object X, np.ndarray y, From 8f4b6641be5009130d93929ed528404bd1235707 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 14 Sep 2021 15:45:21 -0400 Subject: [PATCH 27/48] FIX optimize first partial_fit test --- sklearn/tree/_classes.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index b741dee023323..01b89a5584632 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -429,7 +429,9 @@ def fit( return self - def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): + def partial_fit( + self, X, y, classes=np.unique(y), sample_weight=None, check_input=True + ): # Fit if no tree exists yet if _check_partial_fit_first_call(self, classes): self.fit(X, y, sample_weight=sample_weight, check_input=check_input) @@ -467,8 +469,11 @@ def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): "necessary for Poisson regression." ) + if X.shape[1] != self.n_features_in_: + msg = "Number of features %d does not match previous data %d." + raise ValueError(msg % (X.shape[1], self.n_features_in_)) + # Determine output settings - n_samples, self.n_features_in_ = X.shape is_classification = is_classifier(self) y = np.atleast_1d(y) @@ -1016,7 +1021,9 @@ def fit( ) return self - def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): + def partial_fit( + self, X, y, classes=np.unique(y), sample_weight=None, check_input=True + ): """Update a decision tree classifier from the training set (X, y). Parameters @@ -1437,8 +1444,10 @@ def fit( ) return self - def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): - """Update a decision tree classifier from the training set (X, y). + def partial_fit( + self, X, y, classes=np.unique(y), sample_weight=None, check_input=True + ): + """Update a decision tree regressor from the training set (X, y). Parameters ---------- @@ -1468,7 +1477,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): Returns ------- - self : DecisionTreeClassifier + self : DecisionTreeRegressor Fitted estimator. """ @@ -1698,6 +1707,9 @@ class ExtraTreeClassifier(DecisionTreeClassifier): :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` for basic usage of these attributes. + builder_ : TreeBuilder instance + The underlying TreeBuilder object. + See Also -------- ExtraTreeRegressor : An extremely randomized tree regressor. @@ -1939,6 +1951,9 @@ class ExtraTreeRegressor(DecisionTreeRegressor): :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py` for basic usage of these attributes. + builder_ : TreeBuilder instance + The underlying TreeBuilder object. + See Also -------- ExtraTreeClassifier : An extremely randomized tree classifier. From 3562219b43b993e35f74ddf1f312d4931c0958c8 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 14 Sep 2021 16:12:25 -0400 Subject: [PATCH 28/48] FIX remove FutureWarning filter --- sklearn/tests/test_common.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 195a4b3da785a..daa1a014edb9d 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -104,9 +104,7 @@ def _tested_estimators(type_filter=None): @parametrize_with_checks(list(_tested_estimators())) def test_estimators(estimator, check, request): # Common tests for estimator instances - with ignore_warnings( - category=(FutureWarning, ConvergenceWarning, UserWarning, FutureWarning) - ): + with ignore_warnings(category=(FutureWarning, ConvergenceWarning, UserWarning)): _set_checking_parameters(estimator) check(estimator) @@ -304,7 +302,6 @@ def test_search_cv(estimator, check, request): FutureWarning, ConvergenceWarning, UserWarning, - FutureWarning, FitFailedWarning, ) ): @@ -351,13 +348,12 @@ def test_check_n_features_in_after_fitting(estimator): ) def test_pandas_column_name_consistency(estimator): _set_checking_parameters(estimator) - with ignore_warnings(category=(FutureWarning)): - with pytest.warns(None) as record: - check_dataframe_column_names_consistency( - estimator.__class__.__name__, estimator - ) - for warning in record: - assert "was fitted without feature names" not in str(warning.message) + with pytest.warns(None) as record: + check_dataframe_column_names_consistency( + estimator.__class__.__name__, estimator + ) + for warning in record: + assert "was fitted without feature names" not in str(warning.message) # TODO: As more modules support get_feature_names_out they should be removed From 93ead2dcf889a3fe26dc64ad1c28163b62b01770 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 14 Sep 2021 16:21:08 -0400 Subject: [PATCH 29/48] FIX modify partial_fit parameter --- sklearn/tree/_classes.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 01b89a5584632..c72cf5dc935eb 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -429,11 +429,12 @@ def fit( return self - def partial_fit( - self, X, y, classes=np.unique(y), sample_weight=None, check_input=True - ): + def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): + if classes == None: + classes = np.unique(y) + # Fit if no tree exists yet - if _check_partial_fit_first_call(self, classes): + if _check_partial_fit_first_call(self, classes=classes): self.fit(X, y, sample_weight=sample_weight, check_input=check_input) return self @@ -1021,9 +1022,7 @@ def fit( ) return self - def partial_fit( - self, X, y, classes=np.unique(y), sample_weight=None, check_input=True - ): + def partial_fit(self, X, y, classes=classes, sample_weight=None, check_input=True): """Update a decision tree classifier from the training set (X, y). Parameters @@ -1444,9 +1443,7 @@ def fit( ) return self - def partial_fit( - self, X, y, classes=np.unique(y), sample_weight=None, check_input=True - ): + def partial_fit(self, X, y, classes=classes, sample_weight=None, check_input=True): """Update a decision tree regressor from the training set (X, y). Parameters From bfaa18c34ee251533b633fada3342a46c4968785 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 14 Sep 2021 16:24:28 -0400 Subject: [PATCH 30/48] FIX correct partial_fit parameter --- sklearn/tree/_classes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index c72cf5dc935eb..760532286f749 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -430,7 +430,7 @@ def fit( return self def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): - if classes == None: + if classes is None: classes = np.unique(y) # Fit if no tree exists yet @@ -1022,7 +1022,7 @@ def fit( ) return self - def partial_fit(self, X, y, classes=classes, sample_weight=None, check_input=True): + def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): """Update a decision tree classifier from the training set (X, y). Parameters @@ -1443,7 +1443,7 @@ def fit( ) return self - def partial_fit(self, X, y, classes=classes, sample_weight=None, check_input=True): + def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): """Update a decision tree regressor from the training set (X, y). Parameters From 73779c2f5b58916fe39b1a9c590512059e0c6b21 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 15 Sep 2021 09:30:47 -0400 Subject: [PATCH 31/48] Revert "FIX remove FutureWarning filter" This reverts commit 3562219b43b993e35f74ddf1f312d4931c0958c8. --- sklearn/tests/test_common.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index daa1a014edb9d..195a4b3da785a 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -104,7 +104,9 @@ def _tested_estimators(type_filter=None): @parametrize_with_checks(list(_tested_estimators())) def test_estimators(estimator, check, request): # Common tests for estimator instances - with ignore_warnings(category=(FutureWarning, ConvergenceWarning, UserWarning)): + with ignore_warnings( + category=(FutureWarning, ConvergenceWarning, UserWarning, FutureWarning) + ): _set_checking_parameters(estimator) check(estimator) @@ -302,6 +304,7 @@ def test_search_cv(estimator, check, request): FutureWarning, ConvergenceWarning, UserWarning, + FutureWarning, FitFailedWarning, ) ): @@ -348,12 +351,13 @@ def test_check_n_features_in_after_fitting(estimator): ) def test_pandas_column_name_consistency(estimator): _set_checking_parameters(estimator) - with pytest.warns(None) as record: - check_dataframe_column_names_consistency( - estimator.__class__.__name__, estimator - ) - for warning in record: - assert "was fitted without feature names" not in str(warning.message) + with ignore_warnings(category=(FutureWarning)): + with pytest.warns(None) as record: + check_dataframe_column_names_consistency( + estimator.__class__.__name__, estimator + ) + for warning in record: + assert "was fitted without feature names" not in str(warning.message) # TODO: As more modules support get_feature_names_out they should be removed From 7d724c17a0509e8c9e8f145a0a9151fd7d11f603 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 15 Sep 2021 09:32:44 -0400 Subject: [PATCH 32/48] FIX prevent feature number reset --- sklearn/tree/_classes.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 760532286f749..a0dbc0f49771e 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -430,9 +430,6 @@ def fit( return self def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): - if classes is None: - classes = np.unique(y) - # Fit if no tree exists yet if _check_partial_fit_first_call(self, classes=classes): self.fit(X, y, sample_weight=sample_weight, check_input=check_input) @@ -448,7 +445,7 @@ def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): check_X_params = dict(dtype=DTYPE, accept_sparse="csc") check_y_params = dict(ensure_2d=False, dtype=None) X, y = self._validate_data( - X, y, validate_separately=(check_X_params, check_y_params) + X, y, reset=False, validate_separately=(check_X_params, check_y_params) ) if issparse(X): X.sort_indices() From d3f15ad296009a83220072acc34ce9f632aaae7e Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 15 Sep 2021 09:34:10 -0400 Subject: [PATCH 33/48] MAINT remove duplicate category --- sklearn/tests/test_common.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 195a4b3da785a..4f6818081c67d 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -104,9 +104,7 @@ def _tested_estimators(type_filter=None): @parametrize_with_checks(list(_tested_estimators())) def test_estimators(estimator, check, request): # Common tests for estimator instances - with ignore_warnings( - category=(FutureWarning, ConvergenceWarning, UserWarning, FutureWarning) - ): + with ignore_warnings(category=(FutureWarning, ConvergenceWarning, UserWarning)): _set_checking_parameters(estimator) check(estimator) @@ -304,7 +302,6 @@ def test_search_cv(estimator, check, request): FutureWarning, ConvergenceWarning, UserWarning, - FutureWarning, FitFailedWarning, ) ): From 992e34a1373c0923b8b9a86fb372424a0604d922 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 15 Sep 2021 11:20:02 -0400 Subject: [PATCH 34/48] FIX correct regressor partial_fit checks --- sklearn/tree/_classes.py | 47 +++++++++++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index a0dbc0f49771e..baa247c4263c3 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -150,7 +150,13 @@ def get_n_leaves(self): return self.tree_.n_leaves def fit( - self, X, y, sample_weight=None, check_input=True, X_idx_sorted="deprecated" + self, + X, + y, + classes=None, + sample_weight=None, + check_input=True, + X_idx_sorted="deprecated", ): random_state = check_random_state(self.random_state) @@ -430,9 +436,23 @@ def fit( return self def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): + # Determine output settings + is_classification = is_classifier(self) + + if is_classification: + first_call = _check_partial_fit_first_call(self, classes=classes) + else: + first_call = _check_partial_fit_first_call(self, classes=[]) + # Fit if no tree exists yet - if _check_partial_fit_first_call(self, classes=classes): - self.fit(X, y, sample_weight=sample_weight, check_input=check_input) + if first_call: + self.fit( + X, + y, + classes=classes, + sample_weight=sample_weight, + check_input=check_input, + ) return self if self.ccp_alpha < 0.0: @@ -471,9 +491,6 @@ def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): msg = "Number of features %d does not match previous data %d." raise ValueError(msg % (X.shape[1], self.n_features_in_)) - # Determine output settings - is_classification = is_classifier(self) - y = np.atleast_1d(y) if y.ndim == 1: @@ -973,7 +990,13 @@ def __init__( ) def fit( - self, X, y, sample_weight=None, check_input=True, X_idx_sorted="deprecated" + self, + X, + y, + classes=None, + sample_weight=None, + check_input=True, + X_idx_sorted="deprecated", ): """Build a decision tree classifier from the training set (X, y). @@ -1013,6 +1036,7 @@ def fit( super().fit( X, y, + classes=classes, sample_weight=sample_weight, check_input=check_input, X_idx_sorted=X_idx_sorted, @@ -1395,7 +1419,13 @@ def __init__( ) def fit( - self, X, y, sample_weight=None, check_input=True, X_idx_sorted="deprecated" + self, + X, + y, + classes=None, + sample_weight=None, + check_input=True, + X_idx_sorted="deprecated", ): """Build a decision tree regressor from the training set (X, y). @@ -1434,6 +1464,7 @@ def fit( super().fit( X, y, + classes=classes, sample_weight=sample_weight, check_input=check_input, X_idx_sorted=X_idx_sorted, From 2a72c8f9bbb9fe6f088381566f1699686b37f5f5 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 15 Sep 2021 11:31:16 -0400 Subject: [PATCH 35/48] Revert "MAINT remove duplicate category" This reverts commit d3f15ad296009a83220072acc34ce9f632aaae7e. --- sklearn/tests/test_common.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 4f6818081c67d..195a4b3da785a 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -104,7 +104,9 @@ def _tested_estimators(type_filter=None): @parametrize_with_checks(list(_tested_estimators())) def test_estimators(estimator, check, request): # Common tests for estimator instances - with ignore_warnings(category=(FutureWarning, ConvergenceWarning, UserWarning)): + with ignore_warnings( + category=(FutureWarning, ConvergenceWarning, UserWarning, FutureWarning) + ): _set_checking_parameters(estimator) check(estimator) @@ -302,6 +304,7 @@ def test_search_cv(estimator, check, request): FutureWarning, ConvergenceWarning, UserWarning, + FutureWarning, FitFailedWarning, ) ): From 68d2d7bf78f9ae00a00218efb2b3f0aa25b1962a Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 15 Sep 2021 12:08:30 -0400 Subject: [PATCH 36/48] FIX change parameter order --- sklearn/tree/_classes.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index baa247c4263c3..08c2ad04cd272 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -153,9 +153,9 @@ def fit( self, X, y, - classes=None, sample_weight=None, check_input=True, + classes=None, X_idx_sorted="deprecated", ): @@ -993,9 +993,9 @@ def fit( self, X, y, - classes=None, sample_weight=None, check_input=True, + classes=None, X_idx_sorted="deprecated", ): """Build a decision tree classifier from the training set (X, y). @@ -1036,9 +1036,9 @@ def fit( super().fit( X, y, - classes=classes, sample_weight=sample_weight, check_input=check_input, + classes=classes, X_idx_sorted=X_idx_sorted, ) return self @@ -1422,9 +1422,9 @@ def fit( self, X, y, - classes=None, sample_weight=None, check_input=True, + classes=None, X_idx_sorted="deprecated", ): """Build a decision tree regressor from the training set (X, y). @@ -1464,9 +1464,9 @@ def fit( super().fit( X, y, - classes=classes, sample_weight=sample_weight, check_input=check_input, + classes=classes, X_idx_sorted=X_idx_sorted, ) return self From 631a953080a9e8da2fc8983ea18523d91524c0e5 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 15 Sep 2021 13:20:52 -0400 Subject: [PATCH 37/48] DOC add classes parameter docstring --- sklearn/tree/_classes.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 08c2ad04cd272..cace3ba13fd16 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -1021,6 +1021,11 @@ def fit( Allow to bypass several input checking. Don't use this parameter unless you know what you do. + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + Must be provided at the first call to partial_fit, can be omitted + in subsequent calls. + X_idx_sorted : deprecated, default="deprecated" This parameter is deprecated and has no effect. It will be removed in 1.1 (renaming of 0.26). @@ -1449,6 +1454,11 @@ def fit( Allow to bypass several input checking. Don't use this parameter unless you know what you do. + classes : array-like of shape (n_classes,), default=None + List of all the classes that can possibly appear in the y vector. + Must be provided at the first call to partial_fit, can be omitted + in subsequent calls. + X_idx_sorted : deprecated, default="deprecated" This parameter is deprecated and has no effect. It will be removed in 1.1 (renaming of 0.26). From 9ae93b8b1e9b195af109131152ce65cd068ef700 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 15 Sep 2021 13:29:56 -0400 Subject: [PATCH 38/48] EHN pass classes into first fit --- sklearn/tree/_classes.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index cace3ba13fd16..a75b8cc71214b 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -219,7 +219,10 @@ def fit( y_encoded = np.zeros(y.shape, dtype=int) for k in range(self.n_outputs_): - classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True) + if classes is None: + classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True) + else: + classes_k, y_encoded[:, k] = np.copy(classes) self.classes_.append(classes_k) self.n_classes_.append(classes_k.shape[0]) y = y_encoded From 23fd3925a46020ff90dc42a4ec3b24decb0cb96a Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 15 Sep 2021 13:46:57 -0400 Subject: [PATCH 39/48] FIX add class indices --- sklearn/tree/_classes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index a75b8cc71214b..b44a501f14838 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -222,7 +222,7 @@ def fit( if classes is None: classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True) else: - classes_k, y_encoded[:, k] = np.copy(classes) + classes_k, y_encoded[:, k] = np.copy(classes)[:, k] self.classes_.append(classes_k) self.n_classes_.append(classes_k.shape[0]) y = y_encoded From 665ceef12aa18a39bdc7927b540623922cc75724 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 15 Sep 2021 14:55:23 -0400 Subject: [PATCH 40/48] FIX revert class changes --- sklearn/tree/_classes.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index b44a501f14838..7fe15f00bbbdc 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -219,10 +219,7 @@ def fit( y_encoded = np.zeros(y.shape, dtype=int) for k in range(self.n_outputs_): - if classes is None: - classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True) - else: - classes_k, y_encoded[:, k] = np.copy(classes)[:, k] + classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True) self.classes_.append(classes_k) self.n_classes_.append(classes_k.shape[0]) y = y_encoded @@ -452,9 +449,9 @@ def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): self.fit( X, y, - classes=classes, sample_weight=sample_weight, check_input=check_input, + classes=classes, ) return self From 85689d2b766ef3bee53728a9507066c7e0e01951 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 15 Sep 2021 17:04:25 -0400 Subject: [PATCH 41/48] EHN pass classes into first fit --- sklearn/tree/_classes.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 7fe15f00bbbdc..e9ddc0217c246 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -211,24 +211,35 @@ def fit( check_classification_targets(y) y = np.copy(y) - self.classes_ = [] - self.n_classes_ = [] - if self.class_weight is not None: y_original = np.copy(y) - - y_encoded = np.zeros(y.shape, dtype=int) - for k in range(self.n_outputs_): - classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True) - self.classes_.append(classes_k) - self.n_classes_.append(classes_k.shape[0]) - y = y_encoded - - if self.class_weight is not None: expanded_class_weight = compute_sample_weight( self.class_weight, y_original ) + self.classes_ = [] + self.n_classes_ = [] + + y_encoded = np.zeros(y.shape, dtype=int) + if classes is not None: + classes = np.atleast_1d(classes) + if classes.ndim == 1: + classes = np.array([classes]) + + for k in classes: + self.classes_.append(np.array(k)) + self.n_classes_.append(np.array(k).shape[0]) + + for i in range(n_samples): + for j in range(self.n_outputs_): + y_encoded[i, j] = np.where(self.classes_[j] == y[i, j])[0][0] + else: + for k in range(self.n_outputs_): + classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True) + self.classes_.append(classes_k) + self.n_classes_.append(classes_k.shape[0]) + + y = y_encoded self.n_classes_ = np.array(self.n_classes_, dtype=np.intp) if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: From 71265db1d9e4fca92eca408f63665ec9836990ec Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Thu, 16 Sep 2021 11:12:52 -0400 Subject: [PATCH 42/48] FIX restrict partial_fit to classifiers --- sklearn/tree/_classes.py | 205 +++++++++++++++------------------------ 1 file changed, 77 insertions(+), 128 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index e9ddc0217c246..437cefe5cf79c 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -446,88 +446,6 @@ def fit( return self - def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): - # Determine output settings - is_classification = is_classifier(self) - - if is_classification: - first_call = _check_partial_fit_first_call(self, classes=classes) - else: - first_call = _check_partial_fit_first_call(self, classes=[]) - - # Fit if no tree exists yet - if first_call: - self.fit( - X, - y, - sample_weight=sample_weight, - check_input=check_input, - classes=classes, - ) - return self - - if self.ccp_alpha < 0.0: - raise ValueError("ccp_alpha must be greater than or equal to 0") - - if check_input: - # Need to validate separately here. - # We can't pass multi_ouput=True because that would allow y to be - # csr. - check_X_params = dict(dtype=DTYPE, accept_sparse="csc") - check_y_params = dict(ensure_2d=False, dtype=None) - X, y = self._validate_data( - X, y, reset=False, validate_separately=(check_X_params, check_y_params) - ) - if issparse(X): - X.sort_indices() - - if X.indices.dtype != np.intc or X.indptr.dtype != np.intc: - raise ValueError( - "No support for np.int64 index based sparse matrices" - ) - - if self.criterion == "poisson": - if np.any(y < 0): - raise ValueError( - "Some value(s) of y are negative which is" - " not allowed for Poisson regression." - ) - if np.sum(y) <= 0: - raise ValueError( - "Sum of y is not positive which is " - "necessary for Poisson regression." - ) - - if X.shape[1] != self.n_features_in_: - msg = "Number of features %d does not match previous data %d." - raise ValueError(msg % (X.shape[1], self.n_features_in_)) - - y = np.atleast_1d(y) - - if y.ndim == 1: - # reshape is necessary to preserve the data contiguity against vs - # [:, np.newaxis] that does not. - y = np.reshape(y, (-1, 1)) - - if is_classification: - check_classification_targets(y) - y = np.copy(y) - - y_encoded = np.zeros(y.shape, dtype=int) - for k in range(self.n_outputs_): - classes_k, y_encoded[:, k] = np.unique(y[:, k], return_inverse=True) - y = y_encoded - - if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: - y = np.ascontiguousarray(y, dtype=DOUBLE) - - # Update tree - self.builder_.update(self.tree_, X, y, sample_weight) - - self._prune_tree() - - return self - def _validate_X_predict(self, X, check_input): """Validate the training data on predict (probabilities).""" if check_input: @@ -1094,9 +1012,83 @@ def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): Fitted estimator. """ - super().partial_fit( - X, y, classes=classes, sample_weight=sample_weight, check_input=check_input - ) + first_call = _check_partial_fit_first_call(self, classes=classes) + + # Fit if no tree exists yet + if first_call: + self.fit( + X, + y, + sample_weight=sample_weight, + check_input=check_input, + classes=classes, + ) + return self + + if self.ccp_alpha < 0.0: + raise ValueError("ccp_alpha must be greater than or equal to 0") + + if check_input: + # Need to validate separately here. + # We can't pass multi_ouput=True because that would allow y to be + # csr. + check_X_params = dict(dtype=DTYPE, accept_sparse="csc") + check_y_params = dict(ensure_2d=False, dtype=None) + X, y = self._validate_data( + X, y, reset=False, validate_separately=(check_X_params, check_y_params) + ) + if issparse(X): + X.sort_indices() + + if X.indices.dtype != np.intc or X.indptr.dtype != np.intc: + raise ValueError( + "No support for np.int64 index based sparse matrices" + ) + + if self.criterion == "poisson": + if np.any(y < 0): + raise ValueError( + "Some value(s) of y are negative which is" + " not allowed for Poisson regression." + ) + if np.sum(y) <= 0: + raise ValueError( + "Sum of y is not positive which is " + "necessary for Poisson regression." + ) + + if X.shape[1] != self.n_features_in_: + msg = "Number of features %d does not match previous data %d." + raise ValueError(msg % (X.shape[1], self.n_features_in_)) + + y = np.atleast_1d(y) + + if y.ndim == 1: + # reshape is necessary to preserve the data contiguity against vs + # [:, np.newaxis] that does not. + y = np.reshape(y, (-1, 1)) + + check_classification_targets(y) + y = np.copy(y) + + classes = self.classes_ + if self.n_outputs_ == 1: + classes = [classes] + + y_encoded = np.zeros(y.shape, dtype=int) + for i in range(X.shape[0]): + for j in range(self.n_outputs_): + y_encoded[i, j] = np.where(classes[j] == y[i, j])[0][0] + y = y_encoded + + if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: + y = np.ascontiguousarray(y, dtype=DOUBLE) + + # Update tree + self.builder_.update(self.tree_, X, y, sample_weight) + + self._prune_tree() + return self def predict_proba(self, X, check_input=True): @@ -1467,8 +1459,6 @@ def fit( classes : array-like of shape (n_classes,), default=None List of all the classes that can possibly appear in the y vector. - Must be provided at the first call to partial_fit, can be omitted - in subsequent calls. X_idx_sorted : deprecated, default="deprecated" This parameter is deprecated and has no effect. @@ -1487,51 +1477,10 @@ def fit( y, sample_weight=sample_weight, check_input=check_input, - classes=classes, X_idx_sorted=X_idx_sorted, ) return self - def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): - """Update a decision tree regressor from the training set (X, y). - - Parameters - ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - The training input samples. Internally, it will be converted to - ``dtype=np.float32`` and if a sparse matrix is provided - to a sparse ``csc_matrix``. - - y : array-like of shape (n_samples,) or (n_samples, n_outputs) - The target values (class labels) as integers or strings. - - classes : array-like of shape (n_classes,), default=None - List of all the classes that can possibly appear in the y vector. - Must be provided at the first call to partial_fit, can be omitted - in subsequent calls. - - sample_weight : array-like of shape (n_samples,), default=None - Sample weights. If None, then samples are equally weighted. Splits - that would create child nodes with net zero or negative weight are - ignored while searching for a split in each node. Splits are also - ignored if they would result in any single class carrying a - negative weight in either child node. - - check_input : bool, default=True - Allow to bypass several input checking. - Don't use this parameter unless you know what you do. - - Returns - ------- - self : DecisionTreeRegressor - Fitted estimator. - """ - - super().partial_fit( - X, y, classes=classes, sample_weight=sample_weight, check_input=check_input - ) - return self - def _compute_partial_dependence_recursion(self, grid, target_features): """Fast partial dependence computation. From 814e67e3470f47fc1f468c6be3878a29c5639f66 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Thu, 21 Oct 2021 10:42:34 -0400 Subject: [PATCH 43/48] DOC add changelog --- doc/whats_new/v1.1.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 6d1cff9084f95..19d063c02b138 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -105,6 +105,13 @@ Changelog Setting a transformer to "passthrough" will pass the features unchanged. :pr:`20860` by :user:`Shubhraneel Pal `. +:mod:`sklearn.tree` +.................................. + +- |Enhancement| Added :func:`partial_fit` to :class:`tree.DecisionTreeClassifier` + and :class:`tree.ExtraTreeClassifier`. + :pr:`18889` by :user:`Haoyin Xu `. + :mod:`sklearn.utils` .................... From f0d0eb0a4b6691f8ea5667c7ab4279050816d5f4 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 17 Nov 2021 09:40:31 -0500 Subject: [PATCH 44/48] DOC optimize log format --- doc/whats_new/v1.1.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 19ac050bccd00..514084cf50184 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -60,7 +60,7 @@ Changelog add this information to the plot. :pr:`21038` by :user:`Guillaume Lemaitre `. -- |Enhancement| :class:`cluster.SpectralClustering` and :func:`cluster.spectral` +- |Enhancement| :class:`cluster.SpectralClustering` and :func:`cluster.spectral` now include the new `'cluster_qr'` method from :func:`cluster.cluster_qr` that clusters samples in the embedding space as an alternative to the existing `'kmeans'` and `'discrete'` methods. @@ -144,6 +144,7 @@ Changelog - |Enhancement| Added :func:`partial_fit` to :class:`tree.DecisionTreeClassifier` and :class:`tree.ExtraTreeClassifier`. :pr:`18889` by :user:`Haoyin Xu `. + :mod:`sklearn.preprocessing` ............................ From aef4f840637677193324b227a0054a69a5fc5130 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 17 Nov 2021 09:54:56 -0500 Subject: [PATCH 45/48] FIX remove deprecated parameter --- sklearn/tree/_classes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 5af2f4623e22c..ec127e1f4db7a 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -914,7 +914,6 @@ def fit( sample_weight=None, check_input=True, classes=None, - X_idx_sorted="deprecated", ): """Build a decision tree classifier from the training set (X, y). From 55a6b4bc37c1fff11175a039c3936a54801d8002 Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 30 Nov 2021 14:56:12 -0500 Subject: [PATCH 46/48] FIX optimize n_classes format --- sklearn/tree/_criterion.pyx | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 2c115d0bd6ea1..9d5ff8c81e383 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -207,8 +207,7 @@ cdef class Criterion: cdef class ClassificationCriterion(Criterion): """Abstract criterion for classification.""" - def __cinit__(self, SIZE_t n_outputs, - np.ndarray[SIZE_t, ndim=1] n_classes): + def __cinit__(self, SIZE_t n_outputs, np.ndarray n_classes): """Initialize attributes for this criterion. Parameters @@ -218,6 +217,10 @@ cdef class ClassificationCriterion(Criterion): n_classes : numpy.ndarray, dtype=SIZE_t The number of unique classes in each target """ + cdef SIZE_t dummy = 0 + size_t_dtype = np.array(dummy).dtype + + n_classes = _check_n_classes(n_classes, size_t_dtype) self.sample_weight = NULL self.samples = NULL From 8d3f5c7df1e73660712e6a70b7e464139538683e Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Tue, 30 Nov 2021 14:58:56 -0500 Subject: [PATCH 47/48] FIX add internal function --- sklearn/tree/_criterion.pyx | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 9d5ff8c81e383..a690049c54157 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -203,6 +203,25 @@ cdef class Criterion: - (self.weighted_n_left / self.weighted_n_node_samples * impurity_left))) +def _check_n_classes(n_classes, expected_dtype): + if n_classes.ndim != 1: + raise ValueError( + f"Wrong dimensions for n_classes from the pickle: " + f"expected 1, got {n_classes.ndim}" + ) + + if n_classes.dtype == expected_dtype: + return n_classes + + # Handles both different endianness and different bitness + if n_classes.dtype.kind == "i" and n_classes.dtype.itemsize in [4, 8]: + return n_classes.astype(expected_dtype, casting="same_kind") + + raise ValueError( + "n_classes from the pickle has an incompatible dtype:\n" + f"- expected: {expected_dtype}\n" + f"- got: {n_classes.dtype}" + ) cdef class ClassificationCriterion(Criterion): """Abstract criterion for classification.""" From c47bbb297ebb8bb10b9bad2d81af2fd640c0417e Mon Sep 17 00:00:00 2001 From: Haoyin Xu Date: Wed, 1 Dec 2021 15:02:04 -0500 Subject: [PATCH 48/48] MNT remove unnecessary checks --- sklearn/tree/_classes.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index ec127e1f4db7a..ca6de9060fe95 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -1006,9 +1006,6 @@ def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): ) return self - if self.ccp_alpha < 0.0: - raise ValueError("ccp_alpha must be greater than or equal to 0") - if check_input: # Need to validate separately here. # We can't pass multi_ouput=True because that would allow y to be @@ -1026,18 +1023,6 @@ def partial_fit(self, X, y, classes=None, sample_weight=None, check_input=True): "No support for np.int64 index based sparse matrices" ) - if self.criterion == "poisson": - if np.any(y < 0): - raise ValueError( - "Some value(s) of y are negative which is" - " not allowed for Poisson regression." - ) - if np.sum(y) <= 0: - raise ValueError( - "Sum of y is not positive which is " - "necessary for Poisson regression." - ) - if X.shape[1] != self.n_features_in_: msg = "Number of features %d does not match previous data %d." raise ValueError(msg % (X.shape[1], self.n_features_in_))