Tags: rust, semaphore, rust-tokio

Pause/Wait tokio tasks after a certain condition, and then resume after user input


I have a bunch of tasks that perform a certain operation. If a certain condition is met in one task, I want all tasks that have not started yet to pause (wait), while all currently running tasks finish their operation. Only after all currently running tasks have completed do I want to ask the user whether to continue or end the program; this way, no task ends prematurely if the user chooses to terminate. I use a semaphore to limit the number of tasks running at any given time.

So far this is what I came up with:

// Spawns `range` tasks; a semaphore with `workers` permits limits how many run at once.
// After task #6 finishes its simulated work it sets `pause`; any task that reaches the
// pause check while `pause` is set loops until it is cleared. Tasks that already passed
// the check before task #6 set the flag run to completion.
//
// NOTE(review): nothing in this block ever stores `false` back into `pause`, so once
// task #6 sets it, paused tasks never leave the loop below and `main` never returns.
#[tokio::main]
async fn main() {
    let workers = 5;
    let range: u32 = 20;
    let semaphores: Arc<Semaphore> = std::sync::Arc::new(Semaphore::new(workers));
    let mut handles: Vec<JoinHandle<()>> = Vec::new();
    // Shared pause flag, set by task #6.
    let pause = Arc::new(AtomicBool::new(false));
    // Number of tasks that have finished their work.
    let complete_count = Arc::new(AtomicU32::new(0));
    // Number of tasks currently past the pause check and not yet finished.
    let start_count = Arc::new(AtomicU32::new(0));

    for i in 0..range {
        let semaphores_clone = Arc::clone(&semaphores);
        let complete_count_clone = Arc::clone(&complete_count);
        let pause_clone = Arc::clone(&pause);
        let start_count_clone = Arc::clone(&start_count);

        handles.push(tokio::spawn(async move {
            // The permit is bound for the whole task body, so it is held even while the
            // task spins in the pause loop below.
            let _permit = semaphores_clone.clone().acquire_owned().await.unwrap();

            // Check if pause flag is true
            // `execute_once` makes the "paused" message print only on the first spin.
            //
            // NOTE(review): this is a busy loop. The `continue` branch contains no
            // `.await`, so a paused task never yields back to the tokio scheduler and
            // keeps its worker thread busy for as long as it spins. This presumably starves
            // other tasks of worker threads, which would explain why fewer tasks than
            // expected report being paused.
            let mut execute_once = false;
            loop {
                if pause_clone.load(Ordering::SeqCst) {
                    if !execute_once {
                        println!(
                            "Task #{i} has been paused after {} started tasks and {} completed tasks",
                            start_count_clone.load(Ordering::SeqCst),
                            complete_count_clone.load(Ordering::SeqCst),
                        );
                        execute_once = true;
                    }
                    continue;
                }
                // Add to start_count when a task has started
                // (only reached once the flag is observed as false).
                start_count_clone.fetch_add(1, Ordering::SeqCst);
                break;
            }

            // Perform some operation that takes time
            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

            // Pause all tasks at task #6
            // Task #6 itself has already passed the check, so it runs to completion.
            if i == 6 {
                println!(
                    "Tasks paused by task #{i} after {} started tasks and {} completed tasks",
                    start_count_clone.load(Ordering::SeqCst),
                    complete_count_clone.load(Ordering::SeqCst)
                );
                pause_clone.store(true, Ordering::SeqCst);

                // Ask for user input here
            }

            // Add to complete_count when a task is finished
            // Then subtract start_count
            {
                start_count_clone.fetch_sub(1, Ordering::SeqCst);
                complete_count_clone.fetch_add(1, Ordering::SeqCst);
            }
        }));
    }

    for handle in handles {
        handle.await.unwrap();
    }
}

It works for the most part, but sometimes fewer tasks than the number of workers actually get paused. This can happen even when plenty of tasks are still in the running state. Here's an example output from one of those scenarios:

Tasks paused by task #6 after 5 started tasks and 12 completed tasks
Task #19 has been paused after 4 started tasks and 15 completed tasks

Note that the number of workers and tasks can vary a lot; sometimes there may even be more workers than tasks to handle!

I am not sure this is the best way to go about it, so I'm open to suggestions for a better solution!


Solution

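  • Busy looping is almost always wrong: a loop that `continue`s without any `.await` never yields to the Tokio scheduler, so it occupies a worker thread for as long as it spins. Instead, you should use a condition variable ... which Tokio doesn't have, so we can build one from a `Mutex` and a `Notify`.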

    use tokio;
    use std::sync::Arc;
    use std::sync::atomic::Ordering;
    
    // Spawns `range` tasks; a semaphore with `workers` permits limits how many run at once.
    //
    // The pause flag is a plain `bool` behind a `tokio::sync::Mutex`, paired with a
    // `Notify` so the two act as a condition variable. A task that acquires a permit while
    // the flag is set registers a `Notified` future *while holding the lock*, drops the
    // lock, then awaits that future. Task #6 sets the flag after its work, sleeps as a
    // stand-in for asking the user, clears the flag, and wakes all registered waiters.
    #[tokio::main]
    async fn main() {
        let workers = 5;
        let range: u32 = 20;
        let semaphores: Arc<tokio::sync::Semaphore> =
            std::sync::Arc::new(tokio::sync::Semaphore::new(workers));
        let mut handles: Vec<tokio::task::JoinHandle<()>> = Vec::new();
        // The mutex alone guards the flag; wrapping an atomic inside it would be
        // redundant double synchronization.
        let pause = Arc::new(tokio::sync::Mutex::new(false));
        let notifier = Arc::new(tokio::sync::Notify::new());
        let complete_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let start_count = Arc::new(std::sync::atomic::AtomicU32::new(0));

        for i in 0..range {
            let semaphores_clone = Arc::clone(&semaphores);
            let complete_count_clone = Arc::clone(&complete_count);
            let pause_clone = Arc::clone(&pause);
            let start_count_clone = Arc::clone(&start_count);
            let notifier_clone = Arc::clone(&notifier);
            handles.push(tokio::spawn(async move {
                // `acquire_owned` consumes this task's own Arc; no extra clone is needed.
                let _permit = semaphores_clone
                    .acquire_owned()
                    .await
                    .expect("semaphore is never closed");

                // Check the pause flag. The `Notified` future is created while the lock
                // is held, so a `notify_waiters()` issued after the flag is cleared
                // cannot be missed.
                let should_wait = {
                    let paused = pause_clone.lock().await;
                    if *paused {
                        Some(notifier_clone.notified())
                    } else {
                        None
                    }
                };
                if let Some(wait_handle) = should_wait {
                    println!(
                        "Task #{i} has been paused after {} started tasks and {} completed tasks",
                        start_count_clone.load(Ordering::SeqCst),
                        complete_count_clone.load(Ordering::SeqCst),
                    );
                    // Await without holding the lock so the pausing task can clear the flag.
                    wait_handle.await;
                }

                // Add to start_count when a task has started
                start_count_clone.fetch_add(1, Ordering::SeqCst);

                // Perform some operation that takes time
                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

                // Pause all tasks at task #6
                if i == 6 {
                    println!(
                        "Tasks paused by task #{i} after {} started tasks and {} completed tasks",
                        start_count_clone.load(Ordering::SeqCst),
                        complete_count_clone.load(Ordering::SeqCst)
                    );
                    // The guard is a temporary, released at the end of each statement.
                    *pause_clone.lock().await = true;

                    // Ask for user input here
                    tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
                    *pause_clone.lock().await = false;
                    // Wake every task that registered a `Notified` while the flag was set.
                    notifier_clone.notify_waiters();
                }

                // Add to complete_count when a task is finished
                // Then subtract start_count
                start_count_clone.fetch_sub(1, Ordering::SeqCst);
                complete_count_clone.fetch_add(1, Ordering::SeqCst);
            }));
        }

        for handle in handles {
            handle.await.expect("worker task panicked");
        }
    }
    

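    The main idea is that we must register ourselves as waiting on the notifier (by creating the `Notified` future) while holding the mutex, and then await it without holding the mutex, just as a normal condition variable does. Unlike a normal condition variable, we don't need to wrap the wait in a loop, because `Notify` has no spurious wakeups. (This relies on the pause being triggered only once; if the flag could be set again before a woken task runs, re-check it in a loop.)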

    FYI: all loads and stores in your code can be `Relaxed`, since they don't need to establish ordering for the surrounding code; `SeqCst` is overkill here. (We don't care whether the decrement or the increment happens first, only that both happen.)


    Alternatively, as suggested in a comment, you can use `tokio::sync::watch`, which bundles the flag and the wake-up notification into a single channel.

    use tokio;
    use std::sync::Arc;
    use std::sync::atomic::Ordering;
    
    // Spawns `range` tasks; a semaphore with `workers` permits limits how many run at once.
    //
    // The pause state lives in a `tokio::sync::watch` channel holding a `bool`. A task
    // that acquires a permit while the value is `true` waits until `false` is published.
    // Task #6 publishes `true` after its work, sleeps as a stand-in for asking the user,
    // then publishes `false`, which wakes the waiting tasks.
    #[tokio::main]
    async fn main() {
        let workers = 5;
        let range: u32 = 20;
        let semaphores: Arc<tokio::sync::Semaphore> =
            std::sync::Arc::new(tokio::sync::Semaphore::new(workers));
        let mut handles: Vec<tokio::task::JoinHandle<()>> = Vec::new();
        // `main` keeps both original ends alive until every task is joined, so sending
        // and waiting on the channel never observe a closed channel.
        let (pause_tx, pause_rx) = tokio::sync::watch::channel(false);
        let complete_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let start_count = Arc::new(std::sync::atomic::AtomicU32::new(0));

        for i in 0..range {
            let semaphores_clone = Arc::clone(&semaphores);
            let complete_count_clone = Arc::clone(&complete_count);
            let start_count_clone = Arc::clone(&start_count);
            let pause_tx_clone = pause_tx.clone();
            let mut pause_rx_clone = pause_rx.clone();
            handles.push(tokio::spawn(async move {
                // `acquire_owned` consumes this task's own Arc; no extra clone is needed.
                let _permit = semaphores_clone
                    .acquire_owned()
                    .await
                    .expect("semaphore is never closed");

                // Check if pause flag is true. The read guard returned by `borrow()` is a
                // temporary of the `if` condition and is dropped before the body runs,
                // so the `wait_for` below does not contend with it.
                if *pause_rx_clone.borrow() {
                    println!(
                        "Task #{i} has been paused after {} started tasks and {} completed tasks",
                        start_count_clone.load(Ordering::SeqCst),
                        complete_count_clone.load(Ordering::SeqCst),
                    );
                    // wait for it to be false
                    pause_rx_clone
                        .wait_for(|paused| !*paused)
                        .await
                        .expect("pause sender outlives every task");
                }

                // Add to start_count when a task has started
                start_count_clone.fetch_add(1, Ordering::SeqCst);

                // Perform some operation that takes time
                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

                // Pause all tasks at task #6
                if i == 6 {
                    println!(
                        "Tasks paused by task #{i} after {} started tasks and {} completed tasks",
                        start_count_clone.load(Ordering::SeqCst),
                        complete_count_clone.load(Ordering::SeqCst)
                    );
                    // `send_replace` stores the value and notifies receivers without the
                    // "no receivers" error path that `send(..).unwrap()` has.
                    pause_tx_clone.send_replace(true);

                    // Ask for user input here
                    tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
                    pause_tx_clone.send_replace(false);
                }

                // Add to complete_count when a task is finished
                // Then subtract start_count
                start_count_clone.fetch_sub(1, Ordering::SeqCst);
                complete_count_clone.fetch_add(1, Ordering::SeqCst);
            }));
        }

        for handle in handles {
            handle.await.expect("worker task panicked");
        }
    }