How to use a dict to subset a DataFrame?

Say I am given a DataFrame in which most of the columns hold categorical data.

> data.head()
  age risk     sex smoking
0  28   no    male      no
1  58   no  female      no
2  27   no    male     yes
3  26   no    male      no
4  29  yes  female     yes

I would like to subset this data using a dict of key-value pairs for those categorical variables:

tmp = {'risk':'no', 'smoking':'yes', 'sex':'female'}

In other words, I would like to get the following subset:

data[ (data.risk == 'no') & (data.smoking == 'yes') & (data.sex == 'female')]

What I want to do is:

data[tmp]

What is the most Pythonic / idiomatic pandas way of doing this?


Minimal example:

import numpy as np
import pandas as pd
from pandas import Series, DataFrame

# random 0/1 codes stored as categoricals, then relabelled
x = Series(np.random.randint(0, 2, 50), dtype='category')
x = x.cat.rename_categories(['no', 'yes'])

y = Series(np.random.randint(0, 2, 50), dtype='category')
y = y.cat.rename_categories(['no', 'yes'])

z = Series(np.random.randint(0, 2, 50), dtype='category')
z = z.cat.rename_categories(['male', 'female'])

a = Series(np.random.randint(20, 60, 50), dtype='category')

data = DataFrame({'risk': x, 'smoking': y, 'sex': z, 'age': a})

tmp = {'risk': 'no', 'smoking': 'yes', 'sex': 'female'}
asked Oct 18 '16 15:10 by Thomas Möbius


3 Answers

I would use the .query() method for this task:

qry = ' and '.join(["{} == '{}'".format(k,v) for k,v in tmp.items()])    

data.query(qry)

output:

   age risk     sex smoking
7   24   no  female     yes
22  43   no  female     yes
23  42   no  female     yes
25  24   no  female     yes
32  29   no  female     yes
40  34   no  female     yes
43  35   no  female     yes

Query string:

print(qry)
sex == 'female' and risk == 'no' and smoking == 'yes'
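A slightly more defensive variant of the same query idea, sketched under the assumption of pandas 0.25 or newer (needed for backtick-quoted column names). Formatting values with `{v!r}` lets Python do the quoting, so a value containing an apostrophe does not break the expression, and backticks protect column names that are not valid identifiers. The helper name `query_from_dict` is hypothetical:

```python
import pandas as pd


def query_from_dict(df: pd.DataFrame, criteria: dict) -> pd.DataFrame:
    """Hypothetical helper: AND together one `col == value` test per dict item."""
    # repr() quotes each value safely; backticks quote each column name
    qry = " and ".join(f"`{k}` == {v!r}" for k, v in criteria.items())
    return df.query(qry)


data = pd.DataFrame({
    "risk": ["no", "no", "yes", "no"],
    "smoking": ["yes", "no", "yes", "yes"],
    "sex": ["female", "female", "female", "male"],
})
tmp = {"risk": "no", "smoking": "yes", "sex": "female"}
print(query_from_dict(data, tmp))  # only row 0 matches all three conditions
```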
answered Oct 24 '22 17:10 by MaxU - stop WAR against UA


You can create a lookup DataFrame from the dictionary and then do an inner join with the data, which has the same effect as query (note that merge returns a new default index instead of the original row labels):

from pandas import merge, DataFrame
merge(DataFrame(tmp, index =[0]), data)
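If the original row labels matter, one way to keep them is the `reset_index()` / `set_index('index')` trick that the timing code at the end of this page also uses. A minimal sketch, assuming the frame's index is unnamed (so `reset_index` creates a column called `index`); the helper name `subset_by_merge` is hypothetical:

```python
import pandas as pd


def subset_by_merge(df: pd.DataFrame, criteria: dict) -> pd.DataFrame:
    """Hypothetical helper: inner-join df against a one-row lookup frame."""
    lookup = pd.DataFrame(criteria, index=[0])  # one row, one column per key
    # reset_index() keeps the row labels as an 'index' column through the merge
    out = pd.merge(lookup, df.reset_index(), on=list(criteria))
    # restore the labels and drop the temporary 'index' axis name
    return out.set_index("index").rename_axis(None)


data = pd.DataFrame(
    {
        "risk": ["no", "no", "yes", "no"],
        "smoking": ["yes", "no", "yes", "yes"],
        "sex": ["female", "female", "female", "female"],
    },
    index=[10, 11, 12, 13],
)
tmp = {"risk": "no", "smoking": "yes", "sex": "female"}
print(subset_by_merge(data, tmp))  # rows labelled 10 and 13
```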

answered Oct 24 '22 18:10 by Psidom


You can use a list comprehension with concat and all:

import numpy as np
import pandas as pd

np.random.seed(123)
x = pd.Series(np.random.randint(0,2,10), dtype='category')
x = x.cat.rename_categories(['no', 'yes'])
y = pd.Series(np.random.randint(0,2,10), dtype='category')
y = y.cat.rename_categories(['no', 'yes'])
z = pd.Series(np.random.randint(0,2,10), dtype='category')
z = z.cat.rename_categories(['male', 'female'])

a = pd.Series(np.random.randint(20,60,10), dtype='category')

data = pd.DataFrame({'risk':x, 'smoking':y, 'sex':z, 'age':a})
print(data)
  age risk     sex smoking
0  24   no    male     yes
1  23  yes    male     yes
2  22   no  female      no
3  40   no  female     yes
4  59   no  female      no
5  22   no    male     yes
6  40   no  female      no
7  27  yes    male     yes
8  55  yes    male     yes
9  48   no    male      no
tmp = {'risk':'no', 'smoking':'yes', 'sex':'female'}
mask = pd.concat([data[x[0]].eq(x[1]) for x in tmp.items()], axis=1).all(axis=1)
print (mask)
0    False
1    False
2    False
3     True
4    False
5    False
6    False
7    False
8    False
9    False
dtype: bool

df1 = data[mask]
print (df1)
  age risk     sex smoking
3  40   no  female     yes
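The chained `&` expression from the question can also be generalised directly by folding the per-column masks together. This is a sketch of an alternative to `concat` + `all` (not part of this answer), using the standard-library `functools.reduce` and `operator.and_`; the helper name `mask_from_dict` is hypothetical:

```python
import operator
from functools import reduce

import pandas as pd


def mask_from_dict(df: pd.DataFrame, criteria: dict) -> pd.Series:
    """Hypothetical helper: (df[k1] == v1) & (df[k2] == v2) & ... over all items.

    reduce() without an initial value raises TypeError if criteria is empty.
    """
    return reduce(operator.and_, (df[k].eq(v) for k, v in criteria.items()))


data = pd.DataFrame({
    "risk": pd.Series(["no", "yes", "no", "no"], dtype="category"),
    "smoking": pd.Series(["yes", "yes", "no", "yes"], dtype="category"),
    "sex": pd.Series(["female", "male", "female", "female"], dtype="category"),
})
tmp = {"risk": "no", "smoking": "yes", "sex": "female"}
print(data[mask_from_dict(data, tmp)])  # rows 0 and 3
```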
L = [(x[0], x[1]) for x in tmp.items()]
print (L)
[('smoking', 'yes'), ('sex', 'female'), ('risk', 'no')]

L = pd.concat([data[x[0]].eq(x[1]) for x in tmp.items()], axis=1)
print (L)
  smoking    sex   risk
0    True  False   True
1    True  False  False
2   False   True   True
3    True   True   True
4   False   True   True
5    True  False   True
6   False   True   True
7    True  False  False
8    True  False  False
9   False  False   True
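The boolean frame `L` above can also be built without `concat`, by comparing the selected columns against the dict converted to a Series; pandas matches the Series labels to the column labels, so each column is compared with its own target value. A sketch of that alternative (not part of the original answer), on a small hand-written frame:

```python
import pandas as pd

data = pd.DataFrame({
    "age": [24, 23, 22, 40],
    "risk": pd.Series(["no", "yes", "no", "no"], dtype="category"),
    "smoking": pd.Series(["yes", "yes", "no", "yes"], dtype="category"),
    "sex": pd.Series(["male", "male", "female", "female"], dtype="category"),
})
tmp = {"risk": "no", "smoking": "yes", "sex": "female"}

# list(tmp) selects the columns in dict order; pd.Series(tmp) carries the
# same labels, so the comparison lines up column by column
matches = data[list(tmp)] == pd.Series(tmp)
print(data[matches.all(axis=1)])  # only row 3 matches
```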

Timings (len(data) = 1M):

N = 1000000
np.random.seed(123)
x = pd.Series(np.random.randint(0,2,N), dtype='category')
x = x.cat.rename_categories(['no', 'yes'])
y = pd.Series(np.random.randint(0,2,N), dtype='category')
y = y.cat.rename_categories(['no', 'yes'])
z = pd.Series(np.random.randint(0,2,N), dtype='category')
z = z.cat.rename_categories(['male', 'female'])

a = pd.Series(np.random.randint(20,60,N), dtype='category')

data = pd.DataFrame({'risk':x, 'smoking':y, 'sex':z, 'age':a})

print(data)
#[1000000 rows x 4 columns]


tmp = {'risk':'no', 'smoking':'yes', 'sex':'female'}


In [133]: %timeit (data[pd.concat([data[x[0]].eq(x[1]) for x in tmp.items()], axis=1).all(axis=1)])
10 loops, best of 3: 89.1 ms per loop

In [134]: %timeit (data.query(' and '.join(["{} == '{}'".format(k,v) for k,v in tmp.items()])))
1 loop, best of 3: 237 ms per loop

In [135]: %timeit (pd.merge(pd.DataFrame(tmp, index =[0]), data.reset_index()).set_index('index'))
1 loop, best of 3: 256 ms per loop
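Before comparing timings it is worth confirming that the three approaches select the same rows. A minimal sketch of such a check outside IPython, assuming a smaller synthetic frame (N is arbitrary, and `age` is left as plain integers since no approach filters on it). The query is run with `engine="python"` as an assumption to sidestep numexpr with categorical columns; the numbers it prints depend on the machine and pandas version, so none are claimed here:

```python
import timeit

import numpy as np
import pandas as pd

N = 100_000  # arbitrary size for a quick check
rng = np.random.default_rng(123)


def make_col(labels):
    # random 0/1 codes turned into a categorical with readable labels
    return pd.Categorical.from_codes(rng.integers(0, 2, N), categories=labels)


data = pd.DataFrame({
    "risk": make_col(["no", "yes"]),
    "smoking": make_col(["no", "yes"]),
    "sex": make_col(["male", "female"]),
    "age": rng.integers(20, 60, N),
})
tmp = {"risk": "no", "smoking": "yes", "sex": "female"}


def by_concat():
    mask = pd.concat([data[k].eq(v) for k, v in tmp.items()], axis=1).all(axis=1)
    return data[mask]


def by_query():
    qry = " and ".join(f"{k} == {v!r}" for k, v in tmp.items())
    # explicit engine: categorical (extension) columns are not handled by numexpr
    return data.query(qry, engine="python")


def by_merge():
    out = pd.merge(pd.DataFrame(tmp, index=[0]), data.reset_index())
    return out.set_index("index").rename_axis(None)


for fn in (by_concat, by_query, by_merge):
    # best of 3 repeats, 5 calls each, reported per call
    secs = min(timeit.repeat(fn, number=5, repeat=3)) / 5
    print(f"{fn.__name__}: {secs * 1000:.1f} ms per call")
```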
answered Oct 24 '22 18:10 by jezrael