 

Numpy 3D array max value

Tags:

python

numpy

import numpy as np
a = np.array([[[0.25, 0.10, 0.50, 0.15],
               [0.50, 0.60, 0.70, 0.30]],
              [[0.25, 0.50, 0.20, 0.70],
               [0.80, 0.10, 0.50, 0.15]]])

I need to find the row and column of the max value in a[i]. For i=0, the max is a[0,1,2], so I need a method that returns [1,2] as the position of the max in a[0]. Any pointers, please? NB: np.argmax without an axis flattens the 2D array a[i] and returns a single flat index, and with axis=0 it returns, for each column of a[0], the row that holds that column's max, so neither form directly gives the [row, col] pair.
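To see what np.argmax actually returns on a[0], here is a small sketch. The flat index of 0.70 is 6, and because each row of a[0] has 4 entries, divmod(6, 4) recovers the (row, col) pair (1, 2). The divmod step is only an illustration of the row-major conversion, not something from the question.

```python
import numpy as np

a = np.array([[[0.25, 0.10, 0.50, 0.15],
               [0.50, 0.60, 0.70, 0.30]],
              [[0.25, 0.50, 0.20, 0.70],
               [0.80, 0.10, 0.50, 0.15]]])

# No axis: argmax works on the flattened slice, so 0.70 is reported as flat position 6.
flat_idx = np.argmax(a[0])

# axis=0: for each column, the row index holding that column's largest value.
per_column = np.argmax(a[0], axis=0)

# axis=1: for each row, the column index holding that row's largest value.
per_row = np.argmax(a[0], axis=1)

# Row-major layout: flat = row * n_cols + col, so divmod undoes it.
row, col = divmod(int(flat_idx), a.shape[2])

print(int(flat_idx), per_column.tolist(), per_row.tolist(), (row, col))
```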

asked Feb 18 '26 at 23:02 by Surabhi Amit Chembra



2 Answers

You can also use argmax with unravel_index:

def max_by_index(idx, arr):
    return (idx,) + np.unravel_index(np.argmax(arr[idx]), arr.shape[1:])

For example:

import numpy as np
a = np.array([[[ 0.25,  0.10 ,  0.50 ,  0.15],
               [ 0.50,  0.60 ,  0.70 ,  0.30]],
              [[ 0.25,  0.50 ,  0.20 ,  0.70],
               [ 0.80,  0.10 ,  0.50 ,  0.15]]])

def max_by_index(idx, arr):
    return (idx,) + np.unravel_index(np.argmax(arr[idx]), arr.shape[1:])


print(max_by_index(0, a))

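gives

(0, 1, 2)

On NumPy 2.x the same call prints (0, np.int64(1), np.int64(2)): the values are identical, but NumPy scalars now include their type in their repr.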
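If you need the position for every i at once instead of one slice at a time, the same argmax/unravel_index idea can be vectorized. This is a sketch assuming every a[i] has the shape a.shape[1:]: reshape each slice into a single row, take argmax along axis 1, then unravel all flat indices together. As with np.argmax in general, ties resolve to the first occurrence in row-major order.

```python
import numpy as np

a = np.array([[[0.25, 0.10, 0.50, 0.15],
               [0.50, 0.60, 0.70, 0.30]],
              [[0.25, 0.50, 0.20, 0.70],
               [0.80, 0.10, 0.50, 0.15]]])

# One row per slice a[i]; argmax(axis=1) then gives one flat index per slice.
flat = a.reshape(a.shape[0], -1).argmax(axis=1)

# Convert each flat index back to (row, col) within a slice of shape a.shape[1:].
rows, cols = np.unravel_index(flat, a.shape[1:])

# Pair them up: positions[i] is [row, col] of the max of a[i].
positions = np.stack([rows, cols], axis=1)
print(positions.tolist())  # [[1, 2], [1, 0]]
```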

answered Feb 20 '26 at 12:02 by csunday95



You can use numpy.where, which you can wrap in a simple function to meet your requirements:

def max_by_index(idx, arr):
    return np.where(arr[idx] == np.max(arr[idx]))

In action:

>>> max_by_index(0, a)
(array([1], dtype=int64), array([2], dtype=int64))

You can index your array with this result to access the maximum value:

>>> a[0][max_by_index(0, a)]
array([0.7])

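This returns every location of the maximum value. If you only want a single occurrence, use np.argmax together with np.unravel_index instead of np.where; simply swapping np.max for np.argmax in the comparison above would compare the values against a flat index and give meaningless results.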
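If you want the result shaped like the [1, 2] in the question, np.argwhere finds the same locations as np.where but returns one [row, col] pair per match. A minimal sketch; max_positions is a hypothetical helper name used only for illustration, and the tie example array b is made up to show that all maxima are reported.

```python
import numpy as np

a = np.array([[[0.25, 0.10, 0.50, 0.15],
               [0.50, 0.60, 0.70, 0.30]],
              [[0.25, 0.50, 0.20, 0.70],
               [0.80, 0.10, 0.50, 0.15]]])

def max_positions(idx, arr):
    """Hypothetical helper: every [row, col] where arr[idx] equals its maximum."""
    sl = arr[idx]
    # argwhere returns an (n_matches, 2) array, one [row, col] pair per match.
    return np.argwhere(sl == sl.max())

print(max_positions(0, a).tolist())  # [[1, 2]]

# Made-up example with a tie: both locations of the max are returned.
b = np.array([[[1, 3],
               [3, 0]]])
print(max_positions(0, b).tolist())  # [[0, 1], [1, 0]]
```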

answered Feb 20 '26 at 13:02 by user3483203



