diff --git a/tokio/src/sync/semaphore.rs b/tokio/src/sync/semaphore.rs index 88b3d3d63c5..ccf44ba8a88 100644 --- a/tokio/src/sync/semaphore.rs +++ b/tokio/src/sync/semaphore.rs @@ -620,6 +620,25 @@ impl<'a> SemaphorePermit<'a> { pub fn forget(mut self) { self.permits = 0; } + + /// Merge two [`SemaphorePermit`] instances together, consuming `other` + /// without releasing the permits it holds. + /// + /// Permits held by both `self` and `other` are released when `self` drops. + /// + /// # Panics + /// + /// This function panics if permits from different [`Semaphore`] instances + /// are merged. + #[track_caller] + pub fn merge(&mut self, mut other: Self) { + assert!( + std::ptr::eq(self.sem, other.sem), + "merging permits from different semaphore instances" + ); + self.permits += other.permits; + other.permits = 0; + } } impl OwnedSemaphorePermit { @@ -629,6 +648,25 @@ impl OwnedSemaphorePermit { pub fn forget(mut self) { self.permits = 0; } + + /// Merge two [`OwnedSemaphorePermit`] instances together, consuming `other` + /// without releasing the permits it holds. + /// + /// Permits held by both `self` and `other` are released when `self` drops. + /// + /// # Panics + /// + /// This function panics if permits from different [`Semaphore`] instances + /// are merged. + #[track_caller] + pub fn merge(&mut self, mut other: Self) { + assert!( + Arc::ptr_eq(&self.sem, &other.sem), + "merging permits from different semaphore instances" + ); + self.permits += other.permits; + other.permits = 0; + } } impl Drop for SemaphorePermit<'_> { diff --git a/tokio/tests/sync_panic.rs b/tokio/tests/sync_panic.rs index 11213b51544..6c23664998f 100644 --- a/tokio/tests/sync_panic.rs +++ b/tokio/tests/sync_panic.rs @@ -1,10 +1,10 @@ #![warn(rust_2018_idioms)] #![cfg(all(feature = "full", not(tokio_wasi)))] -use std::error::Error; +use std::{error::Error, sync::Arc}; use tokio::{ runtime::{Builder, Runtime}, - sync::{broadcast, mpsc, oneshot, Mutex, RwLock}, + sync::{broadcast, mpsc, oneshot, Mutex, RwLock, Semaphore}, }; mod support { @@ -160,6 +160,38 @@ fn mpsc_unbounded_receiver_blocking_recv_panic_caller() -> Result<(), Box Result<(), Box> { + let panic_location_file = test_panic(|| { + let sem1 = Arc::new(Semaphore::new(42)); + let sem2 = Arc::new(Semaphore::new(42)); + let mut p1 = sem1.try_acquire_owned().unwrap(); + let p2 = sem2.try_acquire_owned().unwrap(); + p1.merge(p2); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + + Ok(()) +} + +#[test] +fn semaphore_merge_unrelated_permits() -> Result<(), Box> { + let panic_location_file = test_panic(|| { + let sem1 = Semaphore::new(42); + let sem2 = Semaphore::new(42); + let mut p1 = sem1.try_acquire().unwrap(); + let p2 = sem2.try_acquire().unwrap(); + p1.merge(p2); + }); + + // The panic location should be in this file + assert_eq!(&panic_location_file.unwrap(), file!()); + + Ok(()) +} + fn current_thread() -> Runtime { Builder::new_current_thread().enable_all().build().unwrap() } diff --git a/tokio/tests/sync_semaphore.rs b/tokio/tests/sync_semaphore.rs index a061033ed7d..f12edb7dfbc 100644 --- a/tokio/tests/sync_semaphore.rs +++ b/tokio/tests/sync_semaphore.rs @@ -63,6 +63,31 @@ fn forget() { assert!(sem.try_acquire().is_err()); } +#[test] +fn merge() { + let sem = Arc::new(Semaphore::new(3)); + { + let mut p1 = sem.try_acquire().unwrap(); + assert_eq!(sem.available_permits(), 2); + let p2 = sem.try_acquire_many(2).unwrap(); + assert_eq!(sem.available_permits(), 0); + p1.merge(p2); + assert_eq!(sem.available_permits(), 0); + } + assert_eq!(sem.available_permits(), 3); +} + +#[test] +#[cfg(not(tokio_wasm))] // No stack unwinding on wasm targets +#[should_panic] +fn merge_unrelated_permits() { + let sem1 = Arc::new(Semaphore::new(3)); + let sem2 = Arc::new(Semaphore::new(3)); + let mut p1 = sem1.try_acquire().unwrap(); + let p2 = sem2.try_acquire().unwrap(); + p1.merge(p2); +} + #[tokio::test] #[cfg(feature = "full")] async fn stress_test() { diff --git a/tokio/tests/sync_semaphore_owned.rs b/tokio/tests/sync_semaphore_owned.rs index a09346f17f8..f6945764786 100644 --- a/tokio/tests/sync_semaphore_owned.rs +++ b/tokio/tests/sync_semaphore_owned.rs @@ -89,6 +89,31 @@ fn forget() { assert!(sem.try_acquire_owned().is_err()); } +#[test] +fn merge() { + let sem = Arc::new(Semaphore::new(3)); + { + let mut p1 = sem.clone().try_acquire_owned().unwrap(); + assert_eq!(sem.available_permits(), 2); + let p2 = sem.clone().try_acquire_many_owned(2).unwrap(); + assert_eq!(sem.available_permits(), 0); + p1.merge(p2); + assert_eq!(sem.available_permits(), 0); + } + assert_eq!(sem.available_permits(), 3); +} + +#[test] +#[cfg(not(tokio_wasm))] // No stack unwinding on wasm targets +#[should_panic] +fn merge_unrelated_permits() { + let sem1 = Arc::new(Semaphore::new(3)); + let sem2 = Arc::new(Semaphore::new(3)); + let mut p1 = sem1.try_acquire_owned().unwrap(); + let p2 = sem2.try_acquire_owned().unwrap(); + p1.merge(p2) +} + #[tokio::test] #[cfg(feature = "full")] async fn stress_test() {