
Find all indices of maximum in Pandas DataFrame

Tags: python, pandas

I need to find all column labels where the maximum value (per row) is attained in a Pandas DataFrame. For instance, if I have a DataFrame like this:

   cat1  cat2  cat3
0     0     2     2
1     3     0     1
2     1     1     0

then the method I am looking for would yield a result like:

[['cat2', 'cat3'],
 ['cat1'],
 ['cat1', 'cat2']]

This is a list of lists, but some other data structure is also okay.

I cannot use df.idxmax(axis=1), because it only yields the first maximum.
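A quick way to see the limitation, as a minimal sketch that rebuilds the example frame above (the variable name `first_max` is only illustrative):

```python
import pandas as pd

# Rebuild the example frame from the question.
df = pd.DataFrame({'cat1': [0, 3, 1], 'cat2': [2, 0, 1], 'cat3': [2, 1, 0]})

# idxmax reports one label per row: the first column holding the maximum.
first_max = df.idxmax(axis=1)
print(first_max.tolist())  # ['cat2', 'cat1', 'cat1']
```

The ties in rows 0 and 2 ('cat3' and 'cat2' respectively) are dropped.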

RafG asked Feb 07 '14 12:02




1 Answer

Here is the information, in a different data structure:

In [6]: import numpy as np

In [7]: import pandas as pd

In [8]: df = pd.DataFrame({'cat1':[0,3,1], 'cat2':[2,0,1], 'cat3':[2,1,0]})

In [9]: df
Out[9]: 
   cat1  cat2  cat3
0     0     2     2
1     3     0     1
2     1     1     0

[3 rows x 3 columns]

In [10]: rowmax = df.max(axis=1)

The max values are indicated by True values:

In [82]: df.values == rowmax.values[:,None]
Out[82]: 
array([[False,  True,  True],
       [ True, False, False],
       [ True,  True, False]], dtype=bool)

np.where returns the indices where the boolean array above is True.

In [84]: np.where(df.values == rowmax.values[:,None])
Out[84]: (array([0, 0, 1, 2, 2]), array([1, 2, 0, 0, 1]))

The first array holds the row positions (axis=0) and the second array holds the column positions (axis=1). There are five values in each array because five locations are True.
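As a minimal sketch of how the two position arrays line up, they can be zipped into (row label, column label) pairs. The names `rows`, `cols` and `pairs` are illustrative, and `rowmax.values` is used so the broadcast is done on plain NumPy arrays:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'cat1': [0, 3, 1], 'cat2': [2, 0, 1], 'cat3': [2, 1, 0]})
rowmax = df.max(axis=1)

# Positions of every cell equal to its row maximum.
rows, cols = np.where(df.values == rowmax.values[:, None])

# Translate positions into index and column labels.
pairs = [(df.index[i], df.columns[j]) for i, j in zip(rows, cols)]
print(pairs)
```

Each pair names one cell that ties for the maximum of its row.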


You could use itertools.groupby to build the list of lists you posted, though perhaps you don't need this given the data structures above:

In [46]: import itertools as IT

In [47]: import operator

In [48]: idx = np.where(df.values == rowmax.values[:,None])

In [49]: groups = IT.groupby(zip(*idx), key=operator.itemgetter(0))

In [50]: [[df.columns[j] for i, j in grp] for k, grp in groups]
Out[50]: [['cat2', 'cat3'], ['cat1'], ['cat1', 'cat2']]
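
If the itertools step feels heavy, one possible alternative (not part of the session above, just a sketch) is to build the boolean mask with `DataFrame.eq` and index the column labels with each row of the mask. The names `mask` and `result` are illustrative:

```python
import pandas as pd

df = pd.DataFrame({'cat1': [0, 3, 1], 'cat2': [2, 0, 1], 'cat3': [2, 1, 0]})

# Compare every cell with its row maximum; axis=0 aligns the maxima on the row index.
mask = df.eq(df.max(axis=1), axis=0)

# Boolean-index the column labels once per row of the mask.
result = [df.columns[row].tolist() for row in mask.values]
print(result)  # [['cat2', 'cat3'], ['cat1'], ['cat1', 'cat2']]
```

This assumes the frame has no missing values; since NaN never compares equal to anything, a row consisting only of NaN would produce an empty list.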
unutbu answered Oct 15 '22 01:10