
Why does this multi-threaded code behave in a safer way when I add thread.join()?


import threading

# Shared variable
shared_variable = 0
NUM_THREADS = 999
NUM_INCREMENT = 1_000_000


# Function to increment the shared variable
def increment():
    global shared_variable
    for _ in range(NUM_INCREMENT):
        shared_variable += 1


# Creating multiple threads to increment the shared variable concurrently
threads = []
for _ in range(NUM_THREADS):
    thread = threading.Thread(target=increment)
    threads.append(thread)

for thread in threads:
    thread.start()

for thread in threads:
    thread.join()

# Display the value of the shared variable after concurrent increments
print(
    f"The value of the shared variable is: {shared_variable}, expected : {NUM_THREADS * NUM_INCREMENT}"
)

This code always prints the expected number when I join the threads. But if I comment out the join loop, some increments are lost (which is the result I actually expected!).

# for thread in threads:
#    thread.join()

Why does this code behave in a thread-safe way when I add join?


Solution

  • What happens when you comment out the loop which performs all the thread.join() calls is very similar to when you do have the joins:

    Your code will struggle to get through the loop that calls thread.start() on each thread. Once more than a few threads have started, they all compete with each other, including the main thread. In a standard CPython build only one thread can execute Python bytecode at a time (the global interpreter lock), so every extra busy thread slows all the others down, and the situation gets worse and worse as the start loop continues.

    I tried the version without the joins and the last line printed like this:

    The value of the shared variable is: 997692304, expected : 999000000

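    This means that the main thread was so slow at starting all the threads that, by the time it left the start loop and reached the final print(), the threads had already performed 997,692,304 of the 999,000,000 increments, roughly the work of 997 complete threads. The remaining threads were still running.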

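    The reason that so many threads finish is that each one does something very simple (incrementing a number in a loop), whereas each thread.start() call in the main thread is much more complex and time-consuming: in CPython it creates a new OS thread and then waits for that thread to signal that it has started, and both steps have to compete with all the threads that are already running.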

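    Conversely, with the code intact, the thread.join() calls make the main thread deliberately wait for every thread to complete. This guarantees that print() runs only after all the increments have been performed, so you see the full result. It does not make shared_variable += 1 itself thread-safe, though: join() only controls when you read the value, and whether any increments are lost to the race depends on when the interpreter happens to switch threads.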
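The experiment described above can be reproduced at a smaller scale with the sketch below. The constants (50 threads, 100,000 increments each) are assumptions chosen so it finishes quickly; they are not from the question. It times the start loop, then reads shared_variable and counts the threads that are still alive at exactly the point where the question's print() would run without the join loop, and finally reads the value again after join(). Exact numbers depend on the machine and the Python build. On a standard CPython build the "after join" value usually matches the expected total, as the question observed, but nothing in this code guarantees that.

```python
import threading
import time

# Scaled-down versions of the question's constants (assumption: chosen to run quickly).
NUM_THREADS = 50
NUM_INCREMENT = 100_000

shared_variable = 0


def increment():
    global shared_variable
    for _ in range(NUM_INCREMENT):
        shared_variable += 1  # read-modify-write: not atomic, so updates can be lost


threads = [threading.Thread(target=increment) for _ in range(NUM_THREADS)]

# Time the start loop. In CPython each start() waits until the new thread has
# begun running, which can take longer while earlier threads are busy.
t0 = time.perf_counter()
for thread in threads:
    thread.start()
start_loop_seconds = time.perf_counter() - t0

# Snapshot taken where the question's print() runs when the join() loop is commented out.
still_running = sum(thread.is_alive() for thread in threads)
value_without_join = shared_variable

for thread in threads:
    thread.join()  # block until this worker has finished its loop

print(f"start loop took {start_loop_seconds:.3f}s")
print(f"without join: {value_without_join} ({still_running} threads still running)")
print(f"after join:   {shared_variable}, expected: {NUM_THREADS * NUM_INCREMENT}")
```

If still_running comes out greater than 0, the "without join" value is short simply because those threads had not finished yet, which is the effect the answer describes. With this few threads the start loop may be fast enough that every thread finishes first. Raising NUM_THREADS makes the gap more likely to appear.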