python-3.x airflow airflow-2.x orchestration airflow-taskflow

Airflow: Complete all tasks in a TaskGroup before moving on to the next one, and avoid dependencies between TaskGroups


I would like to set up a DAG where all tasks in a single TaskGroup finish before the next TaskGroup starts. In the example (see the screenshot), Workflow_FRA has to complete its tasks run_task_FRA and run_next_task_FRA before Workflow_BEL runs, and so on for the remaining groups.

Below is the DAG script, but it runs the tasks in parallel regardless of the TaskGroup they belong to.

enter image description here

import time
from datetime import datetime
from airflow.utils.task_group import TaskGroup
from airflow.decorators import task
from airflow import DAG

# DAG from the question: one TaskGroup per country, each holding two tasks
# that sleep for 5 seconds, placed between a shared start_task and end_task.
with DAG(
        dag_id="dev_dag",
        concurrency=1,
        start_date=datetime(2024, 2, 27),
        schedule_interval='*/1 * * * *',  # cron expression: every minute
        catchup=False
) as dag:

    @task(task_id="start_task")
    def start_task():
        print("start")

    # Rebinds the name: from here on `start_task` is the result of calling the
    # decorated function, which is what the `>>` expressions below operate on.
    start_task = start_task()

    @task(task_id="end_task")
    def end_task():
        print("end")

    # Same rebinding as for start_task above.
    end_task = end_task()

    # One TaskGroup ("workflow_FR", "workflow_BE", ...) per country code.
    for country in ["FR", "BE", "SP", "EN"]:
        with TaskGroup(group_id=f"workflow_{country}") as workflow:
            @task(task_id=f"run_task_{country}")
            def run_task():
                time.sleep(5)
                print("run task")


            @task(task_id=f"run_next_task_{country}")
            def run_next_task():
                time.sleep(5)
                print("run next task")

            # Every group is wired only to the shared start_task and end_task.
            # Nothing links one group to the next, so no ordering between the
            # groups is declared here.
            start_task >> run_task() >> run_next_task() >> end_task
    # This statement is outside the for loop: `workflow` is whatever the last
    # iteration bound (the "EN" group), so only that group is referenced here.
    start_task >> workflow >> end_task

What I want to achieve is this: if the TaskGroup workflow_BE fails, the following TaskGroups should still be able to run, and I would like to be able to clear the tasks of the failed TaskGroup without running the subsequent TaskGroups again.


Solution

  • You can use the chain helper function (airflow.models.baseoperator.chain) to link the TaskGroups one after another:

    import time
    from datetime import datetime
    from airflow.utils.task_group import TaskGroup
    from airflow.decorators import task
    from airflow.models.baseoperator import chain
    from airflow import DAG
    
    # Runs the per-country TaskGroups strictly one after another:
    # start_task -> workflow_FR -> workflow_BE -> workflow_SP -> workflow_EN
    # -> end_task. A failure inside one group does not stop later groups.
    with DAG(
            dag_id="dev_dag",
            # NOTE(review): presumably caps the DAG at one running task at a
            # time; newer Airflow 2.x releases call this `max_active_tasks` --
            # confirm against the installed version.
            concurrency=1,
            start_date=datetime(2024, 2, 27),
            schedule_interval=None,
            catchup=False
    ) as dag:

        @task(task_id="start_task")
        def start_task():
            print("start")

        # Rebinds the name to the result of calling the decorated function;
        # that object is what `>>` and chain() operate on below.
        start_task = start_task()

        @task(task_id="end_task")
        def end_task():
            print("end")

        end_task = end_task()

        # Ordered list handed to chain(): start task, each TaskGroup in loop
        # order, then the end task.
        tasks = [start_task]
        for country in ["FR", "BE", "SP", "EN"]:
            with TaskGroup(group_id=f"workflow_{country}") as workflow:
                # "all_done" on the group's first task: it starts as soon as
                # the previous group has finished, whether that group
                # succeeded or failed. With the default rule ("all_success") a
                # failed run_next_task in the previous group would leave this
                # task upstream_failed and block every later group.
                @task(
                    task_id=f"run_task_{country}",
                    trigger_rule="all_done"
                )
                def run_task():
                    time.sleep(5)
                    print("run task")


                # "all_done" here as well: the second task still runs when
                # run_task of the same group has failed.
                @task(
                    task_id=f"run_next_task_{country}",
                    trigger_rule="all_done"
                )
                def run_next_task():
                    time.sleep(5)
                    print("run next task")

                # Intra-group order, plus direct edges to the shared start and
                # end tasks.
                start_task >> run_task() >> run_next_task() >> end_task
            tasks.append(workflow)
        tasks.append(end_task)
        # Links each list element to the next one, which is what makes group
        # N+1 wait for group N.
        chain(*tasks)
    

    The dependency graph will look like this: enter image description here