Skip to content

Commit

Permalink
Use sliding rule also on median pivot for stability
Browse files Browse the repository at this point in the history
Median pivot (aka "standard split") results in stability issues when we have a lot of ties.
  • Loading branch information
sturlamolden committed May 8, 2024
1 parent 45dc44d commit 2304fe7
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions scipy/spatial/ckdtree/src/build.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -120,34 +120,34 @@ build(ckdtree *self, ckdtree_intp_t start_idx, intptr_t end_idx,
auto mid = node_indices + n_points / 2;
std::nth_element(
node_indices, mid, node_indices + n_points, index_compare);

split = data[*mid * m + d];
p = partition_pivot(indices + start_idx, indices + end_idx, split);
}
else {
/* split with the sliding midpoint rule */
split = (maxval + minval) / 2;

p = partition_pivot(indices + start_idx, indices + end_idx, split);

/* slide midpoint if necessary */
if (p == start_idx) {
/* no points less than split */
auto min_idx = *std::min_element(
p = partition_pivot(indices + start_idx, indices + end_idx, split);
}

// slide midpoint/pivot if necessary
// we might even need to do this for the median pivot to avoid infinite
// recursion
if (p == start_idx) {
/* no points less than split */
auto min_idx = *std::min_element(
indices + start_idx, indices + end_idx, index_compare);
split = std::nextafter(data[min_idx * m + d], std::numeric_limits<double>::max() );
p = partition_pivot(indices + start_idx, indices + end_idx, split);
}
else if (p == end_idx) {
/* no points greater than split */
auto max_idx = *std::max_element(
split = std::nextafter(data[min_idx * m + d], std::numeric_limits<double>::max() );
p = partition_pivot(indices + start_idx, indices + end_idx, split);
}
else if (p == end_idx) {
/* no points greater than split */
auto max_idx = *std::max_element(
indices + start_idx, indices + end_idx, index_compare);
split = std::nextafter(data[max_idx * m + d], std::numeric_limits<double>::lowest() );
p = partition_pivot(indices + start_idx, indices + end_idx, split);
}
}

if (CKDTREE_UNLIKELY(p == start_idx || p == end_idx)) {
split = std::nextafter(data[max_idx * m + d], std::numeric_limits<double>::lowest() );
p = partition_pivot(indices + start_idx, indices + end_idx, split);
}

if (CKDTREE_UNLIKELY(p == start_idx || p == end_idx)) {
// All children are equal in this dimension, try again with
// this dimension tabooed
assert(!_compact);
Expand Down

0 comments on commit 2304fe7

Please sign in to comment.