I wrote a small piece of code to demonstrate the usage of the CountDownLatch class in Java, but it is not working as expected. I created 5 threads and assigned a task to each thread. Each thread waits for the start signal; once the start signal is given, all threads start their work and call countDown(). Meanwhile, my main thread waits for all the threads to finish their work until it receives the done signal. But the output is not what I expected. Please help if I am missing anything in the concept. Below is the program.
import java.util.Scanner;
import java.util.concurrent.CountDownLatch;

class Task implements Runnable {
    private CountDownLatch startSignal;
    private CountDownLatch doneSignal;
    private int id;

    Task(int id, CountDownLatch startSignal, CountDownLatch doneSignal) {
        this.startSignal = startSignal;
        this.doneSignal = doneSignal;
        this.id = id;
    }

    @Override
    public void run() {
        try {
            startSignal.await(); // block until main opens the start latch
            performTask();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

    private void performTask() {
        try {
            System.out.println("Task started by thread : " + id);
            Thread.sleep(5000); // simulate work
            doneSignal.countDown(); // report this task as done to main
            System.out.println("Task ended by thread : " + id);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}

public class CountDownLatchExample {
    public static void main(String[] args) {
        CountDownLatch startSignal = new CountDownLatch(1);
        CountDownLatch doneSignal = new CountDownLatch(5);
        for (int i = 0; i < 5; ++i) {
            new Thread(new Task(i, startSignal, doneSignal)).start();
        }
        System.out.println("Press enter to start work");
        new Scanner(System.in).nextLine();
        startSignal.countDown(); // release all waiting workers
        try {
            doneSignal.await(); // wait until all 5 workers have counted down
            System.out.println("All Tasks Completed");
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}
Output
Press enter to start work
Task started by thread : 0
Task started by thread : 4
Task started by thread : 3
Task started by thread : 2
Task started by thread : 1
Task ended by thread : 4
Task ended by thread : 2
Task ended by thread : 1
All Tasks Completed
Task ended by thread : 0
Task ended by thread : 3
Expected output
Press enter to start work
Task started by thread : 0
Task started by thread : 4
Task started by thread : 3
Task started by thread : 2
Task started by thread : 1
Task ended by thread : 4
Task ended by thread : 2
Task ended by thread : 1
Task ended by thread : 0
Task ended by thread : 3
All Tasks Completed
In your Task class, you have:
doneSignal.countDown();
System.out.println("Task ended by thread : " + id);
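In other words, you count down the latch before you print "task ended". That allows the main thread to wake up from its call to doneSignal.await() and print "All Tasks Completed" before all the "task ended" print statements have run: the main thread is released as soon as the fifth countDown() brings the count to zero, even if some workers (threads 0 and 3 in your output) have not yet reached their println. Note, though, that the "wrong output" will not always happen; depending on thread scheduling, you will sometimes get your expected output.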
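To see the early release in isolation, here is a minimal sketch (not from the question) with a single worker that, like your performTask(), calls countDown() before recording its "Task ended" message. The extra latch named mainHasPrinted is a hypothetical addition that holds the worker back until the main thread has recorded its own message, so the out-of-order result is reproducible instead of depending on the scheduler. Messages go into a list rather than straight to the console so the order can be checked.

```java
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;

public class EarlyReleaseDemo {

    // One worker counts down doneSignal *before* recording "Task ended",
    // mirroring the ordering in the question's performTask().
    static List<String> run() throws InterruptedException {
        List<String> events = new CopyOnWriteArrayList<>();
        CountDownLatch doneSignal = new CountDownLatch(1);
        // Hypothetical extra latch: keeps the worker paused until main has
        // recorded its message, so the early release happens every time.
        CountDownLatch mainHasPrinted = new CountDownLatch(1);

        Thread worker = new Thread(() -> {
            doneSignal.countDown();       // count reaches zero: main may wake now
            try {
                mainHasPrinted.await();   // stands in for the worker being descheduled
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
            events.add("Task ended");
        });
        worker.start();

        doneSignal.await();               // returns as soon as the count is zero
        events.add("All Tasks Completed");
        mainHasPrinted.countDown();       // now let the worker finish
        worker.join();
        return events;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(run()); // prints [All Tasks Completed, Task ended]
    }
}
```

Nothing in the latch waits for the code that follows countDown(); reaching zero is the only condition await() checks.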
Simply switch those two lines of code around to guarantee the output you want:
System.out.println("Task ended by thread : " + id);
doneSignal.countDown();
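Putting it together, a corrected program might look like the sketch below. It is an assumed variant rather than a drop-in copy of your code: the Enter prompt is replaced by an immediate startSignal.countDown(), the sleep is shortened to 100 ms, messages are also recorded in a list through a hypothetical logLine helper so the order can be checked, and countDown() is placed in a finally block. That last change is a common idiom, not something your question requires, so that an exception in a task cannot leave the main thread blocked in await() forever.

```java
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;

public class FixedCountDownLatchExample {

    // Hypothetical helper: prints a message and records it for checking.
    static void logLine(List<String> log, String message) {
        System.out.println(message);
        log.add(message);
    }

    static class Task implements Runnable {
        private final int id;
        private final CountDownLatch startSignal;
        private final CountDownLatch doneSignal;
        private final List<String> log;

        Task(int id, CountDownLatch startSignal, CountDownLatch doneSignal, List<String> log) {
            this.id = id;
            this.startSignal = startSignal;
            this.doneSignal = doneSignal;
            this.log = log;
        }

        @Override
        public void run() {
            try {
                startSignal.await();
                logLine(log, "Task started by thread : " + id);
                Thread.sleep(100); // shortened from 5000 ms for the sketch
                logLine(log, "Task ended by thread : " + id);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } finally {
                // Count down last, after the "ended" message, and in finally so
                // an exception cannot leave main waiting forever.
                doneSignal.countDown();
            }
        }
    }

    static List<String> run(int workers) throws InterruptedException {
        List<String> log = new CopyOnWriteArrayList<>();
        CountDownLatch startSignal = new CountDownLatch(1);
        CountDownLatch doneSignal = new CountDownLatch(workers);
        for (int i = 0; i < workers; ++i) {
            new Thread(new Task(i, startSignal, doneSignal, log)).start();
        }
        startSignal.countDown(); // start immediately instead of waiting for Enter
        doneSignal.await();
        logLine(log, "All Tasks Completed");
        return log;
    }

    public static void main(String[] args) throws InterruptedException {
        run(5);
    }
}
```

With this ordering "All Tasks Completed" is always the last line, although the interleaving of the workers' start and end lines still varies from run to run.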
Swapping them ensures, by program order within the worker thread, that the print statement happens-before the doneSignal.countDown() call, which itself happens-before the main thread returns from doneSignal.await(). Thus, the "task ended" print statement now happens-before the main thread wakes up and prints the "All Tasks Completed" message.
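The guarantee being relied on is documented for CountDownLatch itself: actions in a thread prior to calling countDown() happen-before actions following a successful return from a corresponding await() in another thread. A minimal sketch of that guarantee, using only the standard library, has a worker write a plain (non-volatile) field before countDown(); the class name and the value 42 are arbitrary choices for illustration.

```java
import java.util.concurrent.CountDownLatch;

public class LatchVisibilityDemo {

    // A plain field: deliberately neither volatile nor guarded by a lock.
    private static int result;

    static int computeAndRead() throws InterruptedException {
        CountDownLatch done = new CountDownLatch(1);
        Thread worker = new Thread(() -> {
            result = 42;        // ordinary write, in program order before countDown()
            done.countDown();   // publishes that write to whoever returns from await()
        });
        worker.start();
        done.await();           // after this returns, the write to result is visible
        return result;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(computeAndRead()); // prints 42
    }
}
```

The println in your Task is just another action before countDown(), so the same rule is what puts "Task ended" ahead of "All Tasks Completed".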