From cd30abe020db0bf270d6409a6cce8f736df52fe2 Mon Sep 17 00:00:00 2001 From: Shailendra Sharma Date: Tue, 13 Jul 2021 14:12:55 +0530 Subject: [PATCH 1/6] add get_or_insert_with function --- crossbeam-skiplist/src/base.rs | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/crossbeam-skiplist/src/base.rs b/crossbeam-skiplist/src/base.rs index 862a1ae3d..cc7c86e79 100644 --- a/crossbeam-skiplist/src/base.rs +++ b/crossbeam-skiplist/src/base.rs @@ -471,6 +471,11 @@ where /// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist. pub fn get_or_insert(&self, key: K, value: V, guard: &Guard) -> RefEntry<'_, K, V> { + self.insert_internal(key, || value, false, guard) + } + + /// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist, where value is calculated with a function. + pub fn get_or_insert_with(&self, key: K, value: F, guard: &Guard) -> RefEntry<'_, K, V> where F: FnOnce() -> V { self.insert_internal(key, value, false, guard) } @@ -831,13 +836,13 @@ where /// Inserts an entry with the specified `key` and `value`. /// /// If `replace` is `true`, then any existing entry with this key will first be removed. - fn insert_internal( + fn insert_internal( &self, key: K, - value: V, + value: F, replace: bool, guard: &Guard, - ) -> RefEntry<'_, K, V> { + ) -> RefEntry<'_, K, V> where F: FnOnce() -> V { self.check_guard(guard); unsafe { @@ -886,7 +891,7 @@ where // Write the key and the value into the node. ptr::write(&mut (*n).key, key); - ptr::write(&mut (*n).value, value); + ptr::write(&mut (*n).value, value()); (Shared::>::from(n as *const _), &*n) }; @@ -1061,7 +1066,7 @@ where /// If there is an existing entry with this key, it will be removed before inserting the new /// one. pub fn insert(&self, key: K, value: V, guard: &Guard) -> RefEntry<'_, K, V> { - self.insert_internal(key, value, true, guard) + self.insert_internal(key, || value, true, guard) } /// Removes an entry with the specified `key` from the map and returns it. From 5073b548806ccb975ff6eb4759342b8f26236708 Mon Sep 17 00:00:00 2001 From: Shailendra Sharma Date: Tue, 13 Jul 2021 15:32:11 +0530 Subject: [PATCH 2/6] updated for rustfmt ci --- crossbeam-skiplist/src/base.rs | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/crossbeam-skiplist/src/base.rs b/crossbeam-skiplist/src/base.rs index cc7c86e79..0c548791c 100644 --- a/crossbeam-skiplist/src/base.rs +++ b/crossbeam-skiplist/src/base.rs @@ -475,7 +475,10 @@ where } /// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist, where value is calculated with a function. - pub fn get_or_insert_with(&self, key: K, value: F, guard: &Guard) -> RefEntry<'_, K, V> where F: FnOnce() -> V { + pub fn get_or_insert_with(&self, key: K, value: F, guard: &Guard) -> RefEntry<'_, K, V> + where + F: FnOnce() -> V, + { self.insert_internal(key, value, false, guard) } @@ -842,7 +845,10 @@ where value: F, replace: bool, guard: &Guard, - ) -> RefEntry<'_, K, V> where F: FnOnce() -> V { + ) -> RefEntry<'_, K, V> + where + F: FnOnce() -> V, + { self.check_guard(guard); unsafe { @@ -887,11 +893,14 @@ where // The reference count is initially two to account for: // 1. The entry that will be returned. // 2. The link at the level 0 of the tower. + + // create value before creating node, so extra allocation doesn't happen if value() function panics + let value = value(); let n = Node::::alloc(height, 2); // Write the key and the value into the node. ptr::write(&mut (*n).key, key); - ptr::write(&mut (*n).value, value()); + ptr::write(&mut (*n).value, value); (Shared::>::from(n as *const _), &*n) }; From 134721a61b09688e722f8008b29fd49a3143343d Mon Sep 17 00:00:00 2001 From: Shailendra Sharma Date: Mon, 19 Jul 2021 22:44:28 +0530 Subject: [PATCH 3/6] Updated with test case for get_or_insert_with - (a) valid case (b) where closure panics. --- crossbeam-skiplist/tests/base.rs | 37 ++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/crossbeam-skiplist/tests/base.rs b/crossbeam-skiplist/tests/base.rs index f08e1409e..7bfde502e 100644 --- a/crossbeam-skiplist/tests/base.rs +++ b/crossbeam-skiplist/tests/base.rs @@ -431,6 +431,43 @@ fn get_or_insert() { assert_eq!(*s.get_or_insert(6, 600, guard).value(), 600); } +#[test] +fn get_or_insert_with() { + let guard = &epoch::pin(); + let s = SkipList::new(epoch::default_collector().clone()); + s.insert(3, 3, guard); + s.insert(5, 5, guard); + s.insert(1, 1, guard); + s.insert(4, 4, guard); + s.insert(2, 2, guard); + + assert_eq!(*s.get(&4, guard).unwrap().value(), 4); + assert_eq!(*s.insert(4, 40, guard).value(), 40); + assert_eq!(*s.get(&4, guard).unwrap().value(), 40); + + assert_eq!(*s.get_or_insert_with(4, || 400, guard).value(), 40); + assert_eq!(*s.get(&4, guard).unwrap().value(), 40); + assert_eq!(*s.get_or_insert_with(6, || 600, guard).value(), 600); +} + +#[test] +fn get_or_insert_with_panic() { + let guard = &epoch::pin(); + let s = SkipList::new(epoch::default_collector().clone()); + s.insert(3, 3, guard); + s.insert(5, 5, guard); + s.insert(1, 1, guard); + s.insert(4, 4, guard); + s.insert(2, 2, guard); + + assert_eq!(*s.get(&4, guard).unwrap().value(), 4); + assert_eq!(*s.insert(4, 40, guard).value(), 40); + assert_eq!(*s.get(&4, guard).unwrap().value(), 40); + + assert_eq!(*s.get_or_insert_with(4, || panic!(), guard).value(), 40); + assert_eq!(*s.get(&4, guard).unwrap().value(), 40); +} + #[test] fn get_next_prev() { let guard = &epoch::pin(); From 36152dcc64871b7a10a92717950563f113209bbc Mon Sep 17 00:00:00 2001 From: Shailendra Sharma Date: Thu, 22 Jul 2021 20:34:35 +0530 Subject: [PATCH 4/6] Update test for closure with panic call. As suggested by taiki-e Co-authored-by: Taiki Endo --- crossbeam-skiplist/tests/base.rs | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/crossbeam-skiplist/tests/base.rs b/crossbeam-skiplist/tests/base.rs index 7bfde502e..07ed4c77f 100644 --- a/crossbeam-skiplist/tests/base.rs +++ b/crossbeam-skiplist/tests/base.rs @@ -452,20 +452,18 @@ fn get_or_insert_with() { #[test] fn get_or_insert_with_panic() { - let guard = &epoch::pin(); - let s = SkipList::new(epoch::default_collector().clone()); - s.insert(3, 3, guard); - s.insert(5, 5, guard); - s.insert(1, 1, guard); - s.insert(4, 4, guard); - s.insert(2, 2, guard); + use std::panic; - assert_eq!(*s.get(&4, guard).unwrap().value(), 4); - assert_eq!(*s.insert(4, 40, guard).value(), 40); - assert_eq!(*s.get(&4, guard).unwrap().value(), 40); - - assert_eq!(*s.get_or_insert_with(4, || panic!(), guard).value(), 40); - assert_eq!(*s.get(&4, guard).unwrap().value(), 40); + let s = SkipList::new(epoch::default_collector().clone()); + let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { + let guard = &epoch::pin(); + s.get_or_insert_with(4, || panic!(), guard); + })); + assert!(res.is_err()); + assert!(s.is_empty()); + let guard = &epoch::pin(); + assert_eq!(*s.get_or_insert_with(4, || 40, guard).value(), 40); + assert_eq!(s.len(), 1); } #[test] From 2027c1bcbb0e7fa59291f0392d2c24e168772e04 Mon Sep 17 00:00:00 2001 From: Shailendra Sharma Date: Thu, 22 Jul 2021 20:56:37 +0530 Subject: [PATCH 5/6] 1) Added note for potentially undesired behaviour closure may have, in case another thread writes before. 2) Added test case for the same. --- crossbeam-skiplist/src/base.rs | 15 ++++++++++---- crossbeam-skiplist/tests/base.rs | 35 ++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/crossbeam-skiplist/src/base.rs b/crossbeam-skiplist/src/base.rs index 0c548791c..9ff107221 100644 --- a/crossbeam-skiplist/src/base.rs +++ b/crossbeam-skiplist/src/base.rs @@ -474,7 +474,14 @@ where self.insert_internal(key, || value, false, guard) } - /// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist, where value is calculated with a function. + /// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist, + /// where value is calculated with a function. + ///
+ ///
+ /// Note: Another thread may write key value first, leading to the result of this closure + /// discarded. If closure is modifying some other state (such as shared counters or shared + /// objects), it may lead to undesired behaviour such as counters being changed without + /// result of closure inserted pub fn get_or_insert_with(&self, key: K, value: F, guard: &Guard) -> RefEntry<'_, K, V> where F: FnOnce() -> V, @@ -887,15 +894,15 @@ where } } + // create value before creating node, so extra allocation doesn't happen if value() function panics + let value = value(); + // Create a new node. let height = self.random_height(); let (node, n) = { // The reference count is initially two to account for: // 1. The entry that will be returned. // 2. The link at the level 0 of the tower. - - // create value before creating node, so extra allocation doesn't happen if value() function panics - let value = value(); let n = Node::::alloc(height, 2); // Write the key and the value into the node. diff --git a/crossbeam-skiplist/tests/base.rs b/crossbeam-skiplist/tests/base.rs index 07ed4c77f..aaa6e75d0 100644 --- a/crossbeam-skiplist/tests/base.rs +++ b/crossbeam-skiplist/tests/base.rs @@ -466,6 +466,41 @@ fn get_or_insert_with_panic() { assert_eq!(s.len(), 1); } +#[test] +fn get_or_insert_with_parallel_run() { + use std::sync::{Arc, Mutex}; + + let s = Arc::new(SkipList::new(epoch::default_collector().clone())); + let s2 = s.clone(); + let called = Arc::new(Mutex::new(false)); + let called2 = called.clone(); + let handle = std::thread::spawn(move || { + let guard = &epoch::pin(); + assert_eq!( + *s2.get_or_insert_with( + 7, + || { + *called2.lock().unwrap() = true; + + // allow main thread to run before we return result + std::thread::sleep(std::time::Duration::from_secs(4)); + 70 + }, + guard, + ) + .value(), + 700 + ); + }); + std::thread::sleep(std::time::Duration::from_secs(2)); + let guard = &epoch::pin(); + + // main thread writes the value first + assert_eq!(*s.get_or_insert(7, 700, guard).value(), 700); + handle.join().unwrap(); + assert!(*called.lock().unwrap()); +} + #[test] fn get_next_prev() { let guard = &epoch::pin(); From 403e899a846c06d0ece944bbdbfe161102747724 Mon Sep 17 00:00:00 2001 From: Shailendra Sharma Date: Thu, 22 Jul 2021 21:27:08 +0530 Subject: [PATCH 6/6] 1) Added get_or_insert_with method for SkipMap too 2) Added corresponding test case for SkipMap --- crossbeam-skiplist/src/base.rs | 4 +-- crossbeam-skiplist/src/map.rs | 33 +++++++++++++++++ crossbeam-skiplist/tests/base.rs | 2 +- crossbeam-skiplist/tests/map.rs | 61 ++++++++++++++++++++++++++++++++ 4 files changed, 97 insertions(+), 3 deletions(-) diff --git a/crossbeam-skiplist/src/base.rs b/crossbeam-skiplist/src/base.rs index 9ff107221..826e8188c 100644 --- a/crossbeam-skiplist/src/base.rs +++ b/crossbeam-skiplist/src/base.rs @@ -476,8 +476,8 @@ where /// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist, /// where value is calculated with a function. - ///
- ///
+ /// + /// /// Note: Another thread may write key value first, leading to the result of this closure /// discarded. If closure is modifying some other state (such as shared counters or shared /// objects), it may lead to undesired behaviour such as counters being changed without diff --git a/crossbeam-skiplist/src/map.rs b/crossbeam-skiplist/src/map.rs index b035d1fc1..6beb3fb91 100644 --- a/crossbeam-skiplist/src/map.rs +++ b/crossbeam-skiplist/src/map.rs @@ -254,6 +254,39 @@ where Entry::new(self.inner.get_or_insert(key, value, guard)) } + /// Finds an entry with the specified key, or inserts a new `key`-`value` pair if none exist, + /// where value is calculated with a function. + /// + /// + /// Note: Another thread may write key value first, leading to the result of this closure + /// discarded. If closure is modifying some other state (such as shared counters or shared + /// objects), it may lead to undesired behaviour such as counters being changed without + /// result of closure inserted + //// + /// This function returns an [`Entry`] which + /// can be used to access the key's associated value. + /// + /// + /// # Example + /// ``` + /// use crossbeam_skiplist::SkipMap; + /// + /// let ages = SkipMap::new(); + /// let gates_age = ages.get_or_insert_with("Bill Gates", || 64); + /// assert_eq!(*gates_age.value(), 64); + /// + /// ages.insert("Steve Jobs", 65); + /// let jobs_age = ages.get_or_insert_with("Steve Jobs", || -1); + /// assert_eq!(*jobs_age.value(), 65); + /// ``` + pub fn get_or_insert_with(&self, key: K, value_fn: F) -> Entry<'_, K, V> + where + F: FnOnce() -> V, + { + let guard = &epoch::pin(); + Entry::new(self.inner.get_or_insert_with(key, value_fn, guard)) + } + /// Returns an iterator over all entries in the map, /// sorted by key. /// diff --git a/crossbeam-skiplist/tests/base.rs b/crossbeam-skiplist/tests/base.rs index aaa6e75d0..c0af2d1b4 100644 --- a/crossbeam-skiplist/tests/base.rs +++ b/crossbeam-skiplist/tests/base.rs @@ -488,7 +488,7 @@ fn get_or_insert_with_parallel_run() { }, guard, ) - .value(), + .value(), 700 ); }); diff --git a/crossbeam-skiplist/tests/map.rs b/crossbeam-skiplist/tests/map.rs index 06a0567e0..d00658505 100644 --- a/crossbeam-skiplist/tests/map.rs +++ b/crossbeam-skiplist/tests/map.rs @@ -370,6 +370,67 @@ fn get_or_insert() { assert_eq!(*s.get_or_insert(6, 600).value(), 600); } +#[test] +fn get_or_insert_with() { + let s = SkipMap::new(); + s.insert(3, 3); + s.insert(5, 5); + s.insert(1, 1); + s.insert(4, 4); + s.insert(2, 2); + + assert_eq!(*s.get(&4).unwrap().value(), 4); + assert_eq!(*s.insert(4, 40).value(), 40); + assert_eq!(*s.get(&4).unwrap().value(), 40); + + assert_eq!(*s.get_or_insert_with(4, || 400).value(), 40); + assert_eq!(*s.get(&4).unwrap().value(), 40); + assert_eq!(*s.get_or_insert_with(6, || 600).value(), 600); +} + +#[test] +fn get_or_insert_with_panic() { + use std::panic; + + let s = SkipMap::new(); + let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { + s.get_or_insert_with(4, || panic!()); + })); + assert!(res.is_err()); + assert!(s.is_empty()); + assert_eq!(*s.get_or_insert_with(4, || 40).value(), 40); + assert_eq!(s.len(), 1); +} + +#[test] +fn get_or_insert_with_parallel_run() { + use std::sync::{Arc, Mutex}; + + let s = Arc::new(SkipMap::new()); + let s2 = s.clone(); + let called = Arc::new(Mutex::new(false)); + let called2 = called.clone(); + let handle = std::thread::spawn(move || { + assert_eq!( + *s2.get_or_insert_with(7, || { + *called2.lock().unwrap() = true; + + // allow main thread to run before we return result + std::thread::sleep(std::time::Duration::from_secs(4)); + 70 + }) + .value(), + 700 + ); + }); + std::thread::sleep(std::time::Duration::from_secs(2)); + + // main thread writes the value first + assert_eq!(*s.get_or_insert(7, 700).value(), 700); + handle.join().unwrap(); + assert!(*called.lock().unwrap()); +} + #[test] fn get_next_prev() { let s = SkipMap::new();