I would like to set up a DAG where all tasks in a single TaskGroup finish before the next TaskGroup starts.
Meaning that in the example (cf. screenshot), the Workflow_FRA
has to complete its tasks run_task_FRA
and run_next_task_FRA
before the Workflow_BEL
runs, and so on.
Below is the DAG script, but it runs the tasks in parallel regardless of the TaskGroups.
import time
from datetime import datetime
from airflow.utils.task_group import TaskGroup
from airflow.decorators import task
from airflow import DAG

# DAG with one bracketing start/end task pair and one TaskGroup per country.
# Note: dependencies are wired inside the loop from start to end, so every
# country's group hangs off start_task in parallel — nothing serializes the
# groups against each other.
with DAG(
    dag_id="dev_dag",
    concurrency=1,
    start_date=datetime(2024, 2, 27),
    schedule_interval='*/1 * * * *',
    catchup=False,
) as dag:

    @task(task_id="start_task")
    def start_task():
        print("start")

    start = start_task()

    @task(task_id="end_task")
    def end_task():
        print("end")

    end = end_task()

    for country in ["FR", "BE", "SP", "EN"]:
        # One group per country, each holding a two-step mini-pipeline.
        with TaskGroup(group_id=f"workflow_{country}") as workflow:

            @task(task_id=f"run_task_{country}")
            def run_task():
                time.sleep(5)
                print("run task")

            @task(task_id=f"run_next_task_{country}")
            def run_next_task():
                time.sleep(5)
                print("run next task")

            # Wire the two steps between the bracketing tasks.
            start >> run_task() >> run_next_task() >> end

        # Redundant with the edges above: attaches the whole group between
        # start and end, still in parallel with the other groups.
        start >> workflow >> end
What I want to achieve is this: if the TaskGroup workflow_BE fails, the subsequent TaskGroup workflows should still be able to run, and I'd like to be able to clear the tasks from the failed group onward without re-running the TaskGroups that already succeeded.
You can use the `chain`
helper to serialize the groups:
import time
from datetime import datetime
from airflow.utils.task_group import TaskGroup
from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow import DAG

# Same DAG as before, but the country groups are collected into a list and
# serialized with chain(), so each TaskGroup only starts after the previous
# one is done. trigger_rule="all_done" on the second step lets the sequence
# keep moving even when an upstream task failed.
with DAG(
    dag_id="dev_dag",
    concurrency=1,
    start_date=datetime(2024, 2, 27),
    schedule_interval=None,
    catchup=False,
) as dag:

    @task(task_id="start_task")
    def start_task():
        print("start")

    start = start_task()

    @task(task_id="end_task")
    def end_task():
        print("end")

    end = end_task()

    # Ordered sequence of nodes to be chained: start, each group, then end.
    sequence = [start]

    for country in ["FR", "BE", "SP", "EN"]:
        with TaskGroup(group_id=f"workflow_{country}") as workflow:

            @task(task_id=f"run_task_{country}")
            def run_task():
                time.sleep(5)
                print("run task")

            # all_done: runs once upstream finishes, success or failure.
            @task(
                task_id=f"run_next_task_{country}",
                trigger_rule="all_done",
            )
            def run_next_task():
                time.sleep(5)
                print("run next task")

            start >> run_task() >> run_next_task() >> end

        sequence.append(workflow)

    sequence.append(end)

    # Link every element to the next one: start >> group_FR >> ... >> end.
    chain(*sequence)