
Conditional override in derived class template

I have a Container class that holds objects whose type may be derived from any combination of some base classes (TypeA, TypeB, etc.). The base class of Container has virtual methods that return a pointer to the contained object; these should return nullptr if the contained object isn't derived from the expected class. I would like to selectively override the base's methods based on Container's template parameter. I tried using SFINAE as follows, but it doesn't compile. I would like to avoid specializing Container for every possible combination because there could be many.

#include <type_traits>
#include <iostream>

using namespace std;

class TypeA {};
class TypeB {};
class TypeAB: public TypeA, public TypeB {};

struct Container_base {
    virtual TypeA* get_TypeA() {return nullptr;}
    virtual TypeB* get_TypeB() {return nullptr;}
};

template <typename T>
struct Container: public Container_base
{
    Container(): ptr(new T()) {}

    //Override only if T is derived from TypeA
    auto get_TypeA() -> enable_if<is_base_of<TypeA, T>::value, TypeA*>::type
    {return ptr;}

    //Override only if T is derived from TypeB
    auto get_TypeB() -> enable_if<is_base_of<TypeB, T>::value, TypeB*>::type
    {return ptr;}

private:
    T* ptr;
};

int main(int argc, char *argv[])
{
    Container<TypeA> typea;
    Container<TypeB> typeb;
    Container<TypeAB> typeab;

    cout << typea.get_TypeA() << endl; //valid pointer
    cout << typea.get_TypeB() << endl; //nullptr

    cout << typeb.get_TypeA() << endl; //nullptr
    cout << typeb.get_TypeB() << endl; //valid pointer

    cout << typeab.get_TypeA() << endl; //valid pointer
    cout << typeab.get_TypeB() << endl; //valid pointer

    return 0;
}
Carlton asked Oct 23 '18 19:10



2 Answers

... or you could change your approach to a simpler one:

template <typename T>
struct Container: public Container_base
{
    TypeA* get_TypeA() override
    {
        if constexpr(is_base_of_v<TypeA, T>)
            return ptr;
        else
            return nullptr;
    }

    ...
};

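and rely on the optimizer to smooth away any wrinkles, such as folding the multiple identical return nullptr functions into one in the final binary. Note that if constexpr requires C++17. With an older compiler a plain if is not a drop-in replacement, because return ptr; must still compile even when T is unrelated to TypeA; the optimizer can only remove a dead branch that is already well-formed.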
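For reference, here is a minimal sketch of how the complete class might look with this approach. It assumes C++17, and it swaps the raw new from the question for std::unique_ptr and adds a virtual destructor to Container_base; both are my additions, not part of the original code.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>

struct TypeA {};
struct TypeB {};
struct TypeAB : TypeA, TypeB {};

struct Container_base {
    virtual ~Container_base() = default;  // assumption: added for safe polymorphic use
    virtual TypeA* get_TypeA() { return nullptr; }
    virtual TypeB* get_TypeB() { return nullptr; }
};

template <typename T>
struct Container : Container_base {
    Container() : ptr(std::make_unique<T>()) {}

    // Always overrides; which branch survives is decided at compile time.
    TypeA* get_TypeA() override {
        if constexpr (std::is_base_of_v<TypeA, T>)
            return ptr.get();  // instantiated only when T* converts to TypeA*
        else
            return nullptr;
    }

    TypeB* get_TypeB() override {
        if constexpr (std::is_base_of_v<TypeB, T>)
            return ptr.get();
        else
            return nullptr;
    }

private:
    std::unique_ptr<T> ptr;  // assumption: owning pointer instead of raw new
};
```

Every Container<T> overrides both getters, so nothing is "conditionally overridden" in the strict sense; the condition simply moves into the function body.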

Edit:

... or (if you insist on using SFINAE) something along these lines:

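// enable_if_t<..., int> = 0 is used instead of a pack of void non-type parameters, which is not valid
template<class B, class T, enable_if_t< is_base_of_v<B, T>, int> = 0> B* cast_impl(T* p) { return p; }
template<class B, class T, enable_if_t<!is_base_of_v<B, T>, int> = 0> B* cast_impl(T*) { return nullptr; }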

template <typename T>
struct Container: public Container_base
{
    ...

    TypeA* get_TypeA() override { return cast_impl<TypeA>(ptr); }
    TypeB* get_TypeB() override { return cast_impl<TypeB>(ptr); }

private:
    T* ptr;
};
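A self-contained sketch of the same idea follows. The point is that SFINAE needs a template parameter in the immediate context, and virtual functions cannot be templates, so the selection is moved into a free function template. Storing the object by value and adding a virtual destructor are assumptions made to keep the example short.

```cpp
#include <cassert>
#include <type_traits>

struct TypeA {};
struct TypeB {};
struct TypeAB : TypeA, TypeB {};

// Chosen when T is B or derives from B: the implicit upcast is valid.
template <class B, class T, std::enable_if_t<std::is_base_of_v<B, T>, int> = 0>
B* cast_impl(T* p) { return p; }

// Chosen otherwise: the types are unrelated, so report "not available".
template <class B, class T, std::enable_if_t<!std::is_base_of_v<B, T>, int> = 0>
B* cast_impl(T*) { return nullptr; }

struct Container_base {
    virtual ~Container_base() = default;  // assumption: added for safe polymorphic use
    virtual TypeA* get_TypeA() { return nullptr; }
    virtual TypeB* get_TypeB() { return nullptr; }
};

template <typename T>
struct Container : Container_base {
    // Non-template virtual overrides that delegate to the SFINAE-selected helper.
    TypeA* get_TypeA() override { return cast_impl<TypeA>(&value); }
    TypeB* get_TypeB() override { return cast_impl<TypeB>(&value); }

private:
    T value{};  // assumption: stored by value to keep the sketch short
};
```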
C.M. answered Sep 29 '22 14:09



CRTP to the rescue!

template<class T, class D, class Base, class=void>
struct Container_getA:Base {};
template<class T, class D, class Base, class=void>
struct Container_getB:Base {};

template<class T, class D, class Base>
struct Container_getA<T, D, Base, std::enable_if_t<std::is_base_of<TypeA,T>{}>>:
  Base
{
  TypeA* get_TypeA() final { return self()->ptr; }
  D* self() { return static_cast<D*>(this); }
};

template<class T, class D, class Base>
struct Container_getB<T, D, Base, std::enable_if_t<std::is_base_of<TypeB,T>{}>>:
  Base
{
  TypeB* get_TypeB() final { return self()->ptr; }
  D* self() { return static_cast<D*>(this); }
};

template <class T>
struct Container: 
  Container_getA< T, Container<T>,
    Container_getB< T, Container<T>,
      Container_base
    >
  >
{
    Container(): ptr(new T()) {}

public: // either public, or complex friend declarations; just make it public
    T* ptr;
};

and done.
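Put together with the question's Container_base, the pieces above might look like the sketch below. The std::unique_ptr member and the virtual destructor are my assumptions (the snippets above use a raw pointer and leave ownership open); the rest follows the same layering.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>

struct TypeA {};
struct TypeB {};
struct TypeAB : TypeA, TypeB {};

struct Container_base {
    virtual ~Container_base() = default;  // assumption: added for safe polymorphic use
    virtual TypeA* get_TypeA() { return nullptr; }
    virtual TypeB* get_TypeB() { return nullptr; }
};

// Primary templates: pass-through layers that add nothing.
template <class T, class D, class Base, class = void>
struct Container_getA : Base {};
template <class T, class D, class Base, class = void>
struct Container_getB : Base {};

// Selected only when T derives from TypeA: this layer adds the override.
template <class T, class D, class Base>
struct Container_getA<T, D, Base, std::enable_if_t<std::is_base_of<TypeA, T>::value>> : Base {
    TypeA* get_TypeA() final { return self()->ptr.get(); }
    D* self() { return static_cast<D*>(this); }  // CRTP downcast to Container<T>
};

// Selected only when T derives from TypeB.
template <class T, class D, class Base>
struct Container_getB<T, D, Base, std::enable_if_t<std::is_base_of<TypeB, T>::value>> : Base {
    TypeB* get_TypeB() final { return self()->ptr.get(); }
    D* self() { return static_cast<D*>(this); }
};

template <class T>
struct Container
    : Container_getA<T, Container<T>, Container_getB<T, Container<T>, Container_base>> {
    Container() : ptr(std::make_unique<T>()) {}
    std::unique_ptr<T> ptr;  // public so the CRTP layers can reach it
};
```

For Container<TypeA> only the Container_getA layer contributes an override; the Container_getB layer collapses to an empty pass-through, so get_TypeB() falls back to the nullptr default in Container_base.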

You can do a bit of work to permit:

struct Container: Bases< T, Container<T>, Container_getA, Container_getB, Container_getC >

or the like where we fold the CRTP bases in.
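One way that "bit of work" could go is sketched below. Bases_impl and the exact shape of Bases are hypothetical names of mine; the idea is just to nest each CRTP layer around the next one and finish at Container_base. The mixins here store the object by value to keep things short, which is also an assumption.

```cpp
#include <cassert>
#include <type_traits>

struct TypeA {};
struct TypeB {};
struct TypeAB : TypeA, TypeB {};

struct Container_base {
    virtual ~Container_base() = default;
    virtual TypeA* get_TypeA() { return nullptr; }
    virtual TypeB* get_TypeB() { return nullptr; }
};

template <class T, class D, class Base, class = void>
struct Container_getA : Base {};
template <class T, class D, class Base>
struct Container_getA<T, D, Base, std::enable_if_t<std::is_base_of<TypeA, T>::value>> : Base {
    TypeA* get_TypeA() final { return &static_cast<D*>(this)->value; }
};

template <class T, class D, class Base, class = void>
struct Container_getB : Base {};
template <class T, class D, class Base>
struct Container_getB<T, D, Base, std::enable_if_t<std::is_base_of<TypeB, T>::value>> : Base {
    TypeB* get_TypeB() final { return &static_cast<D*>(this)->value; }
};

// Hypothetical fold helper: with no mixins left, the result is the root class.
template <class T, class D, class Root, template <class, class, class, class> class... Mixins>
struct Bases_impl { using type = Root; };

// Otherwise wrap the fold of the remaining mixins in the first one.
template <class T, class D, class Root,
          template <class, class, class, class> class First,
          template <class, class, class, class> class... Rest>
struct Bases_impl<T, D, Root, First, Rest...> {
    using type = First<T, D, typename Bases_impl<T, D, Root, Rest...>::type, void>;
};

template <class T, class D, template <class, class, class, class> class... Mixins>
using Bases = typename Bases_impl<T, D, Container_base, Mixins...>::type;

template <class T>
struct Container : Bases<T, Container<T>, Container_getA, Container_getB> {
    T value{};  // assumption: stored by value, public so the mixins can reach it
};
```

The explicit void passed as the fourth argument is what lets the enable_if_t partial specializations be selected, exactly as the defaulted class=void parameter does when the layers are nested by hand.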

You can also clean up your syntax:

template<class...Ts>
struct types {};

template<class T>
struct tag_t {using type=T;};
template<class T>
constexpr tag_t<T> tag{};

Then, instead of having a pile of named getters, have:

template<class List>
struct Container_getters;

template<class T>
struct Container_get {
  virtual T* get( tag_t<T> ) { return nullptr; }
};
template<class...Ts>
struct Container_getters<types<Ts...>>:
  Container_get<Ts>...
{
   using Container_get<Ts>::get...; // C++17
   template<class T>
   T* get() { return get(tag<T>); }
};

and now a central type list can be used to maintain the set of types you can get from the container.
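To see the tag-based getters in isolation, here is a small sketch that overrides one of them by hand. HoldsA is a hypothetical class used only for illustration, the type list is cut down to two entries, and the virtual destructor is my addition. It needs C++17 for the pack-expanded using declaration.

```cpp
#include <cassert>

struct TypeA {};
struct TypeB {};

template <class... Ts> struct types {};
template <class T> struct tag_t { using type = T; };
template <class T> constexpr tag_t<T> tag{};

template <class List> struct Container_getters;

// One virtual getter per type, distinguished by its tag parameter.
template <class T>
struct Container_get {
    virtual T* get(tag_t<T>) { return nullptr; }
};

template <class... Ts>
struct Container_getters<types<Ts...>> : Container_get<Ts>... {
    virtual ~Container_getters() = default;  // assumption: added for polymorphic use
    using Container_get<Ts>::get...;         // C++17: merge all overloads into one set
    template <class T>
    T* get() { return get(tag<T>); }         // convenience wrapper, dispatches virtually
};

using Container_types = types<TypeA, TypeB>;
struct Container_base : Container_getters<Container_types> {};

// Hypothetical hand-written override, just to show the dispatch.
struct HoldsA : Container_base {
    using Container_base::get;  // keep the other overloads visible
    TypeA* get(tag_t<TypeA>) override { return &a; }
    TypeA a;
};
```

Because each getter takes a distinct tag type, the overloads never collide, and a derived class can override exactly the ones it supports.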

We can then use that central type list to write the CRTP intermediate helpers.

template<class Actual, class Derived, class Target, class Base, class=void>
struct Container_impl_get:Base {};
template<class Actual, class Derived, class Target, class Base>
struct Container_impl_get<Actual, Derived, Target, Base,
  std::enable_if_t<std::is_base_of<Target, Actual>{}>
>:Base {
  using Base::get;
  virtual Target* get( tag_t<Target> ) final { return self()->ptr; }
  Derived* self() { return static_cast<Derived*>(this); }
};

and now we just need to write the fold machinery.

template<class Actual, class Derived, class List>
struct Container_get_folder;
template<class Actual, class Derived, class List>
using Container_get_folder_t=typename Container_get_folder<Actual, Derived, List>::type;

template<class Actual, class Derived>
struct Container_get_folder<Actual, Derived, types<>> {
  using type=Container_base;
};
template<class Actual, class Derived, class T0, class...Ts>
struct Container_get_folder<Actual, Derived, types<T0, Ts...>> {
  using type=Container_impl_get<Actual, Derived, T0,
    Container_get_folder_t<Actual, Derived, types<Ts...>>
  >;
};

so we get

using Container_types = types<TypeA, TypeB, TypeC>;
struct Container_base:Container_getters<Container_types> {
};

template <typename T>
struct Container: Container_get_folder_t<T, Container<T>, Container_types>
{
    Container(): ptr(new T()) {}
    T* ptr;
};

and now we can extend this by simply adding a type to Container_types.

Callers who want a specific type can either do:

Container_base* ptr = /* whatever */;
ptr->get<TypeA>()

or

ptr->get(tag<TypeA>);

Both work equally well.
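Assembled into one translation unit, the type-list version might look like the sketch below. TypeC is given an empty definition here since the question does not define it, and the std::unique_ptr member and virtual destructor are assumptions of mine; everything else mirrors the snippets above, reordered so that each name is declared before use.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>

struct TypeA {};
struct TypeB {};
struct TypeC {};  // assumption: placeholder definition
struct TypeAB : TypeA, TypeB {};

template <class... Ts> struct types {};
template <class T> struct tag_t { using type = T; };
template <class T> constexpr tag_t<T> tag{};

// One virtual getter per listed type, defaulting to nullptr.
template <class T>
struct Container_get {
    virtual T* get(tag_t<T>) { return nullptr; }
};
template <class List> struct Container_getters;
template <class... Ts>
struct Container_getters<types<Ts...>> : Container_get<Ts>... {
    virtual ~Container_getters() = default;
    using Container_get<Ts>::get...;  // C++17
    template <class T> T* get() { return get(tag<T>); }
};

using Container_types = types<TypeA, TypeB, TypeC>;
struct Container_base : Container_getters<Container_types> {};

// CRTP layer: overrides get(tag_t<Target>) only when Actual derives from Target.
template <class Actual, class Derived, class Target, class Base, class = void>
struct Container_impl_get : Base {};
template <class Actual, class Derived, class Target, class Base>
struct Container_impl_get<Actual, Derived, Target, Base,
                          std::enable_if_t<std::is_base_of<Target, Actual>::value>> : Base {
    using Base::get;
    Target* get(tag_t<Target>) final { return static_cast<Derived*>(this)->ptr.get(); }
};

// Fold: wrap Container_base in one layer per entry of the type list.
template <class Actual, class Derived, class List> struct Container_get_folder;
template <class Actual, class Derived>
struct Container_get_folder<Actual, Derived, types<>> { using type = Container_base; };
template <class Actual, class Derived, class T0, class... Ts>
struct Container_get_folder<Actual, Derived, types<T0, Ts...>> {
    using type = Container_impl_get<Actual, Derived, T0,
        typename Container_get_folder<Actual, Derived, types<Ts...>>::type>;
};

template <class T>
struct Container : Container_get_folder<T, Container<T>, Container_types>::type {
    Container() : ptr(std::make_unique<T>()) {}
    std::unique_ptr<T> ptr;
};
```

With this in place, a caller holding only a Container_base* can ask for any type in Container_types and gets nullptr for the ones the stored object does not derive from.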

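Live example: it uses a C++14 feature or two (namely the variable template tag), but you can replace tag<X> with tag_t<X>{}. The pack-expanded using declaration in Container_getters additionally needs C++17, as noted in its comment.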

Yakk - Adam Nevraumont answered Sep 29 '22 15:09
