
Airflow : Run a task when some upstream is skipped by shortcircuit

Tags: airflow

I have a task, which I'll call final, that has multiple upstream tasks. When one of those upstream tasks is skipped by a ShortCircuitOperator, final is skipped as well. I don't want final to be skipped, because it has to report on DAG success.

To prevent this I set trigger_rule='all_done' on final, but it still gets skipped.

If I use a BranchPythonOperator instead of the ShortCircuitOperator, final doesn't get skipped. A branching workflow therefore looks like a possible solution, although not an optimal one, but final then no longer respects failures of upstream tasks: it runs even when one of them fails.

How do I get final to run only when its upstream tasks are successful or skipped?

Sample ShortCircuit DAG:

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import ShortCircuitOperator
from datetime import datetime
from random import randint

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

dag = DAG(
    'shortcircuit_test',
    default_args=default_args,
    schedule_interval='* * * * *',
    catchup=False)

def shortcircuit_fn():
    return randint(0, 1) == 1

task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')

work = DummyOperator(dag=dag, task_id='work')
short = ShortCircuitOperator(dag=dag, task_id='short_circuit', python_callable=shortcircuit_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")

task_1 >> short >> work >> final
task_1 >> task_2 >> final

[Image: DAG with shortcircuit operator]

Sample Branch DAG:

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator
from datetime import datetime
from random import randint

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

dag = DAG(
    'branch_test',
    default_args=default_args,
    schedule_interval='* * * * *',
    catchup=False)

# these two are only here to protect tasks from getting skipped as direct dependencies of branch operator
to_do_work = DummyOperator(dag=dag, task_id='to_do_work')
to_skip_work = DummyOperator(dag=dag, task_id='to_skip_work')

def branch_fn():
    return to_do_work.task_id if randint(0, 1) == 1 else to_skip_work.task_id

task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')

work = DummyOperator(dag=dag, task_id='work')
branch = BranchPythonOperator(dag=dag, task_id='branch', python_callable=branch_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")

task_1 >> branch >> to_do_work >> work >> final
branch >> to_skip_work >> final
task_1 >> task_2 >> final

[Image: DAG with branch operator]

Justinas Marozas asked Aug 07 '18 11:08

1 Answer

I ended up developing a custom ShortCircuitOperator based on the original one:

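# Import paths match the Airflow 1.x layout used in the question.
from airflow.models import SkipMixin
from airflow.operators.python_operator import PythonOperator


class ShortCircuitOperator(PythonOperator, SkipMixin):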
    """
    Allows a workflow to continue only if a condition is met. Otherwise, the
    workflow "short-circuits" and downstream tasks that only rely on this operator
    are skipped.

    The ShortCircuitOperator is derived from the PythonOperator. It evaluates a
    condition and short-circuits the workflow if the condition is False. Any
    downstream tasks that only rely on this operator are marked with a state of "skipped".
    If the condition is True, downstream tasks proceed as normal.

    The condition is determined by the result of `python_callable`.
    """

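    def find_tasks_to_skip(self, task, found_tasks=None):
        # Collect downstream tasks whose only upstream lies on the skipped path.
        if found_tasks is None:
            found_tasks = []
        direct_relatives = task.get_direct_relatives(upstream=False)
        for t in direct_relatives:
            # A task with more than one upstream is fed by another path too,
            # so it is left alone and the walk stops there.
            if len(t.upstream_task_ids) == 1:
                found_tasks.append(t)
                self.find_tasks_to_skip(t, found_tasks)
        return found_tasks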

    def execute(self, context):
        condition = super(ShortCircuitOperator, self).execute(context)
        self.log.info("Condition result is %s", condition)

        if condition:
            self.log.info('Proceeding with downstream tasks...')
            return

        self.log.info(
            'Skipping downstream tasks that only rely on this path...')

        tasks_to_skip = self.find_tasks_to_skip(context['task'])
        self.log.debug("Tasks to skip: %s", tasks_to_skip)

        if tasks_to_skip:
            self.skip(context['dag_run'], context['ti'].execution_date,
                      tasks_to_skip)

        self.log.info("Done.")

This operator ensures that a downstream task that relies on multiple paths is not skipped just because one of its upstream tasks was skipped.
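To illustrate the difference, here is a minimal sketch in plain Python (no Airflow required) that models the question's ShortCircuit DAG as a dictionary and compares two traversals. The names DOWNSTREAM, upstream_count, skip_all_downstream and skip_single_path_only are made up for this illustration and are not Airflow APIs. skip_all_downstream approximates what the stock operator did at the time, as far as I understand it (it skipped every task downstream of the short circuit, which is why trigger_rule on final had no effect), while skip_single_path_only mirrors find_tasks_to_skip above.

```python
# Toy model of the question's DAG: each task maps to its direct downstream tasks.
DOWNSTREAM = {
    "task_1": ["short_circuit", "task_2"],
    "short_circuit": ["work"],
    "work": ["final"],
    "task_2": ["final"],
    "final": [],
}


def upstream_count(task_id):
    """Number of direct upstream tasks of task_id in the toy DAG."""
    return sum(task_id in children for children in DOWNSTREAM.values())


def skip_all_downstream(task_id, found=None):
    """Assumed stock behaviour: every task reachable downstream is skipped."""
    if found is None:
        found = []
    for child in DOWNSTREAM[task_id]:
        if child not in found:
            found.append(child)
            skip_all_downstream(child, found)
    return found


def skip_single_path_only(task_id, found=None):
    """Same walk as find_tasks_to_skip: stop at tasks with several upstreams."""
    if found is None:
        found = []
    for child in DOWNSTREAM[task_id]:
        if upstream_count(child) == 1:
            found.append(child)
            skip_single_path_only(child, found)
    return found


print(skip_all_downstream("short_circuit"))    # ['work', 'final']
print(skip_single_path_only("short_circuit"))  # ['work']
```

With the custom walk, final is never marked as skipped by the operator itself, so its own trigger_rule decides whether it runs.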

Michael Spector answered Oct 21 '22 00:10
