diff --git a/tokio/src/sync/semaphore.rs b/tokio/src/sync/semaphore.rs index bff5de98cda..2d3ac3ca8ec 100644 --- a/tokio/src/sync/semaphore.rs +++ b/tokio/src/sync/semaphore.rs @@ -143,7 +143,7 @@ impl Semaphore { } } - /// Tries to acquire n permits from the semaphore. + /// Tries to acquire `n` permits from the semaphore. /// /// If the semaphore has been closed, this returns a [`TryAcquireError::Closed`] /// and a [`TryAcquireError::NoPermits`] if there are no permits left. Otherwise, @@ -180,6 +180,27 @@ impl Semaphore { }) } + /// Acquires `n` permits from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. + /// If the semaphore has been closed, this returns an [`AcquireError`]. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. + /// + /// [`Arc`]: std::sync::Arc + /// [`AcquireError`]: crate::sync::AcquireError + /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit + pub async fn acquire_many_owned( + self: Arc, + n: u32, + ) -> Result { + self.ll_sem.acquire(n).await?; + Ok(OwnedSemaphorePermit { + sem: self, + permits: n, + }) + } + /// Tries to acquire a permit from the semaphore. /// /// The semaphore must be wrapped in an [`Arc`] to call this method. If @@ -202,6 +223,31 @@ impl Semaphore { } } + /// Tries to acquire `n` permits from the semaphore. + /// + /// The semaphore must be wrapped in an [`Arc`] to call this method. If + /// the semaphore has been closed, this returns a [`TryAcquireError::Closed`] + /// and a [`TryAcquireError::NoPermits`] if there are no permits left. + /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the + /// acquired permit. + /// + /// [`Arc`]: std::sync::Arc + /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed + /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits + /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit + pub fn try_acquire_many_owned( + self: Arc, + n: u32, + ) -> Result { + match self.ll_sem.try_acquire(n) { + Ok(_) => Ok(OwnedSemaphorePermit { + sem: self, + permits: n, + }), + Err(e) => Err(e), + } + } + /// Closes the semaphore. /// /// This prevents the semaphore from issuing new permits and notifies all pending waiters. diff --git a/tokio/tests/sync_semaphore_owned.rs b/tokio/tests/sync_semaphore_owned.rs index 8ed6209f3b9..478c3a3326e 100644 --- a/tokio/tests/sync_semaphore_owned.rs +++ b/tokio/tests/sync_semaphore_owned.rs @@ -16,6 +16,22 @@ fn try_acquire() { assert!(p3.is_ok()); } +#[test] +fn try_acquire_many() { + let sem = Arc::new(Semaphore::new(42)); + { + let p1 = sem.clone().try_acquire_many_owned(42); + assert!(p1.is_ok()); + let p2 = sem.clone().try_acquire_owned(); + assert!(p2.is_err()); + } + let p3 = sem.clone().try_acquire_many_owned(32); + assert!(p3.is_ok()); + let p4 = sem.clone().try_acquire_many_owned(10); + assert!(p4.is_ok()); + assert!(sem.try_acquire_owned().is_err()); +} + #[tokio::test] async fn acquire() { let sem = Arc::new(Semaphore::new(1)); @@ -28,6 +44,21 @@ async fn acquire() { j.await.unwrap(); } +#[tokio::test] +async fn acquire_many() { + let semaphore = Arc::new(Semaphore::new(42)); + let permit32 = semaphore.clone().try_acquire_many_owned(32).unwrap(); + let (sender, receiver) = tokio::sync::oneshot::channel(); + let join_handle = tokio::spawn(async move { + let _permit10 = semaphore.clone().acquire_many_owned(10).await.unwrap(); + sender.send(()).unwrap(); + let _permit32 = semaphore.acquire_many_owned(32).await.unwrap(); + }); + receiver.await.unwrap(); + drop(permit32); + join_handle.await.unwrap(); +} + #[tokio::test] async fn add_permits() { let sem = Arc::new(Semaphore::new(0));