I want to find the locations of the minima along a given axis in a rank-3 numpy array. I have obtained these locations with np.argmin, but I'm not sure how to "apply" them to the original array to get the actual minima.
For example:
import numpy as np
a = np.random.randn(10, 5, 2)
min_loc = a.argmin(axis = 0) # this gives an array of shape (5, 2)
Now, the problem is: how do I get the actual minima using min_loc? I have tried a[min_loc], which gives me an array of shape (5, 2, 5, 2). What's the logic behind this shape? How can I use this auxiliary array to get a sensible result of shape (5, 2)?
Note that a.min(axis = 0) is not the solution I'm looking for. I need a solution via argmin.
a[min_loc] does integer array indexing on the first dimension only, i.e. it picks up a (5, 2) shaped subarray for each index in min_loc. Since min_loc itself is (5, 2) shaped and each integer in it selects another (5, 2) shaped array, you end up with a (5, 2, 5, 2) array. For the same reason, a[np.array([0, 3])] has shape (2, 5, 2) and a[np.array([[0], [3]])] has shape (2, 1, 5, 2): you only provide an index for the 1st dimension.
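As a quick check of the shape rule described above, here is a minimal sketch. The variable names shape_single, shape_nested and shape_min_loc are only illustrative. When only the first axis is indexed with an integer array, the result shape is the shape of the index array followed by the remaining axes of a.

```python
import numpy as np

a = np.random.randn(10, 5, 2)
min_loc = a.argmin(axis=0)  # shape (5, 2)

# Indexing only the first axis:
# result shape = index array shape + remaining axes of a, i.e. + (5, 2)
shape_single = a[np.array([0, 3])].shape      # (2,)   + (5, 2)
shape_nested = a[np.array([[0], [3]])].shape  # (2, 1) + (5, 2)
shape_min_loc = a[min_loc].shape              # (5, 2) + (5, 2)

print(shape_single, shape_nested, shape_min_loc)
```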
For your use case, you do not want a subarray for each index in min_loc; you want a single element. For instance, if min_loc = [[5, ...], ...], the first element should have the full index 5, 0, 0 rather than 5, :, :. This is what advanced indexing does when you provide an integer array for every dimension: it picks the element at each position given by the corresponding entries of the index arrays. You can construct the indices for the 2nd and 3rd dimensions from the (5, 2) shape with np.indices:
j, k = np.indices(min_loc.shape)
a[min_loc, j, k]
# [[-1.82762089 -0.80927253]
# [-1.06147046 -1.70961507]
# [-0.59913623 -1.10963768]
# [-2.57382762 -0.77081778]
# [-1.6918745 -1.99800825]]
where j, k are coordinates for the 2nd and 3rd dimensions:
j
#[[0 0]
# [1 1]
# [2 2]
# [3 3]
# [4 4]]
k
#[[0 1]
# [0 1]
# [0 1]
# [0 1]
# [0 1]]
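Because index arrays are broadcast against each other, j and k do not have to be full (5, 2) grids. The sketch below is a variation rather than part of the original answer: it uses np.ogrid to build "open" grids and checks both variants against a.min(axis=0). The seeded generator and the names minima_full, minima_open and same are only there for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded only to make the sketch repeatable
a = rng.standard_normal((10, 5, 2))
min_loc = a.argmin(axis=0)      # shape (5, 2)

# Full index grids, as above: j and k both have shape (5, 2)
j, k = np.indices(min_loc.shape)
minima_full = a[min_loc, j, k]

# Open grids: j_open has shape (5, 1), k_open has shape (1, 2);
# both broadcast against min_loc to shape (5, 2)
j_open, k_open = np.ogrid[:min_loc.shape[0], :min_loc.shape[1]]
minima_open = a[min_loc, j_open, k_open]

# Both variants should reproduce a.min(axis=0) exactly
same = (np.array_equal(minima_full, minima_open)
        and np.array_equal(minima_full, a.min(axis=0)))
print(same)
```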
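Or, as @hpaulj commented, use np.take_along_axis. The index array must have the same number of dimensions as a, hence min_loc[None]; the result has shape (1, 5, 2), so take [0] of it to get the (5, 2) array. (The values shown below differ from those above, so they evidently come from a different random a.)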
np.take_along_axis(a, min_loc[None], axis=0)
# [[[-0.93515242 -2.29665325]
# [-1.30864779 -1.483428 ]
# [-1.24262879 -0.71030707]
# [-1.40322789 -1.35580273]
# [-2.10997209 -2.81922197]]]
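To get from the (1, 5, 2) result to the (5, 2) array asked for, the leading axis can simply be dropped. A minimal sketch follows; the seeded generator and the names picked and minima are only illustrative, and np.expand_dims is used here as an equivalent spelling of min_loc[None].

```python
import numpy as np

rng = np.random.default_rng(0)  # seeded only to make the sketch repeatable
a = rng.standard_normal((10, 5, 2))
min_loc = a.argmin(axis=0)      # shape (5, 2)

# The index array needs the same number of dimensions as a,
# so add a length-1 axis where the reduced axis was
idx = np.expand_dims(min_loc, axis=0)         # shape (1, 5, 2), same as min_loc[None]
picked = np.take_along_axis(a, idx, axis=0)   # shape (1, 5, 2)

# Drop the leading length-1 axis to get the (5, 2) minima
minima = picked[0]

print(minima.shape, np.array_equal(minima, a.min(axis=0)))
```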