 

Select pandas frame rows based on two columns' values

I want to select specific rows based on the values of two columns. For example:

import numpy as np
import pandas as pd

d = {'user': [1., 2., 3., 4], 'item': [5., 6., 7., 8.], 'f1': [9., 16., 17., 18.], 'f2': [4, 5, 6, 5], 'f3': [4, 5, 5, 8]}
df = pd.DataFrame(d)
print(df)

Out:
   f1  f2  f3  item  user
0   9   4   4     5     1
1  16   5   5     6     2
2  17   6   5     7     3
3  18   5   8     8     4

I want to select the rows based on the values of 'user' and 'item'. I have a 2D NumPy array that stores the [user, item] value pairs:

samples = np.array([[1,5],[3,7],[3,7],[2,6]]) 
Out: 
array([[1, 5],
       [3, 7],
       [3, 7],
       [2, 6]])

Then the expected output is:

Out:
   f1  f2  f3  item  user
0   9   4   4     5     1
2  17   6   5     7     3
2  17   6   5     7     3
1  16   5   5     6     2

My final objective is to get a 2D NumPy array that stores the values of all the columns except item and user:

Out: 
array([[9, 4, 4],
       [17, 6, 5],
       [17, 6, 5],
       [16, 5, 5]])

That is, the values of columns f1, f2 and f3.

How can I do this?

Excalibur asked Jun 01 '15 at 20:06


1 Answer

If you make samples a DataFrame with columns user and item, you can obtain the desired values with an inner join. By default, pd.merge joins on all the columns that samples and df have in common, which in this case are user and item. Hence,

result = pd.merge(samples, df, how='inner')

yields

   user  item  f1  f2  f3
0     1     5   9   4   4
1     3     7  17   6   5
2     3     7  17   6   5
3     2     6  16   5   5
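If you prefer to spell out the join keys instead of relying on the shared-column default, pd.merge also accepts on=['user', 'item']. Below is a minimal sketch using the question's data that checks whether both forms give the same frame. The pandas documentation says an inner merge preserves the order of the left keys. So the rows should follow the order of the pairs in samples, and the duplicated (3, 7) pair should be matched twice.

```python
import numpy as np
import pandas as pd

d = {'user': [1., 2., 3., 4], 'item': [5., 6., 7., 8.],
     'f1': [9., 16., 17., 18.], 'f2': [4, 5, 6, 5], 'f3': [4, 5, 5, 8]}
df = pd.DataFrame(d)
samples = pd.DataFrame(np.array([[1, 5], [3, 7], [3, 7], [2, 6]]),
                       columns=['user', 'item'])

# Default: join on every column name the two frames share (user and item).
implicit = pd.merge(samples, df, how='inner')
# The same inner join with the key columns named explicitly.
explicit = pd.merge(samples, df, how='inner', on=['user', 'item'])

print(implicit.equals(explicit))
print(explicit[['user', 'item']].values.tolist())
```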

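import numpy as np
import pandas as pd

d = {'user': [1., 2., 3., 4], 'item': [5., 6., 7., 8.], 'f1': [9., 16., 17., 18.], 'f2': [4, 5, 6, 5], 'f3': [4, 5, 5, 8]}
df = pd.DataFrame(d)
# Wrap the [user, item] pairs in a DataFrame so they can be joined against df.
samples = np.array([[1, 5], [3, 7], [3, 7], [2, 6]])
samples = pd.DataFrame(samples, columns=['user', 'item'])

# Inner join on the shared columns (user, item); duplicated pairs are kept.
result = pd.merge(samples, df, how='inner')
# Keep only the feature columns, then convert them to a NumPy array.
result = result[['f1', 'f2', 'f3']]
result = result.values
print(result)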

yields

[[  9.   4.   4.]
 [ 17.   6.   5.]
 [ 17.   6.   5.]
 [ 16.   5.   5.]]
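A different technique, not the one used in this answer, is to move user and item into a MultiIndex and look the pairs up with .loc. Below is a minimal sketch. It assumes every sample pair exists in df, because .loc raises a KeyError for a missing label. It also assumes each (user, item) pair appears only once in df. Since user and item become the index, only f1, f2 and f3 remain as columns, so you do not need to list them.

```python
import numpy as np
import pandas as pd

d = {'user': [1., 2., 3., 4], 'item': [5., 6., 7., 8.],
     'f1': [9., 16., 17., 18.], 'f2': [4, 5, 6, 5], 'f3': [4, 5, 5, 8]}
df = pd.DataFrame(d)
samples = np.array([[1, 5], [3, 7], [3, 7], [2, 6]])

# Index the frame by the (user, item) pair so rows can be looked up by key.
indexed = df.set_index(['user', 'item'])

# user and item are float columns in this example, so cast the pairs to
# float to match the index levels, then build (user, item) tuples for .loc.
keys = [tuple(pair) for pair in samples.astype(float).tolist()]

# A list of tuples returns one row per key, duplicates included, in the
# order given; the remaining columns are f1, f2 and f3.
result = indexed.loc[keys].to_numpy()
print(result)
```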
unutbu answered Sep 23 '22 at 03:09