
How to share and reset an async timer (tokio::time::Sleep) across multiple tasks in Rust


I’m working on an async timer that can be shared between multiple async tasks:

use std::{
    future::{self, Future},
    pin::Pin,
    sync::{Arc, Mutex},
    time::Duration,
};
use tokio::time::{self, Instant, Sleep};

struct Foo(Mutex<Pin<Box<Sleep>>>);

impl Foo {
    fn new(sleep: Sleep) -> Self {
        Self(Mutex::new(Box::pin(sleep)))
    }

    async fn sleep(&self) {
        future::poll_fn(|cx| self.0.lock().unwrap().as_mut().poll(cx)).await
    }

    fn reset(&self, deadline: Instant) {
        self.0.lock().unwrap().as_mut().reset(deadline);
    }
}

async fn task1(foo: Arc<Foo>) {
    println!("starting task 1 ...");
    let start = Instant::now();

    foo.sleep().await;

    let time = start.elapsed().as_millis();
    println!("task 1 complete in {time} millis ");
}

async fn task2(foo: Arc<Foo>) {
    println!("starting task 2 ...");
    let start = Instant::now();

    foo.sleep().await;

    let time = start.elapsed().as_millis();
    println!("task 2 complete in {time} millis ");
}

#[tokio::main]
pub async fn main() {
    let sleep = time::sleep(Duration::from_secs(3));
    let foo = Arc::new(Foo::new(sleep));

    let task1 = tokio::spawn(task1(foo.clone()));
    let task2 = tokio::spawn(task2(foo));

    tokio::join!(task1, task2);
}

Output:

starting task 2 ...
starting task 1 ...
task 1 complete in 3005 millis
// stuck here

The issue is that only one task completes, while the other gets stuck. Could this be happening because the second polled future’s waker is overwriting the first one?

I came across FutureExt::shared, but it takes ownership of the future. I need the ability to reset the timer while other futures are waiting on it.


Solution

  • After some digging, I found that a Sleep instance stores only a single task waker:

    https://github.com/tokio-rs/tokio/blob/21df16d7595880247642c4fb38f1c365a49de75b/tokio/src/runtime/time/entry.rs#L102

    Every call to poll on Sleep overwrites the waker registered by the previous poll, so only the most recently polled task is woken when the timer fires. The ready state therefore needs to be propagated to all waiting tasks manually:

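    use std::{
        future::{self, Future},
        pin::Pin,
        sync::Mutex,
        task::{Poll, Waker},
    };
    use tokio::time::{Instant, Sleep};

    struct Foo(Mutex<FooInner>);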
    
    struct FooInner {
        sleep: Pin<Box<Sleep>>,
        wakers: Vec<Waker>,
    }
    
    impl Foo {
        fn new(sleep: Sleep) -> Self {
            Self(Mutex::new(FooInner {
                sleep: Box::pin(sleep),
                wakers: Vec::new(),
            }))
        }
    
        async fn sleep(&self) {
            future::poll_fn(|cx| {
                let mut inner = self.0.lock().unwrap();
    
                match inner.sleep.as_mut().poll(cx) {
                    Poll::Ready(()) => {
                        // propagate the ready state to all wakers
                        for waker in inner.wakers.drain(..) {
                            waker.wake();
                        }
    
                        Poll::Ready(())
                    }
                    Poll::Pending => {
                        // Sleep only remembers the latest waker, so keep our own copy
                        inner.wakers.push(cx.waker().clone());
                        Poll::Pending
                    }
                }
            })
            .await
        }
    
        fn reset(&self, deadline: Instant) {
            let mut inner = self.0.lock().unwrap();
    
            inner.sleep.as_mut().reset(deadline);
    
            // deadline might have been reset to an earlier time
            // wake up all wakers to re-evaluate the new deadline
            for waker in inner.wakers.drain(..) {
                waker.wake();
            }
        }
    }
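
    To see the overwrite in isolation, here is a minimal sketch that uses only the standard library. SingleSlot is a hypothetical stand-in for a timer that keeps one waker, as described above, and CountingWaker is an illustrative helper that counts wake-ups; neither is part of tokio.

    ```rust
    use std::{
        future::Future,
        pin::Pin,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        task::{Context, Poll, Wake, Waker},
    };

    /// Hypothetical stand-in for a timer that stores only one waker.
    struct SingleSlot {
        fired: bool,
        waker: Option<Waker>,
    }

    impl SingleSlot {
        /// Simulates the deadline elapsing: only the stored waker is notified.
        fn fire(&mut self) {
            self.fired = true;
            if let Some(waker) = self.waker.take() {
                waker.wake();
            }
        }
    }

    impl Future for SingleSlot {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if self.fired {
                return Poll::Ready(());
            }
            // Each poll replaces whatever waker the previous poll left behind.
            self.waker = Some(cx.waker().clone());
            Poll::Pending
        }
    }

    /// Illustrative waker that counts how many times it was woken.
    struct CountingWaker(AtomicUsize);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Polls the timer from two "tasks", fires it, and returns each wake count.
    fn wake_counts() -> (usize, usize) {
        let first = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let second = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let first_waker = Waker::from(first.clone());
        let second_waker = Waker::from(second.clone());

        let mut timer = SingleSlot { fired: false, waker: None };

        // Task 1 polls first, then task 2 polls and overwrites task 1's waker.
        assert!(Pin::new(&mut timer)
            .poll(&mut Context::from_waker(&first_waker))
            .is_pending());
        assert!(Pin::new(&mut timer)
            .poll(&mut Context::from_waker(&second_waker))
            .is_pending());

        timer.fire();

        (first.0.load(Ordering::SeqCst), second.0.load(Ordering::SeqCst))
    }

    fn main() {
        let (first, second) = wake_counts();
        println!("task 1 woken {first} time(s), task 2 woken {second} time(s)");
    }
    ```

    This prints "task 1 woken 0 time(s), task 2 woken 1 time(s)". The first poller is never notified, which matches the stuck task in the question.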
    

    As mentioned by @啊鹿Dizzyi, switching to tokio::sync::Mutex also solves the problem, provided the guard is held while the Sleep is awaited. The lock guarantees that only a single task can register its waker at a time, since other tasks cannot access the inner Sleep until it has resolved and the guard has been released. In other words, tokio::sync::Mutex propagates the ready state for you.
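
    A minimal sketch of that variant follows, assuming tokio is built with its rt, time and sync features. The helpers sleeper and run_two_sleepers are hypothetical names used only for this illustration. One caveat under these assumptions: because the guard is held for the whole sleep, a reset method that takes the same lock would have to wait until the current sleeper finishes.

    ```rust
    use std::{pin::Pin, sync::Arc, time::Duration};
    use tokio::{
        sync::Mutex,
        time::{self, Instant, Sleep},
    };

    struct Foo(Mutex<Pin<Box<Sleep>>>);

    impl Foo {
        fn new(sleep: Sleep) -> Self {
            Self(Mutex::new(Box::pin(sleep)))
        }

        async fn sleep(&self) {
            // The guard is held across the await, so only one task polls the
            // Sleep at a time; the others queue on the mutex and are woken by it.
            let mut sleep = self.0.lock().await;
            sleep.as_mut().await;
        }
    }

    /// Hypothetical helper: waits on the shared timer and reports the elapsed time.
    async fn sleeper(foo: Arc<Foo>, start: Instant) -> Duration {
        foo.sleep().await;
        start.elapsed()
    }

    /// Hypothetical helper: runs two sleepers against one shared timer.
    async fn run_two_sleepers(delay: Duration) -> (Duration, Duration) {
        let start = Instant::now();
        let foo = Arc::new(Foo::new(time::sleep(delay)));

        let task1 = tokio::spawn(sleeper(foo.clone(), start));
        let task2 = tokio::spawn(sleeper(foo, start));

        // Both tasks are already running; awaiting the handles just collects results.
        (task1.await.unwrap(), task2.await.unwrap())
    }

    fn main() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_time()
            .build()
            .unwrap();
        let (first, second) = runtime.block_on(run_two_sleepers(Duration::from_millis(100)));
        println!("task 1: {first:?}, task 2: {second:?}");
    }
    ```

    Both tasks should report roughly the same elapsed time: once the first task finishes and releases the guard, the second one acquires it and finds the Sleep already elapsed.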

    However, to guarantee FIFO ordering, tokio::sync::Mutex is implemented on top of a semaphore, which is essentially a std::sync::Mutex<LinkedList<Waker>>. If the tasks do not need to complete in FIFO order, manually storing the wakers in a Vec or Slab can be more performant.