DAG marked as "success" if one task fails, because of trigger rule ALL_DONE

Tags:

airflow

I have the following DAG with 3 tasks:

start --> special_task --> end

The task in the middle can succeed or fail, but end must always be executed (imagine this is a task for cleanly closing resources). For that, I used the trigger rule ALL_DONE:

end.trigger_rule = trigger_rule.TriggerRule.ALL_DONE

Using that, end is properly executed if special_task fails. However, since end is the last task and succeeds, the DAG is always marked as SUCCESS.
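
A simplified model helps show why this happens. Airflow derives the state of a DAG run from its leaf (terminal) tasks, and here end is the only leaf. The code below is a hypothetical, plain-Python model of that rule, not Airflow's actual implementation. The state strings mirror Airflow's State values, and the helper names are invented for illustration.

```python
from typing import Dict, List

# Hypothetical, simplified model: a DAG run's state follows its leaf tasks only.
# State strings mirror Airflow's State constants ("success", "failed", ...).
FAILED_STATES = {"failed", "upstream_failed"}
FINISHED_OK_STATES = {"success", "skipped"}


def leaf_tasks(downstream: Dict[str, List[str]]) -> List[str]:
    """Return the tasks that have no downstream tasks."""
    return [task for task, children in downstream.items() if not children]


def dag_run_state(downstream: Dict[str, List[str]], states: Dict[str, str]) -> str:
    """Derive the run state from the leaf-task states (simplified)."""
    leaf_states = [states[task] for task in leaf_tasks(downstream)]
    if any(s in FAILED_STATES for s in leaf_states):
        return "failed"
    if all(s in FINISHED_OK_STATES for s in leaf_states):
        return "success"
    return "running"


# start --> special_task --> end, where end uses ALL_DONE
graph = {"start": ["special_task"], "special_task": ["end"], "end": []}
states = {"start": "success", "special_task": "failed", "end": "success"}
print(dag_run_state(graph, states))  # success: the failed task is not a leaf
```

In this model, the failure of special_task is invisible to the run state because only end is inspected.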

How can I configure my DAG so that if one of the tasks failed, the whole DAG is marked as FAILED?

Example to reproduce

import datetime

from airflow import DAG
from airflow.operators.bash_operator import BashOperator
from airflow.utils import trigger_rule

dag = DAG(
    dag_id='my_dag',
    start_date=datetime.datetime.today(),
    schedule_interval=None
)

start = BashOperator(
    task_id='start',
    bash_command='echo start',
    dag=dag
)

special_task = BashOperator(
    task_id='special_task',
    bash_command='exit 1', # force failure
    dag=dag
)

end = BashOperator(
    task_id='end',
    bash_command='echo end',
    dag=dag
)
end.trigger_rule = trigger_rule.TriggerRule.ALL_DONE

start.set_downstream(special_task)
special_task.set_downstream(end)

This post seems to be related, but the answer does not suit my needs, since the downstream task end must be executed (hence the mandatory trigger_rule).

asked Aug 07 '18 13:08 by norbjd


3 Answers

I thought it was an interesting question and spent some time figuring out how to achieve it without an extra dummy task. It turned out to be somewhat more involved than necessary, but here's the end result:

This is the full DAG:

import airflow
from airflow import AirflowException
from airflow.models import DAG, TaskInstance, BaseOperator
from airflow.operators.bash_operator import BashOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator
from airflow.utils.db import provide_session
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule

default_args = {"owner": "airflow", "start_date": airflow.utils.dates.days_ago(3)}

dag = DAG(
    dag_id="finally_task_set_end_state",
    default_args=default_args,
    schedule_interval="0 0 * * *",
    description="Answer for question https://stackoverflow.com/questions/51728441",
)

start = BashOperator(task_id="start", bash_command="echo start", dag=dag)
failing_task = BashOperator(task_id="failing_task", bash_command="exit 1", dag=dag)


@provide_session
def _finally(task, execution_date, dag, session=None, **_):
    upstream_task_instances = (
        session.query(TaskInstance)
        .filter(
            TaskInstance.dag_id == dag.dag_id,
            TaskInstance.execution_date == execution_date,
            TaskInstance.task_id.in_(task.upstream_task_ids),
        )
        .all()
    )
    upstream_states = [ti.state for ti in upstream_task_instances]
    fail_this_task = State.FAILED in upstream_states

    print("Do logic here...")

    if fail_this_task:
        raise AirflowException("Failing task because one or more upstream tasks failed.")


finally_ = PythonOperator(
    task_id="finally",
    python_callable=_finally,
    trigger_rule=TriggerRule.ALL_DONE,
    provide_context=True,
    dag=dag,
)

successful_task = DummyOperator(task_id="successful_task", dag=dag)

start >> [failing_task, successful_task] >> finally_

Look at the _finally function, which is called by the PythonOperator. There are a few key points here:

  1. Annotate the function with @provide_session and add the argument session=None, so you can query the Airflow DB with session.
  2. Query all upstream task instances of the current task:

         upstream_task_instances = (
             session.query(TaskInstance)
             .filter(
                 TaskInstance.dag_id == dag.dag_id,
                 TaskInstance.execution_date == execution_date,
                 TaskInstance.task_id.in_(task.upstream_task_ids),
             )
             .all()
         )

  3. From the returned task instances, get the states and check whether State.FAILED is among them:

         upstream_states = [ti.state for ti in upstream_task_instances]
         fail_this_task = State.FAILED in upstream_states

  4. Perform your own logic:

         print("Do logic here...")

  5. Finally, fail the task if fail_this_task is True:

         if fail_this_task:
             raise AirflowException("Failing task because one or more upstream tasks failed.")
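
The check in step 3 only looks for State.FAILED, so an upstream task in the upstream_failed state would not make finally fail. One possible variant, which is an assumption and not part of the answer above, treats both states as failures. The helper below (hypothetical name upstream_failed_tasks) isolates that decision in plain Python so it can be tested without the Airflow database. Inside _finally you would feed it the pairs [(ti.task_id, ti.state) for ti in upstream_task_instances].

```python
from typing import Iterable, List, Optional, Tuple

# String values of Airflow's State.FAILED and State.UPSTREAM_FAILED.
FAILURE_STATES = {"failed", "upstream_failed"}


def upstream_failed_tasks(upstream: Iterable[Tuple[str, Optional[str]]]) -> List[str]:
    """Return the sorted task_ids whose state counts as a failure.

    `upstream` holds (task_id, state) pairs; state may be None for a
    task instance that has not run yet.
    """
    return sorted(task_id for task_id, state in upstream if state in FAILURE_STATES)


failed = upstream_failed_tasks([("failing_task", "failed"), ("other_task", "success")])
if failed:
    # In the DAG this branch would raise AirflowException instead of printing.
    print("Would fail 'finally' because of: {}".format(", ".join(failed)))
```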

The end result: finally still runs after failing_task fails, then marks itself as failed, so the DAG run is marked as failed as well.

answered Oct 18 '22 11:10 by Bas Harenslak


As @JustinasMarozas explained in a comment, a solution is to create a dummy task like:

from airflow.operators.dummy_operator import DummyOperator

dummy = DummyOperator(
    task_id='test',
    dag=dag
)

and bind it downstream of special_task:

special_task.set_downstream(dummy)

Thus, the dummy task, which keeps the default all_success trigger rule, is marked as upstream_failed, and because it is a leaf task, the whole DAG run is marked as failed.
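
The trigger-rule side of this workaround can be sketched with a simplified, hypothetical model of the two rules involved. This is plain Python, not Airflow's actual dependency checker, and the function name is invented for illustration. Under all_success (the default), a task runs only when every upstream task succeeded. Under all_done, a task runs once every upstream task has finished, whatever its state.

```python
from typing import List

# Finished task states, mirroring Airflow's State string values.
DONE_STATES = {"success", "failed", "upstream_failed", "skipped"}


def evaluate_trigger_rule(rule: str, upstream_states: List[str]) -> str:
    """Simplified model of what happens to a task given its upstream states.

    Returns "run", "wait", "upstream_failed" or "skipped".
    """
    if not all(s in DONE_STATES for s in upstream_states):
        return "wait"  # some upstream task has not finished yet
    if rule == "all_done":
        return "run"  # runs whatever the upstream outcome
    if rule == "all_success":
        if any(s in ("failed", "upstream_failed") for s in upstream_states):
            return "upstream_failed"
        if "skipped" in upstream_states:
            return "skipped"
        return "run"
    raise ValueError("rule not covered by this sketch: " + rule)


# special_task failed: end (all_done) still runs, dummy (all_success) does not.
print(evaluate_trigger_rule("all_done", ["failed"]))     # run
print(evaluate_trigger_rule("all_success", ["failed"]))  # upstream_failed
```

In this model, end keeps its cleanup role while dummy absorbs the failure. dummy is a leaf task left in a failed-type state, so the run is no longer reported as a success.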

I hope there is an out-of-the-box solution, but in the meantime this one does the job.

answered Oct 18 '22 09:10 by norbjd


To expand on Bas Harenslak's answer, a simpler _finally function that checks the state of all tasks in the DAG run (not only the upstream ones) could be:

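from airflow.utils.state import State

# Pass provide_context=True to the PythonOperator (Airflow 1.x) so kwargs are filled.
def _finally(**kwargs):
    # Look at every task instance of this DAG run except the current one.
    for task_instance in kwargs['dag_run'].get_task_instances():
        if task_instance.current_state() != State.SUCCESS and \
                task_instance.task_id != kwargs['task_instance'].task_id:
            raise Exception("Task {} failed. Failing this DAG run".format(task_instance.task_id))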
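
The condition inside this loop can be pulled out of Airflow for testing. The helper below (hypothetical name non_successful_tasks) is a plain-Python sketch. It takes (task_id, state) pairs, for example [(ti.task_id, ti.current_state()) for ti in dag_run.get_task_instances()], together with the current task's id, and applies the same test as the loop above. Like the original, it also reports skipped tasks and tasks that have not finished yet, such as tasks that are not upstream of the finally task.

```python
from typing import Iterable, List, Optional, Tuple

SUCCESS = "success"  # string value of Airflow's State.SUCCESS


def non_successful_tasks(
    task_states: Iterable[Tuple[str, Optional[str]]], current_task_id: str
) -> List[str]:
    """Return the ids of all other tasks whose state is not 'success'.

    Mirrors the loop above: skipped tasks and unfinished tasks
    (state None, 'running', ...) are reported too.
    """
    return [
        task_id
        for task_id, state in task_states
        if state != SUCCESS and task_id != current_task_id
    ]


states = [("start", "success"), ("failing_task", "failed"), ("finally", "running")]
bad = non_successful_tasks(states, current_task_id="finally")
if bad:
    print("Task {} failed. Failing this DAG run".format(bad[0]))
```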

answered Oct 18 '22 10:10 by GuD