Airflow: Run a task when some upstream is skipped by ShortCircuitOperator



I have a task, which I'll call final, that has multiple upstream tasks. When one of the upstream tasks gets skipped by a ShortCircuitOperator, final gets skipped as well. I don't want final to be skipped, because it has to report on the DAG's success.

To keep it from being skipped I set trigger_rule='all_done', but it still gets skipped.

If I use a BranchPythonOperator instead of the ShortCircuitOperator, final doesn't get skipped. A branching workflow looks like a possible, if not optimal, solution, but with all_done final no longer respects failures of upstream tasks: it runs even when one of them has failed.

How do I get it to run only when every upstream task has either succeeded or been skipped?
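
To pin down the condition being asked for, here is a minimal plain-Python sketch. It is not Airflow code: the function names and state strings are made up for illustration, and the state set is simplified. It only contrasts what all_done accepts with the "successful or skipped" behaviour wanted for final.

```python
# Plain-Python illustration only. The function names and the state strings
# are assumptions for this sketch and are not part of Airflow's API.
FINISHED = {"success", "failed", "skipped"}


def all_done(upstream_states):
    """True once every upstream has finished, whatever the outcome."""
    return all(state in FINISHED for state in upstream_states)


def success_or_skipped(upstream_states):
    """True only if every upstream either succeeded or was skipped."""
    return all(state in {"success", "skipped"} for state in upstream_states)


# A failed upstream satisfies all_done but not the desired rule.
print(all_done(["success", "failed"]))
print(success_or_skipped(["success", "failed"]))
# A skipped upstream should still let the final task run.
print(success_or_skipped(["success", "skipped"]))
```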

Sample ShortCircuit DAG:

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import ShortCircuitOperator
from datetime import datetime
from random import randint

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

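dag = DAG(
    dag_id='short_circuit_example',  # illustrative id; any unique name works
    default_args=default_args,
    schedule_interval='* * * * *')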

def shortcircuit_fn():
    return randint(0, 1) == 1

task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')

work = DummyOperator(dag=dag, task_id='work')
short = ShortCircuitOperator(dag=dag, task_id='short_circuit', python_callable=shortcircuit_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")

task_1 >> short >> work >> final
task_1 >> task_2 >> final

DAG with shortcircuit operator

Sample Branch DAG:

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator
from datetime import datetime
from random import randint

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

dag = DAG(
    dag_id='branch_example',  # illustrative id; any unique name works
    default_args=default_args,
    schedule_interval='* * * * *')

# these two are only here to protect tasks from getting skipped as direct dependencies of branch operator
to_do_work = DummyOperator(dag=dag, task_id='to_do_work')
to_skip_work = DummyOperator(dag=dag, task_id='to_skip_work')

def branch_fn():
    return to_do_work.task_id if randint(0, 1) == 1 else to_skip_work.task_id

task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')

work = DummyOperator(dag=dag, task_id='work')
branch = BranchPythonOperator(dag=dag, task_id='branch', python_callable=branch_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")

task_1 >> branch >> to_do_work >> work >> final
branch >> to_skip_work >> final
task_1 >> task_2 >> final

DAG with branch operator

I ended up developing a custom ShortCircuitOperator based on the original one:

from airflow.models import SkipMixin
from airflow.operators.python_operator import PythonOperator


class ShortCircuitOperator(PythonOperator, SkipMixin):
    """
    Allows a workflow to continue only if a condition is met. Otherwise, the
    workflow "short-circuits" and downstream tasks that only rely on this operator
    are skipped.

    The ShortCircuitOperator is derived from the PythonOperator. It evaluates a
    condition and short-circuits the workflow if the condition is False. Any
    downstream tasks that only rely on this operator are marked with a state of "skipped".
    If the condition is True, downstream tasks proceed as normal.

    The condition is determined by the result of `python_callable`.
    """

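    def find_tasks_to_skip(self, task, found_tasks=None):
        # Collect downstream tasks whose only upstream lies on this path.
        if found_tasks is None:
            found_tasks = []
        direct_relatives = task.get_direct_relatives(upstream=False)
        for t in direct_relatives:
            # A task with more than one upstream also depends on another
            # path, so neither it nor anything below it is skipped.
            if len(t.upstream_task_ids) == 1:
                found_tasks.append(t)
                self.find_tasks_to_skip(t, found_tasks)
        return found_tasks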

    def execute(self, context):
        condition = super(ShortCircuitOperator, self).execute(context)
        self.log.info("Condition result is %s", condition)

        if condition:
            self.log.info('Proceeding with downstream tasks...')
            return

        self.log.info(
            'Skipping downstream tasks that only rely on this path...')

        tasks_to_skip = self.find_tasks_to_skip(context['task'])
        self.log.debug("Tasks to skip: %s", tasks_to_skip)

        if tasks_to_skip:
            self.skip(context['dag_run'], context['ti'].execution_date,
                      tasks_to_skip)


This operator makes sure that no downstream task that relies on multiple paths gets skipped because of one skipped task.
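
To see the effect without running Airflow, the traversal can be tried on the sample graph from the question using a stand-in task class. This is only a sketch: FakeTask and all_downstream are hypothetical helpers written for this illustration, and all_downstream merely approximates what, as far as I understand, the stock operator does (skip everything reachable downstream).

```python
# Stand-in for an Airflow task: only what the traversal needs.
# FakeTask is a hypothetical helper, not an Airflow class.
class FakeTask:
    def __init__(self, task_id):
        self.task_id = task_id
        self.upstream = []
        self.downstream = []

    def __rshift__(self, other):
        # Mimic Airflow's `a >> b` dependency syntax.
        self.downstream.append(other)
        other.upstream.append(self)
        return other


def find_tasks_to_skip(task, found=None):
    """Downstream tasks that depend on this path only (the custom logic)."""
    if found is None:
        found = []
    for t in task.downstream:
        if len(t.upstream) == 1:
            found.append(t)
            find_tasks_to_skip(t, found)
    return found


def all_downstream(task, found=None):
    """Every task reachable downstream (approximates the stock behaviour)."""
    if found is None:
        found = []
    for t in task.downstream:
        if t not in found:
            found.append(t)
            all_downstream(t, found)
    return found


names = ("task_1", "task_2", "short_circuit", "work", "final")
task_1, task_2, short, work, final = (FakeTask(n) for n in names)

task_1 >> short >> work >> final
task_1 >> task_2 >> final

custom_skips = [t.task_id for t in find_tasks_to_skip(short)]
stock_skips = [t.task_id for t in all_downstream(short)]
print(custom_skips)  # ['work']
print(stock_skips)   # ['work', 'final']
```

Here final has two upstream tasks (work and task_2), so the custom traversal stops before it, and final is left for the scheduler to evaluate according to its trigger rule.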

