
Airflow 2.0: Encapsulating a DAG in a class using the TaskFlow API


I have pipelines whose mechanics are always the same: a sequence of two tasks. So I tried to abstract their construction into an abstract parent class (using the TaskFlow API):

from abc import ABC, abstractmethod
from airflow.decorators import dag, task
from datetime import datetime

def AbstractDag(ABC):
    @abstractmethod
    def task_1(self):
        """task 1"""

    @abstractmethod
    def task_2(self, data):
        """task 2"""

    def dag_wrapper(self):
        @dag(schedule_interval=None, start_date=datetime(2022, 1, 1))
        def dag():
            @task(task_id='task_1')
            def task_1():
                return self.task_1()

            @task(task_id='task_2')
            def task_2(data):
                return self.task_2(data)

            task_2(task_1())

        return dag

But when I inherit from this class, my DAG does not show up in the Airflow UI:

class MyCustomDag(AbstractDag):
    def task_1(self):
        return 2

    @abstractmethod
    def task_2(self, data):
        print(data)


custom_dag = MyCustomDag()
dag_object = custom_dag.dag_wrapper()

Do you have any idea how to make this work, or a better way to abstract it?

Thanks a lot! Nicolas


Solution

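  • I was able to get your example DAG to render in the UI with just a few small tweaks. `AbstractDag` is declared with `class` instead of `def`. The inner DAG function is renamed to `_dag`: a nested `def dag()` makes `dag` a local name inside `dag_wrapper`, so the `@dag(...)` line would no longer see the imported decorator. Also, `dag_wrapper` now returns `_dag()` rather than the bare function, because calling the decorated function is what actually creates the DAG object. Finally, the `@abstractmethod` decorator is removed from `MyCustomDag.task_2`; leaving it there keeps the subclass abstract, so it cannot be instantiated.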

    Here is the code I used:

    from abc import ABC, abstractmethod
    from airflow.decorators import dag, task
    from datetime import datetime
    
    class AbstractDag(ABC):
        @abstractmethod
        def task_1(self):
            """task 1"""
    
        @abstractmethod
        def task_2(self, data):
            """task 2"""
    
        def dag_wrapper(self):
            @dag(schedule_interval=None, start_date=datetime(2022, 1, 1))
            def _dag():  # not named "dag", which would shadow the imported @dag decorator
                @task(task_id='task_1')
                def task_1():
                    return self.task_1()
    
                @task(task_id='task_2')
                def task_2(data):
                    return self.task_2(data)
    
                task_2(task_1())
    
            return _dag()  # calling the decorated function returns the DAG object
    
    
    class MyCustomDag(AbstractDag):
        def task_1(self):
            return 2
    
        def task_2(self, data):
            print(data)
    
    
    custom_dag = MyCustomDag()
    dag_object = custom_dag.dag_wrapper()  # module-level name so Airflow can discover the DAG
    
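Three of these fixes address plain Python errors rather than Airflow behavior, so they can be checked without Airflow installed. This minimal sketch uses a hypothetical no-op stand-in named `dag` (not the real `airflow.decorators.dag`) and reports which exception each variant raises: subclassing something declared with `def`, shadowing the decorator with a nested `def dag()`, and re-applying `@abstractmethod` in a subclass.

```python
from abc import ABC, abstractmethod


def dag(**dag_kwargs):
    """Hypothetical stand-in for airflow.decorators.dag (NOT the real API).

    It ignores its arguments and returns the decorated function unchanged,
    which is enough to reproduce the Python scoping problem.
    """
    def decorate(fn):
        return fn
    return decorate


def outcome(fn):
    """Call fn() and return the name of the exception it raises, or "ok"."""
    try:
        fn()
    except Exception as exc:
        return type(exc).__name__
    return "ok"


# 1. `def NotAClass(ABC)` defines a function; ABC is only a parameter name.
def NotAClass(ABC):
    pass


def define_subclass():
    class Sub(NotAClass):  # a function object cannot be used as a base class
        pass


# 2. The nested `def dag()` makes `dag` local to this whole function, so the
#    decorator line reads that local name before it has been assigned.
def shadowed_wrapper():
    @dag(schedule_interval=None)
    def dag():
        pass
    return dag


def renamed_wrapper():
    @dag(schedule_interval=None)  # `dag` now resolves to the module-level name
    def _dag():
        pass
    return _dag


# 3. Re-applying @abstractmethod in the subclass keeps the subclass abstract.
class Base(ABC):
    @abstractmethod
    def task_2(self, data):
        """task 2"""


class StillAbstract(Base):
    @abstractmethod
    def task_2(self, data):
        print(data)


class Concrete(Base):
    def task_2(self, data):
        print(data)


print(outcome(define_subclass))   # TypeError
print(outcome(shadowed_wrapper))  # UnboundLocalError
print(outcome(renamed_wrapper))   # ok
print(outcome(StillAbstract))     # TypeError (calling the class instantiates it)
print(outcome(Concrete))          # ok
```

The fourth fix, returning `_dag()` instead of `_dag`, is Airflow-specific: the `@dag` decorator turns the function into a factory, and only calling it builds the DAG.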
