From 8dd9a05767d12924f974144107396ffdf730ef65 Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Sat, 28 Sep 2019 18:19:36 +0200 Subject: [PATCH 1/6] Implement missing specializations on the MergeJoinBy Iterator --- src/merge_join.rs | 66 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/src/merge_join.rs b/src/merge_join.rs index 5f9a0f401..bf697af3e 100644 --- a/src/merge_join.rs +++ b/src/merge_join.rs @@ -84,4 +84,70 @@ impl Iterator for MergeJoinBy (lower, upper) } + + fn count(mut self) -> usize { + let mut count = 0; + loop { + match (self.left.next(), self.right.next()) { + (None, None) => break count, + (Some(_left), None) => break count + 1 + self.left.count(), + (None, Some(_right)) => break count + 1 + self.right.count(), + (Some(left), Some(right)) => { + count += 1; + match (self.cmp_fn)(&left, &right) { + Ordering::Equal => {} + Ordering::Less => self.right.put_back(right), + Ordering::Greater => self.left.put_back(left), + } + } + } + } + } + + fn last(mut self) -> Option { + let mut previous_element = None; + loop { + match (self.left.next(), self.right.next()) { + (None, None) => break previous_element, + (Some(left), None) => { + break Some(EitherOrBoth::Left(self.left.last().unwrap_or(left))) + } + (None, Some(right)) => { + break Some(EitherOrBoth::Right(self.right.last().unwrap_or(right))) + } + (Some(left), Some(right)) => { + previous_element = match (self.cmp_fn)(&left, &right) { + Ordering::Equal => Some(EitherOrBoth::Both(left, right)), + Ordering::Less => { + self.right.put_back(right); + Some(EitherOrBoth::Left(left)) + } + Ordering::Greater => { + self.left.put_back(left); + Some(EitherOrBoth::Right(right)) + } + } + } + } + } + } + + fn nth(&mut self, mut n: usize) -> Option { + loop { + if n == 0 { + break self.next(); + } + n -= 1; + match (self.left.next(), self.right.next()) { + (None, None) => break None, + (Some(_left), None) => break self.left.nth(n).map(EitherOrBoth::Left), + (None, 
Some(_right)) => break self.right.nth(n).map(EitherOrBoth::Right), + (Some(left), Some(right)) => match (self.cmp_fn)(&left, &right) { + Ordering::Equal => {} + Ordering::Less => self.right.put_back(right), + Ordering::Greater => self.left.put_back(left), + }, + } + } + } } From d73460b2e00929dd0009eb793102f63c898d80dc Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Sat, 28 Sep 2019 19:18:00 +0200 Subject: [PATCH 2/6] Bypass PutBack adaptor where possible because it doesn't specialize these methods either --- src/merge_join.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/merge_join.rs b/src/merge_join.rs index bf697af3e..cbbeb28d0 100644 --- a/src/merge_join.rs +++ b/src/merge_join.rs @@ -90,8 +90,8 @@ impl Iterator for MergeJoinBy loop { match (self.left.next(), self.right.next()) { (None, None) => break count, - (Some(_left), None) => break count + 1 + self.left.count(), - (None, Some(_right)) => break count + 1 + self.right.count(), + (Some(_left), None) => break count + 1 + self.left.into_parts().1.count(), + (None, Some(_right)) => break count + 1 + self.right.into_parts().1.count(), (Some(left), Some(right)) => { count += 1; match (self.cmp_fn)(&left, &right) { @@ -110,10 +110,14 @@ impl Iterator for MergeJoinBy match (self.left.next(), self.right.next()) { (None, None) => break previous_element, (Some(left), None) => { - break Some(EitherOrBoth::Left(self.left.last().unwrap_or(left))) + break Some(EitherOrBoth::Left( + self.left.into_parts().1.last().unwrap_or(left), + )) } (None, Some(right)) => { - break Some(EitherOrBoth::Right(self.right.last().unwrap_or(right))) + break Some(EitherOrBoth::Right( + self.right.into_parts().1.last().unwrap_or(right), + )) } (Some(left), Some(right)) => { previous_element = match (self.cmp_fn)(&left, &right) { From 33a246514ae9cf46b3ec9b84cc6a934f6cc2b971 Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Sat, 28 Sep 2019 19:38:40 +0200 Subject: [PATCH 3/6] Specialize `count`, `last` 
and `nth` PutBack iterator for performance They were missing and did hurt the MergeJoinBy performance improvement on the `nth` method because we couldn't borrow iter without exposing new interfaces or using something like the `take_mut` crate. --- src/adaptors/mod.rs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/adaptors/mod.rs b/src/adaptors/mod.rs index 4d0f5e129..c9b408cbd 100644 --- a/src/adaptors/mod.rs +++ b/src/adaptors/mod.rs @@ -243,6 +243,28 @@ impl Iterator for PutBack size_hint::add_scalar(self.iter.size_hint(), self.top.is_some() as usize) } + fn count(self) -> usize { + self.iter.count() + (self.top.is_some() as usize) + } + + fn last(self) -> Option { + self.iter.last().or(self.top) + } + + fn nth(&mut self, n: usize) -> Option { + match self.top { + None => self.iter.nth(n), + ref mut some => { + if n == 0 { + some.take() + } else { + *some = None; + self.iter.nth(n - 1) + } + } + } + } + fn all(&mut self, mut f: G) -> bool where G: FnMut(Self::Item) -> bool { From b13e35fe9215b2152fd79eb681c6d70e77b8be44 Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Fri, 1 Nov 2019 01:55:43 +0100 Subject: [PATCH 4/6] Specialization test framework + tests --- src/merge_join.rs | 17 ++++++ tests/specializations.rs | 119 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 tests/specializations.rs diff --git a/src/merge_join.rs b/src/merge_join.rs index cbbeb28d0..e2bbc0a87 100644 --- a/src/merge_join.rs +++ b/src/merge_join.rs @@ -31,6 +31,23 @@ pub struct MergeJoinBy { cmp_fn: F } +impl Clone for MergeJoinBy +where + I: Iterator, + J: Iterator, + PutBack>: Clone, + PutBack>: Clone, + F: Clone, +{ + fn clone(&self) -> Self { + MergeJoinBy { + left: self.left.clone(), + right: self.right.clone(), + cmp_fn: self.cmp_fn.clone(), + } + } +} + impl fmt::Debug for MergeJoinBy where I: Iterator + fmt::Debug, I::Item: fmt::Debug, diff --git a/tests/specializations.rs b/tests/specializations.rs new 
file mode 100644 index 000000000..86c8164d4 --- /dev/null +++ b/tests/specializations.rs @@ -0,0 +1,119 @@ +extern crate itertools; + +use itertools::{EitherOrBoth, Itertools}; + +use std::fmt::Debug; +use std::ops::BitXor; + +struct Unspecialized(I); +impl Iterator for Unspecialized +where + I: Iterator, +{ + type Item = I::Item; + + #[inline(always)] + fn next(&mut self) -> Option { + self.0.next() + } + + #[inline(always)] + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} + +fn check_specialized<'a, V, IterItem, Iter, F>(iterator: &Iter, mapper: F) +where + V: Eq + Debug, + IterItem: 'a, + Iter: Iterator + Clone + 'a, + F: Fn(Box + 'a>) -> V, +{ + assert_eq!( + mapper(Box::new(Unspecialized(iterator.clone()))), + mapper(Box::new(iterator.clone())) + ) +} + +fn check_specialized_count_last_nth_sizeh<'a, IterItem, Iter>(it: &Iter, expected_size: usize) +where + IterItem: 'a + Eq + Debug, + Iter: Iterator + Clone + 'a, +{ + let size = it.clone().count(); + assert_eq!(size, expected_size); + check_specialized(it, |i| i.count()); + check_specialized(it, |i| i.last()); + for n in 0..size + 2 { + check_specialized(it, |mut i| i.nth(n)); + } + let mut it_sh = it.clone(); + for n in 0..size + 2 { + let len = it_sh.clone().count(); + let (min, max) = it_sh.size_hint(); + assert_eq!((size - n.min(size)), len); + assert!(min <= len); + if let Some(max) = max { + assert!(len <= max); + } + it_sh.next(); + } +} + +fn check_specialized_fold_xor<'a, IterItem, Iter>(it: &Iter) +where + IterItem: 'a + + BitXor + + Eq + + Debug + + BitXor<::Output, Output = ::Output> + + Clone, + ::Output: + BitXor::Output> + Eq + Debug + Clone, + Iter: Iterator + Clone + 'a, +{ + check_specialized(it, |mut i| { + let first = i.next().map(|f| f.clone() ^ (f.clone() ^ f)); + i.fold(first, |acc, v: IterItem| acc.map(move |a| v ^ a)) + }); +} + +#[test] +fn put_back() { + let test_vec = vec![7, 4, 1]; + { + // Lexical lifetimes support + let pb = 
itertools::put_back(test_vec.iter()); + check_specialized_count_last_nth_sizeh(&pb, 3); + check_specialized_fold_xor(&pb); + } + + let mut pb = itertools::put_back(test_vec.into_iter()); + pb.put_back(1); + check_specialized_count_last_nth_sizeh(&pb, 4); + check_specialized_fold_xor(&pb); +} + +#[test] +fn merge_join_by() { + let i1 = vec![1, 3, 5].into_iter(); + let i2 = vec![0, 3, 4, 5].into_iter(); + let mjb = i1.merge_join_by(i2, std::cmp::Ord::cmp); + check_specialized_count_last_nth_sizeh(&mjb, 5); + // Rust 1.24 compatibility: + fn eob_left_z(eob: EitherOrBoth) -> usize { + eob.left().unwrap_or(0) + } + fn eob_right_z(eob: EitherOrBoth) -> usize { + eob.right().unwrap_or(0) + } + fn eob_both_z(eob: EitherOrBoth) -> usize { + let (a, b) = eob.both().unwrap_or((0, 0)); + assert_eq!(a, b); + a + } + check_specialized_fold_xor(&mjb.clone().map(eob_left_z)); + check_specialized_fold_xor(&mjb.clone().map(eob_right_z)); + check_specialized_fold_xor(&mjb.clone().map(eob_both_z)); +} From a88946d0833aaf82e4c7c66912239d229d9c1406 Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Fri, 1 Nov 2019 12:40:24 +0100 Subject: [PATCH 5/6] More exhaustive MergeJoinBy specialization test --- tests/specializations.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/specializations.rs b/tests/specializations.rs index 86c8164d4..223de28ee 100644 --- a/tests/specializations.rs +++ b/tests/specializations.rs @@ -97,10 +97,10 @@ fn put_back() { #[test] fn merge_join_by() { - let i1 = vec![1, 3, 5].into_iter(); + let i1 = vec![1, 3, 5, 7, 8, 9].into_iter(); let i2 = vec![0, 3, 4, 5].into_iter(); - let mjb = i1.merge_join_by(i2, std::cmp::Ord::cmp); - check_specialized_count_last_nth_sizeh(&mjb, 5); + let mjb = i1.clone().merge_join_by(i2.clone(), std::cmp::Ord::cmp); + check_specialized_count_last_nth_sizeh(&mjb, 8); // Rust 1.24 compatibility: fn eob_left_z(eob: EitherOrBoth) -> usize { eob.left().unwrap_or(0) @@ -115,5 +115,12 @@ fn
merge_join_by() { } check_specialized_fold_xor(&mjb.clone().map(eob_left_z)); check_specialized_fold_xor(&mjb.clone().map(eob_right_z)); + check_specialized_fold_xor(&mjb.clone().map(eob_both_z)); + + // And the other way around + let mjb = i2.merge_join_by(i1, std::cmp::Ord::cmp); + check_specialized_count_last_nth_sizeh(&mjb, 8); + check_specialized_fold_xor(&mjb.clone().map(eob_left_z)); + check_specialized_fold_xor(&mjb.clone().map(eob_right_z)); check_specialized_fold_xor(&mjb.clone().map(eob_both_z)); } From a0eb01c2164320d5e044c3babff1e7ad89b6099b Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Thu, 26 Dec 2019 14:22:35 +0100 Subject: [PATCH 6/6] Added quickchecks --- tests/specializations.rs | 58 +++++++++++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/tests/specializations.rs b/tests/specializations.rs index 223de28ee..87e6b2b26 100644 --- a/tests/specializations.rs +++ b/tests/specializations.rs @@ -1,5 +1,8 @@ extern crate itertools; +#[macro_use] +extern crate quickcheck; + use itertools::{EitherOrBoth, Itertools}; use std::fmt::Debug; @@ -36,13 +39,17 @@ where ) } -fn check_specialized_count_last_nth_sizeh<'a, IterItem, Iter>(it: &Iter, expected_size: usize) -where +fn check_specialized_count_last_nth_sizeh<'a, IterItem, Iter>( + it: &Iter, + known_expected_size: Option, +) where IterItem: 'a + Eq + Debug, Iter: Iterator + Clone + 'a, { let size = it.clone().count(); - assert_eq!(size, expected_size); + if let Some(expected_size) = known_expected_size { + assert_eq!(size, expected_size); + } check_specialized(it, |i| i.count()); check_specialized(it, |i| i.last()); for n in 0..size + 2 { @@ -79,28 +86,36 @@ where }); } -#[test] -fn put_back() { - let test_vec = vec![7, 4, 1]; +fn put_back_test(test_vec: Vec, known_expected_size: Option) { { // Lexical lifetimes support let pb = itertools::put_back(test_vec.iter()); - check_specialized_count_last_nth_sizeh(&pb, 3); + 
check_specialized_count_last_nth_sizeh(&pb, known_expected_size); check_specialized_fold_xor(&pb); } let mut pb = itertools::put_back(test_vec.into_iter()); pb.put_back(1); - check_specialized_count_last_nth_sizeh(&pb, 4); - check_specialized_fold_xor(&pb); + check_specialized_count_last_nth_sizeh(&pb, known_expected_size.map(|x| x + 1)); + check_specialized_fold_xor(&pb) } #[test] -fn merge_join_by() { - let i1 = vec![1, 3, 5, 7, 8, 9].into_iter(); - let i2 = vec![0, 3, 4, 5].into_iter(); +fn put_back() { + put_back_test(vec![7, 4, 1], Some(3)); +} + +quickcheck! { + fn put_back_qc(test_vec: Vec) -> () { + put_back_test(test_vec, None) + } +} + +fn merge_join_by_test(i1: Vec, i2: Vec, known_expected_size: Option) { + let i1 = i1.into_iter(); + let i2 = i2.into_iter(); let mjb = i1.clone().merge_join_by(i2.clone(), std::cmp::Ord::cmp); - check_specialized_count_last_nth_sizeh(&mjb, 8); + check_specialized_count_last_nth_sizeh(&mjb, known_expected_size); // Rust 1.24 compatibility: fn eob_left_z(eob: EitherOrBoth) -> usize { eob.left().unwrap_or(0) @@ -117,10 +132,23 @@ fn merge_join_by() { check_specialized_fold_xor(&mjb.clone().map(eob_right_z)); check_specialized_fold_xor(&mjb.clone().map(eob_both_z)); - // And the other way around + // And the other way around let mjb = i2.merge_join_by(i1, std::cmp::Ord::cmp); - check_specialized_count_last_nth_sizeh(&mjb, 8); + check_specialized_count_last_nth_sizeh(&mjb, known_expected_size); check_specialized_fold_xor(&mjb.clone().map(eob_left_z)); check_specialized_fold_xor(&mjb.clone().map(eob_right_z)); check_specialized_fold_xor(&mjb.clone().map(eob_both_z)); } + +#[test] +fn merge_join_by() { + let i1 = vec![1, 3, 5, 7, 8, 9]; + let i2 = vec![0, 3, 4, 5]; + merge_join_by_test(i1, i2, Some(8)); +} + +quickcheck! { + fn merge_join_by_qc(i1: Vec, i2: Vec) -> () { + merge_join_by_test(i1, i2, None) + } +}