 

How to make a Cython function accept float or double array input?

Suppose I have the following minimal Cython function:

cimport cython

from scipy.linalg.cython_blas cimport dnrm2


cpdef double func(int n, double[:] x):
   cdef int inc = 1
   return dnrm2(&n, &x[0], &inc)

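However, I cannot call it on an np.float32 array x. A double[:] memoryview only accepts float64 buffers, so Cython raises a ValueError (buffer dtype mismatch).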

How can I make func accept either a double[:] or a float[:] and call dnrm2 or snrm2 accordingly? My only solution so far is to write two functions, which duplicates a lot of code.

asked May 26 '26 at 00:05 by P. Camilleri



1 Answer

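You could use a fused type. Cython compiles a separate specialization of func for each type listed in anyfloat. The test "if anyfloat is double" is resolved at compile time, so each specialization keeps only the branch that calls the matching BLAS routine (dnrm2 for double, snrm2 for float). From Python, func(len(x), x) then picks the float or double version based on the dtype of x:

# cython: infer_types=True
cimport cython

from scipy.linalg.cython_blas cimport dnrm2, snrm2

# Fused type: Cython generates one version of func per listed type
ctypedef fused anyfloat:
   double
   float

# [::1] requires a contiguous buffer, which is what inc = 1 assumes
cpdef anyfloat func(int n, anyfloat[::1] x):
   cdef int inc = 1
   # Compile-time check: each specialization keeps only one branch
   if anyfloat is double:
      return dnrm2(&n, &x[0], &inc)
   else:
      return snrm2(&n, &x[0], &inc)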
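For comparison, SciPy already performs the same dtype dispatch at the Python level. scipy.linalg.get_blas_funcs returns the s- or d-prefixed BLAS wrapper that matches the input array. The following is a minimal sketch that assumes NumPy and SciPy are installed, and blas_norm is a hypothetical helper name. These are Python-level wrappers, so unlike the cython_blas routines they cannot be called from nogil Cython code.

```python
import numpy as np
from scipy.linalg import get_blas_funcs


def blas_norm(x):
    """Hypothetical helper: Euclidean norm via the nrm2 matching x's dtype."""
    # get_blas_funcs inspects the dtype of the given arrays and returns
    # snrm2 for float32 input and dnrm2 for float64 input.
    nrm2 = get_blas_funcs("nrm2", (x,))
    return nrm2(x)


x32 = np.array([3.0, 4.0], dtype=np.float32)
x64 = np.array([3.0, 4.0], dtype=np.float64)

# The chosen routine's type prefix is stored in its typecode attribute
print(get_blas_funcs("nrm2", (x32,)).typecode)  # s
print(get_blas_funcs("nrm2", (x64,)).typecode)  # d
print(blas_norm(x32), blas_norm(x64))
```

This does at runtime what the fused type does at compile time. The Cython version avoids the Python call overhead on every call.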
answered May 27 '26 at 14:05 by Paul Panzer



