Skip to content

Commit

Permalink
sync: add merge() to semaphore permits (#4948)
Browse files Browse the repository at this point in the history
  • Loading branch information
domodwyer committed Sep 27, 2022
1 parent 3db330d commit 909de08
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 2 deletions.
38 changes: 38 additions & 0 deletions tokio/src/sync/semaphore.rs
Expand Up @@ -620,6 +620,25 @@ impl<'a> SemaphorePermit<'a> {
pub fn forget(mut self) {
self.permits = 0;
}

/// Merge two [`SemaphorePermit`] instances together, consuming `other`
/// without releasing the permits it holds.
///
/// Permits held by both `self` and `other` are released when `self` drops.
///
/// # Panics
///
/// This function panics if permits from different [`Semaphore`] instances
/// are merged.
#[track_caller]
pub fn merge(&mut self, mut other: Self) {
assert!(
std::ptr::eq(self.sem, other.sem),
"merging permits from different semaphore instances"
);
self.permits += other.permits;
other.permits = 0;
}
}

impl OwnedSemaphorePermit {
Expand All @@ -629,6 +648,25 @@ impl OwnedSemaphorePermit {
pub fn forget(mut self) {
self.permits = 0;
}

/// Merge two [`OwnedSemaphorePermit`] instances together, consuming `other`
/// without releasing the permits it holds.
///
/// Permits held by both `self` and `other` are released when `self` drops.
///
/// # Panics
///
/// This function panics if permits from different [`Semaphore`] instances
/// are merged.
#[track_caller]
pub fn merge(&mut self, mut other: Self) {
assert!(
Arc::ptr_eq(&self.sem, &other.sem),
"merging permits from different semaphore instances"
);
self.permits += other.permits;
other.permits = 0;
}
}

impl Drop for SemaphorePermit<'_> {
Expand Down
36 changes: 34 additions & 2 deletions tokio/tests/sync_panic.rs
@@ -1,10 +1,10 @@
#![warn(rust_2018_idioms)]
#![cfg(all(feature = "full", not(tokio_wasi)))]

use std::error::Error;
use std::{error::Error, sync::Arc};
use tokio::{
runtime::{Builder, Runtime},
sync::{broadcast, mpsc, oneshot, Mutex, RwLock},
sync::{broadcast, mpsc, oneshot, Mutex, RwLock, Semaphore},
};

mod support {
Expand Down Expand Up @@ -160,6 +160,38 @@ fn mpsc_unbounded_receiver_blocking_recv_panic_caller() -> Result<(), Box<dyn Er
Ok(())
}

#[test]
fn semaphore_merge_unrelated_owned_permits() -> Result<(), Box<dyn Error>> {
let panic_location_file = test_panic(|| {
let sem1 = Arc::new(Semaphore::new(42));
let sem2 = Arc::new(Semaphore::new(42));
let mut p1 = sem1.try_acquire_owned().unwrap();
let p2 = sem2.try_acquire_owned().unwrap();
p1.merge(p2);
});

// The panic location should be in this file
assert_eq!(&panic_location_file.unwrap(), file!());

Ok(())
}

#[test]
fn semaphore_merge_unrelated_permits() -> Result<(), Box<dyn Error>> {
let panic_location_file = test_panic(|| {
let sem1 = Semaphore::new(42);
let sem2 = Semaphore::new(42);
let mut p1 = sem1.try_acquire().unwrap();
let p2 = sem2.try_acquire().unwrap();
p1.merge(p2);
});

// The panic location should be in this file
assert_eq!(&panic_location_file.unwrap(), file!());

Ok(())
}

fn current_thread() -> Runtime {
Builder::new_current_thread().enable_all().build().unwrap()
}
25 changes: 25 additions & 0 deletions tokio/tests/sync_semaphore.rs
Expand Up @@ -63,6 +63,31 @@ fn forget() {
assert!(sem.try_acquire().is_err());
}

#[test]
fn merge() {
let sem = Arc::new(Semaphore::new(3));
{
let mut p1 = sem.try_acquire().unwrap();
assert_eq!(sem.available_permits(), 2);
let p2 = sem.try_acquire_many(2).unwrap();
assert_eq!(sem.available_permits(), 0);
p1.merge(p2);
assert_eq!(sem.available_permits(), 0);
}
assert_eq!(sem.available_permits(), 3);
}

#[test]
#[cfg(not(tokio_wasm))] // No stack unwinding on wasm targets
#[should_panic]
fn merge_unrelated_permits() {
let sem1 = Arc::new(Semaphore::new(3));
let sem2 = Arc::new(Semaphore::new(3));
let mut p1 = sem1.try_acquire().unwrap();
let p2 = sem2.try_acquire().unwrap();
p1.merge(p2);
}

#[tokio::test]
#[cfg(feature = "full")]
async fn stress_test() {
Expand Down
25 changes: 25 additions & 0 deletions tokio/tests/sync_semaphore_owned.rs
Expand Up @@ -89,6 +89,31 @@ fn forget() {
assert!(sem.try_acquire_owned().is_err());
}

#[test]
fn merge() {
let sem = Arc::new(Semaphore::new(3));
{
let mut p1 = sem.clone().try_acquire_owned().unwrap();
assert_eq!(sem.available_permits(), 2);
let p2 = sem.clone().try_acquire_many_owned(2).unwrap();
assert_eq!(sem.available_permits(), 0);
p1.merge(p2);
assert_eq!(sem.available_permits(), 0);
}
assert_eq!(sem.available_permits(), 3);
}

#[test]
#[cfg(not(tokio_wasm))] // No stack unwinding on wasm targets
#[should_panic]
fn merge_unrelated_permits() {
let sem1 = Arc::new(Semaphore::new(3));
let sem2 = Arc::new(Semaphore::new(3));
let mut p1 = sem1.try_acquire_owned().unwrap();
let p2 = sem2.try_acquire_owned().unwrap();
p1.merge(p2)
}

#[tokio::test]
#[cfg(feature = "full")]
async fn stress_test() {
Expand Down

0 comments on commit 909de08

Please sign in to comment.