
How to access the `polymorphic` base class for any child of an `std::variant`?

Suppose a base class has several child classes:

class Base 
{
public:
    void printHello() const { std::cout << "Hello" << std::endl; }
};

class Child1: public Base {};
class Child2: public Base {};
class Child3: public Base {};
..
class ChildN: public Base {};

Suppose a variant that can hold any of these classes:

using MyVariant = std::variant<Base, Child1, Child2, Child3, ... ChildN>;

Note: The advantage of this over a simple vector of polymorphic pointers is that all the data lives in one contiguous memory array, which matters because it is going to be transferred to a device. Here the vector holds the actual content of each object, not just a pointer to some heap location.

Finally, suppose I want to work with each element of a vector<MyVariant> through its polymorphic Base interface.

std::vector<MyVariant> myVariantList;
... // Initialization

for (const MyVariant& elem: myVariantList)
{
    const Base* baseElem = get_if_polymorph<Base>(elem); //HOW TO?
    baseElem->printHello();
}

Note: The trivial solution of an if statement for each type is not what I am after, because it should be possible to add new child classes to MyVariant without changing every place it is used (extensibility).

So another way to express the question is:

How to manage polymorphism within std::variant?

Adrian Maire asked Mar 04 '23 08:03



1 Answer

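Use std::visit with a generic lambda that returns a reference to Base. Every alternative is either Base or derived from it, so the conversion is implicit and no object is copied or sliced: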

const Base& baseElem = std::visit(
    [](const auto& x) -> const Base& { return x; },
    elem);

Minimal reproducible example:

#include <iostream>
#include <variant>
#include <vector>

struct Base {
    virtual void hi() const
    {
        std::cout << "Base\n";
    }
};

struct Derived1 : Base {
    void hi() const override
    {
        std::cout << "Derived1\n";
    }
};

struct Derived2 : Base {
    void hi() const override
    {
        std::cout << "Derived2\n";
    }
};

int main()
{
    using Var = std::variant<Base, Derived1, Derived2>;
    std::vector<Var> elems;
    elems.emplace_back(std::in_place_type<Base>);
    elems.emplace_back(std::in_place_type<Derived1>);
    elems.emplace_back(std::in_place_type<Derived2>);
    for (const auto& elem : elems) {
        const Base& x = std::visit(
            [](const auto& x) -> const Base& { return x; },
            elem);
        x.hi();
    }
}

Output:

Base
Derived1
Derived2


L. F. answered May 03 '23 18:05
