Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

util: fuse PollSemaphore #3578

Merged
merged 2 commits into from Mar 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 7 additions & 6 deletions tokio-util/src/sync/poll_semaphore.rs
Expand Up @@ -55,12 +55,13 @@ impl PollSemaphore {
/// the `Waker` from the `Context` passed to the most recent call is
/// scheduled to receive a wakeup.
pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
match ready!(self.permit_fut.poll(cx)) {
Ok(permit) => {
let next_fut = Arc::clone(&self.semaphore).acquire_owned();
self.permit_fut.set(next_fut);
Poll::Ready(Some(permit))
}
let result = ready!(self.permit_fut.poll(cx));

let next_fut = Arc::clone(&self.semaphore).acquire_owned();
self.permit_fut.set(next_fut);

match result {
Ok(permit) => Poll::Ready(Some(permit)),
Err(_closed) => Poll::Ready(None),
}
}
Expand Down
36 changes: 36 additions & 0 deletions tokio-util/tests/poll_semaphore.rs
@@ -0,0 +1,36 @@
use std::future::Future;
use std::sync::Arc;
use std::task::Poll;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore;

type SemRet = Option<OwnedSemaphorePermit>;

fn semaphore_poll<'a>(
sem: &'a mut PollSemaphore,
) -> tokio_test::task::Spawn<impl Future<Output = SemRet> + 'a> {
let fut = futures::future::poll_fn(move |cx| sem.poll_acquire(cx));
tokio_test::task::spawn(fut)
}

#[tokio::test]
async fn it_works() {
let sem = Arc::new(Semaphore::new(1));
let mut poll_sem = PollSemaphore::new(sem.clone());

let permit = sem.acquire().await.unwrap();
let mut poll = semaphore_poll(&mut poll_sem);
assert!(poll.poll().is_pending());
drop(permit);

assert!(matches!(poll.poll(), Poll::Ready(Some(_))));
drop(poll);

sem.close();

assert!(semaphore_poll(&mut poll_sem).await.is_none());

// Check that it is fused.
assert!(semaphore_poll(&mut poll_sem).await.is_none());
assert!(semaphore_poll(&mut poll_sem).await.is_none());
}