Skip to content

Commit

Permalink
Merge #372
Browse files Browse the repository at this point in the history
372: Implement missing specializations on the PutBack adaptor and on the MergeJoinBy Iterator r=jswrenn a=Ten0

Resolves #371

`count`, `last` and `nth` of the `MergeJoinBy` iterator are made faster once one of the two iterators is completely consumed: they then directly call the corresponding method of the only remaining underlying iterator (this pays off when that iterator also specializes these methods).

This is particularly useful when you want to count the number of distinct elements in the union of two sorted iterators of known size (`count`).

Those methods are also specialized on the `PutBack` adaptor for the same performance reasons.
The `nth` specialization on the `MergeJoinBy` iterator depends on the `nth` specialization on the `PutBack` adaptor working.

Co-authored-by: Thomas BESSOU <thomas.bessou@hotmail.fr>
  • Loading branch information
bors[bot] and Ten0 committed Feb 21, 2020
2 parents c620ae8 + 8bae261 commit e7ebc13
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 4 deletions.
22 changes: 22 additions & 0 deletions src/adaptors/mod.rs
Expand Up @@ -233,6 +233,28 @@ impl<I> Iterator for PutBack<I>
size_hint::add_scalar(self.iter.size_hint(), self.top.is_some() as usize)
}

/// Counts the remaining items: whatever the wrapped iterator still yields,
/// plus one when a put-back value is currently held in `top`.
fn count(self) -> usize {
    let held = if self.top.is_some() { 1 } else { 0 };
    self.iter.count() + held
}

/// Returns the final item of the adaptor.
///
/// The held `top` value is yielded before anything from the wrapped
/// iterator, so it is only the last item when the wrapped iterator has
/// nothing left.
fn last(self) -> Option<Self::Item> {
    let held = self.top;
    match self.iter.last() {
        found @ Some(_) => found,
        None => held,
    }
}

/// Skips `n` items and returns the following one.
///
/// A held `top` value counts as the first item: it is returned directly
/// for `n == 0`, otherwise it is discarded and the wrapped iterator is
/// asked for its `n - 1`-th item.
fn nth(&mut self, n: usize) -> Option<Self::Item> {
    match self.top.take() {
        None => self.iter.nth(n),
        Some(held) => {
            if n == 0 {
                Some(held)
            } else {
                // The held value is one of the skipped items; release it
                // before advancing the wrapped iterator.
                drop(held);
                self.iter.nth(n - 1)
            }
        }
    }
}

fn all<G>(&mut self, mut f: G) -> bool
where G: FnMut(Self::Item) -> bool
{
Expand Down
78 changes: 74 additions & 4 deletions src/merge_join.rs
Expand Up @@ -32,10 +32,10 @@ pub struct MergeJoinBy<I: Iterator, J: Iterator, F> {
}

impl<I, J, F> Clone for MergeJoinBy<I, J, F>
where I: Clone + Iterator,
I::Item: Clone,
J: Clone + Iterator,
J::Item: Clone,
where I: Iterator,
J: Iterator,
PutBack<Fuse<I>>: Clone,
PutBack<Fuse<J>>: Clone,
F: Clone,
{
clone_fields!(left, right, cmp_fn);
Expand Down Expand Up @@ -94,4 +94,74 @@ impl<I, J, F> Iterator for MergeJoinBy<I, J, F>

(lower, upper)
}

/// Counts the merged items.
///
/// Both sides are advanced in lockstep while each still yields items; as
/// soon as one side is exhausted the rest of the count is delegated to the
/// `count` of the other side's underlying iterator.
fn count(mut self) -> usize {
    let mut seen = 0;
    loop {
        let (left, right) = match (self.left.next(), self.right.next()) {
            (None, None) => return seen,
            // Only one side remains: the item just pulled counts for one,
            // the underlying iterator counts the rest.
            (Some(_pulled), None) => return seen + 1 + self.left.into_parts().1.count(),
            (None, Some(_pulled)) => return seen + 1 + self.right.into_parts().1.count(),
            (Some(left), Some(right)) => (left, right),
        };
        seen += 1;
        // Exactly one merged item was produced; the larger of the two (if
        // any) has not been emitted yet and goes back to its side.
        match (self.cmp_fn)(&left, &right) {
            Ordering::Equal => {}
            Ordering::Less => self.right.put_back(right),
            Ordering::Greater => self.left.put_back(left),
        }
    }
}

/// Returns the last merged item.
///
/// While both sides yield items the most recent merged item is remembered;
/// once one side is exhausted the answer is the `last` of the other side's
/// underlying iterator, falling back to the item that was just pulled.
fn last(mut self) -> Option<Self::Item> {
    let mut latest = None;
    loop {
        let (left, right) = match (self.left.next(), self.right.next()) {
            (None, None) => return latest,
            (Some(left), None) => {
                let tail = self.left.into_parts().1.last();
                return Some(EitherOrBoth::Left(tail.unwrap_or(left)));
            }
            (None, Some(right)) => {
                let tail = self.right.into_parts().1.last();
                return Some(EitherOrBoth::Right(tail.unwrap_or(right)));
            }
            (Some(left), Some(right)) => (left, right),
        };
        latest = Some(match (self.cmp_fn)(&left, &right) {
            Ordering::Equal => EitherOrBoth::Both(left, right),
            Ordering::Less => {
                // `right` has not been emitted yet; keep it for the next round.
                self.right.put_back(right);
                EitherOrBoth::Left(left)
            }
            Ordering::Greater => {
                // `left` has not been emitted yet; keep it for the next round.
                self.left.put_back(left);
                EitherOrBoth::Right(right)
            }
        });
    }
}

/// Skips `n` merged items and returns the following one.
///
/// Items are skipped one merge step at a time; when one side runs out the
/// remaining skips are delegated to the `nth` of the other side.
fn nth(&mut self, mut n: usize) -> Option<Self::Item> {
    while n > 0 {
        n -= 1;
        match (self.left.next(), self.right.next()) {
            (None, None) => return None,
            // The item just pulled is the one skipped by this step; `n`
            // now holds the number of skips still owed.
            (Some(_skipped), None) => return self.left.nth(n).map(EitherOrBoth::Left),
            (None, Some(_skipped)) => return self.right.nth(n).map(EitherOrBoth::Right),
            (Some(left), Some(right)) => match (self.cmp_fn)(&left, &right) {
                Ordering::Equal => {}
                Ordering::Less => self.right.put_back(right),
                Ordering::Greater => self.left.put_back(left),
            },
        }
    }
    self.next()
}
}
154 changes: 154 additions & 0 deletions tests/specializations.rs
@@ -0,0 +1,154 @@
extern crate itertools;

#[macro_use]
extern crate quickcheck;

use itertools::{EitherOrBoth, Itertools};

use std::fmt::Debug;
use std::ops::BitXor;

/// Iterator wrapper that forwards only `next` and `size_hint` to the wrapped
/// iterator; every other `Iterator` method uses the trait's default
/// implementation, which gives a reference to compare specializations against.
struct Unspecialized<I>(I);

impl<I: Iterator> Iterator for Unspecialized<I> {
    type Item = I::Item;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        let Unspecialized(ref mut inner) = *self;
        inner.next()
    }

    #[inline(always)]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let Unspecialized(ref inner) = *self;
        inner.size_hint()
    }
}

/// Asserts that `mapper` produces the same value for `iterator` as it does
/// for the same iterator wrapped in `Unspecialized`.
///
/// Each run gets its own clone of `iterator`; the unspecialized run is
/// evaluated first.
fn check_specialized<'a, V, IterItem, Iter, F>(iterator: &Iter, mapper: F)
where
    V: Eq + Debug,
    IterItem: 'a,
    Iter: Iterator<Item = IterItem> + Clone + 'a,
    F: Fn(Box<Iterator<Item = IterItem> + 'a>) -> V,
{
    let reference = mapper(Box::new(Unspecialized(iterator.clone())));
    let specialized = mapper(Box::new(iterator.clone()));
    assert_eq!(reference, specialized)
}

/// Checks `count`, `last`, `nth` and `size_hint` of `it` against the
/// unspecialized reference.
///
/// When `known_expected_size` is given, the total item count must match it.
/// `nth` is probed for every index from 0 up to one past the end, and the
/// `size_hint` bounds are checked against the real remaining length after
/// each call to `next`.
fn check_specialized_count_last_nth_sizeh<'a, IterItem, Iter>(
    it: &Iter,
    known_expected_size: Option<usize>,
) where
    IterItem: 'a + Eq + Debug,
    Iter: Iterator<Item = IterItem> + Clone + 'a,
{
    let total = it.clone().count();
    match known_expected_size {
        Some(expected_size) => assert_eq!(total, expected_size),
        None => {}
    }

    check_specialized(it, |i| i.count());
    check_specialized(it, |i| i.last());

    // Two extra probes so that `nth` is also exercised past the end.
    let probes = total + 2;
    for n in 0..probes {
        check_specialized(it, |mut i| i.nth(n));
    }

    let mut cursor = it.clone();
    for consumed in 0..probes {
        let len = cursor.clone().count();
        let (min, max) = cursor.size_hint();
        // Once everything has been consumed the remaining length stays at 0.
        assert_eq!(total.saturating_sub(consumed), len);
        assert!(min <= len);
        if let Some(max) = max {
            assert!(len <= max);
        }
        cursor.next();
    }
}

/// Checks that `fold` over `it` matches the unspecialized reference, using
/// XOR as the combining operation.
///
/// The accumulator is seeded from the first item (as `f ^ (f ^ f)`, which
/// turns the item into the XOR output type); an empty iterator folds to
/// `None`.
fn check_specialized_fold_xor<'a, IterItem, Iter>(it: &Iter)
where
    IterItem: 'a
        + BitXor
        + Eq
        + Debug
        + BitXor<<IterItem as BitXor>::Output, Output = <IterItem as BitXor>::Output>
        + Clone,
    <IterItem as BitXor>::Output:
        BitXor<Output = <IterItem as BitXor>::Output> + Eq + Debug + Clone,
    Iter: Iterator<Item = IterItem> + Clone + 'a,
{
    check_specialized(it, |mut i| {
        let seed = match i.next() {
            Some(f) => Some(f.clone() ^ (f.clone() ^ f)),
            None => None,
        };
        i.fold(seed, |acc, v: IterItem| match acc {
            Some(a) => Some(v ^ a),
            None => None,
        })
    });
}

/// Runs the specialization checks on a `PutBack` adaptor built from
/// `test_vec`: first over borrowed items with nothing put back, then over
/// owned items with one extra value (`1`) put back, which raises the
/// expected size by one.
fn put_back_test(test_vec: Vec<i32>, known_expected_size: Option<usize>) {
    {
        // Lexical lifetimes support: the borrow of `test_vec` has to end
        // before the vector is consumed below.
        let borrowed = itertools::put_back(test_vec.iter());
        check_specialized_count_last_nth_sizeh(&borrowed, known_expected_size);
        check_specialized_fold_xor(&borrowed);
    }

    let size_with_top = known_expected_size.map(|x| x + 1);
    let mut owned = itertools::put_back(test_vec.into_iter());
    owned.put_back(1);
    check_specialized_count_last_nth_sizeh(&owned, size_with_top);
    check_specialized_fold_xor(&owned)
}

/// Fixed-input check of the `PutBack` specializations on three items.
#[test]
fn put_back() {
    let sample = vec![7, 4, 1];
    put_back_test(sample, Some(3));
}

// Property test: runs `put_back_test` on the vectors supplied by
// quickcheck. The expected size is not known up front, so `None` is passed.
quickcheck! {
fn put_back_qc(test_vec: Vec<i32>) -> () {
put_back_test(test_vec, None)
}
}

/// Runs the specialization checks on `merge_join_by` for `i1` merged with
/// `i2`, and then again with the two inputs swapped.
///
/// `known_expected_size`, when given, is the expected number of merged
/// items; it is used for both orderings. The fold check is run on three
/// projections of the merged items (left value, right value, both values),
/// each mapping a missing side to `0`.
fn merge_join_by_test(i1: Vec<usize>, i2: Vec<usize>, known_expected_size: Option<usize>) {
    let i1 = i1.into_iter();
    let i2 = i2.into_iter();
    let mjb = i1.clone().merge_join_by(i2.clone(), std::cmp::Ord::cmp);
    check_specialized_count_last_nth_sizeh(&mjb, known_expected_size);
    // Rust 1.24 compatibility: named functions are used instead of closures
    // (presumably so that the mapped iterator stays `Clone` on that toolchain).
    fn eob_left_z(eob: EitherOrBoth<usize, usize>) -> usize {
        eob.left().unwrap_or(0)
    }
    fn eob_right_z(eob: EitherOrBoth<usize, usize>) -> usize {
        // Project the right-hand value; using `left()` here would merely
        // repeat `eob_left_z` and leave the right side untested.
        eob.right().unwrap_or(0)
    }
    fn eob_both_z(eob: EitherOrBoth<usize, usize>) -> usize {
        let (a, b) = eob.both().unwrap_or((0, 0));
        assert_eq!(a, b);
        a
    }
    check_specialized_fold_xor(&mjb.clone().map(eob_left_z));
    check_specialized_fold_xor(&mjb.clone().map(eob_right_z));
    check_specialized_fold_xor(&mjb.clone().map(eob_both_z));

    // And the other way around
    let mjb = i2.merge_join_by(i1, std::cmp::Ord::cmp);
    check_specialized_count_last_nth_sizeh(&mjb, known_expected_size);
    check_specialized_fold_xor(&mjb.clone().map(eob_left_z));
    check_specialized_fold_xor(&mjb.clone().map(eob_right_z));
    check_specialized_fold_xor(&mjb.clone().map(eob_both_z));
}

/// Fixed-input check of the `MergeJoinBy` specializations.
#[test]
fn merge_join_by() {
    // The two inputs share 3 and 5, so the merge yields 8 items:
    // 0, 1, 3, 4, 5, 7, 8, 9.
    merge_join_by_test(vec![1, 3, 5, 7, 8, 9], vec![0, 3, 4, 5], Some(8));
}

// Property test: runs `merge_join_by_test` on the pairs of vectors supplied
// by quickcheck. The merged size is not known up front, so `None` is passed.
quickcheck! {
fn merge_join_by_qc(i1: Vec<usize>, i2: Vec<usize>) -> () {
merge_join_by_test(i1, i2, None)
}
}

0 comments on commit e7ebc13

Please sign in to comment.