I've been trying to write an async version of the map
function in Python for doing IO.
To do that, I'm using a queue with a producer/consumer.
At first it seemed to work well, but only when no exceptions occur.
In particular, if I use queue.join(), it works when there are no exceptions, but blocks when one is raised.
If I use gather(*tasks), it works when there is an exception, but blocks otherwise.
So it only finishes sometimes, and I just don't understand why.
Here is the code I've implemented:
import asyncio
from asyncio import Queue
from typing import Iterable, Callable, TypeVar

Input = TypeVar("Input")
Output = TypeVar("Output")

STOP = object()


def parallel_map(func: Callable[[Input], Output], iterable: Iterable[Input]) -> Iterable[Output]:
    """
    Parallel version of `map`, backed by asyncio.

    Only suitable to do IO in parallel (not for CPU intensive tasks, otherwise it will block).
    """
    number_of_parallel_calls = 9

    async def worker(input_queue: Queue, output_queue: Queue) -> None:
        while True:
            data = await input_queue.get()
            try:
                output = func(data)
                # Simulate an exception:
                # raise RuntimeError("")
                output_queue.put_nowait(output)
            finally:
                input_queue.task_done()

    async def group_results(output_queue: Queue) -> Iterable[Output]:
        output = []
        while True:
            item = await output_queue.get()
            if item is not STOP:
                output.append(item)
            output_queue.task_done()
            if item is STOP:
                break
        return output

    async def procedure() -> Iterable[Output]:
        # First, produce a queue of inputs
        input_queue: Queue = asyncio.Queue()
        for i in iterable:
            input_queue.put_nowait(i)

        # Then, assign a pool of tasks to consume it (and also produce outputs in a new queue)
        output_queue: Queue = asyncio.Queue()
        tasks = []
        for _ in range(number_of_parallel_calls):
            task = asyncio.create_task(worker(input_queue, output_queue))
            tasks.append(task)

        # Wait for the input queue to be fully consumed (only works if no exception
        # occurs in the tasks, blocks otherwise):
        await input_queue.join()
        # Gather tasks (only works when an exception is raised in a task, blocks otherwise):
        # await asyncio.gather(*tasks)

        for task in tasks:
            task.cancel()

        # Indicate that the output queue is complete, to stop the results consumer
        output_queue.put_nowait(STOP)

        # Consume the output_queue, and return its data as a list
        group_results_task = asyncio.create_task(group_results(output_queue))
        await output_queue.join()
        output = await group_results_task
        return output

    return asyncio.run(procedure())


if __name__ == "__main__":
    def my_function(x):
        return x * x

    data = [1, 2, 3, 4]
    print(parallel_map(my_function, data))
I think I'm misunderstanding something basic but important about Python asyncio, but I'm not sure what.
The problem is, you are NOT catching exceptions.
From the Python docs:

The count of unfinished tasks goes up whenever an item is added to the queue. The count goes down whenever a consumer coroutine calls task_done() to indicate that the item was retrieved and all work on it is complete. When the count of unfinished tasks drops to zero, join() unblocks.

So a Queue is essentially counting the number of put() calls, and reducing that counter by 1 on every task_done() call. If a worker stops before processing the whole queue, you are blocked in Queue.join().
Look at your worker code:
async def worker(input_queue: Queue, output_queue: Queue) -> None:
    while True:
        data = await input_queue.get()
        try:
            output = func(data)
            output_queue.put_nowait(output)
        finally:
            input_queue.task_done()
Your worker stops as soon as it encounters an exception, because try-finally only guarantees cleanup, not actual error handling: the exception still propagates and kills the while loop.
Therefore, what is happening in your case is:

1. Each Queue.put() call increases the internal counter; let's say we called it n times.
2. Workers call Queue.task_done() to decrease the internal counter, inside the finally block.
3. A worker hits an exception, runs its finally block one last time, and dies.
4. Since the Queue.task_done() call count n' is n' < n, the internal counter still holds a positive value.
5. Queue.join() hangs indefinitely until the internal counter is 0, which never happens, as all the workers died.

This is a design flaw.
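Here is a small standalone sketch of that failure mode (my own illustration, not your exact code): the finally block runs, but the exception still propagates and the worker dies on its first item, leaving the rest of the queue untouched.

import asyncio

async def fragile_worker(queue: asyncio.Queue) -> None:
    while True:
        item = await queue.get()
        try:
            raise RuntimeError(f"failed on {item!r}")
        finally:
            queue.task_done()          # cleanup runs...
            print("task_done called")  # ...but the exception propagates anyway

async def demo():
    queue: asyncio.Queue = asyncio.Queue()
    for i in range(3):
        queue.put_nowait(i)
    worker = asyncio.create_task(fragile_worker(queue))
    await asyncio.gather(worker, return_exceptions=True)
    print(f"worker is dead, {queue.qsize()} items left unprocessed")

asyncio.run(demo())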
There are multiple design factors worth changing for the sake of a simpler implementation that is less prone to failure and performs better.
Do note that this comes from my experience with Python, so don't take it as concrete fact.
For design factors, I've made the following changes:

- Your code supports a plain function only, so it's better to support a coroutine function too.
- The input_queue is guaranteed to be fully populated before any worker starts, so checking Queue.empty() is enough to determine the end of the loop.
- Since nothing else is added to the input_queue, there's no need for a sentinel: you know how long the given iterable was via queue.qsize().
- Use await Queue.put() rather than put_nowait(): you can't be sure the Queue has room at the precise moment you're putting onto it (see the short sketch after the function code).
- Rather than stopping the workers on an Exception, put the errors in the result and process the whole queue, then simply re-raise at the user's choice.
- Checking for a sentinel on every loop iteration is not needed for this task, and list.append will hinder your script's performance.
- Don't import Queue from asyncio directly: a bare Queue does not tell the reader clearly enough whether it's from queue or some other library with a built-in Queue object.
- No need for queue.join() to run group_results; just await it altogether.

Function Code:
import asyncio


def parallel_map(func, iterable, concurrent_limit=2, raise_error=False):
    async def worker(input_queue: asyncio.Queue, output_queue: asyncio.Queue):
        while not input_queue.empty():
            # EDIT(2024-11-26): use get_nowait here, not await get
            idx, item = input_queue.get_nowait()
            try:
                # Support both a coroutine function and a plain function
                if asyncio.iscoroutinefunction(func):
                    output = await func(item)
                else:
                    output = func(item)
                await output_queue.put((idx, output))
            except Exception as err:
                await output_queue.put((idx, err))
            finally:
                input_queue.task_done()

    async def group_results(input_size, output_queue: asyncio.Queue):
        output = {}  # using a dict removes the need to sort the list
        for _ in range(input_size):
            idx, val = await output_queue.get()  # gets a tuple (idx, result)
            output[idx] = val
            output_queue.task_done()
        return [output[i] for i in range(input_size)]

    async def procedure():
        # Populate the input queue
        input_queue: asyncio.Queue = asyncio.Queue()
        for idx, item in enumerate(iterable):
            input_queue.put_nowait((idx, item))

        # Remember the size before consuming the queue
        input_size = input_queue.qsize()

        # Generate the task pool, and start collecting data
        output_queue: asyncio.Queue = asyncio.Queue()
        result_task = asyncio.create_task(group_results(input_size, output_queue))
        tasks = [
            asyncio.create_task(worker(input_queue, output_queue))
            for _ in range(concurrent_limit)
        ]

        # Wait for the tasks to complete
        await asyncio.gather(*tasks)

        # Wait for result fetching
        results = await result_task

        # Re-raise the errors at once if raise_error is set
        if raise_error and (errors := [err for err in results if isinstance(err, Exception)]):
            # noinspection PyUnboundLocalVariable
            raise Exception(errors)  # never runs before assignment, safe to ignore

        return results

    return asyncio.run(procedure())
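On the await Queue.put() point from the list above: with an unbounded queue, put_nowait() never fails, but as soon as you give the queue a maxsize, put_nowait() raises QueueFull while await queue.put() simply waits for a free slot. A minimal sketch of the difference:

import asyncio

async def demo():
    queue: asyncio.Queue = asyncio.Queue(maxsize=1)
    await queue.put("first")  # fine, the queue has a free slot

    try:
        queue.put_nowait("second")  # the queue is full: raises immediately
    except asyncio.QueueFull:
        print("put_nowait() raised QueueFull")

    async def consume_later():
        await asyncio.sleep(0.1)
        queue.get_nowait()  # frees a slot

    consumer = asyncio.create_task(consume_later())
    await queue.put("second")  # suspends until the consumer frees a slot
    print("await put() succeeded once a slot opened")
    await consumer

asyncio.run(demo())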
Test code:
if __name__ == "__main__":
    import random
    import time

    data = [1, 2, 3]
    err_data = [1, 'yo', 3]

    def test_normal_function(data_, raise_=False):
        def my_function(x):
            t = random.uniform(1, 2)
            print(f"Sleep {t:.3} start")
            time.sleep(t)
            print(f"Awake after {t:.3}")
            return x * x

        print(f"Normal function: {parallel_map(my_function, data_, raise_error=raise_)}\n")

    def test_coroutine(data_, raise_=False):
        async def my_coro(x):
            t = random.uniform(1, 2)
            print(f"Coroutine sleep {t:.3} start")
            await asyncio.sleep(t)
            print(f"Coroutine awake after {t:.3}")
            return x * x

        print(f"Coroutine {parallel_map(my_coro, data_, raise_error=raise_)}\n")

    # Tests start
    print(f"Test for data {data}:")
    test_normal_function(data)
    test_coroutine(data)

    print(f"Test for data {err_data} without raise:")
    test_normal_function(err_data)
    test_coroutine(err_data)

    print(f"Test for data {err_data} with raise:")
    test_normal_function(err_data, True)
    test_coroutine(err_data, True)  # this line will not run, but works the same
The above tests the following conditions for both a plain function and a coroutine: normal data, error-producing data without re-raising, and error-producing data with re-raising.
Even with exceptions this will not cancel the tasks; rather, it processes the whole queue.
Output:
Test for data [1, 2, 3]:
Sleep 1.71 start
Awake after 1.71
Sleep 1.74 start
Awake after 1.74
Sleep 1.83 start
Awake after 1.83
Normal function: [1, 4, 9]
Coroutine sleep 1.32 start
Coroutine sleep 1.01 start
Coroutine awake after 1.01
Coroutine sleep 1.98 start
Coroutine awake after 1.32
Coroutine awake after 1.98
Coroutine [1, 4, 9]
Test for data [1, 'yo', 3] without raise:
Sleep 1.57 start
Awake after 1.57
Sleep 1.98 start
Awake after 1.98
Sleep 1.39 start
Awake after 1.39
Normal function: [1, TypeError("can't multiply sequence by non-int of type 'str'"), 9]
Coroutine sleep 1.22 start
Coroutine sleep 2.0 start
Coroutine awake after 1.22
Coroutine sleep 1.96 start
Coroutine awake after 2.0
Coroutine awake after 1.96
Coroutine [1, TypeError("can't multiply sequence by non-int of type 'str'"), 9]
Test for data [1, 'yo', 3] with raise:
Sleep 1.99 start
Awake after 1.99
Sleep 1.74 start
Awake after 1.74
Sleep 1.52 start
Awake after 1.52
Traceback (most recent call last):
...
line 52, in procedure
raise Exception(errors)
Exception: [TypeError("can't multiply sequence by non-int of type 'str'")]
Do note that I've set concurrent_limit to 2 to demonstrate a coroutine waiting for an available worker; that's why one of the three coroutine tasks does not start immediately.
From the output you can also see that some tasks finish before others, but the results stay in order.
P.S.
If you were importing Queue separately because the type hints push lines over the PEP 8 length limit, you can add the type hints as follows:
async def worker(input_queue, output_queue) -> None:
    input_queue: asyncio.Queue
    output_queue: asyncio.Queue
or
async def worker(
    input_queue: asyncio.Queue,
    output_queue: asyncio.Queue
) -> None:
It's not as clean as the original way, but it will help others reading your code.