How can I await an array of tasks and stop waiting on first exception?

I have an array of tasks and I am awaiting them with Task.WhenAll. My tasks fail frequently, and when they do I inform the user with a message box so that she can try again. My problem is that the error is not reported until all tasks have completed. Instead, I would like to inform the user as soon as the first task throws an exception. In other words, I want a version of Task.WhenAll that fails fast. Since no such built-in method exists, I tried to write my own, but my implementation does not behave the way I want. Here is what I came up with:

public static async Task<TResult[]> WhenAllFailFast<TResult>(
    params Task<TResult>[] tasks)
{
    foreach (var task in tasks)
    {
        await task.ConfigureAwait(false);
    }
    return await Task.WhenAll(tasks).ConfigureAwait(false);
}

This generally throws faster than the native Task.WhenAll, but usually not fast enough. A faulted task #2 will not be observed before the completion of task #1. How can I improve it so that it fails as fast as possible?


Update: Cancellation is not in my requirements right now, but let's say that, for consistency, the first cancelled task should stop the awaiting immediately. In this case the combining task returned from WhenAllFailFast should have Status == TaskStatus.Canceled.

Clarification: The cancellation scenario is about the user clicking a Cancel button to stop the tasks from completing. It is not about automatically cancelling the incomplete tasks in case of an exception.
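As background for the Canceled requirement in the update: when an OperationCanceledException escapes an `async Task` method, the returned task ends in the Canceled state rather than Faulted, while any other exception produces a Faulted task. The sketch below shows both cases using tasks that are already completed, so no timing is involved. The class and method names are hypothetical, chosen only for the illustration.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical demo class: shows which status an async method ends in
// depending on the kind of exception that escapes it.
public static class CancellationStatusDemo
{
    // Awaiting an already-canceled task throws a TaskCanceledException
    // (an OperationCanceledException); the async method lets it escape.
    public static async Task<int> AwaitCanceledAsync()
    {
        var cts = new CancellationTokenSource();
        cts.Cancel(); // Task.FromCanceled requires a token that is already canceled
        return await Task.FromCanceled<int>(cts.Token);
    }

    // Any other exception escaping an async method yields a Faulted task.
    public static async Task<int> AwaitFaultedAsync()
    {
        return await Task.FromException<int>(new InvalidOperationException("boom"));
    }
}
```

Both methods complete synchronously, so `AwaitCanceledAsync().Status` is `TaskStatus.Canceled` and `AwaitFaultedAsync().Status` is `TaskStatus.Faulted` as soon as the calls return. A fail-fast combinator that lets the first exception propagate out of an `async` method therefore gets the Canceled status automatically, whereas one built on a `TaskCompletionSource` has to call `TrySetCanceled` explicitly, because `TrySetException` with an OperationCanceledException produces a Faulted task.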

asked Aug 01 '19 16:08 by Theodor Zoulias


3 Answers

I recently needed the WhenAllFailFast method again, so I revised @ZaldronGG's excellent solution to make it a bit more performant (and more in line with Stephen Cleary's recommendations). The implementation below handles around 3,500,000 tasks per second on my PC.

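using System;
using System.Threading;
using System.Threading.Tasks;

public static Task<TResult[]> WhenAllFailFast<TResult>(params Task<TResult>[] tasks)
{
    if (tasks is null) throw new ArgumentNullException(nameof(tasks));
    if (tasks.Length == 0) return Task.FromResult(Array.Empty<TResult>());

    // Validate every element before attaching any continuation.
    foreach (var task in tasks)
    {
        if (task == null) throw new ArgumentException(
            $"The {nameof(tasks)} argument included a null value.", nameof(tasks));
    }

    var results = new TResult[tasks.Length];
    var remaining = tasks.Length;
    // RunContinuationsAsynchronously keeps the caller's continuation from
    // running inline on the thread that completes the tcs.
    var tcs = new TaskCompletionSource<TResult[]>(
        TaskCreationOptions.RunContinuationsAsynchronously);

    for (int i = 0; i < tasks.Length; i++)
    {
        HandleCompletion(tasks[i], i);
    }
    return tcs.Task;

    // Observes a single task. async void is acceptable here because every
    // exception is caught and forwarded to the tcs. The first failure or
    // cancellation wins; later outcomes are ignored by the Try* methods.
    async void HandleCompletion(Task<TResult> task, int index)
    {
        try
        {
            var result = await task.ConfigureAwait(false);
            results[index] = result;
            // The last task to succeed publishes the results.
            if (Interlocked.Decrement(ref remaining) == 0)
            {
                tcs.TrySetResult(results);
            }
        }
        catch (OperationCanceledException)
        {
            tcs.TrySetCanceled();
        }
        catch (Exception ex)
        {
            tcs.TrySetException(ex);
        }
    }
}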
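A quick way to check the fail-fast behavior is to combine a task that never completes with one that has already faulted. The sketch below repeats a condensed copy of the method above so that it compiles on its own; `FailFastDemo` and `FaultsWithoutWaitingForTask1` are hypothetical names used only for this illustration.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public static class FailFastDemo
{
    // Condensed copy of the WhenAllFailFast method above.
    public static Task<TResult[]> WhenAllFailFast<TResult>(params Task<TResult>[] tasks)
    {
        if (tasks is null) throw new ArgumentNullException(nameof(tasks));
        if (tasks.Length == 0) return Task.FromResult(Array.Empty<TResult>());
        foreach (var t in tasks)
            if (t == null) throw new ArgumentException("Null task.", nameof(tasks));

        var results = new TResult[tasks.Length];
        var remaining = tasks.Length;
        var tcs = new TaskCompletionSource<TResult[]>(
            TaskCreationOptions.RunContinuationsAsynchronously);

        for (int i = 0; i < tasks.Length; i++) HandleCompletion(tasks[i], i);
        return tcs.Task;

        async void HandleCompletion(Task<TResult> task, int index)
        {
            try
            {
                results[index] = await task.ConfigureAwait(false);
                if (Interlocked.Decrement(ref remaining) == 0) tcs.TrySetResult(results);
            }
            catch (OperationCanceledException) { tcs.TrySetCanceled(); }
            catch (Exception ex) { tcs.TrySetException(ex); }
        }
    }

    // Task #1 is left pending forever; task #2 is already faulted.
    // Returns true if the combined task faulted without waiting for task #1.
    public static bool FaultsWithoutWaitingForTask1()
    {
        var pending = new TaskCompletionSource<int>();
        var faulted = Task.FromException<int>(new InvalidOperationException("task #2 failed"));

        var combined = WhenAllFailFast(pending.Task, faulted);

        // The faulted task is observed synchronously inside HandleCompletion,
        // so the combined task is already Faulted when WhenAllFailFast returns.
        return !pending.Task.IsCompleted
            && combined.IsFaulted
            && combined.Exception.InnerException is InvalidOperationException;
    }
}
```

`FaultsWithoutWaitingForTask1()` should return `true`: the combined task is Faulted with the original InvalidOperationException while task #1 is still pending.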
answered Nov 11 '22 12:11 by Theodor Zoulias


Your best bet is to build your WhenAllFailFast method using TaskCompletionSource. You can attach a synchronous continuation (via .ContinueWith()) to every input task that faults the TCS as soon as a task ends in the Faulted state (using the same exception object).

Perhaps something like (not fully tested):

using System;
using System.Threading;
using System.Threading.Tasks;

namespace stackoverflow
{
    class Program
    {
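        static async Task Main(string[] args)
        {
            var cts = new CancellationTokenSource();
            cts.Cancel();
            try
            {
                var arr = await WhenAllFastFail(
                    Task.FromResult(42),
                    Task.Delay(2000).ContinueWith<int>(t => throw new Exception("ouch")),
                    Task.FromCanceled<int>(cts.Token));
                Console.WriteLine(string.Join(", ", arr));
            }
            catch (OperationCanceledException)
            {
                // The already-canceled third task cancels the combined task immediately.
                Console.WriteLine("Canceled");
            }
            catch (Exception ex)
            {
                Console.WriteLine($"Faulted: {ex.Message}");
            }
        }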

        public static Task<TResult[]> WhenAllFastFail<TResult>(params Task<TResult>[] tasks)
        {
            if (tasks is null || tasks.Length == 0) return Task.FromResult(Array.Empty<TResult>());

            // defensive copy.
            var defensive = tasks.Clone() as Task<TResult>[];

            var tcs = new TaskCompletionSource<TResult[]>();
            var remaining = defensive.Length;

            Action<Task> check = t =>
            {
                switch (t.Status)
                {
                    case TaskStatus.Faulted:
                        // we 'try' as some other task may beat us to the punch.
                        tcs.TrySetException(t.Exception.InnerException);
                        break;
                    case TaskStatus.Canceled:
                        // we 'try' as some other task may beat us to the punch.
                        tcs.TrySetCanceled();
                        break;
                    default:

                        // remaining reaches zero only when every task succeeded,
                        // so no other continuation can have completed the TCS yet.
                        if (Interlocked.Decrement(ref remaining) == 0)
                        {
                            // get the results into an array.
                            var results = new TResult[defensive.Length];
                            for (var i = 0; i < defensive.Length; ++i) results[i] = defensive[i].Result;
                            tcs.SetResult(results);
                        }
                        break;
                }
            };

            foreach (var task in defensive)
            {
                task.ContinueWith(check, default, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
            }

            return tcs.Task;
        }
    }
}

Edit: The code now unwraps the AggregateException, supports cancellation, returns an array of results, defends against mutation of the input array and against a null or empty argument, and uses an explicit TaskScheduler.

answered Nov 11 '22 14:11 by ZaldronGG


Your loop awaits the tasks one at a time (pseudo-serially), which is why it waits for task #1 to complete before it can notice that task #2 has failed.
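The pseudo-serial behavior can be reproduced deterministically with a `TaskCompletionSource` standing in for a slow task #1 and an already-faulted task #2. This is a sketch; `SequentialAwaitDemo`, `WhenAllSequential` and `Run` are hypothetical names, and `WhenAllSequential` is simply the question's loop under a different name.

```csharp
using System;
using System.Threading.Tasks;

public static class SequentialAwaitDemo
{
    // The question's approach: await each task in order, then Task.WhenAll.
    public static async Task<TResult[]> WhenAllSequential<TResult>(params Task<TResult>[] tasks)
    {
        foreach (var task in tasks)
        {
            await task.ConfigureAwait(false);
        }
        return await Task.WhenAll(tasks).ConfigureAwait(false);
    }

    // Reports whether the combined task was complete before task #1 finished,
    // and whether it ended up Faulted after task #1 finished.
    public static (bool DoneBeforeTask1, bool FaultedAfterTask1) Run()
    {
        var task1 = new TaskCompletionSource<int>();
        var task2 = Task.FromException<int>(new InvalidOperationException("task #2 failed"));

        var combined = WhenAllSequential(task1.Task, task2);
        bool doneBefore = combined.IsCompleted; // the loop is stuck on task #1

        task1.SetResult(1);
        // The loop now reaches task #2 and rethrows its exception.
        try { combined.Wait(); } catch (AggregateException) { }

        return (doneBefore, combined.IsFaulted);
    }
}
```

`Run()` should return `(false, true)`: the combined task stays incomplete while task #1 is pending, even though task #2 faulted from the start, and it faults only after task #1 completes.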

You might find this article helpful; it describes a pattern for aborting after the first failure: http://gigi.nullneuron.net/gigilabs/patterns-for-asynchronous-composite-tasks-in-c/

using System.Linq;
using System.Threading.Tasks;

public static async Task<TResult[]> WhenAllFailFast<TResult>(
    params Task<TResult>[] tasks)
{
    var taskList = tasks.ToList();
    while (taskList.Count > 0)
    {
        var task = await Task.WhenAny(taskList).ConfigureAwait(false);
        if (task.IsFaulted || task.IsCanceled)
        {
            // Awaiting the completed task rethrows its first exception with the
            // original stack trace, or an OperationCanceledException if it was
            // canceled (which makes the returned task Canceled).
            // Left as an exercise for the reader:
            // handle the exception(s);
            // cancel the other running tasks.
            await task.ConfigureAwait(false);
        }

        taskList.Remove(task);
    }
    return await Task.WhenAll(tasks).ConfigureAwait(false);
}
answered Nov 11 '22 13:11 by stannius