Tags: multithreading, rust, barrier

What is wrong with this implementation of Barrier using atomics in Rust?


I've written the following implementation of a Barrier using only atomics:

use std::sync::atomic::{AtomicUsize, Ordering};

/// Spin barrier built from a single atomic arrival counter (the version this
/// question asks about). It can hang when `wait` is called repeatedly; see the
/// note on `wait`.
pub struct Barrier {
  /// Number of threads that have arrived in the current phase; reset to 0 by
  /// the last arrival.
  pub done: AtomicUsize,
  /// Number of threads that must arrive before the barrier releases.
  pub tids: usize,
}

impl Barrier {
  /// Creates a barrier for `tids` participating threads with no arrivals yet.
  pub fn new(tids: usize) -> Barrier {
    Barrier {
      done: AtomicUsize::new(0),
      tids,
    }
  }

  /// Arrives at the barrier; the last arrival resets the counter, everyone
  /// else spins until they observe the counter at 0.
  ///
  /// NOTE(review): this is racy when `wait` is called back-to-back. The last
  /// arrival stores 0 and returns immediately. If it re-enters `wait` and
  /// increments `done` to 1 before a slower waiter's loop has loaded the 0,
  /// that waiter keeps seeing a non-zero value and never leaves. Because it
  /// never arrives at the next phase, the counter never reaches `tids` again
  /// and every other thread spins forever too.
  /// Also, with `tids == 0`, `done + 1` can never equal `tids`, so every caller
  /// spins forever.
  pub fn wait(&self) {
    // `done` is the number of arrivals *before* this one.
    let done = self.done.fetch_add(1, Ordering::SeqCst);
    if done + 1 == self.tids {
      // Last arrival: reset the counter to release the spinning waiters.
      self.done.store(0, Ordering::SeqCst);
    } else {
      // Release is detected only by seeing exactly 0; a later increment by a
      // thread that already left can hide the reset (see note above).
      while self.done.load(Ordering::SeqCst) != 0 {}
    }
  }
}

It doesn't work as expected. For example,

// inside threads loop
barrier.wait();
println!("a");
barrier.wait();
println!("b");

Intuitively, it should work: once .wait() is called, a thread hangs in the while loop and breaks free only after all the threads have called .wait(), at which point the counter is reset for the next .wait(). Instead, it eventually hangs. Below is a usage example:

fn main() {
  println!("Hello, world!");

  // A single barrier shared by every worker; `&Barrier` is `Copy`, so each
  // `move` closure simply gets its own copy of the reference.
  let thread_count = 10;
  let barrier = &Barrier::new(thread_count);

  std::thread::scope(|scope| {
    (0..thread_count).for_each(|tid| {
      // Each worker alternates forever between the two barrier phases.
      scope.spawn(move || loop {
        barrier.wait();
        println!("{} a", tid);
        barrier.wait();
        println!("{} b", tid);
      });
    });
  });
}

Solution

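  • The problem is a race condition between two consecutive barriers: the last thread to arrive resets the counter to 0 and returns immediately. If it reaches the next .wait() and increments the counter again before every other thread has observed the 0, those threads never see the reset and keep spinning. Because they never arrive at the next barrier, all the other threads end up spinning forever as well.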

    If you are sure that you are always using the same threads, you could fix that by utilizing two counters, and flipping back and forth between them. That way all threads wait either for the first or the second one. But there is no way for one thread to bypass the others, as it would have to go through the second counter to block the first one again, and the second one will only unblock if no thread is still left in the first one.

    This one seems to work:

    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    
    /// A reusable spin barrier for a fixed group of `tids` threads.
    ///
    /// Two arrival counters are used alternately. Each `wait` call arrives on
    /// the counter selected by `use_first_done`, and the last arrival flips the
    /// selector before releasing the others. A thread that races ahead into the
    /// next `wait` therefore increments the *other* counter and cannot hide the
    /// release from threads still spinning on the current one.
    pub struct Barrier {
        /// Arrival counters for alternating phases; the active one is chosen by
        /// `use_first_done`.
        pub done: [AtomicUsize; 2],
        /// `true` while `done[0]` is the active counter, `false` for `done[1]`.
        pub use_first_done: AtomicBool,
        /// Number of threads that must call [`Barrier::wait`] before any is released.
        pub tids: usize,
    }

    impl Barrier {
        /// Creates a barrier for `tids` threads, starting on `done[0]`.
        ///
        /// A count of `0` behaves like `1`: `wait` returns immediately instead of
        /// spinning forever.
        pub fn new(tids: usize) -> Barrier {
            Barrier {
                done: [AtomicUsize::new(0), AtomicUsize::new(0)],
                use_first_done: AtomicBool::new(true),
                tids,
            }
        }

        /// Spins until `tids` threads (at least one) have called `wait` in the
        /// current phase.
        ///
        /// Intended to be called once per phase by each thread of the same fixed
        /// group of threads.
        pub fn wait(&self) {
            let done = if self.use_first_done.load(Ordering::SeqCst) {
                &self.done[0]
            } else {
                &self.done[1]
            };

            let num_done = done.fetch_add(1, Ordering::SeqCst) + 1;
            // `max(1)`: with `tids == 0`, `num_done` (always >= 1) would never
            // match and every caller would spin forever.
            if num_done == self.tids.max(1) {
                // Flip the selector *before* resetting. A waiter that observes the
                // reset is then guaranteed to also observe the new selector, so its
                // next `wait` uses the other counter.
                self.use_first_done.fetch_xor(true, Ordering::SeqCst);
                done.store(0, Ordering::SeqCst);
            } else {
                while done.load(Ordering::SeqCst) != 0 {
                    std::hint::spin_loop();
                }
            }
        }
    }
    
    fn main() {
        const WORKERS: usize = 10;

        println!("Hello, world!");

        let shared = &Barrier::new(WORKERS);

        std::thread::scope(|scope| {
            for tid in 0..WORKERS {
                scope.spawn(move || {
                    // Two barrier phases per iteration, repeated forever.
                    loop {
                        shared.wait();
                        println!("{} a", tid);
                        shared.wait();
                        println!("{} b", tid);
                    }
                });
            }
        });
    }
    

    An alternative would be to use an iteration counter.

    For the same reason that alternating between two counters works, an iteration counter that only distinguishes two consecutive iterations (i.e. a boolean) should be sufficient.

    This one works for me as well:

    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    
    /// A reusable spin barrier for a fixed group of `tids` threads, using one
    /// arrival counter plus a one-bit generation flag.
    ///
    /// Waiters spin on the generation flag rather than on the counter. A fast
    /// thread that re-enters `wait` and bumps the (already reset) counter
    /// therefore cannot hide the release from threads that are still spinning.
    pub struct Barrier {
        /// Number of threads that have arrived in the current phase.
        pub done: AtomicUsize,
        /// Generation flag, toggled once per completed phase.
        pub iteration: AtomicBool,
        /// Number of threads that must call [`Barrier::wait`] before any is released.
        pub tids: usize,
    }

    impl Barrier {
        /// Creates a barrier for `tids` threads.
        ///
        /// A count of `0` behaves like `1`: `wait` returns immediately instead of
        /// spinning forever.
        pub fn new(tids: usize) -> Barrier {
            Barrier {
                done: AtomicUsize::new(0),
                iteration: AtomicBool::new(false),
                tids,
            }
        }

        /// Spins until `tids` threads (at least one) have called `wait` in the
        /// current phase.
        ///
        /// Intended to be called once per phase by each thread of the same fixed
        /// group of threads.
        pub fn wait(&self) {
            // Record the generation *before* arriving. Once our increment lands,
            // the last thread may flip the flag at any moment. Reading it afterwards
            // could make us wait for the following flip, which needs us to arrive
            // again: a deadlock.
            let iteration = self.iteration.load(Ordering::SeqCst);
            let num_done = self.done.fetch_add(1, Ordering::SeqCst) + 1;
            // `max(1)`: with `tids == 0`, `num_done` (always >= 1) would never
            // match and every caller would spin forever.
            if num_done == self.tids.max(1) {
                // Reset before publishing the new generation. Released waiters only
                // leave after seeing the flip, so their next increment lands on the
                // fresh counter.
                self.done.store(0, Ordering::SeqCst);
                self.iteration.fetch_xor(true, Ordering::SeqCst);
            } else {
                while iteration == self.iteration.load(Ordering::SeqCst) {
                    std::hint::spin_loop();
                }
            }
        }
    }
    
    fn main() {
        println!("Hello, world!");

        // One barrier for all workers; the `&Barrier` is copied into each
        // `move` closure.
        let worker_count = 10;
        let barrier = &Barrier::new(worker_count);

        std::thread::scope(|scope| {
            (0..worker_count).for_each(|tid| {
                scope.spawn(move || loop {
                    barrier.wait();
                    println!("{} a", tid);
                    barrier.wait();
                    println!("{} b", tid);
                });
            });
        });
    }
    

    IMPORTANT: This only works if the same set of threads calls wait() every time. If different threads use this barrier, a bigger iteration counter is necessary.