Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Suspending through multiple nested coroutines

I'm trying to understand how exactly the new coroutines in C++20 work, but aside from very trivial examples I can't make them work.

My goal is to create deeply nested functions where the innermost one can break out and return control to the outermost code, and, after some condition, the outer code gives control back to that inner function. This is effectively setjmp and longjmp.

I botched together some code using examples I found online:

#include <iostream>
#include <coroutine>
#include <optional>

// Minimal awaitable coroutine task (the question's version).
//
// NOTE(review): an awaiting coroutine is never linked to the task it awaits
// (the handle passed to await_suspend is unused), so resuming an outer task
// cannot first drive a still-suspended inner task to completion.
template <typename T>
struct task 
{
    struct task_promise;

    using promise_type = task_promise;
    using handle_type = std::coroutine_handle<promise_type>;

    // Owned coroutine frame; destroyed in ~task() unless moved from.
    mutable handle_type m_handle;

    task(handle_type handle)
        : m_handle(handle) 
    {

    }

    // Transfers ownership of the frame and leaves `other` empty.
    task(task&& other) noexcept
        : m_handle(other.m_handle)
    {
        other.m_handle = nullptr;
    }

    // A finished task is not suspended on; its value is read immediately.
    bool await_ready()
    {
        return m_handle.done();
    }

    // Resumes the awaited task once (if unfinished); the awaiter stays
    // suspended only if the awaited task still has not finished.
    // `handle` (the awaiter) is not stored anywhere.
    bool await_suspend(std::coroutine_handle<> handle)
    {
        if (!m_handle.done()) {
            m_handle.resume();
        }
        return !m_handle.done();
    }

    auto await_resume()
    {
        return result();
    }

    // Resumes the task at most once, then reads the co_return value.
    // NOTE(review): if the task is still not done after that single resume
    // (e.g. it suspended again), m_value is empty and `*m_value` is
    // undefined behavior.
    T result() const
    {     
        if (!m_handle.done())
            m_handle.resume();

        return *m_handle.promise().m_value;
    }

    // Manually advance the task: resume it once if unfinished and report
    // whether it still has work left (true = not finished yet).
    bool one_step()
    {
        if (!m_handle.done())
            m_handle.resume();
        return !m_handle.done();
    }

    ~task()
    {
        if (m_handle)
            m_handle.destroy();
    }

    struct task_promise 
    {
        // Set by co_return; empty until then.
        std::optional<T>    m_value {};

        // Returns a copy of the stored result (not called in this snippet).
        auto value()
        {
            return m_value;
        }

        // Lazy start: the coroutine body runs only when first resumed.
        auto initial_suspend()
        {
            return std::suspend_always{};
        }

        // Keep the frame alive after completion so m_value can be read.
        // NOTE(review): C++20 requires final_suspend() to be non-throwing;
        // without `noexcept` conforming compilers reject this promise type.
        auto final_suspend()
        {
            return std::suspend_always{};
        }

        // NOTE(review): the returned suspend_always is discarded; co_return
        // does not await the result of return_value.
        auto return_value(T t)
        {
            m_value = t;
            return std::suspend_always{};
        }

        task<T> get_return_object()
        {
            return {handle_type::from_promise(*this)};
        }

        void unhandled_exception()
        {
            std::terminate();
        }

        // Not called anywhere in this snippet.
        void rethrow_if_unhandled_exception()
        {

        }
    };

};

// Leaf coroutine: suspends once, then returns 1.
static task<int> suspend_one()
{
    std::cout<< "suspend_one in\n";
    co_await std::suspend_always();
    std::cout<< "suspend_one return\n";
    co_return 1;
}
// Awaits suspend_one twice and returns the sum of both results.
static task<int> suspend_two()
{
    std::cout<< "suspend_two -> suspend_one #1\n";
    auto a = co_await suspend_one();
    std::cout<< "suspend_two -> suspend_one #2\n";
    auto b = co_await suspend_one();
    std::cout<< "suspend_two return\n";
    co_return a + b;
}

// Awaits suspend_two, suspend_one, then suspend_two again, returning the sum.
// This is the coroutine the question reports as misbehaving.
static task<int> suspend_five()
{
    std::cout<< "suspend_five -> suspend_two #1\n";
    auto a = co_await suspend_two();
    std::cout<< "suspend_five -> suspend_one #2\n";
    auto b = co_await suspend_one();
    std::cout<< "suspend_five -> suspend_two #3\n";
    auto c = co_await suspend_two();
    std::cout<< "suspend_five return\n";
    co_return a + b + c;
}

// Top-level coroutine driven from main(); awaits a mix of nested tasks and
// returns the sum of their results.
static task<int> run()
{
    std::cout<< "run -> suspend_two #1\n";
    auto a = co_await suspend_two();
    std::cout<< "run -> suspend_one #2\n";
    auto b = co_await suspend_one();
    std::cout<< "run -> suspend_five #3\n";
    auto c = co_await suspend_five();
    std::cout<< "run -> suspend_one #4\n";
    auto d = co_await suspend_one();
    std::cout<< "run -> suspend_two #5\n";
    auto e = co_await suspend_two();
    std::cout<< "run return\n";
    co_return a + b + c + d + e;
}

// Steps the root coroutine until one_step() reports it finished, then uses
// its result as the process exit status.
int main()
{
    std::cout<< "main in\n";
    auto r = run();
    std::cout<< "main -> while\n";
    while (r.one_step()){  std::cout<< "<<<< while loop\n"; }

    std::cout<< "main return\n";
    return r.result();
}

https://gcc.godbolt.org/z/JULJCi

Function run works as expected, but I have a problem with e.g. suspend_five, where it never reaches the line

std::cout<< "suspend_five -> suspend_two #3\n";

Probably my version of task is completely broken, but I have no idea where to find the error or how it should look. Or is the thing I want to achieve simply not supported by C++20? co_yield could be a candidate for a workaround, because it looks easier to nest manually (for (auto z : f()) co_yield z;), but the goal of this question is to understand the internal mechanics of the C++20 functionality rather than to solve some existing problem.

like image 326
Yankes Avatar asked May 09 '20 12:05

Yankes


1 Answer

After some digging and reading the documentation, I came to the conclusion that code like:


    // Manually advance the task: resume it once if unfinished and report
    // whether it still has work left (true = not finished yet).
    bool one_step()
    {
        if (!m_handle.done())
            m_handle.resume();
        return !m_handle.done();
    }

is completely broken and done backwards. You cannot call m_handle.resume() if the inner task is not finished (i.e. while something like m_handle.inner.done() is false). This means we first need to "move" the innermost task forward before we can move the outer one.

This is the new version that "works". I probably missed some details, but at least the result is similar to what I expect.

#include <coroutine>
#include <cstdio>
#include <optional>
#include <utility>

namespace
{

/// Coroutine task whose awaiting chain can be stepped from the outside.
///
/// Each task<T> owns one coroutine frame. When a task<T> coroutine co_awaits
/// another task<T>, the two frames are linked (m_inner_handler on the caller,
/// m_outer_handler on the callee). one_step() on the outer task can then
/// resume the innermost suspended frame first and walk back up the chain as
/// frames complete.
///
/// Coroutines start eagerly (initial_suspend is suspend_never). They stay
/// suspended at their final suspend point so the result can still be read and
/// the frame is destroyed by ~task().
template <typename T>
struct task
{
    struct task_promise;

    using promise_type = task_promise;
    using handle_type = std::coroutine_handle<promise_type>;

    /// Owned coroutine frame; null after the task has been moved from.
    mutable handle_type m_handle;

    // Deliberately implicit: get_return_object relies on copy-list-init.
    task(handle_type handle)
        : m_handle(handle)
    {
    }

    task(task&& other) noexcept
        : m_handle(other.m_handle)
    {
        other.m_handle = nullptr;
    }

    // Exactly one task owns a frame; copying would destroy it twice.
    task(const task&) = delete;
    task& operator=(const task&) = delete;

    /// True when there is nothing to suspend on: the coroutine already ran to
    /// completion without suspending, or the task was moved from.
    bool await_ready()
    {
        return !m_handle || m_handle.done();
    }

    /// Awaited from a coroutine whose promise is not task<T>::promise_type.
    /// The awaiter is not recorded, so this task finishing never leads
    /// one_step() back to that awaiter.
    bool await_suspend(std::coroutine_handle<>)
    {
        return true;
    }

    /// Awaited from another task<T> coroutine. Links caller and callee so
    /// one_step() can descend to the callee and later return to the caller.
    bool await_suspend(std::coroutine_handle<promise_type> handle)
    {
        handle.promise().m_inner_handler = m_handle;
        m_handle.promise().m_outer_handler = handle;
        return true;
    }

    /// Returns the value passed to co_return. Requires a non-null handle.
    /// Throws std::bad_optional_access instead of invoking undefined behavior
    /// if the coroutine has not produced a value yet.
    auto await_resume()
    {
        return m_handle.promise().m_value.value();
    }

    /// Advances the chain of linked tasks by one suspension.
    ///
    /// Descends from this task to the innermost task it is (transitively)
    /// awaiting and resumes it. Each time a task finishes, it resumes the task
    /// that awaited it, stopping at the next suspension.
    ///
    /// @return true if the chain suspended again (more steps needed); false
    ///         when the outermost reachable task finished, or when this task
    ///         is empty (moved from).
    bool one_step()
    {
        if (!m_handle)
            return false;

        // Phase 1: find the innermost frame this task is waiting on.
        auto curr = m_handle;
        while (curr.promise().m_inner_handler)
            curr = curr.promise().m_inner_handler;

        // Phase 2: resume it; when it completes, hand control to its awaiter.
        while (!curr.done())
        {
            curr.resume();
            if (!curr.done())
                return true;

            auto outer = curr.promise().m_outer_handler;
            if (!outer)
                return false;

            // The awaiter no longer waits on the finished frame.
            outer.promise().m_inner_handler = nullptr;
            curr = outer;
        }
        return false;
    }

    ~task()
    {
        if (m_handle)
            m_handle.destroy();
    }

    struct task_promise
    {
        /// Set by co_return; empty until then.
        std::optional<T> m_value {};
        /// Task this coroutine is currently suspended on (null if none).
        std::coroutine_handle<promise_type> m_inner_handler {};
        /// Task coroutine that is awaiting this one (null for the root).
        std::coroutine_handle<promise_type> m_outer_handler {};

        /// Returns a copy of the stored result.
        auto value()
        {
            return m_value;
        }

        auto initial_suspend() noexcept
        {
            return std::suspend_never{};
        }

        // Must be noexcept: C++20 requires co_await promise.final_suspend()
        // to be non-throwing. Suspending here keeps the frame (and m_value)
        // alive until ~task() destroys it.
        auto final_suspend() noexcept
        {
            return std::suspend_always{};
        }

        // Anything returned from here would be discarded by co_return.
        void return_value(T t)
        {
            m_value = std::move(t);
        }

        task<T> get_return_object()
        {
            return {handle_type::from_promise(*this)};
        }

        void unhandled_exception()
        {
            std::terminate();
        }

        void rethrow_if_unhandled_exception()
        {
        }
    };
};

// Finishes without ever suspending, so awaiting it never leaves the caller.
task<int> suspend_none()
{
    std::puts("suspend_none");
    co_return 0;
}

// Leaf coroutine: prints, suspends exactly once, prints again, yields 1.
task<int> suspend_one()
{
    std::fputs("suspend_one \\\n", stdout);
    co_await std::suspend_always{};
    std::fputs("suspend_one /\n", stdout);
    co_return 1;
}
// Interleaves non-suspending and suspending children; sums the suspend_one
// results.
task<int> suspend_two()
{
    int total = 0;
    co_await suspend_none();
    total += co_await suspend_one();
    co_await suspend_none();
    total += co_await suspend_one();
    co_return total;
}

// Two nested suspend_two awaits; result is 1 plus both of their results.
task<int> suspend_five()
{
    int total = 1;
    total += co_await suspend_two();
    total += co_await suspend_two();
    co_return total;
}

// Root coroutine: awaits suspend_five three times in sequence and returns
// 5 plus the three results.
task<int> run()
{
    std::puts("run");
    int total = 5;
    for (int round = 0; round < 3; ++round)
        total += co_await suspend_five();
    co_return total;
}
    
}

// Drives the root coroutine one suspension at a time, printing a marker
// between steps; the process exit status is run()'s result.
int main()
{
    std::puts("main in");
    auto root = run();
    std::puts("main -> while");
    while (root.one_step())
        std::puts("              while loop");

    std::puts("main return");
    return root.await_resume();
}

https://gcc.godbolt.org/z/f8zKPqK1d

The most important part is:

    // Called when a task coroutine awaits this task. It records the awaited
    // task as the caller's inner frame and the caller as the awaited task's
    // outer frame, so one_step() can walk down to the innermost frame and back up.
    bool await_suspend(std::coroutine_handle<promise_type> handle)
    {
        handle.promise().m_inner_handler = m_handle;
        m_handle.promise().m_outer_handler = handle;
        return true;
    }

This is where I link the frames together, which lets us advance the innermost one before going "up the stack".

I think this could be considered a poor man's stackful coroutine.

PS: The example was updated to also handle coroutines that never co_await.

like image 107
Yankes Avatar answered Nov 14 '22 14:11

Yankes