I have 3 NumPy arrays of shape 2xN (with N large, a few million), call them a1, a2, a3. I also have an array of shape Nx3, call it permutations, whose entries are indices (0, 1 or 2) referring to a1, a2 and a3. It looks like this: [[0, 1, 2], [1, 2, 0], [1, 0, 2], ...] with N rows.
I want to create 3 new NumPy arrays b1, b2, b3 of shape 2xN that hold the contents of a1, a2, a3, but with each column shuffled among the three arrays according to the corresponding row of permutations: column i of b(j+1) should be column i of a(permutations[i, j] + 1).
I have tried fancy indexing on the stacked arrays, and numpy.choose, but I can't get either to work. I'm looking for a solution without Python loops. Any help would be greatly appreciated!
EDIT
To clarify, here is a Python loop implementation of what I'm trying to do:
import numpy as np

aa = np.dstack((a1, a2, a3))   # shape (2, N, 3)
bb = np.empty_like(aa)
for i, o in enumerate(permutations):
    bb[:, i, np.arange(3)] = aa[:, i, o]
Then I would retrieve b1, b2, b3 from bb.
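A tiny self-contained run of that loop makes the target layout concrete. The values below are hypothetical, chosen only so that each number identifies its source array (tens digit) and column (units digit); bb comes out with shape (2, N, 3), and b1, b2, b3 are its slices along the last axis.

```python
import numpy as np

# Hypothetical inputs with N = 2: the tens digit marks the source array
a1 = np.array([[10, 11], [12, 13]])
a2 = np.array([[20, 21], [22, 23]])
a3 = np.array([[30, 31], [32, 33]])
permutations = np.array([[0, 1, 2],   # column 0: keep the order a1, a2, a3
                         [2, 0, 1]])  # column 1: take a3, a1, a2

aa = np.dstack((a1, a2, a3))          # shape (2, N, 3)
bb = np.empty_like(aa)
for i, o in enumerate(permutations):
    bb[:, i, np.arange(3)] = aa[:, i, o]

# b1, b2, b3 are the last-axis slices of bb, each of shape (2, N)
b1, b2, b3 = bb[:, :, 0], bb[:, :, 1], bb[:, :, 2]
# b1 is [[10, 31], [12, 33]]: column 1 now comes from a3
```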
With fancy-indexing, you could do -
bb = aa[:,np.arange(N),permutations.T]
Please note that the result has shape (2,3,N), so b1, b2, b3 are its slices along axis 1:
b1,b2,b3 = bb[:,0,:], bb[:,1,:], bb[:,2,:]
Or, if you need bb to have the same (2,N,3) shape as in the posted code, you could add:
bb = bb.swapaxes(1,2)
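As a quick sanity check, the sketch below uses random test data of a hypothetical size to compare the fancy-indexing result with the loop, and shows that `np.moveaxis` lets you unpack b1, b2, b3 in a single assignment. The `loop_version` helper is simply the question's loop wrapped in a function, and the `np.argsort` trick is just one way to generate genuine row-wise permutations for testing.

```python
import numpy as np

def loop_version(aa, permutations):
    # Reference implementation: the loop from the question
    bb = np.empty_like(aa)
    for i, o in enumerate(permutations):
        bb[:, i, np.arange(3)] = aa[:, i, o]
    return bb

rng = np.random.default_rng(0)
N = 1000
aa = rng.random((2, N, 3))
# Argsort of random values gives a random permutation of [0, 1, 2] per row
permutations = np.argsort(rng.random((N, 3)), axis=1)

# np.arange(N) (shape (N,)) broadcasts against permutations.T (shape (3, N)),
# so the two adjacent advanced indices yield shape (2, 3, N)
bb = aa[:, np.arange(N), permutations.T]
assert bb.shape == (2, 3, N)
assert np.array_equal(bb.swapaxes(1, 2), loop_version(aa, permutations))

# Move the "which output array" axis to the front and unpack directly
b1, b2, b3 = np.moveaxis(bb, 1, 0)   # each of shape (2, N)
```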
Here's another approach using linear indexing, slicing and of course NumPy broadcasting -
idx = permutations + 3*np.arange(N)[:,None]
bb = aa.reshape(2,-1)[:,idx].reshape(2,N,3)
This produces bb with the same (2,N,3) shape as the posted loop code.
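The `linear_index_app` function in the timing session below reads `N` from the enclosing scope. A variant that takes N and the number of source arrays from `aa.shape` is a little more robust; here that count is called `k`, a name introduced only for this sketch. It relies on `reshape` using C (row-major) order by default, which puts `aa[r, i, j]` at column `k*i + j` of the flattened array.

```python
import numpy as np

def linear_index(aa, permutations):
    """Linear-indexing variant that takes N and k from aa.shape instead of a global N."""
    rows, n, k = aa.shape                            # e.g. (2, N, 3)
    # In aa.reshape(rows, -1), aa[r, i, j] sits at column k*i + j (C order),
    # so offset row i of permutations by k*i
    idx = permutations + k * np.arange(n)[:, None]   # shape (N, k)
    return aa.reshape(rows, -1)[:, idx]              # shape (rows, N, k)

# Usage with hypothetical random data
rng = np.random.default_rng(1)
N = 500
aa = rng.random((2, N, 3))
permutations = np.argsort(rng.random((N, 3)), axis=1)
bb = linear_index(aa, permutations)
```

The trailing `.reshape(2,N,3)` in the version above is harmless but redundant, since indexing the flattened array with an (N, 3) index array already yields shape (2, N, 3).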
Runtime test
In [189]: def original_app(aa,permutations):
...:     bb = np.empty_like(aa)
...:     for i, o in enumerate(permutations):
...:         bb[:,i, np.arange(3)] = aa[:, i, o]
...:     return bb
...:
...:
...: def linear_index_app(aa,permutations):
...:     idx = permutations + 3*np.arange(N)[:,None]
...:     return aa.reshape(2,-1)[:,idx].reshape(2,N,3)
...:
In [190]: # Setup input arrays
...: N = 10000
...: a1 = np.random.rand(2,N)
...: a2 = np.random.rand(2,N)
...: a3 = np.random.rand(2,N)
...:
...: permutations = np.random.randint(0,3,(N,3))
...: aa = np.dstack((a1, a2, a3))
In [191]: %timeit original_app(aa,permutations)
10 loops, best of 3: 128 ms per loop
In [192]: %timeit aa[:,np.arange(N),permutations.T]
1000 loops, best of 3: 972 µs per loop
In [193]: %timeit linear_index_app(aa,permutations)
1000 loops, best of 3: 1.02 ms per loop
So fancy indexing comes out slightly ahead of linear indexing, and both vectorized approaches are more than 100x faster than the loop on this input.
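Not part of the answer above, but NumPy (1.15 and later) also provides `np.take_along_axis`, which expresses this "gather along the last axis" directly and returns the same (2, N, 3) layout as the loop. A minimal sketch with hypothetical random data follows; it was not included in the timings above, so no speed claim is implied.

```python
import numpy as np

rng = np.random.default_rng(2)
N = 10000
aa = rng.random((2, N, 3))
permutations = np.argsort(rng.random((N, 3)), axis=1)   # true row permutations

# The index array needs the same ndim as aa; its leading axis of size 1
# broadcasts over the 2 rows of aa
bb = np.take_along_axis(aa, permutations[None, :, :], axis=2)   # shape (2, N, 3)
b1, b2, b3 = bb[..., 0], bb[..., 1], bb[..., 2]
```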