Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept an ndarray as a point #31

Merged
merged 1 commit into from Aug 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Expand Up @@ -5,6 +5,12 @@ file is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and
this project adheres to [Semantic
Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Changed

- `BallTree` and `VantagePointTree` accept an ndarray as a point.

## [0.4.0] - 2020-06-01

### Added
Expand Down Expand Up @@ -44,6 +50,7 @@ Versioning](https://semver.org/spec/v2.0.0.html).

- A ball tree data structure to find nearest neighbors.

[Unreleased]: https://github.com/petabi/petal-neighbors/compare/0.4.0...master
[0.4.0]: https://github.com/petabi/petal-neighbors/compare/0.3.0...0.4.0
[0.3.0]: https://github.com/petabi/petal-neighbors/compare/0.2.0...0.3.0
[0.2.0]: https://github.com/petabi/petal-neighbors/compare/0.1.0...0.2.0
Expand Down
2 changes: 1 addition & 1 deletion benches/ball_tree.rs
Expand Up @@ -29,7 +29,7 @@ fn query_radius(c: &mut Criterion) {
b.iter(|| {
for i in 0..n {
let query = &data[i * dim..i * dim + dim];
tree.query_radius(query, 0.2);
tree.query_radius(&query.into(), 0.2);
}
})
});
Expand Down
114 changes: 53 additions & 61 deletions src/ball_tree.rs
@@ -1,6 +1,6 @@
use crate::distance::{self, Distance};
use crate::ArrayError;
use ndarray::{ArrayBase, ArrayView1, CowArray, Data, Ix2};
use ndarray::{Array1, ArrayBase, ArrayView1, CowArray, Data, Ix1, Ix2};
use num_traits::{Float, FromPrimitive, Zero};
use std::cmp;
use std::collections::BinaryHeap;
Expand Down Expand Up @@ -97,17 +97,20 @@ where
/// # Examples
///
/// ```
/// use ndarray::array;
/// use ndarray::{array, aview1};
/// use petal_neighbors::{BallTree, distance};
///
/// let points = array![[1., 1.], [1., 2.], [9., 9.]];
/// let tree = BallTree::euclidean(points).expect("valid array");
/// let (index, distance) = tree.query_nearest(&[8., 8.]);
/// let (index, distance) = tree.query_nearest(&aview1(&[8., 8.]));
/// assert_eq!(index, 2); // points[2] is the nearest.
/// assert!((2_f64.sqrt() - distance).abs() < 1e-8);
/// ```
pub fn query_nearest(&self, point: &[A]) -> (usize, A) {
self.nearest_neighbor_in_subtree(point, 0, A::infinity())
pub fn query_nearest<S>(&self, point: &ArrayBase<S, Ix1>) -> (usize, A)
where
S: Data<Elem = A>,
{
self.nearest_neighbor_in_subtree(&point.view(), 0, A::infinity())
.expect("0 is a valid index")
}

Expand All @@ -117,17 +120,20 @@ where
/// # Examples
///
/// ```
/// use ndarray::array;
/// use ndarray::{array, aview1};
/// use petal_neighbors::{BallTree, distance};
///
/// let points = array![[1., 1.], [1., 2.], [9., 9.]];
/// let tree = BallTree::euclidean(points).expect("non-empty input");
/// let (indices, distances) = tree.query(&[3., 3.], 2);
/// let (indices, distances) = tree.query(&aview1(&[3., 3.]), 2);
/// assert_eq!(indices, &[1, 0]); // points[1] is the nearest, followed by points[0].
/// ```
pub fn query(&self, point: &[A], k: usize) -> (Vec<usize>, Vec<A>) {
pub fn query<S>(&self, point: &ArrayBase<S, Ix1>, k: usize) -> (Vec<usize>, Vec<A>)
where
S: Data<Elem = A>,
{
let mut neighbors = BinaryHeap::with_capacity(k);
self.nearest_k_neighbors_in_subtree(point, 0, A::infinity(), k, &mut neighbors);
self.nearest_k_neighbors_in_subtree(&point.view(), 0, A::infinity(), k, &mut neighbors);
let sorted = neighbors.into_sorted_vec();
let indices = sorted.iter().map(|v| v.idx).collect();
let distances = sorted.iter().map(|v| v.distance).collect();
Expand All @@ -140,16 +146,19 @@ where
/// # Examples
///
/// ```
/// use ndarray::array;
/// use ndarray::{array, aview1};
/// use petal_neighbors::{BallTree, distance};
///
/// let points = array![[1., 0.], [2., 0.], [9., 0.]];
/// let tree = BallTree::euclidean(points).expect("non-empty input");
/// let indices = tree.query_radius(&[3., 0.], 1.5);
/// let indices = tree.query_radius(&aview1(&[3., 0.]), 1.5);
/// assert_eq!(indices, &[1]); // The distance to points[1] is less than 1.5.
/// ```
pub fn query_radius(&self, point: &[A], distance: A) -> Vec<usize> {
self.neighbors_within_radius_in_subtree(point, distance, 0)
pub fn query_radius<S>(&self, point: &ArrayBase<S, Ix1>, distance: A) -> Vec<usize>
where
S: Data<Elem = A>,
{
self.neighbors_within_radius_in_subtree(&point.view(), distance, 0)
}

/// Finds the nearest neighbor and its distance in the subtree rooted at `root`.
Expand All @@ -159,7 +168,7 @@ where
/// Panics if `root` is out of bound.
fn nearest_neighbor_in_subtree(
&self,
point: &[A],
point: &ArrayView1<A>,
root: usize,
radius: A,
) -> Option<(usize, A)> {
Expand All @@ -174,10 +183,7 @@ where
(0, A::infinity()),
|(min_i, min_dist), &i| {
let distance = self.distance;
let dist = distance(
point,
self.points.row(i).as_slice().expect("standard layout"),
);
let dist = distance(point, &self.points.row(i));

if dist < min_dist {
(i, dist)
Expand Down Expand Up @@ -223,7 +229,7 @@ where
/// Panics if `root` is out of bound.
fn nearest_k_neighbors_in_subtree(
&self,
point: &[A],
point: &ArrayView1<A>,
root: usize,
radius: A,
k: usize,
Expand All @@ -239,10 +245,7 @@ where
.iter()
.filter_map(|&i| {
let distance = self.distance;
let dist = distance(
point,
self.points.row(i).as_slice().expect("standard layout"),
);
let dist = distance(point, &self.points.row(i));

if dist < radius {
Some(Neighbor::new(i, dist))
Expand Down Expand Up @@ -281,7 +284,7 @@ where
/// Panics if `root` is out of bound.
fn neighbors_within_radius_in_subtree(
&self,
point: &[A],
point: &ArrayView1<A>,
radius: A,
root: usize,
) -> Vec<usize> {
Expand All @@ -306,10 +309,7 @@ where
} else if root_node.is_leaf {
neighbors.extend(self.idx[root_node.range.clone()].iter().filter_map(|&i| {
let distance = self.distance;
let dist = distance(
point,
self.points.row(i).as_slice().expect("standard layout"),
);
let dist = distance(point, &self.points.row(i));
if dist < radius {
Some(i)
} else {
Expand Down Expand Up @@ -386,7 +386,7 @@ impl<A> Eq for Neighbor<A> where A: Float {}
#[derive(Clone, Debug)]
struct Node<A> {
range: Range<usize>,
centroid: Vec<A>,
centroid: Array1<A>,
radius: A,
is_leaf: bool,
}
Expand All @@ -405,7 +405,7 @@ where
fn init(&mut self, points: &CowArray<A, Ix2>, idx: &[usize], distance: Distance<A>) {
let mut sum = idx
.iter()
.fold(vec![A::zero(); points.ncols()], |mut sum, &i| {
.fold(Array1::<A>::zeros(points.ncols()), |mut sum, &i| {
for (s, v) in sum.iter_mut().zip(points.row(i)) {
*s += *v;
}
Expand All @@ -416,15 +416,12 @@ where
self.centroid = sum;

self.radius = idx.iter().fold(A::zero(), |max, &i| {
A::max(
distance(&self.centroid, &points.row(i).as_slice().unwrap()),
max,
)
A::max(distance(&self.centroid.view(), &points.row(i)), max)
});
}

fn distance_bounds(&self, point: &[A], distance: Distance<A>) -> (A, A) {
let centroid_dist = distance(point, &self.centroid);
fn distance_bounds(&self, point: &ArrayView1<A>, distance: Distance<A>) -> (A, A) {
let centroid_dist = distance(point, &self.centroid.view());
let mut lb = centroid_dist - self.radius;
if lb < A::zero() {
lb = A::zero();
Expand All @@ -433,8 +430,8 @@ where
(lb, ub)
}

fn distance_lower_bound(&self, point: &[A], distance: Distance<A>) -> A {
let centroid_dist = distance(point, &self.centroid);
fn distance_lower_bound(&self, point: &ArrayView1<A>, distance: Distance<A>) -> A {
let centroid_dist = distance(point, &self.centroid.view());
let lb = centroid_dist - self.radius;
if lb < A::zero() {
A::zero()
Expand All @@ -452,7 +449,7 @@ where
fn default() -> Self {
Self {
range: (0..0),
centroid: Vec::new(),
centroid: Array1::<A>::zeros(0),
radius: A::zero(),
is_leaf: false,
}
Expand Down Expand Up @@ -593,7 +590,7 @@ mod test {
use super::*;
use crate::distance;
use approx;
use ndarray::{array, aview1, aview2, Array, Axis};
use ndarray::{arr1, array, aview1, aview2, Array, Axis};
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;

Expand All @@ -607,7 +604,7 @@ mod test {
Some(distance::euclidean_reduced),
)
.expect("`data` should not be empty");
let point = [0., 0.];
let point = aview1(&[0., 0.]);
tree.query_nearest(&point);
}

Expand All @@ -621,7 +618,7 @@ mod test {
)
.expect("`array` should not be empty");

let point = [0., 0.];
let point = aview1(&[0., 0.]);
let neighbor = tree.query_nearest(&point);
assert_eq!(neighbor.0, 0);
assert!(approx::abs_diff_eq!(neighbor.1, 2_f64.sqrt()));
Expand All @@ -634,7 +631,7 @@ mod test {
neighbors.sort_unstable();
assert_eq!(neighbors, &[0, 1]);

let point = [1.1, 1.2];
let point = aview1(&[1.1, 1.2]);
let neighbor = tree.query_nearest(&point);
assert_eq!(neighbor.0, 1);
assert!(approx::abs_diff_eq!(
Expand All @@ -647,7 +644,7 @@ mod test {
assert_eq!(indices[0], neighbor.0);
assert!(approx::abs_diff_eq!(distances[0], neighbor.1));

let point = [7., 7.];
let point = aview1(&[7., 7.]);
let neighbor = tree.query_nearest(&point);
assert_eq!(neighbor.0, 2);
assert!(approx::abs_diff_eq!(neighbor.1, 8_f64.sqrt()));
Expand Down Expand Up @@ -675,7 +672,7 @@ mod test {
)
.expect("`array` should not be empty");

let point = [1., 2.];
let point = aview1(&[1., 2.]);
let neighbor = tree.query_nearest(&point);
assert_eq!(neighbor.0, 0);
assert!(approx::abs_diff_eq!(neighbor.1, 0_f64.sqrt()));
Expand All @@ -700,7 +697,7 @@ mod test {
)
.expect("`array` should not be empty");

let point = [1., 2.];
let point = aview1(&[1., 2.]);
let neighbor = tree.query_nearest(&point);
assert!(approx::abs_diff_eq!(neighbor.1, 1_f64.sqrt()));
}
Expand All @@ -714,14 +711,9 @@ mod test {
.expect("`array` should not be empty");
for _ in 0..10 {
let query = Array::random(DIMENSION, Uniform::new(0., 1.));
let (_, bt_distances) =
bt.query(query.as_slice().expect("should be contiguous in memory"), 5);
let naive_neighbors = naive_k_nearest_neighbors(
&array,
query.as_slice().expect("should be contiguous in memory"),
5,
distance::euclidean,
);
let (_, bt_distances) = bt.query(&query, 5);
let naive_neighbors =
naive_k_nearest_neighbors(&array, &query.view(), 5, distance::euclidean);
for (bt_dist, naive_neighbor) in bt_distances.iter().zip(naive_neighbors.iter()) {
assert!(approx::abs_diff_eq!(*bt_dist, naive_neighbor.distance));
}
Expand All @@ -738,14 +730,14 @@ mod test {
)
.expect("`array` should not be empty");

let neighbors = bt.query_radius(&[0.1], 1.);
let neighbors = bt.query_radius(&aview1(&[0.1]), 1.);
assert_eq!(neighbors, &[0]);

let mut neighbors = bt.query_radius(&[3.2], 1.);
let mut neighbors = bt.query_radius(&aview1(&[3.2]), 1.);
neighbors.sort_unstable();
assert_eq!(neighbors, &[2, 3]);

let neighbors = bt.query_radius(&[9.], 0.9);
let neighbors = bt.query_radius(&aview1(&[9.]), 0.9);
assert!(neighbors.is_empty());
}

Expand All @@ -755,12 +747,12 @@ mod test {
let idx: [usize; 3] = [0, 1, 2];
let mut node = Node::default();
node.init(&array.view().into(), &idx, distance::euclidean);
assert_eq!(node.centroid, [0., 4.]);
assert_eq!(node.centroid, arr1(&[0., 4.]));
assert_eq!(node.radius, 5.);

let idx: [usize; 2] = [0, 2];
node.init(&array.into(), &idx, distance::euclidean);
assert_eq!(node.centroid, [0., 1.5]);
assert_eq!(node.centroid, arr1(&[0., 1.5]));
}

#[test]
Expand Down Expand Up @@ -838,7 +830,7 @@ mod test {
/// Panics if any row in `neighbors` is not contiguous in memory.
fn naive_k_nearest_neighbors<'a, A, S>(
neighbors: &'a ArrayBase<S, Ix2>,
point: &[A],
point: &ArrayView1<A>,
k: usize,
distance: Distance<A>,
) -> Vec<Neighbor<A>>
Expand All @@ -851,7 +843,7 @@ mod test {
.enumerate()
.map(|(i, n)| Neighbor {
idx: i,
distance: distance(n.to_slice().unwrap(), point),
distance: distance(&n, point),
})
.collect::<Vec<Neighbor<A>>>();
knn.sort();
Expand Down
9 changes: 6 additions & 3 deletions src/distance.rs
@@ -1,15 +1,17 @@
//! Distance metrics.

use ndarray::ArrayView1;
use num_traits::{Float, Zero};
use std::ops::AddAssign;

/// The type of a distance metric function.
pub type Distance<A> = fn(&[A], &[A]) -> A;
pub type Distance<A> = fn(&ArrayView1<A>, &ArrayView1<A>) -> A;

/// Euclidean distance before taking a squre root. Used as a lightweight version
/// of [`euclidean`] for relative comparisions.
/// [`euclidean`]: #method.euclidean
pub fn euclidean_reduced<A>(x1: &[A], x2: &[A]) -> A
#[must_use]
pub fn euclidean_reduced<A>(x1: &ArrayView1<A>, x2: &ArrayView1<A>) -> A
where
A: Float + Zero + AddAssign,
{
Expand All @@ -23,7 +25,8 @@ where
}

/// Euclidean distance metric.
pub fn euclidean<A>(x1: &[A], x2: &[A]) -> A
#[must_use]
pub fn euclidean<A>(x1: &ArrayView1<A>, x2: &ArrayView1<A>) -> A
where
A: Float + Zero + AddAssign,
{
Expand Down