Top 3 Values Per Row in Pandas

I have a large Pandas dataframe that is in the vein of:

| ID | Var1 | Var2 | Var3 | Var4 | Var5 |
|----|------|------|------|------|------|
| 1  | 1    | 2    | 3    | 4    | 5    |
| 2  | 10   | 9    | 8    | 7    | 6    |
| 3  | 25   | 37   | 41   | 24   | 21   |
| 4  | 102  | 11   | 72   | 56   | 151  |
...

and I would like to generate output that looks like this, taking the column names of the 3 highest values for each row:

| ID | 1st Max | 2nd Max | 3rd Max |
|----|---------|---------|---------|
| 1  | Var5    | Var4    | Var3    |
| 2  | Var1    | Var2    | Var3    |
| 3  | Var3    | Var2    | Var1    |
| 4  | Var5    | Var1    | Var3    |
...

I have tried using df.idxmax(axis=1), which returns the column name of the 1st maximum, but I am unsure how to compute the other two.

Any help on this would be truly appreciated, thanks!

quicklegit asked Feb 04 '23 19:02


1 Answer

Use numpy.argsort to get the column positions of the values sorted in descending order, select the top 3 by indexing, and finally pass the result to the DataFrame constructor:

import numpy as np
import pandas as pd

df = df.set_index('ID')
df = pd.DataFrame(df.columns.values[np.argsort(-df.values, axis=1)[:, :3]],
                  index=df.index,
                  columns=['1st Max', '2nd Max', '3rd Max']).reset_index()
print(df)
   ID 1st Max 2nd Max 3rd Max
0   1    Var5    Var4    Var3
1   2    Var1    Var2    Var3
2   3    Var3    Var2    Var1
3   4    Var5    Var1    Var3
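
For anyone who wants to try this end to end, here is a minimal self-contained sketch of the same argsort idea. The sample frame is rebuilt from the table in the question, and `top_columns` is a hypothetical helper name, not part of pandas or NumPy. Passing `kind="stable"` is an added assumption so that tied values keep their left-to-right column order.

```python
import numpy as np
import pandas as pd

# Sample frame rebuilt from the table in the question.
df = pd.DataFrame({
    "ID": [1, 2, 3, 4],
    "Var1": [1, 10, 25, 102],
    "Var2": [2, 9, 37, 11],
    "Var3": [3, 8, 41, 72],
    "Var4": [4, 7, 24, 56],
    "Var5": [5, 6, 21, 151],
})


def top_columns(frame, labels=("1st Max", "2nd Max", "3rd Max"), id_col="ID"):
    """Hypothetical helper: names of the len(labels) largest columns per row."""
    values = frame.set_index(id_col)
    # Negating the values makes an ascending argsort yield descending order.
    # kind="stable" (an assumption here) keeps column order for tied values.
    order = np.argsort(-values.to_numpy(), axis=1, kind="stable")
    top = order[:, :len(labels)]
    # Index the column-name array with the 2-D position array.
    names = values.columns.to_numpy()[top]
    return pd.DataFrame(names, index=values.index,
                        columns=list(labels)).reset_index()


result = top_columns(df)
print(result)
```

Negating the array only works for numeric columns; for other dtypes one would probably need a different way to reverse the order.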

Or, if performance is not important, use nlargest with apply on each row:

c = ['1st Max','2nd Max','3rd Max']
df = (df.set_index('ID')
        .apply(lambda x: pd.Series(x.nlargest(3).index, index=c), axis=1)
        .reset_index())
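
The row-wise variant can be checked the same way. This is a rough sketch on the sample data from the question; `top3_names` is a hypothetical function name used only to make the lambda easier to read.

```python
import pandas as pd

# Sample frame rebuilt from the table in the question.
df = pd.DataFrame({
    "ID": [1, 2, 3, 4],
    "Var1": [1, 10, 25, 102],
    "Var2": [2, 9, 37, 11],
    "Var3": [3, 8, 41, 72],
    "Var4": [4, 7, 24, 56],
    "Var5": [5, 6, 21, 151],
})

labels = ["1st Max", "2nd Max", "3rd Max"]


def top3_names(row):
    # nlargest(3) returns the three largest values in descending order;
    # the index of that result holds the matching column names.
    return pd.Series(row.nlargest(3).index, index=labels)


# apply(..., axis=1) calls top3_names once per row, which is why this
# version tends to be slower on large frames.
row_wise = df.set_index("ID").apply(top3_names, axis=1).reset_index()
print(row_wise)
```

On the four sample rows this should give the same column names as the argsort version.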

jezrael answered Feb 06 '23 14:02