Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added get_or_insert_with function #718

Merged
merged 6 commits into from Jul 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 25 additions & 4 deletions crossbeam-skiplist/src/base.rs
Expand Up @@ -471,6 +471,21 @@ where

/// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist.
pub fn get_or_insert(&self, key: K, value: V, guard: &Guard) -> RefEntry<'_, K, V> {
self.insert_internal(key, || value, false, guard)
}

/// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist,
/// where value is calculated with a function.
///
///
/// <b>Note:</b> Another thread may write key value first, leading to the result of this closure
/// discarded. If closure is modifying some other state (such as shared counters or shared
/// objects), it may lead to <u>undesired behaviour</u> such as counters being changed without
/// result of closure inserted
pub fn get_or_insert_with<F>(&self, key: K, value: F, guard: &Guard) -> RefEntry<'_, K, V>
where
F: FnOnce() -> V,
{
self.insert_internal(key, value, false, guard)
}

Expand Down Expand Up @@ -831,13 +846,16 @@ where
/// Inserts an entry with the specified `key` and `value`.
///
/// If `replace` is `true`, then any existing entry with this key will first be removed.
fn insert_internal(
fn insert_internal<F>(
&self,
key: K,
value: V,
value: F,
replace: bool,
guard: &Guard,
) -> RefEntry<'_, K, V> {
) -> RefEntry<'_, K, V>
where
F: FnOnce() -> V,
{
self.check_guard(guard);

unsafe {
Expand Down Expand Up @@ -876,6 +894,9 @@ where
}
}

// create value before creating node, so extra allocation doesn't happen if value() function panics
let value = value();

// Create a new node.
let height = self.random_height();
let (node, n) = {
Expand Down Expand Up @@ -1061,7 +1082,7 @@ where
/// If there is an existing entry with this key, it will be removed before inserting the new
/// one.
pub fn insert(&self, key: K, value: V, guard: &Guard) -> RefEntry<'_, K, V> {
self.insert_internal(key, value, true, guard)
self.insert_internal(key, || value, true, guard)
}

/// Removes an entry with the specified `key` from the map and returns it.
Expand Down
33 changes: 33 additions & 0 deletions crossbeam-skiplist/src/map.rs
Expand Up @@ -254,6 +254,39 @@ where
Entry::new(self.inner.get_or_insert(key, value, guard))
}

/// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist,
/// where value is calculated with a function.
///
///
/// <b>Note:</b> Another thread may write key value first, leading to the result of this closure
/// discarded. If closure is modifying some other state (such as shared counters or shared
/// objects), it may lead to <u>undesired behaviour</u> such as counters being changed without
/// result of closure inserted
////
/// This function returns an [`Entry`] which
/// can be used to access the key's associated value.
///
///
/// # Example
/// ```
/// use crossbeam_skiplist::SkipMap;
///
/// let ages = SkipMap::new();
/// let gates_age = ages.get_or_insert_with("Bill Gates", || 64);
/// assert_eq!(*gates_age.value(), 64);
///
/// ages.insert("Steve Jobs", 65);
/// let jobs_age = ages.get_or_insert_with("Steve Jobs", || -1);
/// assert_eq!(*jobs_age.value(), 65);
/// ```
pub fn get_or_insert_with<F>(&self, key: K, value_fn: F) -> Entry<'_, K, V>
where
F: FnOnce() -> V,
{
let guard = &epoch::pin();
Entry::new(self.inner.get_or_insert_with(key, value_fn, guard))
}

/// Returns an iterator over all entries in the map,
/// sorted by key.
///
Expand Down
70 changes: 70 additions & 0 deletions crossbeam-skiplist/tests/base.rs
Expand Up @@ -431,6 +431,76 @@ fn get_or_insert() {
assert_eq!(*s.get_or_insert(6, 600, guard).value(), 600);
}

#[test]
fn get_or_insert_with() {
let guard = &epoch::pin();
let s = SkipList::new(epoch::default_collector().clone());
s.insert(3, 3, guard);
s.insert(5, 5, guard);
s.insert(1, 1, guard);
s.insert(4, 4, guard);
s.insert(2, 2, guard);

assert_eq!(*s.get(&4, guard).unwrap().value(), 4);
assert_eq!(*s.insert(4, 40, guard).value(), 40);
assert_eq!(*s.get(&4, guard).unwrap().value(), 40);

assert_eq!(*s.get_or_insert_with(4, || 400, guard).value(), 40);
assert_eq!(*s.get(&4, guard).unwrap().value(), 40);
assert_eq!(*s.get_or_insert_with(6, || 600, guard).value(), 600);
}

#[test]
fn get_or_insert_with_panic() {
use std::panic;

let s = SkipList::new(epoch::default_collector().clone());
let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
let guard = &epoch::pin();
s.get_or_insert_with(4, || panic!(), guard);
}));
assert!(res.is_err());
assert!(s.is_empty());
let guard = &epoch::pin();
assert_eq!(*s.get_or_insert_with(4, || 40, guard).value(), 40);
assert_eq!(s.len(), 1);
}

#[test]
fn get_or_insert_with_parallel_run() {
use std::sync::{Arc, Mutex};

let s = Arc::new(SkipList::new(epoch::default_collector().clone()));
let s2 = s.clone();
let called = Arc::new(Mutex::new(false));
let called2 = called.clone();
let handle = std::thread::spawn(move || {
let guard = &epoch::pin();
assert_eq!(
*s2.get_or_insert_with(
7,
|| {
*called2.lock().unwrap() = true;

// allow main thread to run before we return result
std::thread::sleep(std::time::Duration::from_secs(4));
70
},
guard,
)
.value(),
700
);
});
std::thread::sleep(std::time::Duration::from_secs(2));
let guard = &epoch::pin();

// main thread writes the value first
assert_eq!(*s.get_or_insert(7, 700, guard).value(), 700);
handle.join().unwrap();
assert!(*called.lock().unwrap());
}

#[test]
fn get_next_prev() {
let guard = &epoch::pin();
Expand Down
61 changes: 61 additions & 0 deletions crossbeam-skiplist/tests/map.rs
Expand Up @@ -370,6 +370,67 @@ fn get_or_insert() {
assert_eq!(*s.get_or_insert(6, 600).value(), 600);
}

#[test]
fn get_or_insert_with() {
let s = SkipMap::new();
s.insert(3, 3);
s.insert(5, 5);
s.insert(1, 1);
s.insert(4, 4);
s.insert(2, 2);

assert_eq!(*s.get(&4).unwrap().value(), 4);
assert_eq!(*s.insert(4, 40).value(), 40);
assert_eq!(*s.get(&4).unwrap().value(), 40);

assert_eq!(*s.get_or_insert_with(4, || 400).value(), 40);
assert_eq!(*s.get(&4).unwrap().value(), 40);
assert_eq!(*s.get_or_insert_with(6, || 600).value(), 600);
}

#[test]
fn get_or_insert_with_panic() {
use std::panic;

let s = SkipMap::new();
let res = panic::catch_unwind(panic::AssertUnwindSafe(|| {
s.get_or_insert_with(4, || panic!());
}));
assert!(res.is_err());
assert!(s.is_empty());
assert_eq!(*s.get_or_insert_with(4, || 40).value(), 40);
assert_eq!(s.len(), 1);
}

#[test]
fn get_or_insert_with_parallel_run() {
use std::sync::{Arc, Mutex};

let s = Arc::new(SkipMap::new());
let s2 = s.clone();
let called = Arc::new(Mutex::new(false));
let called2 = called.clone();
let handle = std::thread::spawn(move || {
assert_eq!(
*s2.get_or_insert_with(7, || {
*called2.lock().unwrap() = true;

// allow main thread to run before we return result
std::thread::sleep(std::time::Duration::from_secs(4));
70
})
.value(),
700
);
});
std::thread::sleep(std::time::Duration::from_secs(2));

// main thread writes the value first
assert_eq!(*s.get_or_insert(7, 700).value(), 700);
handle.join().unwrap();
assert!(*called.lock().unwrap());
}

#[test]
fn get_next_prev() {
let s = SkipMap::new();
Expand Down