 

How to use the output of argmin as an index with NumPy [duplicate]

I want to find the locations of the minima along a given axis of a rank-3 NumPy array. I have obtained these locations with np.argmin, but I'm not sure how to "apply" them to the original array to get the actual minimum values.

For example:

import numpy as np

a = np.random.randn(10, 5, 2)
min_loc = a.argmin(axis=0)  # this gives an array of shape (5, 2)

Now, how do I get the actual minima using min_loc? I have tried a[min_loc], which gives an array of shape (5, 2, 5, 2). What's the logic behind this shape? And how can I use this index array to get the expected result of shape (5, 2)?

Note that a.min(axis = 0) is not the solution I'm looking for. I need a solution via argmin.

Asked by Pythonist, Nov 18 '25 at 18:11



1 Answer

a[min_loc] does integer array indexing on the first dimension only: for each integer i in min_loc, it picks up the whole (5, 2) subarray a[i, :, :]. Since min_loc itself has shape (5, 2) and each of its entries selects another (5, 2) subarray, you end up with a (5, 2, 5, 2) array. For the same reason, a[np.array([0, 3])] has shape (2, 5, 2) and a[np.array([[0], [3]])] has shape (2, 1, 5, 2): you only provide an index for the 1st dimension, so the result shape is the index array's shape followed by the remaining dimensions of a.
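The shape rule above can be checked directly. This is a minimal sketch; it uses a seeded np.random.default_rng generator instead of the question's np.random.randn (an assumption made only so the run is reproducible), and the shapes do not depend on the data.

```python
import numpy as np

# Same setup as the question: random data of shape (10, 5, 2), seeded for reproducibility
rng = np.random.default_rng(0)
a = rng.standard_normal((10, 5, 2))
min_loc = a.argmin(axis=0)  # shape (5, 2), values in range(10)

# An integer array used as the only index selects whole (5, 2) slabs
# along axis 0, so the result shape is index.shape + a.shape[1:]
print(a[min_loc].shape)               # (5, 2, 5, 2)
print(a[np.array([0, 3])].shape)      # (2, 5, 2)
print(a[np.array([[0], [3]])].shape)  # (2, 1, 5, 2)

# Each entry of a[min_loc] is a full slab: a[min_loc][0, 0] equals a[min_loc[0, 0]]
print(np.array_equal(a[min_loc][0, 0], a[min_loc[0, 0]]))  # True
```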

For your use case, you do not want to pick up a subarray for each index in min_loc; you want a single element. For instance, if min_loc = [[5, ...], ...], the first element should have the full index (5, 0, 0) instead of (5, :, :). Advanced indexing does exactly this when you provide an integer array for every dimension: the index arrays are broadcast together, and each position in the result picks the element at the corresponding coordinates. You can construct the indices for the 2nd and 3rd dimensions from the (5, 2) shape with np.indices:

j, k = np.indices(min_loc.shape)
a[min_loc, j, k]

# [[-1.82762089 -0.80927253]
#  [-1.06147046 -1.70961507]
#  [-0.59913623 -1.10963768]
#  [-2.57382762 -0.77081778]
#  [-1.6918745  -1.99800825]]

where j, k are coordinates for the 2nd and 3rd dimensions:

j
#[[0 0]
# [1 1]
# [2 2]
# [3 3]
# [4 4]]

k  
#[[0 1]
# [0 1]
# [0 1]
# [0 1]
# [0 1]]
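Because the index arrays are broadcast together, j and k do not have to be full (5, 2) grids. The sketch below is a variation that is not part of the original answer: it swaps in np.ix_ to build "open" grids of shapes (5, 1) and (1, 2), which broadcast against min_loc to the same coordinate pairs. The seeded random array is again an assumption for reproducibility.

```python
import numpy as np

rng = np.random.default_rng(2)
a = rng.standard_normal((10, 5, 2))
min_loc = a.argmin(axis=0)  # shape (5, 2)

# Dense grids, as in the answer: j and k are both full (5, 2) arrays
j, k = np.indices(min_loc.shape)

# Open grids: shapes (5, 1) and (1, 2). Broadcasting them against
# min_loc's (5, 2) yields the same (row, column) pairs as j and k
n_rows, n_cols = min_loc.shape
j_open, k_open = np.ix_(np.arange(n_rows), np.arange(n_cols))
print(j_open.shape, k_open.shape)  # (5, 1) (1, 2)

minima = a[min_loc, j_open, k_open]  # shape (5, 2)
print(np.array_equal(minima, a[min_loc, j, k]))  # True
print(np.array_equal(minima, a.min(axis=0)))     # True
```

The result is identical; the open grids simply avoid materializing two full-size index arrays, which only matters for large arrays.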

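Alternatively, as @hpaulj commented, use np.take_along_axis. It requires an index array with the same number of dimensions as a, hence min_loc[None], and the result keeps that length-1 axis, so its shape is (1, 5, 2); index it with [0] to get (5, 2). (The sample output below comes from a different random a than the output above.)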

np.take_along_axis(a, min_loc[None], axis=0)

# [[[-0.93515242 -2.29665325]
#   [-1.30864779 -1.483428  ]
#   [-1.24262879 -0.71030707]
#   [-1.40322789 -1.35580273]
#   [-2.10997209 -2.81922197]]]
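The take_along_axis approach generalizes to any axis. The helper below, min_via_argmin, is a hypothetical name used only for illustration: it restores the axis removed by argmin with np.expand_dims (equivalent to min_loc[None] when axis=0), gathers with np.take_along_axis, and then drops the length-1 axis again. The seeded random array is an assumption for reproducibility.

```python
import numpy as np

def min_via_argmin(a, axis):
    """Hypothetical helper: minima along `axis`, gathered with argmin indices."""
    # argmin drops `axis`; take_along_axis needs an index array with the
    # same ndim as `a`, so restore that axis with length 1 (like min_loc[None])
    loc = np.expand_dims(a.argmin(axis=axis), axis=axis)
    picked = np.take_along_axis(a, loc, axis=axis)  # length 1 along `axis`
    return picked.squeeze(axis=axis)                # drop it again

rng = np.random.default_rng(3)
a = rng.standard_normal((10, 5, 2))

print(min_via_argmin(a, 0).shape)  # (5, 2)
for ax in range(a.ndim):
    print(ax, np.array_equal(min_via_argmin(a, ax), a.min(axis=ax)))  # True for each axis
```

Unlike the np.indices construction, this version needs no per-dimension coordinate arrays, and negative axes work as well.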
Answered by Psidom, Nov 21 '25 at 07:11



