diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a3afa5..e11b672 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ file is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Changed + +- `BallTree` and `VantagePointTree` accept an ndarray as a point. + ## [0.4.0] - 2020-06-01 ### Added @@ -44,6 +50,7 @@ Versioning](https://semver.org/spec/v2.0.0.html). - A ball tree data structure to find nearest neighbors. +[Unreleased]: https://github.com/petabi/petal-neighbors/compare/0.4.0...master [0.4.0]: https://github.com/petabi/petal-neighbors/compare/0.3.0...0.4.0 [0.3.0]: https://github.com/petabi/petal-neighbors/compare/0.2.0...0.3.0 [0.2.0]: https://github.com/petabi/petal-neighbors/compare/0.1.0...0.2.0 diff --git a/benches/ball_tree.rs b/benches/ball_tree.rs index 9acba03..57cc691 100644 --- a/benches/ball_tree.rs +++ b/benches/ball_tree.rs @@ -29,7 +29,7 @@ fn query_radius(c: &mut Criterion) { b.iter(|| { for i in 0..n { let query = &data[i * dim..i * dim + dim]; - tree.query_radius(query, 0.2); + tree.query_radius(&query.into(), 0.2); } }) }); diff --git a/src/ball_tree.rs b/src/ball_tree.rs index 3279f23..4529085 100644 --- a/src/ball_tree.rs +++ b/src/ball_tree.rs @@ -1,6 +1,6 @@ use crate::distance::{self, Distance}; use crate::ArrayError; -use ndarray::{ArrayBase, ArrayView1, CowArray, Data, Ix2}; +use ndarray::{Array1, ArrayBase, ArrayView1, CowArray, Data, Ix1, Ix2}; use num_traits::{Float, FromPrimitive, Zero}; use std::cmp; use std::collections::BinaryHeap; @@ -97,17 +97,20 @@ where /// # Examples /// /// ``` - /// use ndarray::array; + /// use ndarray::{array, aview1}; /// use petal_neighbors::{BallTree, distance}; /// /// let points = array![[1., 1.], [1., 2.], [9., 9.]]; /// let tree = BallTree::euclidean(points).expect("valid array"); - /// let (index, distance) = tree.query_nearest(&[8., 8.]); + /// let (index, distance) = tree.query_nearest(&aview1(&[8., 8.])); /// assert_eq!(index, 2); // points[2] is the nearest. /// assert!((2_f64.sqrt() - distance).abs() < 1e-8); /// ``` - pub fn query_nearest(&self, point: &[A]) -> (usize, A) { - self.nearest_neighbor_in_subtree(point, 0, A::infinity()) + pub fn query_nearest(&self, point: &ArrayBase) -> (usize, A) + where + S: Data, + { + self.nearest_neighbor_in_subtree(&point.view(), 0, A::infinity()) .expect("0 is a valid index") } @@ -117,17 +120,20 @@ where /// # Examples /// /// ``` - /// use ndarray::array; + /// use ndarray::{array, aview1}; /// use petal_neighbors::{BallTree, distance}; /// /// let points = array![[1., 1.], [1., 2.], [9., 9.]]; /// let tree = BallTree::euclidean(points).expect("non-empty input"); - /// let (indices, distances) = tree.query(&[3., 3.], 2); + /// let (indices, distances) = tree.query(&aview1(&[3., 3.]), 2); /// assert_eq!(indices, &[1, 0]); // points[1] is the nearest, followed by points[0]. /// ``` - pub fn query(&self, point: &[A], k: usize) -> (Vec, Vec) { + pub fn query(&self, point: &ArrayBase, k: usize) -> (Vec, Vec) + where + S: Data, + { let mut neighbors = BinaryHeap::with_capacity(k); - self.nearest_k_neighbors_in_subtree(point, 0, A::infinity(), k, &mut neighbors); + self.nearest_k_neighbors_in_subtree(&point.view(), 0, A::infinity(), k, &mut neighbors); let sorted = neighbors.into_sorted_vec(); let indices = sorted.iter().map(|v| v.idx).collect(); let distances = sorted.iter().map(|v| v.distance).collect(); @@ -140,16 +146,19 @@ where /// # Examples /// /// ``` - /// use ndarray::array; + /// use ndarray::{array, aview1}; /// use petal_neighbors::{BallTree, distance}; /// /// let points = array![[1., 0.], [2., 0.], [9., 0.]]; /// let tree = BallTree::euclidean(points).expect("non-empty input"); - /// let indices = tree.query_radius(&[3., 0.], 1.5); + /// let indices = tree.query_radius(&aview1(&[3., 0.]), 1.5); /// assert_eq!(indices, &[1]); // The distance to points[1] is less than 1.5. /// ``` - pub fn query_radius(&self, point: &[A], distance: A) -> Vec { - self.neighbors_within_radius_in_subtree(point, distance, 0) + pub fn query_radius(&self, point: &ArrayBase, distance: A) -> Vec + where + S: Data, + { + self.neighbors_within_radius_in_subtree(&point.view(), distance, 0) } /// Finds the nearest neighbor and its distance in the subtree rooted at `root`. @@ -159,7 +168,7 @@ where /// Panics if `root` is out of bound. fn nearest_neighbor_in_subtree( &self, - point: &[A], + point: &ArrayView1, root: usize, radius: A, ) -> Option<(usize, A)> { @@ -174,10 +183,7 @@ where (0, A::infinity()), |(min_i, min_dist), &i| { let distance = self.distance; - let dist = distance( - point, - self.points.row(i).as_slice().expect("standard layout"), - ); + let dist = distance(point, &self.points.row(i)); if dist < min_dist { (i, dist) @@ -223,7 +229,7 @@ where /// Panics if `root` is out of bound. fn nearest_k_neighbors_in_subtree( &self, - point: &[A], + point: &ArrayView1, root: usize, radius: A, k: usize, @@ -239,10 +245,7 @@ where .iter() .filter_map(|&i| { let distance = self.distance; - let dist = distance( - point, - self.points.row(i).as_slice().expect("standard layout"), - ); + let dist = distance(point, &self.points.row(i)); if dist < radius { Some(Neighbor::new(i, dist)) @@ -281,7 +284,7 @@ where /// Panics if `root` is out of bound. fn neighbors_within_radius_in_subtree( &self, - point: &[A], + point: &ArrayView1, radius: A, root: usize, ) -> Vec { @@ -306,10 +309,7 @@ where } else if root_node.is_leaf { neighbors.extend(self.idx[root_node.range.clone()].iter().filter_map(|&i| { let distance = self.distance; - let dist = distance( - point, - self.points.row(i).as_slice().expect("standard layout"), - ); + let dist = distance(point, &self.points.row(i)); if dist < radius { Some(i) } else { @@ -386,7 +386,7 @@ impl Eq for Neighbor where A: Float {} #[derive(Clone, Debug)] struct Node { range: Range, - centroid: Vec, + centroid: Array1, radius: A, is_leaf: bool, } @@ -405,7 +405,7 @@ where fn init(&mut self, points: &CowArray, idx: &[usize], distance: Distance) { let mut sum = idx .iter() - .fold(vec![A::zero(); points.ncols()], |mut sum, &i| { + .fold(Array1::::zeros(points.ncols()), |mut sum, &i| { for (s, v) in sum.iter_mut().zip(points.row(i)) { *s += *v; } @@ -416,15 +416,12 @@ where self.centroid = sum; self.radius = idx.iter().fold(A::zero(), |max, &i| { - A::max( - distance(&self.centroid, &points.row(i).as_slice().unwrap()), - max, - ) + A::max(distance(&self.centroid.view(), &points.row(i)), max) }); } - fn distance_bounds(&self, point: &[A], distance: Distance) -> (A, A) { - let centroid_dist = distance(point, &self.centroid); + fn distance_bounds(&self, point: &ArrayView1, distance: Distance) -> (A, A) { + let centroid_dist = distance(point, &self.centroid.view()); let mut lb = centroid_dist - self.radius; if lb < A::zero() { lb = A::zero(); @@ -433,8 +430,8 @@ where (lb, ub) } - fn distance_lower_bound(&self, point: &[A], distance: Distance) -> A { - let centroid_dist = distance(point, &self.centroid); + fn distance_lower_bound(&self, point: &ArrayView1, distance: Distance) -> A { + let centroid_dist = distance(point, &self.centroid.view()); let lb = centroid_dist - self.radius; if lb < A::zero() { A::zero() @@ -452,7 +449,7 @@ where fn default() -> Self { Self { range: (0..0), - centroid: Vec::new(), + centroid: Array1::::zeros(0), radius: A::zero(), is_leaf: false, } @@ -593,7 +590,7 @@ mod test { use super::*; use crate::distance; use approx; - use ndarray::{array, aview1, aview2, Array, Axis}; + use ndarray::{arr1, array, aview1, aview2, Array, Axis}; use ndarray_rand::rand_distr::Uniform; use ndarray_rand::RandomExt; @@ -607,7 +604,7 @@ mod test { Some(distance::euclidean_reduced), ) .expect("`data` should not be empty"); - let point = [0., 0.]; + let point = aview1(&[0., 0.]); tree.query_nearest(&point); } @@ -621,7 +618,7 @@ mod test { ) .expect("`array` should not be empty"); - let point = [0., 0.]; + let point = aview1(&[0., 0.]); let neighbor = tree.query_nearest(&point); assert_eq!(neighbor.0, 0); assert!(approx::abs_diff_eq!(neighbor.1, 2_f64.sqrt())); @@ -634,7 +631,7 @@ mod test { neighbors.sort_unstable(); assert_eq!(neighbors, &[0, 1]); - let point = [1.1, 1.2]; + let point = aview1(&[1.1, 1.2]); let neighbor = tree.query_nearest(&point); assert_eq!(neighbor.0, 1); assert!(approx::abs_diff_eq!( @@ -647,7 +644,7 @@ mod test { assert_eq!(indices[0], neighbor.0); assert!(approx::abs_diff_eq!(distances[0], neighbor.1)); - let point = [7., 7.]; + let point = aview1(&[7., 7.]); let neighbor = tree.query_nearest(&point); assert_eq!(neighbor.0, 2); assert!(approx::abs_diff_eq!(neighbor.1, 8_f64.sqrt())); @@ -675,7 +672,7 @@ mod test { ) .expect("`array` should not be empty"); - let point = [1., 2.]; + let point = aview1(&[1., 2.]); let neighbor = tree.query_nearest(&point); assert_eq!(neighbor.0, 0); assert!(approx::abs_diff_eq!(neighbor.1, 0_f64.sqrt())); @@ -700,7 +697,7 @@ mod test { ) .expect("`array` should not be empty"); - let point = [1., 2.]; + let point = aview1(&[1., 2.]); let neighbor = tree.query_nearest(&point); assert!(approx::abs_diff_eq!(neighbor.1, 1_f64.sqrt())); } @@ -714,14 +711,9 @@ mod test { .expect("`array` should not be empty"); for _ in 0..10 { let query = Array::random(DIMENSION, Uniform::new(0., 1.)); - let (_, bt_distances) = - bt.query(query.as_slice().expect("should be contiguous in memory"), 5); - let naive_neighbors = naive_k_nearest_neighbors( - &array, - query.as_slice().expect("should be contiguous in memory"), - 5, - distance::euclidean, - ); + let (_, bt_distances) = bt.query(&query, 5); + let naive_neighbors = + naive_k_nearest_neighbors(&array, &query.view(), 5, distance::euclidean); for (bt_dist, naive_neighbor) in bt_distances.iter().zip(naive_neighbors.iter()) { assert!(approx::abs_diff_eq!(*bt_dist, naive_neighbor.distance)); } @@ -738,14 +730,14 @@ mod test { ) .expect("`array` should not be empty"); - let neighbors = bt.query_radius(&[0.1], 1.); + let neighbors = bt.query_radius(&aview1(&[0.1]), 1.); assert_eq!(neighbors, &[0]); - let mut neighbors = bt.query_radius(&[3.2], 1.); + let mut neighbors = bt.query_radius(&aview1(&[3.2]), 1.); neighbors.sort_unstable(); assert_eq!(neighbors, &[2, 3]); - let neighbors = bt.query_radius(&[9.], 0.9); + let neighbors = bt.query_radius(&aview1(&[9.]), 0.9); assert!(neighbors.is_empty()); } @@ -755,12 +747,12 @@ mod test { let idx: [usize; 3] = [0, 1, 2]; let mut node = Node::default(); node.init(&array.view().into(), &idx, distance::euclidean); - assert_eq!(node.centroid, [0., 4.]); + assert_eq!(node.centroid, arr1(&[0., 4.])); assert_eq!(node.radius, 5.); let idx: [usize; 2] = [0, 2]; node.init(&array.into(), &idx, distance::euclidean); - assert_eq!(node.centroid, [0., 1.5]); + assert_eq!(node.centroid, arr1(&[0., 1.5])); } #[test] @@ -838,7 +830,7 @@ mod test { /// Panics if any row in `neighbors` is not contiguous in memory. fn naive_k_nearest_neighbors<'a, A, S>( neighbors: &'a ArrayBase, - point: &[A], + point: &ArrayView1, k: usize, distance: Distance, ) -> Vec> @@ -851,7 +843,7 @@ mod test { .enumerate() .map(|(i, n)| Neighbor { idx: i, - distance: distance(n.to_slice().unwrap(), point), + distance: distance(&n, point), }) .collect::>>(); knn.sort(); diff --git a/src/distance.rs b/src/distance.rs index 74d2a16..64c072b 100644 --- a/src/distance.rs +++ b/src/distance.rs @@ -1,15 +1,17 @@ //! Distance metrics. +use ndarray::ArrayView1; use num_traits::{Float, Zero}; use std::ops::AddAssign; /// The type of a distance metric function. -pub type Distance = fn(&[A], &[A]) -> A; +pub type Distance = fn(&ArrayView1, &ArrayView1) -> A; /// Euclidean distance before taking a squre root. Used as a lightweight version /// of [`euclidean`] for relative comparisions. /// [`euclidean`]: #method.euclidean -pub fn euclidean_reduced(x1: &[A], x2: &[A]) -> A +#[must_use] +pub fn euclidean_reduced(x1: &ArrayView1, x2: &ArrayView1) -> A where A: Float + Zero + AddAssign, { @@ -23,7 +25,8 @@ where } /// Euclidean distance metric. -pub fn euclidean(x1: &[A], x2: &[A]) -> A +#[must_use] +pub fn euclidean(x1: &ArrayView1, x2: &ArrayView1) -> A where A: Float + Zero + AddAssign, { diff --git a/src/vantage_point_tree.rs b/src/vantage_point_tree.rs index 74d7c19..8fc9203 100644 --- a/src/vantage_point_tree.rs +++ b/src/vantage_point_tree.rs @@ -1,6 +1,6 @@ use crate::distance::{self, Distance}; use crate::ArrayError; -use ndarray::{ArrayBase, CowArray, Data, Ix2}; +use ndarray::{ArrayBase, ArrayView1, CowArray, Data, Ix1, Ix2}; use num_traits::{Float, Zero}; use std::ops::AddAssign; @@ -70,30 +70,30 @@ where /// # Examples /// /// ``` - /// use ndarray::array; + /// use ndarray::{array, aview1}; /// use petal_neighbors::{VantagePointTree, distance}; /// /// let points = array![[1., 1.], [1., 2.], [9., 9.]]; /// let tree = VantagePointTree::euclidean(points).expect("valid array"); - /// let (index, distance) = tree.query_nearest(&[8., 8.]); + /// let (index, distance) = tree.query_nearest(&aview1(&[8., 8.])); /// assert_eq!(index, 2); // points[2] is the nearest. /// assert!((2_f64.sqrt() - distance).abs() < 1e-8); /// ``` - pub fn query_nearest(&self, needle: &[A]) -> (usize, A) { + pub fn query_nearest(&self, needle: &ArrayBase) -> (usize, A) + where + S: Data, + { let mut nearest = DistanceIndex { distance: A::max_value(), id: NULL, }; - self.search_node(&self.nodes[self.root], needle, &mut nearest); + self.search_node(&self.nodes[self.root], &needle.view(), &mut nearest); (nearest.id, nearest.distance) } - fn search_node(&self, node: &Node, needle: &[A], nearest: &mut DistanceIndex) { + fn search_node(&self, node: &Node, needle: &ArrayView1, nearest: &mut DistanceIndex) { let distance = self.distance; - let distance = distance( - self.points.row(node.vantage_point).as_slice().unwrap(), - needle, - ); + let distance = distance(&self.points.row(node.vantage_point), needle); if distance < nearest.distance { nearest.distance = distance; @@ -166,10 +166,7 @@ where let rest = &mut indexes[..vp_pos]; for r in rest.iter_mut() { - r.distance = distance( - points.row(r.id).as_slice().unwrap(), - points.row(vantage_point).as_slice().unwrap(), - ); + r.distance = distance(&points.row(r.id), &points.row(vantage_point)); } rest.sort_unstable_by(|a, b| a.distance.partial_cmp(&b.distance).expect("unexpected nan")); @@ -210,7 +207,7 @@ struct DistanceIndex { #[cfg(test)] mod test { use super::*; - use ndarray::array; + use ndarray::{array, aview1}; #[test] fn euclidian() { @@ -224,6 +221,6 @@ mod test { ]; let vp = VantagePointTree::euclidean(points).expect("valid array"); - assert_eq!(vp.query_nearest(&[0.95, 1.96]).0, 0); + assert_eq!(vp.query_nearest(&aview1(&[0.95, 1.96])).0, 0); } }