This is a pretty common SQL query:
Select the rows with the maximum value in column X, grouped by group_id.
The result is, for every group_id, one (the first) row whose column X value is the maximum within that group.
I have a 2D NumPy array with many columns, but let's simplify it to (ID, X, Y):
import numpy as np
rows = np.array([[1, 22, 1236],
                 [1, 11, 1563],
                 [2, 13, 1234],
                 [2, 10, 1224],
                 [2, 23, 1111],
                 [2, 23, 1250]])
And I want to get:
[[1 22 1236]
[2 23 1111]]
I am able to do it with a cumbersome loop, something like:
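# assumes rows are sorted (grouped) by ID, i.e. column 0
row_grouped_with_max = []
max_row = rows[0]
last_max = max_row[1]
last_row_group = max_row[0]
for row in rows:
    if row[0] != last_row_group:
        # a new group starts: store the previous group's max row
        row_grouped_with_max.append(max_row)
        last_row_group = row[0]
        max_row = row
        last_max = row[1]
    elif last_max < row[1]:
        # strict "<" keeps the first row when X values tie
        max_row = row
        last_max = row[1]
row_grouped_with_max.append(max_row)  # max row of the last group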
How can I do this in a clean NumPy way?
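One pure NumPy approach, sketched here under the assumption that X is a signed numeric column (it is negated to get a descending sort), is to sort the rows by ID and then by X descending with `np.lexsort`, and then take the first row of each ID block with `np.unique(..., return_index=True)`. Adding the original row position as the least significant sort key makes ties on X resolve to the earliest row, matching the "one (first) row" requirement.

```python
import numpy as np

rows = np.array([[1, 22, 1236],
                 [1, 11, 1563],
                 [2, 13, 1234],
                 [2, 10, 1224],
                 [2, 23, 1111],
                 [2, 23, 1250]])

# np.lexsort treats the LAST key as the primary key:
# primary = ID, secondary = X descending, tertiary = original position.
positions = np.arange(len(rows))
order = np.lexsort((positions, -rows[:, 1], rows[:, 0]))
sorted_rows = rows[order]

# After sorting, the first row of each ID block holds that group's maximum X.
# return_index=True gives the index of the first occurrence of each unique ID.
_, first_idx = np.unique(sorted_rows[:, 0], return_index=True)
result = sorted_rows[first_idx]
print(result)
```

Unlike the loop, this does not require the input to be pre-sorted by ID, and the result comes out ordered by ID.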
An alternative is to use the pandas library (it is easier to manipulate such arrays there, IMO):
In [1]: import numpy as np
...: import pandas as pd
In [2]: rows = np.array([[1,22,1236],
...: [1,11,1563],
...: [2,13,1234],
...: [2,10,1224],
...: [2,23,1111],
...: [2,23,1250]])
...: print(rows)
[[ 1 22 1236]
[ 1 11 1563]
[ 2 13 1234]
[ 2 10 1224]
[ 2 23 1111]
[ 2 23 1250]]
In [3]: df = pd.DataFrame(rows)
...: print(df)
0 1 2
0 1 22 1236
1 1 11 1563
2 2 13 1234
3 2 10 1224
4 2 23 1111
5 2 23 1250
In [4]: g = df.groupby([0])[1].transform('max')
...: print(g)
0 22
1 22
2 23
3 23
4 23
5 23
dtype: int32
In [5]: df2 = df[df[1] == g]
...: print(df2)
0 1 2
0 1 22 1236
4 2 23 1111
5 2 23 1250
In [6]: df3 = df2.drop_duplicates([0])  # keep the first max row per ID (column 0)
...: print(df3)
0 1 2
0 1 22 1236
4 2 23 1111
In [7]: mtx = df3.to_numpy()
...: print(mtx)
[[ 1 22 1236]
[ 2 23 1111]]
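A shorter pandas route, offered as a sketch rather than a replacement for the steps above: `SeriesGroupBy.idxmax` returns, for each group, the index label of the first row holding that group's maximum, so ties and deduplication are handled in one step. The second part uses made-up data in which both groups share the same maximum X, a case where deduplicating on the X column instead of the ID would silently lose a group.

```python
import numpy as np
import pandas as pd

rows = np.array([[1, 22, 1236],
                 [1, 11, 1563],
                 [2, 13, 1234],
                 [2, 10, 1224],
                 [2, 23, 1111],
                 [2, 23, 1250]])
df = pd.DataFrame(rows)

# For each ID (column 0), idxmax gives the index label of the FIRST row
# holding that group's maximum X (column 1), so ties go to the earliest row.
idx = df.groupby(0)[1].idxmax()
mtx = df.loc[idx].to_numpy()
print(mtx)

# Hypothetical data: both groups share the same maximum X (5).
# One row per group is still returned, since nothing is deduplicated on X.
other = pd.DataFrame([[1, 5, 100], [1, 3, 200], [2, 5, 300]])
other_max = other.loc[other.groupby(0)[1].idxmax()].to_numpy()
print(other_max)
```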