Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sync: add merge() to semaphore permits #4948

Merged
merged 6 commits into from Sep 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 38 additions & 0 deletions tokio/src/sync/semaphore.rs
Expand Up @@ -620,6 +620,25 @@ impl<'a> SemaphorePermit<'a> {
pub fn forget(mut self) {
self.permits = 0;
}

/// Merge two [`SemaphorePermit`] instances together, consuming `other`
/// without releasing the permits it holds.
///
/// Permits held by both `self` and `other` are released when `self` drops.
domodwyer marked this conversation as resolved.
Show resolved Hide resolved
///
/// # Panics
///
/// This function panics if permits from different [`Semaphore`] instances
/// are merged.
#[track_caller]
pub fn merge(&mut self, mut other: Self) {
assert!(
std::ptr::eq(self.sem, other.sem),
"merging permits from different semaphore instances"
);
self.permits += other.permits;
other.permits = 0;
}
}

impl OwnedSemaphorePermit {
Expand All @@ -629,6 +648,25 @@ impl OwnedSemaphorePermit {
pub fn forget(mut self) {
self.permits = 0;
}

/// Merge two [`OwnedSemaphorePermit`] instances together, consuming `other`
/// without releasing the permits it holds.
///
/// Permits held by both `self` and `other` are released when `self` drops.
///
/// # Panics
///
/// This function panics if permits from different [`Semaphore`] instances
/// are merged.
#[track_caller]
pub fn merge(&mut self, mut other: Self) {
assert!(
Arc::ptr_eq(&self.sem, &other.sem),
"merging permits from different semaphore instances"
);
self.permits += other.permits;
other.permits = 0;
}
}

impl Drop for SemaphorePermit<'_> {
Expand Down
36 changes: 34 additions & 2 deletions tokio/tests/sync_panic.rs
@@ -1,10 +1,10 @@
#![warn(rust_2018_idioms)]
#![cfg(all(feature = "full", not(tokio_wasi)))]

use std::error::Error;
use std::{error::Error, sync::Arc};
use tokio::{
runtime::{Builder, Runtime},
sync::{broadcast, mpsc, oneshot, Mutex, RwLock},
sync::{broadcast, mpsc, oneshot, Mutex, RwLock, Semaphore},
};

mod support {
Expand Down Expand Up @@ -160,6 +160,38 @@ fn mpsc_unbounded_receiver_blocking_recv_panic_caller() -> Result<(), Box<dyn Er
Ok(())
}

#[test]
fn semaphore_merge_unrelated_owned_permits() -> Result<(), Box<dyn Error>> {
let panic_location_file = test_panic(|| {
let sem1 = Arc::new(Semaphore::new(42));
let sem2 = Arc::new(Semaphore::new(42));
let mut p1 = sem1.try_acquire_owned().unwrap();
let p2 = sem2.try_acquire_owned().unwrap();
p1.merge(p2);
});

// The panic location should be in this file
assert_eq!(&panic_location_file.unwrap(), file!());

Ok(())
}

#[test]
fn semaphore_merge_unrelated_permits() -> Result<(), Box<dyn Error>> {
let panic_location_file = test_panic(|| {
let sem1 = Semaphore::new(42);
let sem2 = Semaphore::new(42);
let mut p1 = sem1.try_acquire().unwrap();
let p2 = sem2.try_acquire().unwrap();
p1.merge(p2);
});

// The panic location should be in this file
assert_eq!(&panic_location_file.unwrap(), file!());

Ok(())
}

fn current_thread() -> Runtime {
Builder::new_current_thread().enable_all().build().unwrap()
}
25 changes: 25 additions & 0 deletions tokio/tests/sync_semaphore.rs
Expand Up @@ -63,6 +63,31 @@ fn forget() {
assert!(sem.try_acquire().is_err());
}

#[test]
fn merge() {
let sem = Arc::new(Semaphore::new(3));
{
let mut p1 = sem.try_acquire().unwrap();
assert_eq!(sem.available_permits(), 2);
let p2 = sem.try_acquire_many(2).unwrap();
assert_eq!(sem.available_permits(), 0);
p1.merge(p2);
assert_eq!(sem.available_permits(), 0);
}
assert_eq!(sem.available_permits(), 3);
}

#[test]
#[cfg(not(tokio_wasm))] // No stack unwinding on wasm targets
#[should_panic]
fn merge_unrelated_permits() {
let sem1 = Arc::new(Semaphore::new(3));
let sem2 = Arc::new(Semaphore::new(3));
let mut p1 = sem1.try_acquire().unwrap();
let p2 = sem2.try_acquire().unwrap();
p1.merge(p2);
}

#[tokio::test]
#[cfg(feature = "full")]
async fn stress_test() {
Expand Down
25 changes: 25 additions & 0 deletions tokio/tests/sync_semaphore_owned.rs
Expand Up @@ -89,6 +89,31 @@ fn forget() {
assert!(sem.try_acquire_owned().is_err());
}

#[test]
fn merge() {
let sem = Arc::new(Semaphore::new(3));
{
let mut p1 = sem.clone().try_acquire_owned().unwrap();
assert_eq!(sem.available_permits(), 2);
let p2 = sem.clone().try_acquire_many_owned(2).unwrap();
assert_eq!(sem.available_permits(), 0);
p1.merge(p2);
assert_eq!(sem.available_permits(), 0);
}
assert_eq!(sem.available_permits(), 3);
}

#[test]
#[cfg(not(tokio_wasm))] // No stack unwinding on wasm targets
#[should_panic]
fn merge_unrelated_permits() {
let sem1 = Arc::new(Semaphore::new(3));
let sem2 = Arc::new(Semaphore::new(3));
let mut p1 = sem1.try_acquire_owned().unwrap();
let p2 = sem2.try_acquire_owned().unwrap();
p1.merge(p2)
}

#[tokio::test]
#[cfg(feature = "full")]
async fn stress_test() {
Expand Down