Skip to content

Commit

Permalink
FIX fix performance regression in trees with low-cardinality features (
Browse files Browse the repository at this point in the history
…scikit-learn#23410)

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
  • Loading branch information
3 people committed Aug 4, 2022
1 parent 3b92f2d commit d02a401
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions sklearn/tree/_splitter.pyx
Expand Up @@ -26,7 +26,6 @@ from ._utils cimport log
from ._utils cimport rand_int
from ._utils cimport rand_uniform
from ._utils cimport RAND_R_MAX
from ..utils._sorting cimport simultaneous_sort

cdef double INFINITY = np.inf

Expand Down Expand Up @@ -342,7 +341,7 @@ cdef class BestSplitter(BaseDenseSplitter):
for i in range(start, end):
Xf[i] = self.X[samples[i], current.feature]

simultaneous_sort(&Xf[start], &samples[start], end - start)
sort(&Xf[start], &samples[start], end - start)

if Xf[end - 1] <= Xf[start] + FEATURE_THRESHOLD:
features[f_j], features[n_total_constants] = features[n_total_constants], features[f_j]
Expand Down Expand Up @@ -1161,11 +1160,11 @@ cdef class BestSparseSplitter(BaseSparseSplitter):
current.feature = features[f_j]
self.extract_nnz(current.feature, &end_negative, &start_positive,
&is_samples_sorted)

# Sort the positive and negative parts of `Xf`
simultaneous_sort(&Xf[start], &samples[start], end_negative - start)
sort(&Xf[start], &samples[start], end_negative - start)
if start_positive < end:
simultaneous_sort(&Xf[start_positive], &samples[start_positive], end - start_positive)
sort(&Xf[start_positive], &samples[start_positive],
end - start_positive)

# Update index_to_samples to take into account the sort
for p in range(start, end_negative):
Expand Down

0 comments on commit d02a401

Please sign in to comment.