Tags: python, airflow, airflow-taskflow

Airflow branching: A task that only sometimes depends on an upstream task


I have two tasks: task_a and task_b. The DAG parameters run_task_a and run_task_b determine whether each task should run. A further parameter, param_for_task_a, is an input to task_a. Here's the important part:

If task_a is run, then task_b should start only after task_a has finished. However, if task_a is not run, then task_b can start whenever.

(Motivation: task_a is the main task. A new run of task_a can result in defunct artifacts, which task_b cleans up. However, one may wish to trigger task_b independently.)

This is what I have written so far:

from airflow.decorators import dag, task
from airflow.models.param import Param
from datetime import datetime

# Arguments applied to every task in the DAG.
# `provide_context` is intentionally omitted: it was removed from
# PythonOperator in Airflow 2.0 (which `airflow.decorators` requires), where
# the task context is passed to callables automatically. default_args only
# forwards keys matching an operator's signature, so the key was silently
# ignored anyway.
default_args = {
  'owner': 'xyz',
  'email_on_retry': False,
  'email_on_failure': False,
  'retries': 0,
  'depends_on_past': False,
}

@dag(
  default_args=default_args,
  start_date=datetime(2024, 3, 7),
  schedule_interval=None,
  params={
    'run_task_a': Param(True, type='boolean'),
    'run_task_b': Param(True, type='boolean'),
    'param_for_task_a': Param('foo', enum=['foo', 'bar'], type='string'),
  },
)
def my_dag():
  # Question version: a single branch task picks task_a and/or task_b from
  # the boolean DAG params run_task_a / run_task_b.

  # Wrap the DAG params in a dict; the return value is this task's XCom
  # output, which the tasks below receive as an argument.
  @task
  def get_context_values(**context):
    return {'params': context['params']}

  # Return the task_ids whose run_* flag is truthy (task_a before task_b);
  # the branch follows only the task_ids it returns.
  @task.branch
  def branching(context_values):
    flags = context_values['params']
    flag_for_task = (('task_a', 'run_task_a'), ('task_b', 'run_task_b'))
    return [task_id for task_id, flag in flag_for_task if flags[flag]]

  # Placeholder work, selected by the param_for_task_a enum ('foo' / 'bar').
  @task
  def task_a(context_values):
    mode = context_values['params']['param_for_task_a']

    if mode == 'foo':
      # Do some stuff
      pass
    elif mode == 'bar':
      # Do some different stuff
      pass

    return None

  @task
  def task_b():
    # Do some more stuff
    return None

  # Taskflow. NOTE: task_a and task_b are both plain children of the branch,
  # with no task_a >> task_b edge, so task_b does not wait for task_a -- this
  # is the ordering problem the question is about.
  params_xcom = get_context_values()
  branch = branching(params_xcom)
  branch >> [task_a(params_xcom), task_b()]

# Bind the DAG to a module-level name: Airflow < 2.4 only discovers DAG
# objects held in module globals. On 2.4+ the @dag result is also
# auto-registered, and the same object is simply de-duplicated.
my_dag_instance = my_dag()

*(Screenshot placeholder: the original image is not available.)*

The problem arises when run_task_a == True and run_task_b == True. Both tasks run, but task_b does not wait for task_a to finish, because there is no dependency between them. I tried adding this dependency by making task_b downstream of task_a. However, task_b then does not run when run_task_a == False and run_task_b == True: the branch skips task_a, and under the default all_success trigger rule that skip propagates to task_b. Trigger rules also don't seem to be the solution, since task_b should not run if run_task_b == False.


Solution

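  • After a lot of trial and error, we got it to work using short-circuit tasks instead of a branch. short_circuit_a sets ignore_downstream_trigger_rules=False, so when run_task_a is False it skips only task_a. The decision about task_b is then left to task_b's NONE_FAILED_MIN_ONE_SUCCESS trigger rule. short_circuit_b keeps the default (True), so when run_task_b is False it skips task_b regardless of trigger rules: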

    from airflow.decorators import dag, task
    from airflow.models.param import Param
    from airflow.utils.trigger_rule import TriggerRule
    from datetime import datetime
    
    # Arguments applied to every task in the DAG.
    # `provide_context` is intentionally omitted: it was removed from
    # PythonOperator in Airflow 2.0 (which `airflow.decorators` requires), where
    # the task context is passed to callables automatically. default_args only
    # forwards keys matching an operator's signature, so the key was silently
    # ignored anyway.
    default_args = {
      'owner': 'xyz',
      'email_on_retry': False,
      'email_on_failure': False,
      'retries': 0,
      'depends_on_past': False,
    }
    
    @dag(
      default_args=default_args,
      start_date=datetime(2024, 3, 7),
      # NOTE(review): Airflow >= 2.4 prefers `schedule=`; `schedule_interval`
      # is deprecated there. Kept for compatibility -- confirm target version.
      schedule_interval=None,
      params={
        'run_task_a': Param(True, type='boolean'),
        'run_task_b': Param(True, type='boolean'),
        'param_for_task_a': Param('foo', enum=['foo', 'bar'], type='string'),
      },
    )
    def my_dag():

      # Wrap the DAG params in a dict; the return value is this task's XCom
      # output, which the tasks below receive as an argument.
      @task
      def get_context_values(**context):
        return {'params': context['params']}

      # Generic gate: returns the DAG param named by `key`. A truthy value lets
      # downstream tasks run; a falsy one short-circuits (skips) them.
      @task.short_circuit
      def short_circuit(context_values, key):
        flags = context_values['params']
        return flags[key]

      # Placeholder work, selected by the param_for_task_a enum ('foo' / 'bar').
      @task
      def task_a(context_values):
        mode = context_values['params']['param_for_task_a']

        if mode == 'foo':
          # Do some stuff
          pass
        elif mode == 'bar':
          # Do some different stuff
          pass

        return None

      # NONE_FAILED_MIN_ONE_SUCCESS: a *skipped* task_a does not block task_b,
      # as long as nothing upstream failed and its gate succeeded.
      @task(trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
      def task_b():
        # Do some more stuff
        return None

      # Taskflow
      params_xcom = get_context_values()

      # ignore_downstream_trigger_rules=False: when run_task_a is False only the
      # direct child (task_a) is skipped; task_b's trigger rule then decides.
      gate_a = short_circuit.override(
        task_id='short_circuit_a',
        ignore_downstream_trigger_rules=False,
      )(params_xcom, 'run_task_a')
      run_a = task_a(params_xcom)

      # Default ignore_downstream_trigger_rules: when run_task_b is False,
      # everything downstream of this gate (task_b) is skipped outright.
      gate_b = short_circuit.override(task_id='short_circuit_b')(
        params_xcom, 'run_task_b')
      run_b = task_b()

      # task_b always waits for task_a whenever task_a actually runs.
      gate_a >> run_a >> run_b
      gate_b >> run_b
    
    # Bind the DAG to a module-level name: Airflow < 2.4 only discovers DAG
    # objects held in module globals. On 2.4+ the @dag result is also
    # auto-registered, and the same object is simply de-duplicated.
    my_dag_instance = my_dag()
    

    *(Screenshot placeholder: the original image is not available.)*