From d02a4012be1c48ff30c06d9a6c27d01644cc058c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Thu, 19 May 2022 09:48:29 +0200 Subject: [PATCH] FIX fix performance regression in trees with low-cardinality features (#23410) Co-authored-by: Guillaume Lemaitre Co-authored-by: Thomas J. Fan --- sklearn/tree/_splitter.pyx | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 8a40d14cac5b7..76b502f98f144 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -26,7 +26,6 @@ from ._utils cimport log from ._utils cimport rand_int from ._utils cimport rand_uniform from ._utils cimport RAND_R_MAX -from ..utils._sorting cimport simultaneous_sort cdef double INFINITY = np.inf @@ -342,7 +341,7 @@ cdef class BestSplitter(BaseDenseSplitter): for i in range(start, end): Xf[i] = self.X[samples[i], current.feature] - simultaneous_sort(&Xf[start], &samples[start], end - start) + sort(&Xf[start], &samples[start], end - start) if Xf[end - 1] <= Xf[start] + FEATURE_THRESHOLD: features[f_j], features[n_total_constants] = features[n_total_constants], features[f_j] @@ -1161,11 +1160,11 @@ cdef class BestSparseSplitter(BaseSparseSplitter): current.feature = features[f_j] self.extract_nnz(current.feature, &end_negative, &start_positive, &is_samples_sorted) - # Sort the positive and negative parts of `Xf` - simultaneous_sort(&Xf[start], &samples[start], end_negative - start) + sort(&Xf[start], &samples[start], end_negative - start) if start_positive < end: - simultaneous_sort(&Xf[start_positive], &samples[start_positive], end - start_positive) + sort(&Xf[start_positive], &samples[start_positive], + end - start_positive) # Update index_to_samples to take into account the sort for p in range(start, end_negative):