diff --git a/crossbeam-skiplist/src/base.rs b/crossbeam-skiplist/src/base.rs index 862a1ae3d..826e8188c 100644 --- a/crossbeam-skiplist/src/base.rs +++ b/crossbeam-skiplist/src/base.rs @@ -471,6 +471,21 @@ where /// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist. pub fn get_or_insert(&self, key: K, value: V, guard: &Guard) -> RefEntry<'_, K, V> { + self.insert_internal(key, || value, false, guard) + } + + /// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist, + /// where value is calculated with a function. + /// + /// + /// Note: Another thread may write key value first, leading to the result of this closure + /// discarded. If closure is modifying some other state (such as shared counters or shared + /// objects), it may lead to undesired behaviour such as counters being changed without + /// result of closure inserted + pub fn get_or_insert_with(&self, key: K, value: F, guard: &Guard) -> RefEntry<'_, K, V> + where + F: FnOnce() -> V, + { self.insert_internal(key, value, false, guard) } @@ -831,13 +846,16 @@ where /// Inserts an entry with the specified `key` and `value`. /// /// If `replace` is `true`, then any existing entry with this key will first be removed. - fn insert_internal( + fn insert_internal( &self, key: K, - value: V, + value: F, replace: bool, guard: &Guard, - ) -> RefEntry<'_, K, V> { + ) -> RefEntry<'_, K, V> + where + F: FnOnce() -> V, + { self.check_guard(guard); unsafe { @@ -876,6 +894,9 @@ where } } + // create value before creating node, so extra allocation doesn't happen if value() function panics + let value = value(); + // Create a new node. let height = self.random_height(); let (node, n) = { @@ -1061,7 +1082,7 @@ where /// If there is an existing entry with this key, it will be removed before inserting the new /// one. pub fn insert(&self, key: K, value: V, guard: &Guard) -> RefEntry<'_, K, V> { - self.insert_internal(key, value, true, guard) + self.insert_internal(key, || value, true, guard) } /// Removes an entry with the specified `key` from the map and returns it. diff --git a/crossbeam-skiplist/src/map.rs b/crossbeam-skiplist/src/map.rs index b035d1fc1..6beb3fb91 100644 --- a/crossbeam-skiplist/src/map.rs +++ b/crossbeam-skiplist/src/map.rs @@ -254,6 +254,39 @@ where Entry::new(self.inner.get_or_insert(key, value, guard)) } + /// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist, + /// where value is calculated with a function. + /// + /// + /// Note: Another thread may write key value first, leading to the result of this closure + /// discarded. If closure is modifying some other state (such as shared counters or shared + /// objects), it may lead to undesired behaviour such as counters being changed without + /// result of closure inserted + //// + /// This function returns an [`Entry`] which + /// can be used to access the key's associated value. + /// + /// + /// # Example + /// ``` + /// use crossbeam_skiplist::SkipMap; + /// + /// let ages = SkipMap::new(); + /// let gates_age = ages.get_or_insert_with("Bill Gates", || 64); + /// assert_eq!(*gates_age.value(), 64); + /// + /// ages.insert("Steve Jobs", 65); + /// let jobs_age = ages.get_or_insert_with("Steve Jobs", || -1); + /// assert_eq!(*jobs_age.value(), 65); + /// ``` + pub fn get_or_insert_with(&self, key: K, value_fn: F) -> Entry<'_, K, V> + where + F: FnOnce() -> V, + { + let guard = &epoch::pin(); + Entry::new(self.inner.get_or_insert_with(key, value_fn, guard)) + } + /// Returns an iterator over all entries in the map, /// sorted by key. /// diff --git a/crossbeam-skiplist/tests/base.rs b/crossbeam-skiplist/tests/base.rs index f08e1409e..c0af2d1b4 100644 --- a/crossbeam-skiplist/tests/base.rs +++ b/crossbeam-skiplist/tests/base.rs @@ -431,6 +431,76 @@ fn get_or_insert() { assert_eq!(*s.get_or_insert(6, 600, guard).value(), 600); } +#[test] +fn get_or_insert_with() { + let guard = &epoch::pin(); + let s = SkipList::new(epoch::default_collector().clone()); + s.insert(3, 3, guard); + s.insert(5, 5, guard); + s.insert(1, 1, guard); + s.insert(4, 4, guard); + s.insert(2, 2, guard); + + assert_eq!(*s.get(&4, guard).unwrap().value(), 4); + assert_eq!(*s.insert(4, 40, guard).value(), 40); + assert_eq!(*s.get(&4, guard).unwrap().value(), 40); + + assert_eq!(*s.get_or_insert_with(4, || 400, guard).value(), 40); + assert_eq!(*s.get(&4, guard).unwrap().value(), 40); + assert_eq!(*s.get_or_insert_with(6, || 600, guard).value(), 600); +} + +#[test] +fn get_or_insert_with_panic() { + use std::panic; + + let s = SkipList::new(epoch::default_collector().clone()); + let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { + let guard = &epoch::pin(); + s.get_or_insert_with(4, || panic!(), guard); + })); + assert!(res.is_err()); + assert!(s.is_empty()); + let guard = &epoch::pin(); + assert_eq!(*s.get_or_insert_with(4, || 40, guard).value(), 40); + assert_eq!(s.len(), 1); +} + +#[test] +fn get_or_insert_with_parallel_run() { + use std::sync::{Arc, Mutex}; + + let s = Arc::new(SkipList::new(epoch::default_collector().clone())); + let s2 = s.clone(); + let called = Arc::new(Mutex::new(false)); + let called2 = called.clone(); + let handle = std::thread::spawn(move || { + let guard = &epoch::pin(); + assert_eq!( + *s2.get_or_insert_with( + 7, + || { + *called2.lock().unwrap() = true; + + // allow main thread to run before we return result + std::thread::sleep(std::time::Duration::from_secs(4)); + 70 + }, + guard, + ) + .value(), + 700 + ); + }); + std::thread::sleep(std::time::Duration::from_secs(2)); + let guard = &epoch::pin(); + + // main thread writes the value first + assert_eq!(*s.get_or_insert(7, 700, guard).value(), 700); + handle.join().unwrap(); + assert!(*called.lock().unwrap()); +} + #[test] fn get_next_prev() { let guard = &epoch::pin(); diff --git a/crossbeam-skiplist/tests/map.rs b/crossbeam-skiplist/tests/map.rs index 06a0567e0..d00658505 100644 --- a/crossbeam-skiplist/tests/map.rs +++ b/crossbeam-skiplist/tests/map.rs @@ -370,6 +370,67 @@ fn get_or_insert() { assert_eq!(*s.get_or_insert(6, 600).value(), 600); } +#[test] +fn get_or_insert_with() { + let s = SkipMap::new(); + s.insert(3, 3); + s.insert(5, 5); + s.insert(1, 1); + s.insert(4, 4); + s.insert(2, 2); + + assert_eq!(*s.get(&4).unwrap().value(), 4); + assert_eq!(*s.insert(4, 40).value(), 40); + assert_eq!(*s.get(&4).unwrap().value(), 40); + + assert_eq!(*s.get_or_insert_with(4, || 400).value(), 40); + assert_eq!(*s.get(&4).unwrap().value(), 40); + assert_eq!(*s.get_or_insert_with(6, || 600).value(), 600); +} + +#[test] +fn get_or_insert_with_panic() { + use std::panic; + + let s = SkipMap::new(); + let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { + s.get_or_insert_with(4, || panic!()); + })); + assert!(res.is_err()); + assert!(s.is_empty()); + assert_eq!(*s.get_or_insert_with(4, || 40).value(), 40); + assert_eq!(s.len(), 1); +} + +#[test] +fn get_or_insert_with_parallel_run() { + use std::sync::{Arc, Mutex}; + + let s = Arc::new(SkipMap::new()); + let s2 = s.clone(); + let called = Arc::new(Mutex::new(false)); + let called2 = called.clone(); + let handle = std::thread::spawn(move || { + assert_eq!( + *s2.get_or_insert_with(7, || { + *called2.lock().unwrap() = true; + + // allow main thread to run before we return result + std::thread::sleep(std::time::Duration::from_secs(4)); + 70 + }) + .value(), + 700 + ); + }); + std::thread::sleep(std::time::Duration::from_secs(2)); + + // main thread writes the value first + assert_eq!(*s.get_or_insert(7, 700).value(), 700); + handle.join().unwrap(); + assert!(*called.lock().unwrap()); +} + #[test] fn get_next_prev() { let s = SkipMap::new();