python airflow airflow-2.x

Is it possible to build a tree of expanded Airflow DAG tasks? (dynamic task mapping over dynamic task mapping output)


I want to generate dynamic tasks from the output of a dynamically mapped task. Each mapped task returns a list, and I'd like to create a separate mapped task for each element of that list, so that the process forms a tree (image: Airflow dynamic task tree).

Is it possible to expand over the output of a dynamically mapped task, so that the result is a sequence of map operations instead of a map followed by a reduce?

What I tried:

In my local environment, I'm using:

Astronomer Runtime 9.6.0 based on Airflow 2.7.3+astro.2
Git Version: .release:9fad9363bb0e7520a991b5efe2c192bb3405b675

For the sake of the experiment, I'm using three tasks with a single string as an input and a string list as an output.

1. Expand over a task group that contains an expanded task (map over a group with mapped tasks):

import datetime
import logging

from airflow.decorators import dag, task, task_group

@dag(schedule_interval=None, start_date=datetime.datetime(2023, 9, 27))
def try_dag3():

    @task
    def first() -> list[str]:
        return ["0", "1"]

    first_task = first()

    @task_group
    def my_group(input: str) -> list[str]:
    
        @task
        def second(input: str) -> list[str]:
            logging.info(f"input: {input}")
            result = []
            for i in range(3):
                result.append(f"{input}_{i}")

            # ['0_0', '0_1', '0_2']
            # ['1_0', '1_1', '1_2']
            return result

        second_task = second.expand(input=first_task)

        @task
        def third(input: str, input1: str = None):
            logging.info(f"input: {input}, input1: {input1}")
            return input

        third_task = third.expand(input=second_task)
        
    my_group.expand(input=first_task)

try_dag3()

This fails with NotImplementedError: operator expansion in an expanded task group is not yet supported

2. Expand over the result of an expanded task (map over mapped tasks):

import datetime
import logging

from airflow.decorators import dag, task

@dag(start_date=datetime.datetime(2023, 9, 27))
def try_dag1():

    @task
    def first() -> list[str]:
        return ["0", "1"]

    first_task = first()

    @task
    def second(input: str) -> list[str]:
        logging.info(f"source: {input}")
        result = []
        for i in range(3):
            result.append(f"{input}_{i}")

        # ['0_0', '0_1', '0_2']
        # ['1_0', '1_1', '1_2']
        return result

    # this expands fine into two tasks from the list returned by first_task
    second_task = second.expand(input=first_task)

    @task
    def third(input: str):
        logging.info(f"source: {input}")
        return input

    # this doesn't expand - there are two mapped tasks, and the input value is a list, not a string
    third_task = third.expand(input=second_task)


try_dag1()

However, the output of second is not expanded element by element: each mapped third task receives a whole list of strings instead of a single string (image: dag1 graph). Log of the third[0] task:

[2024-01-05, 11:40:30 UTC] {try_dag1.py:30} INFO - source: ['0_0', '0_1', '0_2']
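The behaviour in attempt 2 can be reproduced without Airflow. The sketch below is a plain-Python approximation, not Airflow code: the functions are stand-ins for the tasks, and it assumes that expanding over a mapped task produces one downstream call per upstream mapped instance, each receiving that instance's whole return value (which is what the third[0] log suggests).

```python
# Plain-Python stand-ins for the tasks in attempt 2 (no Airflow involved).
def first() -> list[str]:
    return ["0", "1"]


def second(input: str) -> list[str]:
    return [f"{input}_{i}" for i in range(3)]


# Rough equivalent of second.expand(input=first_task):
# one call per element of the list returned by first().
second_results = [second(x) for x in first()]

# Rough equivalent of third.expand(input=second_task):
# one call per mapped instance of second, so each call gets a whole list.
third_inputs = list(second_results)

print(len(third_inputs))  # number of mapped "third" calls
print(third_inputs[0])    # what the first of them receives
```

Under that assumption the nesting is never removed on its own: the outer list has one entry per mapped instance of second, and those entries are what third is mapped over.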

3. Expand over the expanded task together with a constant input (to test whether the structure is possible):

import datetime
import logging

from airflow.decorators import dag, task

@dag(start_date=datetime.datetime(2023, 9, 27))
def try_dag0():

    @task
    def first() -> list[str]:
        return ["0", "1"]

    first_task = first()

    @task
    def second(input: str) -> list[str]:
        logging.info(f"input: {input}")
        result = []
        for i in range(3):
            result.append(f"{input}_{i}")

        # ['0_0', '0_1', '0_2']
        # ['1_0', '1_1', '1_2']
        return result

    second_task = second.expand(input=first_task)

    @task
    def third(input: str, input1: str = None):
        logging.info(f"input: {input}, input1: {input1}")
        return input

    third_task = third.expand(input=second_task, input1=["a", "b", "c"])


try_dag0()

It looks like mapped tasks can be expanded over the constant list passed to input1, but the input value is still an unexpanded list (image: dag0 graph). Log of the third[0] task:

[2024-01-05, 12:51:39 UTC] {try_dag0.py:33} INFO - input: ['0_0', '0_1', '0_2'], input1: a
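Expanding over two arguments at once creates one mapped task per combination of their values. The sketch below approximates that in plain Python with itertools.product. It is not Airflow code, the variable names are illustrative, and the enumeration order is an assumption, although the first combination does match the third[0] log.

```python
from itertools import product

# Outputs of the two mapped "second" instances, as in the question.
second_results = [["0_0", "0_1", "0_2"], ["1_0", "1_1", "1_2"]]
# Constant list passed as input1.
input1_values = ["a", "b", "c"]

# One combination per (input, input1) pair; "input" is still a whole list.
combos = [
    {"input": i, "input1": j}
    for i, j in product(second_results, input1_values)
]

print(len(combos))
print(combos[0])
```

The extra argument multiplies the number of mapped tasks but does not unpack the lists coming from second.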


Solution

  • You would need to add a task which collects and flattens the result of second.

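    from itertools import chain

    from airflow.decorators import task

    # Place these definitions inside the @dag-decorated function,
    # as in the question.

    @task
    def first() -> list[str]:
        return ['1', '2']
    
    @task
    def second(input: str) -> list[str]:
        return [f"{input}_{i}" for i in ['1', '2', '3']]
    
    # Regular (unmapped) task: receives the outputs of all mapped
    # instances of second and flattens them into a single list.
    @task
    def second_collect(input: list[list[str]]) -> list[str]:
        return list(chain.from_iterable(input))
    
    @task
    def third(input: str) -> str:
        return f"Result: {input}!"
    
    sc = second_collect(second.expand(input=first()))
    # Map third over the flattened list, one task per element.
    third.expand(input=sc)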
    

    Result of second_collect is ['1_1', '1_2', '1_3', '2_1', '2_2', '2_3'] (flattened result of mapped tasks).
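The flattening step can be checked in isolation. This is a minimal plain-Python sketch of what second_collect does with the collected outputs; the helper name flatten and the variable names are illustrative and not part of the answer's DAG.

```python
from itertools import chain


def flatten(nested: list[list[str]]) -> list[str]:
    # Concatenate the inner lists in order, removing one level of nesting.
    return list(chain.from_iterable(nested))


# Outputs of the two mapped "second" instances for inputs '1' and '2'.
mapped_outputs = [[f"{x}_{i}" for i in ["1", "2", "3"]] for x in ["1", "2"]]

flat = flatten(mapped_outputs)
print(flat)
```

Because the flattened list comes from a single unmapped task, the next expand call sees six separate strings rather than two lists.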

    Results of third mapped tasks are: