
Python asyncio: Queue.join() finishes only when exceptions are not raised. Why? (context: writing an async map function)


I've been trying to write an async version of the map function in Python for doing IO.

To do that, I'm using queues with a producer/consumer pattern.

At first it seemed to work well, but only when no exception was raised.

In particular, if I use queue.join(), it works when no exception is raised but blocks when one is. If I use gather(*tasks), it works when an exception is raised but blocks when none is.

So it only finishes sometimes, and I just don't understand why.

Here is the code I've implemented:

import asyncio
from asyncio import Queue
from typing import Iterable, Callable, TypeVar

Input = TypeVar("Input")
Output = TypeVar("Output")
STOP = object()


def parallel_map(func: Callable[[Input], Output], iterable: Iterable[Input]) -> Iterable[Output]:
    """
    Parallel version of `map`, backed by asyncio.
    Only suitable to do IO in parallel (not for CPU intensive tasks, otherwise it will block).
    """

    number_of_parallel_calls = 9

    async def worker(input_queue: Queue, output_queue: Queue) -> None:
        while True:
            data = await input_queue.get()
            try:
                output = func(data)
                # Simulate an exception:
                # raise RuntimeError("")
                output_queue.put_nowait(output)
            finally:
                input_queue.task_done()

    async def group_results(output_queue: Queue) -> Iterable[Output]:
        output = []
        while True:
            item = await output_queue.get()
            if item is not STOP:
                output.append(item)
            output_queue.task_done()
            if item is STOP:
                break
        return output

    async def procedure() -> Iterable[Output]:
        # First, produce a queue of inputs
        input_queue: Queue = asyncio.Queue()
        for i in iterable:
            input_queue.put_nowait(i)

        # Then, assign a pool of tasks to consume it (and also produce outputs in a new queue)
        output_queue: Queue = asyncio.Queue()
        tasks = []
        for _ in range(number_of_parallel_calls):
            task = asyncio.create_task(worker(input_queue, output_queue))
            tasks.append(task)

        # Wait for the input queue to be fully consumed (only works if no exception occurs in the tasks), blocks otherwise.
        await input_queue.join()
        # Gather tasks, only works when an exception is raised in a task, blocks otherwise
        # await asyncio.gather(*tasks)

        for task in tasks:
            task.cancel()

        # Indicate that the output queue is complete, to stop the worker
        output_queue.put_nowait(STOP)

        # Consume the output_queue, and return its data as a list
        group_results_task = asyncio.create_task(group_results(output_queue))
        await output_queue.join()
        output = await group_results_task
        return output

    return asyncio.run(procedure())

if __name__ == "__main__":
    def my_function(x):
        return x * x

    data = [1, 2, 3, 4]
    print(parallel_map(my_function, data))

I think I'm misunderstanding something basic but important about Python asyncio, but I'm not sure what.


Solution

  • The problem is that you are NOT catching exceptions.

    From the Python docs (asyncio.Queue.join()):

    The count of unfinished tasks goes up whenever an item is added to the queue. The count goes down whenever a consumer coroutine calls task_done() to indicate that the item was retrieved and all work on it is complete. When the count of unfinished tasks drops to zero, join() unblocks.

    So the Queue essentially counts calls to put() (and put_nowait()), and decreases that counter by 1 on every task_done() call. If the workers stop before the whole queue has been processed, you stay blocked in Queue.join().
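A small sketch of that counting behavior, using only the standard library: join_finishes is a hypothetical helper that enqueues a few items, retrieves all of them, but acknowledges only some with task_done(), then checks whether join() returns within a short timeout (the timeout just keeps the demo from hanging).

```python
import asyncio


async def join_finishes(acknowledged: int, items: int = 3) -> bool:
    """Hypothetical helper: True if join() unblocks within a short timeout."""
    queue: asyncio.Queue = asyncio.Queue()
    for item in range(items):
        queue.put_nowait(item)  # each put/put_nowait raises the unfinished count

    for _ in range(items):
        queue.get_nowait()  # retrieving an item does NOT lower the count

    for _ in range(acknowledged):
        queue.task_done()  # only task_done() lowers it

    try:
        await asyncio.wait_for(queue.join(), timeout=0.1)
        return True
    except asyncio.TimeoutError:
        return False  # count is still above zero, so join() keeps waiting


if __name__ == "__main__":
    print(asyncio.run(join_finishes(acknowledged=2)))  # False: one item never acknowledged
    print(asyncio.run(join_finishes(acknowledged=3)))  # True
```

Note that the queue is already empty in both cases; join() cares only about the unfinished count, not about qsize().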

    In your worker code:

        async def worker(input_queue: Queue, output_queue: Queue) -> None:
            while True:
                data = await input_queue.get()
                try:
                    output = func(data)
                    output_queue.put_nowait(output)
                finally:
                    input_queue.task_done()
    

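    Your worker stops when it encounters an exception, because try-finally only guarantees cleanup; it does not handle the error. The finally block calls task_done() for that one item, then the exception propagates out of the while True loop and ends the task.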
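To see this in isolation, here is a sketch with a single worker shaped like yours, where failing_func is a hypothetical stand-in for func that always raises. The finally block runs for the first item, then the exception ends the task and the remaining items are left in the queue.

```python
import asyncio


def failing_func(x):
    # Hypothetical stand-in for func: always raises.
    raise RuntimeError(f"bad input {x!r}")


async def worker(input_queue: asyncio.Queue, acknowledged: list) -> None:
    # Same shape as the question's worker: try/finally with no except clause.
    while True:
        data = await input_queue.get()
        try:
            failing_func(data)
        finally:
            input_queue.task_done()
            acknowledged.append(data)  # proves the cleanup ran


async def run_single_worker() -> tuple:
    input_queue: asyncio.Queue = asyncio.Queue()
    for i in range(3):
        input_queue.put_nowait(i)
    acknowledged: list = []
    task = asyncio.create_task(worker(input_queue, acknowledged))
    await asyncio.sleep(0.05)  # let the worker run
    error = task.exception() if task.done() else None
    return task.done(), type(error).__name__, acknowledged, input_queue.qsize()


if __name__ == "__main__":
    # (True, 'RuntimeError', [0], 2): one item acknowledged, worker dead, two items stranded
    print(asyncio.run(run_single_worker()))
```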

    Therefore what is happening for your case is:

    1. Each Queue.put() call increases the internal counter; let's say we called it n times.
    2. Workers start calling Queue.task_done() to decrease the internal counter.
    3. When a worker encounters an error, it stops after executing the finally block.
    4. Once all workers have stopped, the number of task_done() calls n' is less than n, so the internal counter is still positive.
    5. Queue.join() hangs until the internal counter reaches 0, which never happens because all workers have died.
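The steps above can be reproduced with a short sketch. It assumes more items than workers (5 items, 2 workers) and a worker body that always raises; wait_for with a timeout stands in for the indefinite hang so the script terminates, and reproduce_hang is a hypothetical name. With fewer items than workers, each item would still be acknowledged once and join() would return, which is worth keeping in mind when testing with tiny inputs.

```python
import asyncio


async def worker(input_queue: asyncio.Queue) -> None:
    # The question's worker, with func replaced by an always-failing call.
    while True:
        data = await input_queue.get()
        try:
            raise RuntimeError(f"failed on {data}")
        finally:
            input_queue.task_done()


async def reproduce_hang(n_items: int = 5, n_workers: int = 2) -> tuple:
    input_queue: asyncio.Queue = asyncio.Queue()
    for i in range(n_items):
        input_queue.put_nowait(i)  # step 1: counter goes up to n_items
    tasks = [asyncio.create_task(worker(input_queue)) for _ in range(n_workers)]

    try:
        # Steps 2-5: each worker dies after one item, so join() cannot unblock.
        await asyncio.wait_for(input_queue.join(), timeout=0.2)
        joined = True
    except asyncio.TimeoutError:
        joined = False

    dead_workers = sum(t.done() for t in tasks)
    for t in tasks:
        if t.done():
            t.exception()  # retrieve it so asyncio doesn't warn "never retrieved"
    return joined, dead_workers, input_queue.qsize()


if __name__ == "__main__":
    print(asyncio.run(reproduce_hang()))  # (False, 2, 3)
```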

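    This is a design flaw. It also explains the gather(*tasks) behavior you saw: gather() waits for every worker, but a while True worker never returns on its own, so gather() only finishes when a worker raises (by default it propagates the first exception) and blocks forever otherwise.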
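The opposite behavior with gather(*tasks) follows from the same worker shape. With the default return_exceptions=False, gather() completes only when every worker returns, or as soon as one of them raises; a while True worker never returns on its own. The sketch below shows both cases; gather_outcome is a hypothetical helper, and the 0.2 s timeout stands in for blocking forever.

```python
import asyncio


async def worker(input_queue: asyncio.Queue, fail: bool) -> None:
    # Same shape as the question's worker: loops forever, no except clause.
    while True:
        data = await input_queue.get()
        try:
            if fail:
                raise RuntimeError(f"failed on {data}")
        finally:
            input_queue.task_done()


async def gather_outcome(fail: bool) -> str:
    input_queue: asyncio.Queue = asyncio.Queue()
    for i in range(5):
        input_queue.put_nowait(i)
    tasks = [asyncio.create_task(worker(input_queue, fail)) for _ in range(2)]
    try:
        await asyncio.wait_for(asyncio.gather(*tasks), timeout=0.2)
        outcome = "returned"
    except asyncio.TimeoutError:
        outcome = "timed out"  # healthy workers sit in get() forever
    except RuntimeError:
        outcome = "raised"  # the first worker exception is propagated
    for task in tasks:
        task.cancel()
    # Wait for the workers to finish and collect their results/exceptions.
    await asyncio.gather(*tasks, return_exceptions=True)
    return outcome


if __name__ == "__main__":
    print(asyncio.run(gather_outcome(fail=False)))  # timed out
    print(asyncio.run(gather_outcome(fail=True)))  # raised
```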


    Additional helpful changes

    There are several design changes worth making, to keep the implementation simple, make it less prone to failure, and improve performance.

    Note that this is based on my experience with Python, so don't take it as hard fact.
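If you want to keep your original structure, the smallest change is an except clause inside the loop, so a failing item is recorded instead of killing the worker. This sketch assumes that storing the exception object as that item's result is acceptable and, for brevity, replaces the output queue and STOP sentinel with a dict keyed by input position; parallel_map_minimal is a hypothetical name.

```python
import asyncio
from typing import Any, Callable, Iterable, List


def parallel_map_minimal(func: Callable[[Any], Any], iterable: Iterable[Any],
                         number_of_parallel_calls: int = 9) -> List[Any]:
    async def worker(input_queue: asyncio.Queue, results: dict) -> None:
        while True:
            idx, data = await input_queue.get()
            try:
                results[idx] = func(data)
            except Exception as err:  # keep the worker alive; store the error instead
                results[idx] = err
            finally:
                input_queue.task_done()

    async def procedure() -> List[Any]:
        input_queue: asyncio.Queue = asyncio.Queue()
        for idx, item in enumerate(iterable):
            input_queue.put_nowait((idx, item))
        results: dict = {}
        tasks = [asyncio.create_task(worker(input_queue, results))
                 for _ in range(number_of_parallel_calls)]
        await input_queue.join()  # now always unblocks: every item is acknowledged
        for task in tasks:
            task.cancel()  # workers are parked in get(); stop them explicitly
        await asyncio.gather(*tasks, return_exceptions=True)
        return [results[i] for i in range(len(results))]

    return asyncio.run(procedure())
```

Because the workers still loop forever, they must be cancelled after join(); the rewrite below avoids that by letting workers exit once the queue is empty.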

    I've made the following design changes:

    Function code:

    import asyncio
    
    
    def parallel_map(func, iterable, concurrent_limit=2, raise_error=False):
        async def worker(input_queue: asyncio.Queue, output_queue: asyncio.Queue):
    
            while not input_queue.empty():
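                # EDIT(2024-11-26): use get_nowait() here, not await get().
                # Nothing is awaited between the empty() check and this call, so the queue
                # cannot be drained in between; get_nowait() raises QueueEmpty instead of hanging.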
                idx, item = input_queue.get_nowait()
                try:
                    # Support both coroutine functions and plain functions.
                    if asyncio.iscoroutinefunction(func):
                        output = await func(item)
                    else:
                        output = func(item)
                    await output_queue.put((idx, output))
    
                except Exception as err:
                    await output_queue.put((idx, err))
    
                finally:
                    input_queue.task_done()
    
        async def group_results(input_size, output_queue: asyncio.Queue):
            output = {}  # using dict to remove the need to sort list
    
            for _ in range(input_size):
                idx, val = await output_queue.get()  # gets tuple(idx, result)
                output[idx] = val
                output_queue.task_done()
    
            return [output[i] for i in range(input_size)]
    
        async def procedure():
            # populating input queue
            input_queue: asyncio.Queue = asyncio.Queue()
            for idx, item in enumerate(iterable):
                input_queue.put_nowait((idx, item))
            
            # Remember size before using Queue
            input_size = input_queue.qsize()
    
            # Generate task pool, and start collecting data.
            output_queue: asyncio.Queue = asyncio.Queue()
            result_task = asyncio.create_task(group_results(input_size, output_queue))
            tasks = [
                asyncio.create_task(worker(input_queue, output_queue))
                for _ in range(concurrent_limit)
            ]
            
            # Wait for tasks complete
            await asyncio.gather(*tasks)
            
            # Wait for result fetching
            results = await result_task
            
            # Re-raise errors at once if raise_error
            if raise_error and (errors := [err for err in results if isinstance(err, Exception)]):
                # noinspection PyUnboundLocalVariable
                raise Exception(errors)  # It never runs before assignment, safe to ignore.
    
            return results
    
        return asyncio.run(procedure())
    

    Test code:

    if __name__ == "__main__":
        import random
        import time
    
        data = [1, 2, 3]
        err_data = [1, 'yo', 3]
    
        def test_normal_function(data_, raise_=False):
            def my_function(x):
                t = random.uniform(1, 2)
                print(f"Sleep {t:.3} start")
    
                time.sleep(t)
                print(f"Awake after {t:.3}")
    
                return x * x
    
            print(f"Normal function: {parallel_map(my_function, data_, raise_error=raise_)}\n")
    
        def test_coroutine(data_, raise_=False):
            async def my_coro(x):
                t = random.uniform(1, 2)
                print(f"Coroutine sleep {t:.3} start")
    
                await asyncio.sleep(t)
                print(f"Coroutine awake after {t:.3}")
    
                return x * x
    
            print(f"Coroutine {parallel_map(my_coro, data_, raise_error=raise_)}\n")
    
        # Test starts
        print(f"Test for data {data}:")
        test_normal_function(data)
        test_coroutine(data)
    
        print(f"Test for data {err_data} without raise:")
        test_normal_function(err_data)
        test_coroutine(err_data)
    
        print(f"Test for data {err_data} with raise:")
        test_normal_function(err_data, True)
        test_coroutine(err_data, True)  # Not reached: the previous call raises, but it would behave the same.
    

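    The code above tests three conditions, for both a plain function and a coroutine function: valid data, data containing an invalid item with raise_error=False, and the same data with raise_error=True.

    Even when exceptions occur, no task is cancelled; every item in the queue is still processed.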

    Output:

    Test for data [1, 2, 3]:
    Sleep 1.71 start
    Awake after 1.71
    Sleep 1.74 start
    Awake after 1.74
    Sleep 1.83 start
    Awake after 1.83
    Normal function: [1, 4, 9]
    
    Coroutine sleep 1.32 start
    Coroutine sleep 1.01 start
    Coroutine awake after 1.01
    Coroutine sleep 1.98 start
    Coroutine awake after 1.32
    Coroutine awake after 1.98
    Coroutine [1, 4, 9]
    
    Test for data [1, 'yo', 3] without raise:
    Sleep 1.57 start
    Awake after 1.57
    Sleep 1.98 start
    Awake after 1.98
    Sleep 1.39 start
    Awake after 1.39
    Normal function: [1, TypeError("can't multiply sequence by non-int of type 'str'"), 9]
    
    Coroutine sleep 1.22 start
    Coroutine sleep 2.0 start
    Coroutine awake after 1.22
    Coroutine sleep 1.96 start
    Coroutine awake after 2.0
    Coroutine awake after 1.96
    Coroutine [1, TypeError("can't multiply sequence by non-int of type 'str'"), 9]
    
    Test for data [1, 'yo', 3] with raise:
    Sleep 1.99 start
    Awake after 1.99
    Sleep 1.74 start
    Awake after 1.74
    Sleep 1.52 start
    Awake after 1.52
    Traceback (most recent call last):
    ...
    line 52, in procedure
        raise Exception(errors)
    Exception: [TypeError("can't multiply sequence by non-int of type 'str'")]
    

    Note that I've set concurrent_limit to 2 to demonstrate a coroutine waiting for an available worker. That's why one of the three coroutine calls does not start immediately.
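A common alternative for bounding concurrency, swapped in here in place of the queue and worker pool, is asyncio.Semaphore combined with gather(). In this sketch bounded_map and measure_peak are hypothetical names, bounded_map only accepts coroutine functions, and return_exceptions=True mirrors raise_error=False by leaving exceptions in the result list.

```python
import asyncio
from typing import Any, Awaitable, Callable, Iterable, List, Tuple


async def bounded_map(coro_func: Callable[[Any], Awaitable[Any]],
                      iterable: Iterable[Any], concurrent_limit: int = 2) -> List[Any]:
    semaphore = asyncio.Semaphore(concurrent_limit)

    async def run_one(item: Any) -> Any:
        async with semaphore:  # at most concurrent_limit bodies run at once
            return await coro_func(item)

    # gather returns results in input order, whatever order they finish in.
    return await asyncio.gather(*(run_one(item) for item in iterable),
                                return_exceptions=True)


async def measure_peak(concurrent_limit: int, n_items: int) -> Tuple[List[Any], int]:
    state = {"running": 0, "peak": 0}

    async def slow_square(x: int) -> int:
        state["running"] += 1
        state["peak"] = max(state["peak"], state["running"])
        await asyncio.sleep(0.01)  # simulated I/O
        state["running"] -= 1
        return x * x

    results = await bounded_map(slow_square, range(n_items), concurrent_limit)
    return results, state["peak"]


if __name__ == "__main__":
    print(asyncio.run(measure_peak(concurrent_limit=2, n_items=5)))  # ([0, 1, 4, 9, 16], 2)
```

One trade-off: this creates one task per input up front, whereas the worker pool keeps only concurrent_limit tasks alive, which can matter for very large inputs.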

    From the output you can also see that some tasks finish before others, but the results are still returned in input order.
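In the output above, the plain-function calls run strictly one after another even with two workers, because my_function calls time.sleep, which blocks the whole event loop while func(item) runs. If func does blocking I/O, one option (assuming Python 3.9+ for asyncio.to_thread) is to run plain functions in a thread instead of calling them directly in the worker. call_maybe_blocking, blocking_square, and run_three are hypothetical names for this sketch.

```python
import asyncio
import inspect
import time
from typing import Any, Callable, List, Tuple


async def call_maybe_blocking(func: Callable[[Any], Any], item: Any) -> Any:
    """Await coroutine functions directly; run plain functions in a thread."""
    if inspect.iscoroutinefunction(func):
        return await func(item)
    # asyncio.to_thread (Python 3.9+) runs func in the loop's default executor,
    # so a blocking call such as time.sleep no longer freezes the event loop.
    return await asyncio.to_thread(func, item)


def blocking_square(x: int) -> int:
    time.sleep(0.2)  # stands in for blocking I/O
    return x * x


async def run_three() -> Tuple[List[int], float]:
    start = time.perf_counter()
    results = await asyncio.gather(
        *(call_maybe_blocking(blocking_square, x) for x in [1, 2, 3])
    )
    return list(results), time.perf_counter() - start


if __name__ == "__main__":
    results, elapsed = asyncio.run(run_three())
    # The three sleeps overlap, so elapsed should be well under 0.6 s.
    print(results, f"{elapsed:.2f}s")
```

Threads only help with blocking I/O; in CPython, CPU-bound functions still contend for the GIL, which matches the original docstring's warning.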


    P.S.

    If you're importing Queue separately because the type hints would otherwise exceed the PEP 8 line-length limit, you can add the type hints as follows:

    async def worker(input_queue, output_queue) -> None:
        input_queue: asyncio.Queue
        output_queue: asyncio.Queue
    

    or

    async def worker(
            input_queue: asyncio.Queue,
            output_queue: asyncio.Queue
        ) -> None:
    

    It's not as clean as the original way, but it will help others reading your code.