python airflow

How to stop/continue a task in Airflow?


I'm trying to build an Airflow DAG that updates a database. If I get no errors while fetching data from the API, I need to insert the data into the database. If there are any errors, I need to send the error messages.

So I need to add a check on the length of the errors dict. [image: enter image description here]

Actually, I also need to transform the data when there are no errors, before inserting it into the database, but I don't think that code needs to be added to the question.

How can I stop/continue tasks using decorators? The continue/stop logic is: if `len(errors) == 0`, insert the table into the database; otherwise, send the error text to Telegram.

P.S. I build a dict of tables because I actually get several tables from the API.

Here's my code.

import datetime
import time

import api_library
import clickhouse_connect
import pandas as pd
from airflow.decorators import dag, task

# Passed to the @dag decorator as its default_args.
default_args = {
    'owner': 'owner',
    'depends_on_past': False,
    'retries': 2,
    'retry_delay': datetime.timedelta(minutes=5),
    # The month is written as 6: a zero-padded decimal literal such as 06
    # is a SyntaxError in Python 3.
    'start_date': datetime.datetime(2024, 6, 20),
}

# Cron expression passed to the @dag decorator.
schedule_interval = '*/15 * * * *'

# ClickHouse client, created when this module is imported.
# NOTE(review): the connection settings and password are hard-coded
# placeholders here — presumably they should come from a secrets store or
# an Airflow connection; confirm before deploying.
connect = clickhouse_connect.get_client(
    host='*.*.*.*',
    port=8443,
    database='database',
    username='admin',
    password='password',
)

# API credentials (masked in this snippet).
bearer_key = '***'
user_key = '***'

def get_data_or_raise_error_with_retry(func, max_tries=10, **args):
    """Call ``func`` up to ``max_tries`` times and return its first result.

    Despite the name, nothing is raised: when every attempt fails, the
    exception caught on the last attempt is returned as a value, so callers
    can tell a failure from a result by inspecting its type.

    Args:
        func: Callable to invoke.
        max_tries: Maximum number of attempts.
        **args: Keyword arguments forwarded to ``func`` on every attempt.

    Returns:
        The value returned by the first successful call; otherwise the last
        exception caught, or ``None`` when ``max_tries`` is less than 1.
    """
    last_error = None
    for _ in range(max_tries):
        # Pause before every call to avoid hitting the API too quickly.
        time.sleep(0.3)
        try:
            return func(**args)
        except Exception as error:
            # Remember the failure and try again; returning here would end
            # the loop after a single attempt and defeat the retry.
            last_error = error
    return last_error

def make_dict_api_tables(tables: list):
    """Wrap the first fetched result in a dict keyed by table name.

    Args:
        tables: Sequence whose first element is the result obtained for the
            'stores' table (per the original comments: a DataFrame, a tuple,
            or an error).

    Returns:
        ``{'stores': {'result': tables[0], 'error_text': <message>}}`` where
        the message names the table and the function used to fetch it.
    """
    table_name = 'stores'
    fetch_function = 'get_table'
    message = 'Error in table {}.\nCheck function.{}.'.format(table_name, fetch_function)
    return {
        table_name: {
            'result': tables[0],
            'error_text': message,
        },
    }

def make_dict_with_errors(tables_dict: dict):
    """Collect the error text of every table whose result is not valid data.

    A result counts as valid when it is a ``pd.DataFrame`` or a ``tuple``;
    ``isinstance`` is used so that subclasses (for example a namedtuple or a
    DataFrame subclass) are accepted as well. Anything else is treated as
    an error.

    Args:
        tables_dict: Mapping of table name to a dict with the keys
            ``'result'`` and ``'error_text'``.

    Returns:
        Mapping of table name to its error text; empty when every result is
        valid.
    """
    return {
        name: entry['error_text']
        for name, entry in tables_dict.items()
        if not isinstance(entry['result'], (pd.DataFrame, tuple))
    }


# DAG definition: declares tasks that connect to the API, fetch a table and
# build the table/error dicts.
# NOTE(review): within this snippet the tasks are only defined — none of them
# is called or chained, and dag_update_database() itself is never invoked.
@dag(default_args=default_args, schedule_interval=schedule_interval, catchup=False, concurrency=4)
def dag_update_database():
    # Builds the API client from the two keys and returns it.
    # NOTE(review): api_library is imported as a module at the top of the file
    # but is called directly here — confirm the intended constructor.
    # NOTE(review): presumably the returned client is passed between tasks via
    # XCom and therefore has to be serializable — TODO confirm.
    @task
    def connect_to_api(bearer_key: str, user_key: str):
        # connecting to API
        api = api_library(bearer_key, user_key)
        return api
    
    # Fetches the table through the retry helper; `tries` is passed
    # positionally, so it lands in the helper's max_tries parameter.
    @task
    def get_table_from_api(api, tries: int):
        # get table, result is expected to be a pd.DataFrame
        result_from_salons_api = get_data_or_raise_error_with_retry(api.get_table, tries)
        # NOTE(review): if the result is a DataFrame, list(...) yields its
        # column labels rather than the frame, and an exception object is not
        # iterable — presumably [result_from_salons_api] was intended; confirm.
        return list(result_from_salons_api)


    # Returns a (tables_dict, errors) tuple built by the two module-level
    # helpers.
    @task
    def make_dict_with_tables_and_errors(table: list):
        # make dict with table
        tables_dict = make_dict_api_tables(table)
        # make dict with errors, dict will be empty if there's no errors
        errors = make_dict_with_errors(tables_dict)
        return tables_dict, errors

Solution

  • I believe you may benefit from using BranchPythonOperator. For general reference, see https://www.astronomer.io/docs/learn/airflow-branch-operator#taskbranch-branchpythonoperator. In short, the idea is simple and straightforward: add if/else logic that decides how to proceed based on the upstream XCom result.

    In code, it might look like this:

    import datetime
    from airflow.decorators import dag, task
    
    
    @dag(
        dag_id="validate_dag",
        schedule_interval="*/15 * * * *",
        start_date=datetime.datetime(2024, 6, 20),
        catchup=False,
        default_args={
            "owner": "owner",
            "depends_on_past": False,
            "retries": 2,
            "retry_delay": datetime.timedelta(minutes=5),
        },
    )
    def database_update_r():
        # Example DAG: fetch -> build tables/errors -> branch on the errors.

        @task
        def get_table_from_api():
            # Stand-in for the real API call: a fixed list of records.
            return [{"id": 1, "name": "store1"}, {"id": 2, "name": "store2"}]

        @task
        def make_dict_with_tables_and_errors(table: list) -> dict[str, dict]:
            # Placeholder table payload.
            tables_payload = {"table": "table_content"}
            # An empty dict means "no errors" and selects write_to_database.
            # To exercise the send_error branch instead, use a non-empty dict,
            # e.g. {"error": "error"}.
            found_errors = {}
            return {"tables_output": tables_payload, "errors_output": found_errors}

        @task.branch
        def validate_on_errors(errors: dict):
            # Return the task_id of the single downstream task that should run.
            return "send_error" if errors else "write_to_database"

        @task
        def send_error():
            # Placeholder: send the error text to Telegram.
            pass

        @task
        def write_to_database():
            # Placeholder: update the database.
            pass

        # Wire the pipeline step by step; only the "errors_output" entry of
        # the upstream result is handed to the branch task.
        fetched = get_table_from_api()
        outputs = make_dict_with_tables_and_errors(fetched)
        branch = validate_on_errors(outputs["errors_output"])
        branch >> [send_error(), write_to_database()]


    database_update_r()
    

    As a result, one of the tasks after the branch operator will be skipped. [image: enter image description here]