Skip to content

Commit

Permalink
Replace DrainFilter with ExtractIf (#341)
Browse files Browse the repository at this point in the history
* Replace `DrainFilter` with `ExtractIf`

* Use `core::iter::FromIterator` instead of `std::iter::FromIterator`
  • Loading branch information
NaokiM03 committed Mar 5, 2024
1 parent defe74d commit 5522939
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 148 deletions.
3 changes: 1 addition & 2 deletions Cargo.toml
Expand Up @@ -16,8 +16,7 @@ documentation = "https://docs.rs/smallvec/"
write = []
specialization = []
may_dangle = []
drain_filter = []
drain_keep_rest = ["drain_filter"]
extract_if = []

[dependencies]
serde = { version = "1", optional = true, default-features = false }
Expand Down
172 changes: 44 additions & 128 deletions src/lib.rs
Expand Up @@ -26,11 +26,11 @@
//! When this feature is enabled, `SmallVec<u8, _>` implements the `std::io::Write` trait.
//! This feature is not compatible with `#![no_std]` programs.
//!
//! ### `drain_filter`
//! ### `extract_if`
//!
//! **This feature is unstable.** It may change to match the unstable `drain_filter` method in libstd.
//! **This feature is unstable.** It may change to match the unstable `extract_if` method in libstd.
//!
//! Enables the `drain_filter` method, which produces an iterator that calls a user-provided
//! Enables the `extract_if` method, which produces an iterator that calls a user-provided
//! closure to determine which elements of the vector to remove and yield from the iterator.
//!
//! ### `specialization`
Expand Down Expand Up @@ -380,13 +380,13 @@ impl<'a, T: 'a, const N: usize> Drop for Drain<'a, T, N> {
}
}

#[cfg(feature = "drain_filter")]
#[cfg(feature = "extract_if")]
/// An iterator which uses a closure to determine if an element should be removed.
///
/// Returned from [`SmallVec::drain_filter`][1].
/// Returned from [`SmallVec::extract_if`][1].
///
/// [1]: struct.SmallVec.html#method.drain_filter
pub struct DrainFilter<'a, T, const N: usize, F>
/// [1]: struct.SmallVec.html#method.extract_if
pub struct ExtractIf<'a, T, const N: usize, F>
where
F: FnMut(&mut T) -> bool,
{
Expand All @@ -399,29 +399,23 @@ where
old_len: usize,
/// The filter test predicate.
pred: F,
/// A flag that indicates a panic has occurred in the filter test predicate.
/// This is used as a hint in the drop implementation to prevent consumption
/// of the remainder of the `DrainFilter`. Any unprocessed items will be
/// backshifted in the `vec`, but no further items will be dropped or
/// tested by the filter predicate.
panic_flag: bool,
}

#[cfg(feature = "drain_filter")]
impl<T, const N: usize, F> core::fmt::Debug for DrainFilter<'_, T, N, F>
#[cfg(feature = "extract_if")]
impl<T, const N: usize, F> core::fmt::Debug for ExtractIf<'_, T, N, F>
where
F: FnMut(&mut T) -> bool,
T: core::fmt::Debug,
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_tuple("DrainFilter")
f.debug_tuple("ExtractIf")
.field(&self.vec.as_slice())
.finish()
}
}

#[cfg(feature = "drain_filter")]
impl<T, F, const N: usize> Iterator for DrainFilter<'_, T, N, F>
#[cfg(feature = "extract_if")]
impl<T, F, const N: usize> Iterator for ExtractIf<'_, T, N, F>
where
F: FnMut(&mut T) -> bool,
{
Expand All @@ -432,9 +426,7 @@ where
while self.idx < self.old_len {
let i = self.idx;
let v = core::slice::from_raw_parts_mut(self.vec.as_mut_ptr(), self.old_len);
self.panic_flag = true;
let drained = (self.pred)(&mut v[i]);
self.panic_flag = false;
// Update the index *after* the predicate is called. If the index
// is updated prior and the predicate panics, the element at this
// index would be leaked.
Expand All @@ -444,8 +436,8 @@ where
return Some(core::ptr::read(&v[i]));
} else if self.del > 0 {
let del = self.del;
let src: *const Self::Item = &v[i];
let dst: *mut Self::Item = &mut v[i - del];
let src: *const T = &v[i];
let dst: *mut T = &mut v[i - del];
core::ptr::copy_nonoverlapping(src, dst, 1);
}
}
Expand All @@ -458,109 +450,27 @@ where
}
}

#[cfg(feature = "drain_filter")]
impl<T, F, const N: usize> Drop for DrainFilter<'_, T, N, F>
#[cfg(feature = "extract_if")]
impl<T, F, const N: usize> Drop for ExtractIf<'_, T, N, F>
where
F: FnMut(&mut T) -> bool,
{
fn drop(&mut self) {
struct BackshiftOnDrop<'a, 'b, T, const N: usize, F>
where
F: FnMut(&mut T) -> bool,
{
drain: &'b mut DrainFilter<'a, T, N, F>,
}

impl<'a, 'b, T, const N: usize, F> Drop for BackshiftOnDrop<'a, 'b, T, N, F>
where
F: FnMut(&mut T) -> bool,
{
fn drop(&mut self) {
unsafe {
if self.drain.idx < self.drain.old_len && self.drain.del > 0 {
// This is a pretty messed up state, and there isn't really an
// obviously right thing to do. We don't want to keep trying
// to execute `pred`, so we just backshift all the unprocessed
// elements and tell the vec that they still exist. The backshift
// is required to prevent a double-drop of the last successfully
// drained item prior to a panic in the predicate.
let ptr = self.drain.vec.as_mut_ptr();
let src = ptr.add(self.drain.idx);
let dst = src.sub(self.drain.del);
let tail_len = self.drain.old_len - self.drain.idx;
src.copy_to(dst, tail_len);
}
self.drain.vec.set_len(self.drain.old_len - self.drain.del);
}
}
}

let backshift = BackshiftOnDrop { drain: self };

// Attempt to consume any remaining elements if the filter predicate
// has not yet panicked. We'll backshift any remaining elements
// whether we've already panicked or if the consumption here panics.
if !backshift.drain.panic_flag {
backshift.drain.for_each(drop);
}
}
}

#[cfg(feature = "drain_keep_rest")]
impl<T, F, const N: usize> DrainFilter<'_, T, N, F>
where
F: FnMut(&mut T) -> bool,
{
/// Keep unyielded elements in the source `Vec`.
///
/// # Examples
///
/// ```
/// # use smallvec::{smallvec, SmallVec};
///
/// let mut vec: SmallVec<char, 2> = smallvec!['a', 'b', 'c'];
/// let mut drain = vec.drain_filter(|_| true);
///
/// assert_eq!(drain.next().unwrap(), 'a');
///
/// // This call keeps 'b' and 'c' in the vec.
/// drain.keep_rest();
///
/// // If we wouldn't call `keep_rest()`,
/// // `vec` would be empty.
/// assert_eq!(vec, SmallVec::<char, 2>::from_slice(&['b', 'c']));
/// ```
pub fn keep_rest(self) {
// At this moment layout looks like this:
//
// _____________________/-- old_len
// / \
// [kept] [yielded] [tail]
// \_______/ ^-- idx
// \-- del
//
// Normally `Drop` impl would drop [tail] (via .for_each(drop), ie still calling `pred`)
//
// 1. Move [tail] after [kept]
// 2. Update length of the original vec to `old_len - del`
// a. In case of ZST, this is the only thing we want to do
// 3. Do *not* drop self, as everything is put in a consistent state already, there is nothing to do
let mut this = ManuallyDrop::new(self);

unsafe {
// ZSTs have no identity, so we don't need to move them around.
let needs_move = core::mem::size_of::<T>() != 0;

if needs_move && this.idx < this.old_len && this.del > 0 {
let ptr = this.vec.as_mut_ptr();
let src = ptr.add(this.idx);
let dst = src.sub(this.del);
let tail_len = this.old_len - this.idx;
if self.idx < self.old_len && self.del > 0 {
// We get here when the iterator is dropped before every element has been
// visited (the caller stopped iterating early, or `pred` panicked) and
// earlier removals have left a gap of `del` slots. There isn't really an
// obviously right thing to do in the panic case. We don't want to keep
// trying to execute `pred`, so we just backshift all the unprocessed
// elements and tell the vec that they still exist. The backshift
// is required to prevent a double-drop of the last successfully
// drained item prior to a panic in the predicate.
let ptr = self.vec.as_mut_ptr();
let src = ptr.add(self.idx);
let dst = src.sub(self.del);
let tail_len = self.old_len - self.idx;
src.copy_to(dst, tail_len);
}

let new_len = this.old_len - this.del;
this.vec.set_len(new_len);
self.vec.set_len(self.old_len - self.del);
}
}
}
Expand Down Expand Up @@ -961,11 +871,18 @@ impl<T, const N: usize> SmallVec<T, N> {
}
}

#[cfg(feature = "drain_filter")]
#[cfg(feature = "extract_if")]
/// Creates an iterator which uses a closure to determine if an element should be removed.
///
/// If the closure returns true, the element is removed and yielded. If the closure returns
/// false, the element will remain in the vector and will not be yielded by the iterator.
/// If the closure returns true, the element is removed and yielded.
/// If the closure returns false, the element will remain in the vector and will not be yielded
/// by the iterator.
///
/// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating
/// or the iteration short-circuits, then the remaining elements will be retained.
/// Use [`retain`] with a negated predicate if you do not need the returned iterator.
///
/// [`retain`]: SmallVec::retain
///
/// Using this method is equivalent to the following code:
/// ```
Expand All @@ -984,11 +901,11 @@ impl<T, const N: usize> SmallVec<T, N> {
///
/// # assert_eq!(vec, SmallVec::<i32, 8>::from_slice(&[1i32, 4, 5]));
/// ```
/// ///
/// But `drain_filter` is easier to use. `drain_filter` is also more efficient,
///
/// But `extract_if` is easier to use. `extract_if` is also more efficient,
/// because it can backshift the elements of the array in bulk.
///
/// Note that `drain_filter` also lets you mutate every element in the filter closure,
/// Note that `extract_if` also lets you mutate every element in the filter closure,
/// regardless of whether you choose to keep or remove it.
///
/// # Examples
Expand All @@ -999,13 +916,13 @@ impl<T, const N: usize> SmallVec<T, N> {
/// # use smallvec::SmallVec;
/// let mut numbers: SmallVec<i32, 16> = SmallVec::from_slice(&[1i32, 2, 3, 4, 5, 6, 8, 9, 11, 13, 14, 15]);
///
/// let evens = numbers.drain_filter(|x| *x % 2 == 0).collect::<SmallVec<i32, 16>>();
/// let evens = numbers.extract_if(|x| *x % 2 == 0).collect::<SmallVec<i32, 16>>();
/// let odds = numbers;
///
/// assert_eq!(evens, SmallVec::<i32, 16>::from_slice(&[2i32, 4, 6, 8, 14]));
/// assert_eq!(odds, SmallVec::<i32, 16>::from_slice(&[1i32, 3, 5, 9, 11, 13, 15]));
/// ```
pub fn drain_filter<F>(&mut self, filter: F) -> DrainFilter<'_, T, N, F>
pub fn extract_if<F>(&mut self, filter: F) -> ExtractIf<'_, T, N, F>
where
F: FnMut(&mut T) -> bool,
{
Expand All @@ -1016,13 +933,12 @@ impl<T, const N: usize> SmallVec<T, N> {
self.set_len(0);
}

DrainFilter {
ExtractIf {
vec: self,
idx: 0,
del: 0,
old_len,
pred: filter,
panic_flag: false,
}
}

Expand Down
22 changes: 4 additions & 18 deletions src/tests.rs
@@ -1,6 +1,6 @@
use crate::{smallvec, SmallVec};

use std::iter::FromIterator;
use core::iter::FromIterator;

use alloc::borrow::ToOwned;
use alloc::boxed::Box;
Expand Down Expand Up @@ -1060,27 +1060,13 @@ fn test_clone_from() {
assert_eq!(&*b, &[20, 21, 22]);
}

#[cfg(feature = "drain_filter")]
#[cfg(feature = "extract_if")]
#[test]
fn drain_filter() {
fn test_extract_if() {
let mut a: SmallVec<u8, 2> = smallvec![1u8, 2, 3, 4, 5, 6, 7, 8];

let b: SmallVec<u8, 2> = a.drain_filter(|x| *x % 3 == 0).collect();
let b: SmallVec<u8, 2> = a.extract_if(|x| *x % 3 == 0).collect();

assert_eq!(a, SmallVec::<u8, 2>::from_slice(&[1u8, 2, 4, 5, 7, 8]));
assert_eq!(b, SmallVec::<u8, 2>::from_slice(&[3u8, 6]));
}

#[cfg(feature = "drain_keep_rest")]
#[test]
fn drain_keep_rest() {
let mut a: SmallVec<i32, 3> = smallvec![1i32, 2, 3, 4, 5, 6, 7, 8];
let mut df = a.drain_filter(|x| *x % 2 == 0);

assert_eq!(df.next().unwrap(), 2);
assert_eq!(df.next().unwrap(), 4);

df.keep_rest();

assert_eq!(a, SmallVec::<i32, 3>::from_slice(&[1i32, 3, 5, 6, 7, 8]));
}

0 comments on commit 5522939

Please sign in to comment.