
In-place shuffling of multidimensional arrays

I am trying to implement a NaN-safe shuffling procedure in Cython that can shuffle along several axes of a multidimensional array with an arbitrary number of dimensions.

In the simple case of a 1D array, one can shuffle over the indices that hold non-NaN values using the Fisher–Yates algorithm:

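import numpy as np
cimport numpy as np

def shuffle1D(np.ndarray[double, ndim=1] x):
    # Positions of the non-NaN entries; only these take part in the shuffle.
    cdef np.ndarray[np.intp_t, ndim=1] idx = np.where(~np.isnan(x))[0]
    cdef Py_ssize_t i, j, n, m

    randint = np.random.randint
    # Fisher-Yates: swap entry i with a random entry j drawn from [0, i].
    for i in range(idx.shape[0] - 1, 0, -1):
        j = randint(i + 1)
        n, m = idx[i], idx[j]
        x[n], x[m] = x[m], x[n]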

I would like to extend this algorithm to handle large multidimensional arrays without calling reshape (which triggers a copy in some more complicated cases not considered here). To do so, I would need to get rid of the fixed input dimension, which seems possible neither with typed NumPy arrays nor with memoryviews in Cython. Is there a workaround?
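The remark about reshape can be checked with a small NumPy sketch. It assumes a 3x4 example array, and the names `flat_view` and `flat_copy` are only illustrative. Flattening a contiguous array gives a view, while flattening its transpose forces a copy, so a shuffle applied to that copy would never reach the original data.

```python
import numpy as np

a = np.arange(12.0).reshape(3, 4)

flat_view = a.reshape(-1)      # contiguous input: reshape returns a view
flat_copy = a.T.reshape(-1)    # transposed input: the data has to be copied

print(np.shares_memory(a, flat_view))   # True
print(np.shares_memory(a, flat_copy))   # False

# Writing through the copy leaves the original array untouched.
flat_copy[0] = -1.0
print(a[0, 0])                           # 0.0
```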

Many thanks in advance!

user45893 asked Sep 29 '14 15:09


1 Answer

The following algorithm is based on slices, so no copy is made, and it should work for any np.ndarray. The main steps are:

  • np.ndindex() is used to run through all the multidimensional indices over the remaining axes, that is, every axis except the one you want to shuffle along
  • the shuffle you already developed for the 1-D case is applied along that axis.
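The first step can be sketched in plain NumPy. This is only an illustration under a few assumptions: `axis_lines` is a hypothetical helper name, and the 2x3x4 array is an arbitrary example. It shows how np.ndindex() enumerates the remaining axes and how re-inserting a full slice at the shuffled axis selects a 1-D view rather than a copy.

```python
import numpy as np

def axis_lines(x, axis):
    """Yield (key, line) for every 1-D line of x along `axis` (hypothetical helper)."""
    axis = axis % x.ndim
    # Shape of x with the shuffled axis removed.
    outer_shape = x.shape[:axis] + x.shape[axis + 1:]
    for outer in np.ndindex(*outer_shape):
        # Re-insert a full slice at the position of the shuffled axis.
        key = outer[:axis] + (slice(None),) + outer[axis:]
        yield key, x[key]   # basic indexing: x[key] is a view into x

x = np.arange(24.0).reshape(2, 3, 4)
lines = list(axis_lines(x, axis=1))
print(len(lines))                          # 8, one line per (i, k) pair
print(lines[0][0])                         # (0, slice(None, None, None), 0)
print(np.shares_memory(lines[0][1], x))    # True
```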

Code:

import numpy as np
cimport numpy as np

def shuffleND(np.ndarray x, axis=-1):
    cdef np.ndarray[np.intp_t, ndim=1] idx
    cdef Py_ssize_t i, j, n, m
    randint = np.random.randint
    # Map a negative axis (e.g. -1) onto its non-negative equivalent.
    if axis < 0:
        axis += x.ndim
    # Shape of x with the shuffled axis removed.
    shape = list(x.shape)
    shape.pop(axis)
    for slices in np.ndindex(*shape):
        slices = list(slices)
        # Index that selects the 1-D line along `axis` at this position.
        axis_slice = slices[:]
        axis_slice.insert(axis, slice(None))
        # Positions of the non-NaN entries along that line.
        idx = np.where(~np.isnan(x[tuple(axis_slice)]))[0]
        # Fisher-Yates over the non-NaN positions only.
        for i in range(idx.shape[0]-1, 0, -1):
            j = randint(i+1)
            n, m = idx[i], idx[j]
            slice1 = slices[:]
            slice1.insert(axis, n)
            slice2 = slices[:]
            slice2.insert(axis, m)
            slice1 = tuple(slice1)
            slice2 = tuple(slice2)
            x[slice1], x[slice2] = x[slice2], x[slice1]
    return x
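
For checking results without compiling anything, a pure NumPy reference can be handy. The sketch below is not the Cython routine above: `shuffle_nd_reference` is a hypothetical name, and instead of pairwise Fisher–Yates swaps it permutes the non-NaN slots of each line in one vectorized assignment. It assumes a float array and NumPy 1.17 or later for np.random.default_rng.

```python
import numpy as np

def shuffle_nd_reference(x, axis=-1, rng=None):
    """Shuffle the non-NaN entries of x in place along `axis` (illustrative sketch)."""
    rng = np.random.default_rng() if rng is None else rng
    # np.moveaxis returns a view, so the writes below modify x itself.
    moved = np.moveaxis(x, axis, -1)
    for lead in np.ndindex(*moved.shape[:-1]):
        line = moved[lead]                       # 1-D view along the chosen axis
        idx = np.flatnonzero(~np.isnan(line))    # positions of non-NaN entries
        # The right-hand side is evaluated into a copy first, so this
        # assignment permutes the non-NaN slots without overwriting data.
        line[idx] = line[rng.permutation(idx)]
    return x

x = np.arange(24, dtype=float).reshape(2, 3, 4)
x[0, 1, 2] = np.nan
before = x.copy()
shuffle_nd_reference(x, axis=1, rng=np.random.default_rng(0))

# NaN positions are unchanged, and each line along axis 1 keeps its values.
print(np.array_equal(np.isnan(x), np.isnan(before)))          # True
print(np.allclose(np.nansum(x, axis=1), np.nansum(before, axis=1)))  # True
```

The two printed checks are also a reasonable sanity test for the compiled shuffleND: NaN entries must stay where they were, and every line along the shuffled axis must contain the same values as before.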
Saullo G. P. Castro answered Oct 20 '22 00:10
