
C++ class template parameter must have a specific parent class [duplicate]

Given a class MyClass with one template parameter

template<typename T>
class MyClass;

and another class MySecondClass with two template parameters.

template<typename T, typename U>
class MySecondClass;

I would like to restrict MyClass so that it only accepts a T derived from MySecondClass, with any template arguments. I already know I need something like

template<typename T, typename = std::enable_if<std::is_base_of<MySecondClass<?,?>, T>::value>>
class MyClass

I am just not sure what to put in place of the ?, since I want to allow every possible specialization of MySecondClass.

user1056903 asked Aug 22 '16 13:08


2 Answers

You can use a template template parameter for the base template, then check whether a T* can be converted to a pointer to some specialization Of<Args...>:

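#include <type_traits>
#include <utility>

// True if T derives from (or is) some specialization Of<Args...>
template <template <typename...> class Of, typename T>
struct is_base_instantiation_of {
    // Viable when T* converts to Of<Args...>* for some deduced Args...
    template <typename... Args>
    static std::true_type test (Of<Args...>*);
    // Fallback for every other type
    static std::false_type test (...);

    using type = decltype(test(std::declval<T*>()));
    static constexpr auto value = type::value;
};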


TartanLlama answered Nov 10 '22 01:11


You can write a custom trait that checks whether a type is derived from some specialization of a class template, then use this trait inside a static_assert:

#include <type_traits>
#include <utility>

// True if U is derived from (or is) some specialization T<Args...>
template <template <typename...> class T, typename U>
struct is_derived_from_template {
    // Viable only if U binds to const T<Args...>& for some deduced Args...
    template <typename... Args>
    static decltype(static_cast<const T<Args...>&>(std::declval<U>()), std::true_type{}) test(
            const T<Args...>&);
    // Fallback for all other types
    static std::false_type test(...);

    static constexpr bool value = decltype(test(std::declval<U>()))::value;
};

template <typename T1, typename T2>
struct MyParentClass {};

template<typename T>
struct MyClass {
    static_assert(is_derived_from_template<MyParentClass, T>::value, "T must derive from MyParentClass");
};

struct DerivedFromMyParentClass : MyParentClass<int, float>{};

struct Foo{};

int main() {
    MyClass<DerivedFromMyParentClass> m; // OK
    MyClass<Foo> f;                      // error: static_assert fails, as intended
}


m.s. answered Nov 10 '22 00:11
