
Python ThreadPoolExecutor: how to submit a task from inside a submitted function


I am using Python's concurrent.futures library with ThreadPoolExecutor, and I want to submit a new task from inside a function that was itself submitted to the executor. The following code is a minimal example that tries to submit f2 from inside the submitted f1:

import concurrent.futures


def f2():
    print("hello, f2")


def f1():
    print("hello, f1")
    executor.submit(f2)


with concurrent.futures.ThreadPoolExecutor(16) as executor:
    executor.submit(f1)

Output (Python 3.12):

hello, f1

Output is the same even with:

import concurrent.futures


def f2():
    print("hello, f2")
    return 3


def f1():
    print("hello, f1")
    print(executor.submit(f2).result())


with concurrent.futures.ThreadPoolExecutor(16) as executor:
    executor.submit(f1)

But this code works:

import concurrent.futures


def f2():
    print("hello, f2")
    return 3


def f1():
    print("hello, f1")
    return executor.submit(f2).result()


with concurrent.futures.ThreadPoolExecutor(16) as executor:
    print(executor.submit(f1).result())

Output:

hello, f1
hello, f2
3

Why is f2 not called in example 1?

Update:

Without the with statement, the output of example 1 is nondeterministic (sometimes it prints hello, f2 and sometimes it doesn't).


The real use is like:

import concurrent.futures


def f3(arg1, arg2):
    print(f"hello, f3, {arg1}, {arg2}")


def f2(arg1, arg2):
    print(f"hello, f2 {arg1}")
    for i in range(10):
        executor.submit(f3, arg2, i)


def f1(arg1, arg2, arg3):
    print(f"hello, f1 {arg1}")
    for i in range(10):
        executor.submit(f2, i, arg2, arg3)


with concurrent.futures.ThreadPoolExecutor(16) as executor:
    for i in range(10):
        executor.submit(f1, i, 1, 2)

It's complicated, so I want a simple solution.

The calls to f1 and f2 are also expensive, so I want all the tasks to run concurrently (a thread pool of size 16 is OK though). f1 and f2 should return once all their tasks have been submitted (not completed).


Solution

  • You have:

    with concurrent.futures.ThreadPoolExecutor(16) as executor:
        executor.submit(f1)
    

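    The main thread submits a task specifying worker function f1. Then the main thread exits the block and an implicit call to executor.shutdown() is made. With the default arguments, shutdown() still lets every task that has already been submitted run to completion (nothing is thrown away), but from that moment on any call to executor.submit() raises RuntimeError('cannot schedule new futures after shutdown'). In your code the call to shutdown occurs before worker function f1 has had a chance to submit the new task with f2 as the worker function, so f1's own call to executor.submit(f2) raises. Because f1 runs in a worker thread, that exception is simply stored in the Future returned by executor.submit(f1), which is never inspected, so the failure is silent. This can be demonstrated as follows: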

    with concurrent.futures.ThreadPoolExecutor(16) as executor:
        executor.submit(f1)
        import time
        time.sleep(.1)
    

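    We have delayed the call to shutdown by .1 seconds, giving f1 a chance to submit f2 first. But even this has a race condition: is .1 seconds always enough time for f1 to submit the second task? We cannot depend on this method.

    This also accounts for your other observations. In your second example, the call to submit raises before .result() is ever reached. Your third example works because the main thread blocks on f1's result inside the with block, so shutdown is not called until f1 has finished. Without the with statement, f1 races against interpreter exit instead, after which submit raises RuntimeError('cannot schedule new futures after interpreter shutdown'); that is why f2 only sometimes runs.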
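To see where the exception goes, here is a minimal sketch that forces f1 to call submit after shutdown has been requested and then inspects f1's Future instead of ignoring it. release_f1 is a hypothetical Event added only to make the ordering deterministic; it is not part of your code:

```python
import concurrent.futures
import threading

# Hypothetical helper: holds f1 back until the main thread has called shutdown().
release_f1 = threading.Event()


def f2():
    print("hello, f2")


def f1():
    print("hello, f1")
    release_f1.wait()
    executor.submit(f2)  # raises RuntimeError because shutdown() was already called


executor = concurrent.futures.ThreadPoolExecutor(16)
f1_future = executor.submit(f1)
executor.shutdown(wait=False)  # what leaving the with block does, minus the waiting
release_f1.set()

# exception() blocks until f1 has finished and returns whatever f1 raised.
print(repr(f1_future.exception()))
executor.shutdown()  # now wait for the worker thread to exit
```

In your original code nothing ever looks at that Future, which is why the failure is silent.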

    TL;DR

    You can skip ahead to the Solution section below if you don't want to read through the following solutions for simpler cases.

    Attempts

    To remove that race condition we can use a threading.Event that gets set only after all the tasks we need to submit have started executing:

    import concurrent.futures
    from threading import Event
    
    all_tasks_submitted = Event()
    
    def f2():
        all_tasks_submitted.set()
        print("hello, f2")
        return 3
    
    
    def f1():
        print("hello, f1")
        print(executor.submit(f2).result())
    
    
    with concurrent.futures.ThreadPoolExecutor(16) as executor:
        executor.submit(f1)
        all_tasks_submitted.wait()
    

    Prints:

    hello, f1
    hello, f2
    3
    

    So now let's look at your actual case. First, there is a slight bug: f2 takes only two arguments, but f1 invokes it with three (i, arg2 and arg3), which raises a TypeError that goes unnoticed. The code below simply drops the i argument.
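Errors raised inside fire-and-forget tasks like these end up stored in Futures that nobody inspects. One way to surface them, sketched here with the standard Future.add_done_callback method (report_exception is a hypothetical helper name), is to attach a callback that prints any exception a task raised:

```python
import concurrent.futures


def report_exception(future):
    """Done-callback that prints an exception a task raised instead of hiding it."""
    if future.cancelled():
        return
    exc = future.exception()  # the future is already done, so this does not block
    if exc is not None:
        print(f"task failed: {exc!r}")


def f2(arg1, arg2):
    print(f"hello, f2 {arg1}")


with concurrent.futures.ThreadPoolExecutor(4) as executor:
    # One argument too many, just like the call to f2 inside the question's f1:
    bad_future = executor.submit(f2, 0, 1, 2)
    bad_future.add_done_callback(report_exception)
```

The same callback could be attached to every submit call made from f1 and f2.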

    This is a far more complicated case in that we are ultimately trying to start 10 * 10 * 10 = 1000 f3 tasks, so we now need a counter to keep track of how many f3 tasks have been started:

    import concurrent.futures
    from threading import Event, Lock
    
    all_tasks_started = Event()
    lock = Lock()
    
    
    NUM_F3_TASKS = 1_000
    total_f3_tasks_started = 0
    
    def f3(arg1, arg2):
        global total_f3_tasks_started
    
        with lock:
            total_f3_tasks_started += 1
            n = total_f3_tasks_started
    
        if n == NUM_F3_TASKS:
            all_tasks_started.set()
    
        print(f"hello, f3, {arg1}, {arg2}, f3 tasks started = {n}")
    
    
    def f2(arg1, arg2):
        print(f"hello, f2 {arg1}")
        for i in range(10):
            executor.submit(f3, arg2, i)
    
    
    def f1(arg1, arg2, arg3):
        print(f"hello, f1 {arg1}")
        for i in range(10):
            executor.submit(f2, arg2, arg3)
    
    
    with concurrent.futures.ThreadPoolExecutor(16) as executor:
        for i in range(10):
            executor.submit(f1, i, 1, 2)
        all_tasks_started.wait()
    

    Prints:

    hello, f1 0
    hello, f1 1
    hello, f1 2
    hello, f1 3
    hello, f1 4
    hello, f1 5
    hello, f1 6
    hello, f1 7
    hello, f1 8
    hello, f1 9
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    
    ...
    
    hello, f3, 2, 2, f3 tasks started = 993
    hello, f3, 2, 4, f3 tasks started = 995
    hello, f3, 2, 6, f3 tasks started = 997
    hello, f3, 2, 8, f3 tasks started = 999
    hello, f3, 2, 3, f3 tasks started = 994
    hello, f3, 2, 7, f3 tasks started = 998
    hello, f3, 2, 5, f3 tasks started = 996
    hello, f3, 2, 9, f3 tasks started = 1000
    

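    But this means that you need to know in advance exactly how many f3 tasks will be created. You might be tempted to solve the problem by having each f1 task not return until all the f2 tasks it submitted have completed, and each f2 task not return until all the f3 tasks it submitted have completed. But a waiting parent task ties up a worker thread: to run everything concurrently you would have 10 f1 tasks, 100 f2 tasks and 1000 f3 tasks in flight at once, requiring a thread pool of size 1110. With only 16 threads, the waiting f1 and f2 tasks can occupy every worker, so the f3 tasks they are waiting for never get to run and the pool deadlocks.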
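The risk of parents blocking on their children can be seen in miniature with a single worker thread. In this sketch, parent and child are hypothetical stand-ins for f1 and f2: the parent occupies the only worker while its child sits in the queue behind it, and a timeout is used so the demo reports the problem instead of hanging:

```python
import concurrent.futures


def child():
    return "child done"


def parent():
    # Submit a child and block this worker thread until the child finishes.
    inner = executor.submit(child)
    try:
        return inner.result(timeout=0.5)
    except concurrent.futures.TimeoutError:
        # The child is stuck in the queue: the only worker is busy running us.
        return "timed out waiting for child"


with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    outcome = executor.submit(parent).result()

print(outcome)
```

Without the timeout, parent would wait forever. With 16 workers all occupied by waiting f1 and f2 tasks, the same thing happens on a larger scale.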

    Solution

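    We use an explicit task queue that is served by POOL_SIZE long-running pool_executor tasks. Because a task puts its child tasks on the queue before pool_executor calls task_done() for it, the queue's count of unfinished tasks cannot drop to zero while any running task might still add work, so task_queue.join() returns only once every task, however deeply nested, has completed: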

    import concurrent.futures
    from queue import Queue
    from threading import Lock
    
    task_queue = Queue()
    lock = Lock()
    
    task_number = 0
    
    def f3(arg1, arg2):
        global task_number
    
        with lock:
            task_number += 1
            n = task_number
    
        print(f"hello, f3, {arg1}, {arg2}, task_number = {n}")
    
    
    def f2(arg1, arg2):
        print(f"hello, f2 {arg1}")
        for i in range(10):
            task_queue.put((f3, arg2, i))
    
    
    def f1(arg1, arg2, arg3):
        print(f"hello, f1 {arg1}")
        for i in range(10):
            task_queue.put((f2, arg2, arg3))
    
    
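    def pool_executor():
        while True:
            task = task_queue.get()
            if task is None:
                # sentinel to terminate
                return
    
            fn, *args = task
            try:
                fn(*args)
            except Exception as e:
                # Don't let one failing task kill this worker:
                print(f"{fn.__name__} failed: {e!r}")
            finally:
                # Show this work has been completed (even if it failed);
                # otherwise task_queue.join() would block forever:
                task_queue.task_done()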
    
    
    POOL_SIZE = 16
    
    with concurrent.futures.ThreadPoolExecutor(POOL_SIZE) as executor:
        for _ in range(POOL_SIZE):
            executor.submit(pool_executor)
    
        for i in range(10):
            task_queue.put((f1, i, 1, 2))
    
        # Wait for all tasks to complete
        task_queue.join()
        # Now we need to terminate the running pool_executor tasks:
        # Add sentinels:
        for _ in range(POOL_SIZE):
            task_queue.put(None)
    

    Prints:

    hello, f1 0
    hello, f1 1
    hello, f1 3
    hello, f1 5
    hello, f1 7
    hello, f1 9
    hello, f1 2
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f1 4
    hello, f1 6
    hello, f1 8
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    hello, f2 1
    
    ...
    
    hello, f3, 2, 1, task_number = 992
    hello, f3, 2, 2, task_number = 993
    hello, f3, 2, 4, task_number = 995
    hello, f3, 2, 6, task_number = 997
    hello, f3, 2, 8, task_number = 999
    hello, f3, 2, 3, task_number = 994
    hello, f3, 2, 7, task_number = 998
    hello, f3, 2, 5, task_number = 996
    hello, f3, 2, 9, task_number = 1000
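Why task_queue.join() waits for nested tasks comes down to Queue's count of unfinished tasks, which put() increments and task_done() decrements. A tiny single-threaded sketch of that bookkeeping follows; it reads Queue.unfinished_tasks, an attribute CPython's queue.Queue uses internally but does not document, purely for illustration:

```python
from queue import Queue

q = Queue()

q.put("parent")          # one unfinished task
parent = q.get()
q.put("child")           # the parent adds its child *before* reporting itself done
q.task_done()            # parent done, but the child is still outstanding
print(q.unfinished_tasks)  # prints 1, so join() would still block here

child = q.get()
q.task_done()            # child done: nothing is outstanding any more
print(q.unfinished_tasks)  # prints 0

q.join()                 # returns immediately because the count is zero
```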
    

    Perhaps you should consider creating your own thread pool with daemon threads, which are stopped automatically when the main program exits. (You could still use the technique of adding sentinel values to tell these threads to terminate when they are no longer needed, in which case the threads need not be daemon threads.)

    from queue import Queue
    from threading import Lock, Thread
    
    ...
    
    def pool_executor():
        while True:
            fn, *args = task_queue.get()
            try:
                fn(*args)
            except Exception as e:
                print(f"{fn.__name__} failed: {e!r}")
            finally:
                # Show this work has been completed (even if it failed):
                task_queue.task_done()
    
    
    POOL_SIZE = 16
    
    for _ in range(POOL_SIZE):
        Thread(target=pool_executor, daemon=True).start()
    
    for i in range(10):
        task_queue.put((f1, i, 1, 2))
    
    # Wait for all tasks to complete
    task_queue.join()
    

    A New Type of Multithreading Pool

    We can abstract this into a thread pool that lets running tasks submit any number of additional tasks and lets the caller wait for all of them to complete. That is, we wait until the task queue has quiesced: the queue is empty and no task is currently running, so no new tasks can be added:

    from queue import Queue
    from threading import Thread
    
    class ThreadPool:
        def __init__(self, pool_size):
            self._pool_size = pool_size
            self._task_queue = Queue()
            self._shutting_down = False
            for _ in range(self._pool_size):
                Thread(target=self._executor, daemon=True).start()
    
        def __enter__(self):
            return self
    
        def __exit__(self, exc_type, exc_val, exc_tb):
            self.shutdown()
    
        def _terminate_threads(self):
            """Tell threads to terminate."""
            # No new tasks in case this is an immediate shutdown:
            self._shutting_down = True
    
            for _ in range(self._pool_size):
                self._task_queue.put(None)
            self._task_queue.join()  # Wait for all threads to terminate
    
    
        def shutdown(self, wait=True):
            if wait:
                # Wait until the task queue quiesces (becomes empty).
                # Running tasks may be continuing to submit tasks to the queue but
                # the expectation is that at some point no more tasks will be added
                # and we wait for the queue to become empty:
                self._task_queue.join()
            self._terminate_threads()
    
        def submit(self, fn, *args):
            if self._shutting_down:
                return
            self._task_queue.put((fn, args))
    
        def _executor(self):
            while True:
                task = self._task_queue.get()
                if task is None:  # sentinel
                    self._task_queue.task_done()
                    return
                fn, args = task
                try:
                    fn(*args)
                except Exception as e:
                    print(e)
                # Show this work has been completed:
                self._task_queue.task_done()
    
    ###############################################
    
    from threading import Lock
    
    lock = Lock()
    
    task_number = 0
    
    results = []
    
    def f3(arg1, arg2):
        global task_number
    
        with lock:
            task_number += 1
            n = task_number
    
        #print(f"hello, f3, {arg1}, {arg2}, task_number = {n}")
        results.append(f"hello, f3, {arg1}, {arg2}, task_number = {n}")
    
    
    def f2(arg1, arg2):
        for i in range(10):
            pool.submit(f3, arg2, i)
    
    def f1(arg1, arg2, arg3):
        for i in range(10):
            pool.submit(f2, arg2, arg3)
    
    
    with ThreadPool(16) as pool:
        for i in range(10):
            pool.submit(f1, i, 1, 2)
    
    for result in results:
        print(result)
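If you would rather keep concurrent.futures.ThreadPoolExecutor and its Future objects, the same quiescence idea can be sketched as a small wrapper that counts tasks that have been submitted but not yet finished, using Future.add_done_callback and a threading.Condition. QuiescentExecutor and wait_quiet are hypothetical names, not part of the standard library:

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class QuiescentExecutor:
    """Hypothetical wrapper around ThreadPoolExecutor that can wait for quiescence."""

    def __init__(self, max_workers):
        self._executor = ThreadPoolExecutor(max_workers)
        self._pending = 0                  # tasks submitted but not yet finished
        self._cond = threading.Condition()

    def submit(self, fn, *args, **kwargs):
        # Count the task before it can possibly finish, so a child is always
        # counted while its parent is still pending.
        with self._cond:
            self._pending += 1
        try:
            future = self._executor.submit(fn, *args, **kwargs)
        except BaseException:
            self._task_finished(None)      # nothing was scheduled: undo the count
            raise
        future.add_done_callback(self._task_finished)
        return future

    def _task_finished(self, _future):
        with self._cond:
            self._pending -= 1
            if self._pending == 0:
                self._cond.notify_all()

    def wait_quiet(self):
        """Block until no task is queued or running, so none can submit more."""
        with self._cond:
            self._cond.wait_for(lambda: self._pending == 0)

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.wait_quiet()                  # only shut down after everything finished
        self._executor.shutdown()


# Usage with the question's task structure (f3 just counts how often it ran).
counter = 0
counter_lock = threading.Lock()


def f3(arg1, arg2):
    global counter
    with counter_lock:
        counter += 1


def f2(arg1, arg2):
    for i in range(10):
        pool.submit(f3, arg2, i)


def f1(arg1, arg2, arg3):
    for i in range(10):
        pool.submit(f2, arg2, arg3)


with QuiescentExecutor(16) as pool:
    for i in range(10):
        pool.submit(f1, i, 1, 2)

print(counter, "f3 tasks ran")
```

Because a parent's children are counted before the parent's own done-callback runs, the count can only reach zero once no task is left that could submit more work.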
    

    Another Way That Uses Standard concurrent.futures Methods

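    As you have observed, in the above solution an f1 task will complete before the f2 tasks it has submitted have completed, and f2 tasks will complete before their f3 tasks have. The problem with your original code was that shutdown was implicitly called before all 1000 f3 tasks had been submitted. We can prevent this premature shutdown from occurring by having each worker function add the Future instances of the tasks it submits to a shared list, futures, and having the main thread await all of their results before leaving the with block. This works because a task appends its children's futures before it returns, and a for loop over a list also visits items appended while the loop is running, so the loop only ends once no task is left that could submit more work: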

    from concurrent.futures import ThreadPoolExecutor
    from threading import Lock
    
    task_number = 0
    
    lock = Lock()
    
    futures = []
    
    def f3(arg1, arg2):
        global task_number
    
        with lock:
            task_number += 1
            n = task_number
    
        print(f"hello, f3, {arg1}, {arg2}, f3 tasks started = {n}")
    
    
    def f2(arg1, arg2):
        print(f"hello, f2 {arg1}")
        futures.extend(
            executor.submit(f3, arg2, i)
            for i in range(10)
        )
    
    
    def f1(arg1, arg2, arg3):
        print(f"hello, f1 {arg1}")
        futures.extend(
            executor.submit(f2, arg2, arg3)
            for i in range(10)
        )
    
    
    with ThreadPoolExecutor(16) as executor:
        futures.extend(
            executor.submit(f1, i, 1, 2)
            for i in range(10)
        )
    
        cnt = 0
        for future in futures:
            future.result()
            cnt += 1
        print(cnt, 'tasks completed.')
    

    Prints:

    ...
    hello, f3, 2, 4, f3 tasks started = 995
    hello, f3, 2, 6, f3 tasks started = 997
    hello, f3, 2, 8, f3 tasks started = 999
    hello, f3, 2, 3, f3 tasks started = 994
    hello, f3, 2, 7, f3 tasks started = 998
    hello, f3, 2, 5, f3 tasks started = 996
    hello, f3, 2, 9, f3 tasks started = 1000
    1110 tasks completed.
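If you prefer not to rely on a for loop picking up items that other threads append to the list while it iterates, a variant sketch takes explicit snapshots of the list under a lock and repeats until a pass finds nothing new. track and drain are hypothetical helper names:

```python
import concurrent.futures
import threading

futures = []                      # every Future submitted so far
futures_lock = threading.Lock()


def track(future):
    """Hypothetical helper: record a Future so the main thread will wait for it."""
    with futures_lock:
        futures.append(future)
    return future


def drain():
    """Wait for all tracked Futures, including ones added while we wait."""
    waited = 0
    while True:
        with futures_lock:
            batch = futures[waited:]      # futures we have not waited for yet
        if not batch:
            return waited
        for future in batch:
            future.result()               # re-raises any exception from the task
        waited += len(batch)


def f3(arg1, arg2):
    pass


def f2(arg1, arg2):
    for i in range(10):
        track(executor.submit(f3, arg2, i))


def f1(arg1, arg2, arg3):
    for i in range(10):
        track(executor.submit(f2, arg2, arg3))


with concurrent.futures.ThreadPoolExecutor(16) as executor:
    for i in range(10):
        track(executor.submit(f1, i, 1, 2))
    total = drain()                       # must happen before the implicit shutdown

print(total, "tasks completed.")
```

A task always tracks its children before it returns, so once a pass over the list finds no new futures, every task that will ever be submitted has been waited for.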