 

Template function overloading and SFINAE implementations

I'm spending some time learning how to use templates in C++. I have never used them before, and I'm not always sure what can or cannot be achieved in different situations.

As an exercise I'm wrapping some of the BLAS and LAPACK functions that I use in my work, and I'm currently working on a wrapper for ?GELS (which computes the solution of a linear system of equations).

 A x + b = 0

The ?GELS function (for real values only) exists under two names: SGELS for single precision and DGELS for double precision.

The interface I have in mind is a function solve, used like this:

 const std::size_t rows = /* number of rows for A */;
 const std::size_t cols = /* number of cols for A */;
 std::array< double, rows * cols > A = { /* values */ };
 std::array< double, ??? > b = { /* values */ };  // ??? it can be either
                                                  // rows or cols. It depends on user
                                                  // problem, in general
                                                  // max( dim(x), dim(b) ) =
                                                  // max( cols, rows )     
 solve< double, rows, cols >(A, b);
 // the solution x is stored in b, thus b
 // must be "large" enough to accommodate x

Depending on user requirements, the problem may be overdetermined or underdetermined, which means:

  • if it is overdetermined, dim(b) > dim(x) (the solution is a least-squares minimization)
  • if it is underdetermined, dim(b) < dim(x) (the solution is the minimum-norm solution)
  • or the normal case, in which dim(b) = dim(x) (the solution uses the inverse of A)

(without considering singular cases).

Since ?GELS stores the result in the input vector b, the std::array should have enough space to accommodate the solution, as described in the code comments (max(rows, cols)).
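The size rule and the case distinction above are plain compile-time arithmetic on rows and cols. A minimal C++11 sketch, assuming the case depends only on comparing rows with cols; problem_kind, classify and required_b_size are hypothetical names, not part of LAPACK or of the wrapper:

```cpp
#include <array>
#include <cstddef>

// Hypothetical: which of the three cases a rows x cols system falls into.
enum class problem_kind { overdetermined, underdetermined, square };

// C++11-style constexpr: a single return statement.
constexpr problem_kind classify(std::size_t rows, std::size_t cols) {
  return rows > cols ? problem_kind::overdetermined    // more equations than unknowns
       : rows < cols ? problem_kind::underdetermined   // fewer equations than unknowns
                     : problem_kind::square;
}

// b holds the right-hand side (rows entries) on input and x (cols entries)
// on output, so it needs max(rows, cols) elements.
constexpr std::size_t required_b_size(std::size_t rows, std::size_t cols) {
  return rows > cols ? rows : cols;
}

// Usage: the result is a constant expression, so it can size a std::array.
constexpr std::size_t rows = 4;
constexpr std::size_t cols = 3;
using b_type = std::array<double, required_b_size(rows, cols)>;  // 4 elements

static_assert(classify(rows, cols) == problem_kind::overdetermined,
              "4x3 is overdetermined");
```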

I want to determine at compile time which kind of solution to adopt (it is a parameter change in the ?GELS call). I have two functions (simplified for the sake of the question) that handle the precision and already know the dimension of b and the number of rows/cols:

namespace wrap {

template <std::size_t rows, std::size_t cols, std::size_t dimb>
void solve(std::array<float, rows * cols> & A, std::array<float, dimb> & b) {
  SGELS(/* Called in the right way */);
}

template <std::size_t rows, std::size_t cols, std::size_t dimb>
void solve(std::array<double, rows * cols> & A, std::array<double, dimb> & b) {
  DGELS(/* Called in the right way */);
}

}; /* namespace wrap */

These are part of an internal wrapper. The user-facing function determines the size required for the b vector through templates:

#include <array>
#include <cstddef>
#include <type_traits>

/** This struct makes the max between rows and cols */
template < std::size_t rows, std::size_t cols >
struct biggest_dim {
  static std::size_t const value = std::conditional< rows >= cols, std::integral_constant< std::size_t, rows >,
                                                     std::integral_constant< std::size_t, cols > >::type::value;
};

/** A type for the b array is selected using "biggest_dim" */
template < typename REAL_T, std::size_t rows, std::size_t cols >
using b_array_t = std::array< REAL_T, biggest_dim< rows, cols >::value >;

/** Here we have the function that allows only the call with b of
 *  the correct size to continue */
template < typename REAL_T, std::size_t rows, std::size_t cols >
void solve(std::array< REAL_T, rows * cols > & A, b_array_t< REAL_T, rows, cols > & b) {
  static_assert(std::is_floating_point< REAL_T >::value, "Only float/double accepted");
  wrap::solve< rows, cols, biggest_dim< rows, cols >::value >(A, b);
} 

This works. But I want to go one step further, and I really don't have a clue how to do it: if the user calls solve with a b that is too small, the compiler raises an extremely difficult-to-read error.

I'm trying to insert a static_assert that helps the user understand the error. But every approach I can think of requires two functions with the same signature (a kind of template overloading?), for which I cannot find an SFINAE strategy (and they do not compile at all).

Do you think it is possible to raise a static assertion at compile time when b has the wrong dimension, without changing the user interface? I hope the question is clear enough.

@Caninonos: by user interface I mean how the user calls the solver, that is:

 solve< type, number of rows, number of cols > (matrix A, vector b)

This is a constraint I put on the exercise in order to improve my skills, so I don't know whether a solution actually exists. The type of b must match the function call; this would be easy if I added another template parameter and changed the user interface, but that would violate my constraint.

Minimal complete and working example

This is a minimal, complete and working example. As requested, I removed any reference to linear algebra concepts; it is purely a problem of numbers. The cases are:

  • N1 = 2, N2 = 2. Since N3 = max(N1, N2) = 2, everything works
  • N1 = 2, N2 = 1. Since N3 = max(N1, N2) = N1 = 2, everything works
  • N1 = 1, N2 = 2. Since N3 = max(N1, N2) = N2 = 2, everything works
  • N1 = 1, N2 = 2. Since N3 = N1 = 1 < N2, it correctly raises a compilation error. I want to intercept that error with a static assertion that explains that the dimension N3 is wrong. As of now, the error is difficult to read and understand.

You can view and test it online here

asked Mar 08 '18 at 11:03 by Matteo Ragni



3 Answers

First some improvements that simplify the design a bit and help readability:

  • there is no need for biggest_dim. std::max is constexpr since C++14. You should use it instead.

  • there is no need for b_array_t. You can just write std::array< REAL_T, std::max(N1, N2)>
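A minimal sketch of both simplifications, assuming C++14 or later; solve_sketch is a hypothetical stand-in that returns the accepted size of b instead of calling ?GELS:

```cpp
#include <algorithm>  // std::max, constexpr since C++14
#include <array>
#include <cassert>
#include <cstddef>
#include <type_traits>

// std::max(rows, cols) is a constant expression, so it can appear directly
// in the template argument list, replacing both biggest_dim and b_array_t.
template <typename REAL_T, std::size_t rows, std::size_t cols>
std::size_t solve_sketch(std::array<REAL_T, rows * cols>& A,
                         std::array<REAL_T, std::max(rows, cols)>& b) {
  (void)A;
  return b.size();  // placeholder for the real ?GELS call
}

// The same expression works anywhere a constant expression is required.
static_assert(std::is_same<std::array<double, std::max<std::size_t>(2, 3)>,
                           std::array<double, 3>>::value,
              "max(2, 3) == 3");

inline void usage_example() {
  std::array<double, 2 * 3> A{};
  std::array<double, 3> b{};  // max(2, 3) elements
  // Extra parentheses: the commas in the template argument list would
  // otherwise split the assert macro's argument.
  assert((solve_sketch<double, 2, 3>(A, b) == 3));
}
```

Because b's type here is computed from rows and cols rather than deduced, a b of the wrong size still only produces a "no matching function" error, which is the problem the rest of this answer addresses.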

And now to your problem. One nice way in C++17 is:

template < typename REAL_T, std::size_t N1, std::size_t N2, std::size_t N3>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) {

    if constexpr (N3 == std::max(N1, N2))
        wrap::internal< N1, N2, N3 >(A, b);
    else
        static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");

        // don't write static_assert(false)
        // this would make the program ill-formed (*)
} 

Or, as pointed out by @max66:

template < typename REAL_T, std::size_t N1, std::size_t N2, std::size_t N3>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) {

    static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");

    if constexpr (N3 == std::max(N1, N2))
        wrap::internal< N1, N2, N3 >(A, b);

} 

Tadaa!! Simple, elegant, nice error message.
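A self-contained version of the second variant, assuming C++17; wrap::internal is replaced by a hypothetical stub that only counts calls. Because N3 is the last template parameter and is deduced from b, the call site keeps the question's form solve<type, rows, cols>(A, b):

```cpp
#include <algorithm>  // std::max
#include <array>
#include <cassert>
#include <cstddef>

namespace wrap {
// Hypothetical stand-in for the internal ?GELS wrapper: it only counts calls.
inline int internal_calls = 0;

template <std::size_t N1, std::size_t N2, std::size_t N3, typename REAL_T>
void internal(std::array<REAL_T, N1 * N2>&, std::array<REAL_T, N3>&) {
  ++internal_calls;  // the real version would call SGELS or DGELS here
}
}  // namespace wrap

// N3 is last and deduced from b; REAL_T, N1, N2 are given explicitly.
template <typename REAL_T, std::size_t N1, std::size_t N2, std::size_t N3>
void solve(std::array<REAL_T, N1 * N2>& A, std::array<REAL_T, N3>& b) {
  static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");
  // When the condition is false the call is a discarded statement and is
  // never instantiated, so wrap::internal adds no noise to the error.
  if constexpr (N3 == std::max(N1, N2))
    wrap::internal<N1, N2, N3>(A, b);
}

inline void usage_example() {
  std::array<double, 2 * 3> A{};
  std::array<double, 3> b{};  // max(2, 3) == 3: accepted
  solve<double, 2, 3>(A, b);
  assert(wrap::internal_calls == 1);
  // std::array<double, 2> too_small{};
  // solve<double, 2, 3>(A, too_small);  // error: "invalid 3rd dimension"
}
```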

The difference between the constexpr if version and a plain static_assert, i.e.:

void solve(...)
{
   static_assert(...);
   wrap::internal(...);
}

is that with only the static_assert, the compiler will still try to instantiate wrap::internal when the static_assert fails, polluting the error output. With the constexpr if, the call to wrap::internal is not part of the body when the condition fails, so the error output is clean.


(*) The reason I didn't just write static_assert(false, "error msg") is that it would make the program ill-formed, no diagnostic required. See constexpr if and static_assert.
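If you do want an unconditional failure in the else branch, a common C++17 workaround is to make the assertion's condition depend on a template parameter, so it is only evaluated when that branch is instantiated. A minimal sketch; dependent_false is a hypothetical helper name, and the ?GELS call is replaced by returning b's size:

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <type_traits>

// Hypothetical helper: always false, but dependent on a template argument,
// so a static_assert using it is only checked in an instantiated branch.
template <std::size_t>
struct dependent_false : std::false_type {};

template <typename REAL_T, std::size_t N1, std::size_t N2, std::size_t N3>
std::size_t solve(std::array<REAL_T, N1 * N2>& A, std::array<REAL_T, N3>& b) {
  (void)A;
  if constexpr (N3 == std::max(N1, N2)) {
    return b.size();  // real code would forward to the internal wrapper here
  } else {
    // static_assert(false, ...) here would make the template ill-formed,
    // no diagnostic required, in C++17; the dependent condition avoids that.
    static_assert(dependent_false<N3>::value, "invalid 3rd dimension");
    return 0;
  }
}

inline void usage_example() {
  std::array<float, 2 * 2> A{};
  std::array<float, 2> b{};
  assert((solve<float, 2, 2>(A, b) == 2));
  // std::array<float, 1> small{};
  // solve<float, 2, 2>(A, small);  // error: "invalid 3rd dimension"
}
```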


You can also make float/double deducible, if you want, by moving that template parameter after the non-deducible ones:

template < std::size_t N1, std::size_t N2, std::size_t N3,  typename REAL_T>
void solve(std::array< REAL_T, N1 * N2 > & A, std::array< REAL_T, N3> & b) {

So the call becomes:

solve< n1_3, n2_3>(A_3, b_3);
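A compilable sketch of the reordered signature, assuming C++14 or later; the LAPACK dispatch is replaced by returning sizeof(REAL_T), only to make the deduced precision visible:

```cpp
#include <algorithm>  // std::max
#include <array>
#include <cassert>
#include <cstddef>

// N1 and N2 come first and are given explicitly; N3 and REAL_T are deduced
// from the arguments.
template <std::size_t N1, std::size_t N2, std::size_t N3, typename REAL_T>
std::size_t solve(std::array<REAL_T, N1 * N2>& A, std::array<REAL_T, N3>& b) {
  static_assert(N3 == std::max(N1, N2), "invalid 3rd dimension");
  (void)A;
  (void)b;
  // Placeholder for the SGELS/DGELS dispatch.
  return sizeof(REAL_T);
}

inline void usage_example() {
  std::array<float, 2 * 3> A_f{};
  std::array<float, 3> b_f{};
  std::array<double, 2 * 3> A_d{};
  std::array<double, 3> b_d{};
  assert((solve<2, 3>(A_f, b_f) == sizeof(float)));   // REAL_T = float
  assert((solve<2, 3>(A_d, b_d) == sizeof(double)));  // REAL_T = double
}
```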
answered Oct 21 '22 at 22:10 by bolov


Why don't you try combining tag dispatch with some static_asserts? Below is one way of achieving what you want, I hope. All three valid cases are routed to the correct BLAS calls, mismatched types and dimensions are rejected, and the float/double restriction is enforced, all with user-friendly messages thanks to static_assert.

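EDIT: I am not sure about your C++ version requirement, but the code below is C++11-friendly (__PRETTY_FUNCTION__ is a GCC/Clang extension, used here only to show which function runs).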

#include <algorithm>
#include <iostream>
#include <type_traits>

template <class value_t, int nrows, int ncols> struct Matrix {};
template <class value_t, int rows> struct Vector {};

template <class value_t> struct blas;

template <> struct blas<float> {
  static void overdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
  static void underdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
  static void normal(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
};

template <> struct blas<double> {
  static void overdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
  static void underdet(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
  static void normal(...) { std::cout << __PRETTY_FUNCTION__ << std::endl; }
};

class overdet {};
class underdet {};
class normal {};

template <class T1, class T2, int nrows, int ncols, int dim>
void solve(const Matrix<T1, nrows, ncols> &lhs, Vector<T2, dim> &rhs) {
  static_assert(std::is_same<T1, T2>::value,
                "lhs and rhs must have the same value types");
  static_assert(dim >= nrows && dim >= ncols,
                "rhs does not have enough space");
  static_assert(std::is_same<T1, float>::value ||
                std::is_same<T1, double>::value,
                "Only float or double are accepted");
  solve_impl(lhs, rhs,
             typename std::conditional<(nrows < ncols), underdet,
             typename std::conditional<(nrows > ncols), overdet,
                                                        normal>::type>::type{});
}

template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
                Vector<value_t, dim> &rhs, underdet) {
  /* get the pointers and dimension information from lhs and rhs */
  blas<value_t>::underdet(
      /* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);
}

template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
                Vector<value_t, dim> &rhs, overdet) {
  /* get the pointers and dimension information from lhs and rhs */
  blas<value_t>::overdet(
      /* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);
}

template <class value_t, int nrows, int ncols, int dim>
void solve_impl(const Matrix<value_t, nrows, ncols> &lhs,
                Vector<value_t, dim> &rhs, normal) {
  /* get the pointers and dimension information from lhs and rhs */
  blas<value_t>::normal(
      /* trans, m, n, nrhs, A, lda, B, ldb, work, lwork, info */);
}

int main() {
  /* valid types */
  Matrix<float, 2, 4> A1;
  Matrix<float, 4, 4> A2;
  Matrix<float, 5, 4> A3;
  Vector<float, 4> b1;
  Vector<float, 5> b2;
  solve(A1, b1);
  solve(A2, b1);
  solve(A3, b2);

  Matrix<int, 4, 4> A4;
  Vector<int, 4> b3;
  // solve(A4, b3); // static_assert for float & double

  Matrix<float, 4, 4> A5;
  Vector<int, 4> b4;
  // solve(A5, b4); // static_assert for different types

  // solve(A3, b1); // static_assert for dimension problem

  return 0;
}
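The same tag dispatch also works with the question's std::array interface, since the size of b can be a trailing, deduced template parameter. A minimal C++11 sketch; the enum picked is a hypothetical addition that just reports which overload was selected, in place of the BLAS calls:

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <type_traits>

// Tags for the three cases, as in the answer above.
struct overdet {};
struct underdet {};
struct normal {};

// Hypothetical return value, only so the selected overload is observable.
enum class picked { overdet, underdet, normal };

// One overload per tag; real versions would call the matching ?GELS variant.
template <class T, std::size_t N>
picked solve_impl(std::array<T, N>&, overdet) { return picked::overdet; }
template <class T, std::size_t N>
picked solve_impl(std::array<T, N>&, underdet) { return picked::underdet; }
template <class T, std::size_t N>
picked solve_impl(std::array<T, N>&, normal) { return picked::normal; }

// The call site stays solve<type, rows, cols>(A, b): dim is deduced from b.
template <class REAL_T, std::size_t rows, std::size_t cols, std::size_t dim>
picked solve(std::array<REAL_T, rows * cols>& A, std::array<REAL_T, dim>& b) {
  static_assert(std::is_same<REAL_T, float>::value ||
                std::is_same<REAL_T, double>::value,
                "Only float or double are accepted");
  static_assert(dim >= rows && dim >= cols, "b does not have enough space");
  (void)A;
  typedef typename std::conditional<(rows < cols), underdet,
          typename std::conditional<(rows > cols), overdet, normal>::type>::type tag;
  return solve_impl(b, tag{});
}

inline void usage_example() {
  std::array<double, 2 * 2> A22{};
  std::array<double, 1 * 2> A12{};  // also the storage size of a 2x1 matrix
  std::array<double, 2> b{};
  assert((solve<double, 2, 2>(A22, b) == picked::normal));
  assert((solve<double, 1, 2>(A12, b) == picked::underdet));
  assert((solve<double, 2, 1>(A12, b) == picked::overdet));
}
```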
answered Oct 21 '22 at 23:10 by Arda Aytekin


You have to consider why the interface offers this (convoluted) mess of parameters. The author had several things in mind. First of all, you can solve problems of the form A x + b == 0 and A^T x + b == 0 with one function. Second, the given A and b can point into memory of matrices larger than the ones the algorithm needs; this is what the LDA and LDB parameters are for.

It is the subaddressing that makes things complicated. If you want a simple but maybe useful enough API, you could choose to ignore that part:

#include <array>
#include <cstddef>
#include <type_traits>

using ::std::size_t;
using ::std::array;

template<typename T, size_t rows, size_t cols>
using matrix = array<T, rows * cols>;

enum class TransposeMode : bool {
  None = false, Transposed = true
};

// See https://stackoverflow.com/questions/14637356/
template<typename T> struct always_false_t : std::false_type {};
template<typename T> constexpr bool always_false_v = always_false_t<T>::value;

template < typename T, size_t rowsA, size_t colsA, size_t rowsB, size_t colsB
    , TransposeMode mode = TransposeMode::None >
void solve(matrix<T, rowsA, colsA>& A, matrix<T, rowsB, colsB>& B)
{
  // Since the algorithm works in place, b needs to be able to store
  // both input and output
  static_assert(rowsB >= rowsA && rowsB >= colsA, "b is too small");
  // LDA = rowsA, LDB = rowsB
  if constexpr (::std::is_same_v<T, float>) {
    // SGELS(mode == TransposeMode::None ? 'N' : 'T', ....);
  } else if constexpr (::std::is_same_v<T, double>) {
    // DGELS(mode == TransposeMode::None ? 'N' : 'T', ....);
  } else {
    static_assert(always_false_v<T>, "Unknown type");
  }
}
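Because matrix is an alias for std::array<T, rows * cols>, the dimensions in this solve cannot be deduced from the arguments, so they must be spelled out at the call site. A minimal C++17 sketch in which the SGELS/DGELS calls are replaced by returning a hypothetical gels_call record of the routine and TRANS value that would be used:

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <type_traits>

template <typename T, std::size_t rows, std::size_t cols>
using matrix = std::array<T, rows * cols>;

enum class TransposeMode : bool { None = false, Transposed = true };

template <typename T> struct always_false_t : std::false_type {};
template <typename T> constexpr bool always_false_v = always_false_t<T>::value;

// Hypothetical stand-in for a LAPACK call: 'S' or 'D' plus the TRANS char.
struct gels_call { char routine; char trans; };

template <typename T, std::size_t rowsA, std::size_t colsA, std::size_t rowsB,
          std::size_t colsB, TransposeMode mode = TransposeMode::None>
gels_call solve(matrix<T, rowsA, colsA>& A, matrix<T, rowsB, colsB>& B) {
  static_assert(rowsB >= rowsA && rowsB >= colsA, "b is too small");
  (void)A;
  (void)B;
  const char trans = mode == TransposeMode::None ? 'N' : 'T';
  if constexpr (std::is_same_v<T, float>) {
    return {'S', trans};  // would be SGELS(trans, ...)
  } else if constexpr (std::is_same_v<T, double>) {
    return {'D', trans};  // would be DGELS(trans, ...)
  } else {
    static_assert(always_false_v<T>, "Unknown type");
  }
}

inline void usage_example() {
  matrix<double, 3, 2> A{};  // 3x2, column-major
  matrix<double, 3, 1> B{};  // one right-hand side
  // std::array<double, 6> and std::array<double, 3> do not carry the
  // dimensions, so all four of them are given explicitly.
  gels_call c = solve<double, 3, 2, 3, 1>(A, B);
  assert(c.routine == 'D' && c.trans == 'N');
}
```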

Now, to address the subaddressing made possible by LDA and LDB: I propose making it part of your data type, not directly part of the template signature. You want your own matrix type that can reference storage in a matrix. Perhaps something like this:

// Since we store elements in column-major order, we can always pretend
// that our matrix has fewer columns, and fewer rows, than allocated, as long
// as we keep the allocated row count (actualRows, LAPACK's leading
// dimension) as the column stride; otherwise the addressing into the array is off.
// Thus, we only need four parameters:
// offset (= colSkipped * actualRows + rowSkipped), actualRows, rows, cols
// We store the offset implicitly by adjusting our begin pointer
template<typename T, size_t rows, size_t cols, size_t actualRows = rows>
class matrix_view { // Name derived from string_view :)
  static_assert(actualRows >= rows, "actualRows must be at least rows");
  T* start;
  matrix_view(T* start) : start(start) {}
  template<typename U, size_t r, size_t c, size_t ac>
  friend class matrix_view;
public:
  // Non-template constructor: lets class template argument deduction
  // (matrix_view{A}) deduce T, rows and cols; actualRows defaults to rows.
  matrix_view(matrix<T, rows, cols>& ref)
  : start(ref.data()) { }

  template<size_t rowSkipped, size_t colSkipped, size_t newRows, size_t newCols>
  auto submat() {
    static_assert(colSkipped + newCols <= cols, "can only shrink");
    static_assert(rowSkipped + newRows <= rows, "can only shrink");
    auto newStart = start + colSkipped * actualRows + rowSkipped;
    using newType = matrix_view<T, newRows, newCols, actualRows>;
    return newType{ newStart };
  }
  T* data() {
    return start;
  }
};
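The pointer arithmetic in submat relies on column-major storage: element (i, j) of a matrix whose columns are ld elements apart is at index j * ld + i, and a view keeps the original ld as its stride. A minimal sketch checking this with plain indices (cm_index is a hypothetical helper), for a 5x5 matrix viewed from row 1, column 1 as in submat<1, 1, 4, 4>():

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical helper: column-major index of element (row, col) for a
// matrix whose columns are `ld` elements apart (the leading dimension).
constexpr std::size_t cm_index(std::size_t row, std::size_t col, std::size_t ld) {
  return col * ld + row;
}

inline void usage_example() {
  // 5x5 matrix stored column-major; each element holds its own index.
  std::array<int, 25> storage{};
  for (std::size_t k = 0; k < storage.size(); ++k) storage[k] = static_cast<int>(k);

  // View skipping 1 row and 1 column: offset = colSkipped * actualRows + rowSkipped.
  const int* start = storage.data() + cm_index(1, 1, 5);  // offset 6
  // Inside the view the stride stays 5 (actualRows), not 4 (the view's rows).
  assert(start[cm_index(0, 0, 5)] == storage[cm_index(1, 1, 5)]);
  assert(start[cm_index(3, 3, 5)] == storage[cm_index(4, 4, 5)]);  // index 24
}
```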

Now you'd need to adapt your interface to this new data type, which just means introducing a few new parameters. The checks stay essentially the same.

// Using a class instead of a type alias allows us to use deduction guides.
// Declare it before matrix_view, whose constructor takes a matrix&.
// Replaces: using matrix = std::array from above
template<typename T, size_t rows, size_t cols>
class matrix {
public:
    std::array<T, rows * cols> storage;
    auto data() { return storage.data(); }
    auto data() const { return storage.data(); }
};

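// Mocks of the LAPACK routines, missing a few parameters at the end.
// `integer` stands in for LAPACK's integer type (assumed to be int here).
using integer = int;
extern void sgels(char TRANS
  , integer M, integer N , integer NRHS
  , float* A, integer LDA
  , float* B, integer LDB);
extern void dgels(char TRANS
  , integer M, integer N , integer NRHS
  , double* A, integer LDA
  , double* B, integer LDB);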
// Replaces the solve method from above
template < typename T, size_t rowsA, size_t colsA, size_t actualRowsA
    , size_t rowsB, size_t colsB, size_t actualRowsB
    , TransposeMode mode = TransposeMode::None >
void solve(matrix_view<T, rowsA, colsA, actualRowsA> A, matrix_view<T, rowsB, colsB, actualRowsB> B)
{
    static_assert(rowsB >= rowsA && rowsB >= colsA, "b is too small");
    char transMode = mode == TransposeMode::None ? 'N' : 'T';
    // LDA = actualRowsA, LDB = actualRowsB
    if constexpr (::std::is_same_v<T, float>) {
      sgels(transMode, rowsA, colsA, colsB, A.data(), actualRowsA, B.data(), actualRowsB);
    } else if constexpr (::std::is_same_v<T, double>) {
      dgels(transMode, rowsA, colsA, colsB, A.data(), actualRowsA, B.data(), actualRowsB);
    } else {
      static_assert(always_false_v<T>, "Unknown type");
    }
}

Example usage:

int main() {
  matrix<float, 5, 5> A;
  matrix<float, 4, 1> b;

  auto viewA = matrix_view{A}.submat<1, 1, 4, 4>();
  auto viewb = matrix_view{b};
  solve(viewA, viewb);
  // solve(viewA, viewb.submat<1, 0, 2, 1>()); // Error: b is too small
  // solve(matrix_view{A}, viewb.submat<0, 0, 5, 1>()); // Error: can only shrink (b is 4x1 and can not be viewed as 5x1)
}
answered Oct 21 '22 at 22:10 by WorldSEnder