Select values from a set of arrays according to an array of permutations

I have 3 numpy arrays of shape 2xN (with N large, a few million), call them a1, a2, a3. I also have an array of shape Nx3, call it permutations, whose values (0, 1 or 2) refer to the arrays a1, a2, a3. It looks like: [[0, 1, 2], [1, 2, 0], [1, 0, 2], ...] with N rows.

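I want to create another 3 numpy arrays b1, b2, b3 of shape 2xN that hold the contents of the original a1, a2, a3, but with each column permuted according to the corresponding row of permutations: column i of b1, b2, b3 comes from column i of the arrays selected by permutations[i, 0], permutations[i, 1] and permutations[i, 2] respectively.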

I have tried fancy indexing on the stacked arrays, and numpy.choose, but I can't get either to work. I'm looking for a solution without Python loops. Any help would be greatly appreciated!

EDIT

To clarify, here is the Python loop implementation of what I'm trying to do:

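import numpy as np

aa = np.dstack((a1, a2, a3))   # shape (2, N, 3)
bb = np.empty_like(aa)
for i, o in enumerate(permutations):
    # column i of the j-th output comes from column i of array o[j]
    bb[:, i, np.arange(3)] = aa[:, i, o]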

Then I would retrieve b1, b2, b3 from bb as bb[:, :, 0], bb[:, :, 1] and bb[:, :, 2].
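To pin down the intended mapping, here is a self-contained version of the loop above, wrapped in a function and run on a tiny hand-checkable input. The function name loop_reference and the toy values are illustrative assumptions, not part of the original question.

```python
import numpy as np

def loop_reference(a1, a2, a3, permutations):
    """Hypothetical wrapper around the question's loop.

    Returns b1, b2, b3 with b_j[:, i] == (a1, a2, a3)[permutations[i, j]][:, i].
    """
    aa = np.dstack((a1, a2, a3))        # shape (2, N, 3)
    bb = np.empty_like(aa)
    for i, o in enumerate(permutations):
        bb[:, i, np.arange(3)] = aa[:, i, o]
    return bb[:, :, 0], bb[:, :, 1], bb[:, :, 2]

# Toy input with N = 2 (hypothetical values chosen to be easy to trace)
a1 = np.array([[10, 11], [12, 13]])
a2 = np.array([[20, 21], [22, 23]])
a3 = np.array([[30, 31], [32, 33]])
permutations = np.array([[0, 1, 2],    # column 0 keeps the order a1, a2, a3
                         [2, 0, 1]])   # column 1 takes a3, a1, a2

b1, b2, b3 = loop_reference(a1, a2, a3, permutations)
# b1 == [[10, 31], [12, 33]]
# b2 == [[20, 11], [22, 13]]
# b3 == [[30, 21], [32, 23]]
```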

martinako asked Jan 20 '26 21:01



1 Answer

With fancy indexing, you could do:

bb = aa[:,np.arange(N),permutations.T]

Note that the result has shape (2, 3, N), so b1, b2, b3 are selected with:

b1,b2,b3 = bb[:,0,:], bb[:,1,:], bb[:,2,:]

If you want bb to have the same shape as in the posted loop code, (2, N, 3), add:

bb = bb.swapaxes(1,2)
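A minimal check of the fancy-indexing answer against the question's loop may help show why the shapes come out as described: np.arange(N) (shape (N,)) broadcasts with permutations.T (shape (3, N)), so bb[:, j, i] picks aa[:, i, permutations[i, j]]. The seeded random data and the use of np.argsort to build genuine row permutations are assumptions made only for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)        # seeded for reproducibility
N = 1000
a1, a2, a3 = (rng.random((2, N)) for _ in range(3))
# Each row is a genuine permutation of [0, 1, 2] (assumption for the test data)
permutations = np.argsort(rng.random((N, 3)), axis=1)
aa = np.dstack((a1, a2, a3))          # shape (2, N, 3)

# Fancy indexing: the two index arrays broadcast to shape (3, N)
bb = aa[:, np.arange(N), permutations.T]
print(bb.shape)                       # (2, 3, 1000)
b1, b2, b3 = bb[:, 0, :], bb[:, 1, :], bb[:, 2, :]

# Reference: the loop from the question, which produces shape (2, N, 3)
bb_loop = np.empty_like(aa)
for i, o in enumerate(permutations):
    bb_loop[:, i, np.arange(3)] = aa[:, i, o]

# swapaxes(1, 2) turns (2, 3, N) into the loop's (2, N, 3) layout
assert np.array_equal(bb.swapaxes(1, 2), bb_loop)
assert np.array_equal(b1, bb_loop[:, :, 0])
```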

Here's another approach using linear indexing, slicing and NumPy broadcasting:

idx = permutations + 3*np.arange(N)[:,None]    
bb = aa.reshape(2,-1)[:,idx].reshape(2,N,3)

This creates bb with the same shape, (2, N, 3), as the posted loop code.
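The linear-indexing trick relies on C-order flattening: aa.reshape(2, -1) lays out each (N, 3) plane row by row, so aa[:, i, j] sits at flat position 3*i + j, and idx stores exactly those positions. The small permutations array below is a hypothetical example chosen to make the offsets easy to read.

```python
import numpy as np

N = 4
# Hypothetical small permutations array, chosen so the offsets are easy to read
permutations = np.array([[0, 1, 2],
                         [1, 2, 0],
                         [1, 0, 2],
                         [2, 1, 0]])

# Row i is shifted by 3*i, the flat offset of aa[:, i, 0] in aa.reshape(2, -1)
idx = permutations + 3 * np.arange(N)[:, None]
print(idx)
# [[ 0  1  2]
#  [ 4  5  3]
#  [ 7  6  8]
#  [11 10  9]]

# Indexing the (2, 3*N) array with the (N, 3) array idx already yields (2, N, 3)
rng = np.random.default_rng(1)
aa = rng.random((2, N, 3))
bb = aa.reshape(2, -1)[:, idx]
print(bb.shape)   # (2, 4, 3)

# Same result as the fancy-indexing version moved to (2, N, 3)
assert np.array_equal(bb, aa[:, np.arange(N), permutations.T].swapaxes(1, 2))
```

Because indexing with idx already produces shape (2, N, 3), the trailing .reshape(2,N,3) in the answer leaves the result unchanged; it is harmless but can be dropped.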


Runtime test

In [189]: def original_app(aa,permutations):
     ...:     bb = np.empty_like(aa)
     ...:     for i, o in enumerate(permutations):
     ...:         bb[:,i, np.arange(3)] = aa[:, i, o]
     ...:     return bb
     ...: 
     ...: 
     ...: def linear_index_app(aa,permutations):
     ...:     idx = permutations + 3*np.arange(N)[:,None]    
     ...:     return aa.reshape(2,-1)[:,idx].reshape(2,N,3)
     ...: 

In [190]: # Setup input arrays
     ...: N = 10000
     ...: a1 = np.random.rand(2,N)
     ...: a2 = np.random.rand(2,N)
     ...: a3 = np.random.rand(2,N)
     ...: 
     ...: permutations = np.random.randint(0,3,(N,3))
     ...: aa = np.dstack((a1, a2, a3))


In [191]: %timeit original_app(aa,permutations)
10 loops, best of 3: 128 ms per loop

In [192]: %timeit aa[:,np.arange(N),permutations.T]
1000 loops, best of 3: 972 µs per loop

In [193]: %timeit linear_index_app(aa,permutations)
1000 loops, best of 3: 1.02 ms per loop

So fancy indexing is the fastest of the three here (about 1 ms versus 128 ms for the loop at N = 10000), with linear indexing close behind.
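Two other vectorized routes, not benchmarked above, may be worth comparing on real data: np.take_along_axis, which expresses "pick along the last axis" directly, and np.choose, which the question mentions, once its index array is shaped to broadcast against the (2, N) choices. This sketch only checks that both agree with the fancy-indexing result; their relative speed is not measured here, and the seeded test data is an assumption.

```python
import numpy as np

rng = np.random.default_rng(2)
N = 1000
a1, a2, a3 = (rng.random((2, N)) for _ in range(3))
permutations = np.argsort(rng.random((N, 3)), axis=1)
aa = np.dstack((a1, a2, a3))                       # shape (2, N, 3)

# Reference: fancy indexing from the answer, moved to (2, N, 3)
ref = aa[:, np.arange(N), permutations.T].swapaxes(1, 2)

# np.take_along_axis: indices of shape (1, N, 3) broadcast over the first axis
via_take = np.take_along_axis(aa, permutations[None, :, :], axis=2)

# np.choose: index array of shape (3, 1, N) broadcasts with each (2, N) choice,
# giving out[j, r, i] == (a1, a2, a3)[permutations[i, j]][r, i]
via_choose = np.choose(permutations.T[:, None, :], (a1, a2, a3))
b1, b2, b3 = via_choose                            # each of shape (2, N)

assert np.array_equal(via_take, ref)
assert np.array_equal(np.stack((b1, b2, b3), axis=2), ref)
```

np.choose returns b1, b2, b3 stacked along the first axis without going through np.dstack, which may be convenient when the three arrays are already separate.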

Divakar answered Jan 23 '26 11:01
