I have a bunch of tasks that perform a certain operation. If a certain condition is met in a task, I want all tasks after it to go into a paused/waiting state, while all currently running tasks finish their operation. Only after all currently running tasks have completed do I want to ask for user input on whether to continue or end the program; this way no task is terminated prematurely if the user decides to stop. I use a semaphore to limit the number of tasks running at any given time.
So far this is what I came up with:
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use tokio::sync::Semaphore;
use tokio::task::JoinHandle;

#[tokio::main]
async fn main() {
    let workers = 5;
    let range: u32 = 20;
    let semaphores: Arc<Semaphore> = Arc::new(Semaphore::new(workers));
    let mut handles: Vec<JoinHandle<()>> = Vec::new();
    let pause = Arc::new(AtomicBool::new(false));
    let complete_count = Arc::new(AtomicU32::new(0));
    let start_count = Arc::new(AtomicU32::new(0));

    for i in 0..range {
        let semaphores_clone = Arc::clone(&semaphores);
        let complete_count_clone = Arc::clone(&complete_count);
        let pause_clone = Arc::clone(&pause);
        let start_count_clone = Arc::clone(&start_count);
        handles.push(tokio::spawn(async move {
            let _permit = semaphores_clone.acquire_owned().await.unwrap();
            // Check if pause flag is true
            let mut execute_once = false;
            loop {
                if pause_clone.load(Ordering::SeqCst) {
                    if !execute_once {
                        println!(
                            "Task #{i} has been paused after {} started tasks and {} completed tasks",
                            start_count_clone.load(Ordering::SeqCst),
                            complete_count_clone.load(Ordering::SeqCst),
                        );
                        execute_once = true;
                    }
                    continue;
                }
                // Add to start_count when a task has started
                start_count_clone.fetch_add(1, Ordering::SeqCst);
                break;
            }
            // Perform some operation that takes time
            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
            // Pause all tasks at task #6
            if i == 6 {
                println!(
                    "Tasks paused by task #{i} after {} started tasks and {} completed tasks",
                    start_count_clone.load(Ordering::SeqCst),
                    complete_count_clone.load(Ordering::SeqCst)
                );
                pause_clone.store(true, Ordering::SeqCst);
                // Ask for user input here
            }
            // Add to complete_count when a task is finished
            // Then subtract start_count
            start_count_clone.fetch_sub(1, Ordering::SeqCst);
            complete_count_clone.fetch_add(1, Ordering::SeqCst);
        }));
    }

    for handle in handles {
        handle.await.unwrap();
    }
}
It works for the most part, but sometimes fewer tasks than the number of workers actually get paused. This can happen even when there are still plenty of tasks left to run. Here's an example output from one of those scenarios:
Tasks paused by task #6 after 5 started tasks and 12 completed tasks
Task #19 has been paused after 4 started tasks and 15 completed tasks
Note that the number of workers and tasks can vary a lot; sometimes there may even be more workers than there are tasks to handle! I'm not sure this is the best way to go about it, so I'm open to suggestions if there's a better solution!
Busy looping is almost always wrong. Instead you should use a condition variable ... which tokio doesn't have, so we can make one using a Mutex and a Notify:
use tokio;
use std::sync::Arc;
use std::sync::atomic::Ordering;

#[tokio::main]
async fn main() {
    let workers = 5;
    let range: u32 = 20;
    let semaphores: Arc<tokio::sync::Semaphore> = Arc::new(tokio::sync::Semaphore::new(workers));
    let mut handles: Vec<tokio::task::JoinHandle<()>> = Vec::new();
    let pause = Arc::new(tokio::sync::Mutex::new(std::sync::atomic::AtomicBool::new(false)));
    let notifier = Arc::new(tokio::sync::Notify::new());
    let complete_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
    let start_count = Arc::new(std::sync::atomic::AtomicU32::new(0));

    for i in 0..range {
        let semaphores_clone = Arc::clone(&semaphores);
        let complete_count_clone = Arc::clone(&complete_count);
        let pause_clone = Arc::clone(&pause);
        let start_count_clone = Arc::clone(&start_count);
        let notifier_clone = Arc::clone(&notifier);
        handles.push(tokio::spawn(async move {
            let _permit = semaphores_clone.acquire_owned().await.unwrap();
            // Check if pause flag is true
            let should_wait = {
                let mut_lock = pause_clone.lock().await;
                if mut_lock.load(Ordering::SeqCst) {
                    // Register as an awaiter while holding the mutex
                    Some(notifier_clone.notified())
                } else {
                    None
                }
            };
            if let Some(wait_handle) = should_wait {
                println!(
                    "Task #{i} has been paused after {} started tasks and {} completed tasks",
                    start_count_clone.load(Ordering::SeqCst),
                    complete_count_clone.load(Ordering::SeqCst),
                );
                // Await it while not holding the mutex
                wait_handle.await;
            }
            // Add to start_count when a task has started
            start_count_clone.fetch_add(1, Ordering::SeqCst);
            // Perform some operation that takes time
            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
            // Pause all tasks at task #6
            if i == 6 {
                println!(
                    "Tasks paused by task #{i} after {} started tasks and {} completed tasks",
                    start_count_clone.load(Ordering::SeqCst),
                    complete_count_clone.load(Ordering::SeqCst)
                );
                pause_clone.lock().await.store(true, Ordering::SeqCst);
                // Ask for user input here
                tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
                pause_clone.lock().await.store(false, Ordering::SeqCst);
                notifier_clone.notify_waiters(); // must notify waiters
            }
            // Add to complete_count when a task is finished
            // Then subtract start_count
            start_count_clone.fetch_sub(1, Ordering::SeqCst);
            complete_count_clone.fetch_add(1, Ordering::SeqCst);
        }));
    }

    for handle in handles {
        handle.await.unwrap();
    }
}
The main idea is that we must register ourselves as awaiting the notifier while holding the mutex, then await it while not holding the mutex, just as a normal condition variable does. But unlike a normal condition variable we don't need to wrap it in a loop, because there can be no spurious wakeups.
FYI: all loads and stores in your code can be Relaxed, since they don't need to establish any ordering for the surrounding code; SeqCst is overkill here. (We don't care whether the decrement or the increment happens first, only that both happen, in either order.)
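For illustration, here is a small self-contained sketch (not taken from your code) of the same kind of counter bookkeeping with Relaxed; the counter is just a statistic, and no other data is synchronized through it:

use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};

#[tokio::main]
async fn main() {
    let complete_count = Arc::new(AtomicU32::new(0));
    let mut handles = Vec::new();
    for _ in 0..20 {
        let complete_count = Arc::clone(&complete_count);
        handles.push(tokio::spawn(async move {
            // Relaxed is enough: the increment is still atomic, we just don't
            // ask it to order any surrounding reads or writes.
            complete_count.fetch_add(1, Ordering::Relaxed);
        }));
    }
    for handle in handles {
        // Awaiting the JoinHandle synchronizes with the task's completion,
        // so every increment is visible to the load below.
        handle.await.unwrap();
    }
    println!("completed: {}", complete_count.load(Ordering::Relaxed));
}

The same substitution applies to the fetch_add/fetch_sub/load calls in your tasks.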
Or, as suggested in a comment, you can use tokio::sync::watch:
use tokio;
use std::sync::Arc;
use std::sync::atomic::Ordering;

#[tokio::main]
async fn main() {
    let workers = 5;
    let range: u32 = 20;
    let semaphores: Arc<tokio::sync::Semaphore> = Arc::new(tokio::sync::Semaphore::new(workers));
    let mut handles: Vec<tokio::task::JoinHandle<()>> = Vec::new();
    let (pause_tx, pause_rx) = tokio::sync::watch::channel(false);
    let complete_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
    let start_count = Arc::new(std::sync::atomic::AtomicU32::new(0));

    for i in 0..range {
        let semaphores_clone = Arc::clone(&semaphores);
        let complete_count_clone = Arc::clone(&complete_count);
        let start_count_clone = Arc::clone(&start_count);
        let pause_tx_clone = pause_tx.clone();
        let mut pause_rx_clone = pause_rx.clone();
        handles.push(tokio::spawn(async move {
            let _permit = semaphores_clone.acquire_owned().await.unwrap();
            // Check if pause flag is true
            if *pause_rx_clone.borrow() {
                println!(
                    "Task #{i} has been paused after {} started tasks and {} completed tasks",
                    start_count_clone.load(Ordering::SeqCst),
                    complete_count_clone.load(Ordering::SeqCst),
                );
                // Wait for it to become false again
                pause_rx_clone.wait_for(|value| !*value).await.unwrap();
            }
            // Add to start_count when a task has started
            start_count_clone.fetch_add(1, Ordering::SeqCst);
            // Perform some operation that takes time
            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
            // Pause all tasks at task #6
            if i == 6 {
                println!(
                    "Tasks paused by task #{i} after {} started tasks and {} completed tasks",
                    start_count_clone.load(Ordering::SeqCst),
                    complete_count_clone.load(Ordering::SeqCst)
                );
                pause_tx_clone.send(true).unwrap();
                // Ask for user input here
                tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
                pause_tx_clone.send(false).unwrap();
            }
            // Add to complete_count when a task is finished
            // Then subtract start_count
            start_count_clone.fetch_sub(1, Ordering::SeqCst);
            complete_count_clone.fetch_add(1, Ordering::SeqCst);
        }));
    }

    for handle in handles {
        handle.await.unwrap();
    }
}
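For the // Ask for user input here placeholder, a minimal sketch (assuming a simple y/n prompt on stdin) is to read the answer on a blocking thread so the async runtime isn't stalled:

// Hypothetical helper: ask on stdin whether to continue. Reading stdin is
// blocking, so it is done on tokio's blocking thread pool.
async fn ask_continue() -> bool {
    tokio::task::spawn_blocking(|| {
        println!("Pause condition hit. Continue with the remaining tasks? [y/n]");
        let mut line = String::new();
        std::io::stdin()
            .read_line(&mut line)
            .expect("failed to read from stdin");
        line.trim().eq_ignore_ascii_case("y")
    })
    .await
    .expect("input task panicked")
}

In the watch version this would replace the one-second sleep: send false to resume the waiting tasks if ask_continue().await returns true, otherwise leave the pause flag set (or exit) so no new task ever starts.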