Filter a pandas data frame by requiring presence of multiple items in a MultiIndex level

I have a table of data with a multi-index. The first level of the multi-index is a name corresponding to a given (DNA) sequence; the second level corresponds to a specific type of sequence variant: wt, m1, m2, m3 in the example below. Not all sequences will have all variant types (see seqA and seqC below).

import pandas as pd

df = pd.DataFrame(data={'A': range(1, 9), 'B': range(1, 9), 'C': range(1, 9)},
                  index=pd.MultiIndex.from_tuples([('seqA', 'wt'), ('seqA', 'm1'),
                                                   ('seqA', 'm2'), ('seqB', 'wt'),
                                                   ('seqB', 'm1'), ('seqB', 'm2'),
                                                   ('seqB', 'm3'), ('seqC', 'wt')]))

df.index.rename(['seq_name', 'type'], inplace=True)
print(df)

               A  B  C
seq_name type         
seqA     wt    1  1  1
         m1    2  2  2
         m2    3  3  3
seqB     wt    4  4  4
         m1    5  5  5
         m2    6  6  6
         m3    7  7  7
seqC     wt    8  8  8

I want to perform subsequent analyses only on the sequences that have specific type(s) of variants (m1 and m2 in this example). So I want to filter my data frame to require that a given seq_name has all of the variant types specified in a list.

My current solution works, but it's pretty clunky and not very aesthetically pleasing, IMO.

var_l = ['wt', 'm1', 'm2']
df1 = df[df.index.get_level_values('type').isin(var_l)]  # Drop variant types not of interest

set_l = []
for v in var_l:  # Filter for each variant type individually, and store the seq_names
    df2 = df[df.index.get_level_values('type').isin([v])]
    set_l.append(set(df2.index.get_level_values('seq_name')))

seq_s = set.intersection(*set_l)  # Get seq_names that have all three variant types
df3 = df1[df1.index.get_level_values('seq_name').isin(seq_s)]  # Filter based on seq_name
print(df3)

               A  B  C
seq_name type         
seqA     wt    1  1  1
         m1    2  2  2
         m2    3  3  3
seqB     wt    4  4  4
         m1    5  5  5
         m2    6  6  6

I feel like there must be a one-liner that can do this. Something like:

var_l = ['wt', 'm1', 'm2']
filtered_df = filterDataframe(df1, var_l)
print(filtered_df)

               A  B  C
seq_name type         
seqA     wt    1  1  1
         m1    2  2  2
         m2    3  3  3
seqB     wt    4  4  4
         m1    5  5  5
         m2    6  6  6

I've tried searching this site, but have only found answers that let you filter by any item in a list, rather than requiring all of them.

asked Mar 17 '17 by HikerT

2 Answers

You can use query with groupby + filter:

var_l = ['wt', 'm1', 'm2']

filtered_df = (df.query('type in @var_l')
                 .groupby(level=0)
                 .filter(lambda x: len(x) == len(var_l)))
print (filtered_df)
               A  B  C
seq_name type         
seqA     wt    1  1  1
         m1    2  2  2
         m2    3  3  3
seqB     wt    4  4  4
         m1    5  5  5
         m2    6  6  6
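
For reference, these are the group sizes that the lambda compares with len(var_l) (a quick sketch; it assumes each (seq_name, type) pair appears at most once, otherwise a duplicate could satisfy the length test spuriously):

print (df.query('type in @var_l').groupby(level=0).size())
seq_name
seqA    3
seqB    3
seqC    1
dtype: int64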

Another solution computes the group sizes with transform('size') and then filters by boolean indexing:

filtered_df = df.query('type in @var_l')
filtered_df = filtered_df[filtered_df.groupby(level=0)['A']
                                     .transform('size')
                                     .eq(len(var_l))
                                     .rename(None)]

print (filtered_df)
               A  B  C
seq_name type         
seqA     wt    1  1  1
         m1    2  2  2
         m2    3  3  3
seqB     wt    4  4  4
         m1    5  5  5
         m2    6  6  6

It works because:

print (filtered_df.groupby(level=0)['A'].transform('size'))
seq_name  type
seqA      wt      3
          m1      3
          m2      3
seqB      wt      3
          m1      3
          m2      3
seqC      wt      1
Name: A, dtype: int32

print (filtered_df.groupby(level=0)['A']
                  .transform('size')
                  .eq(len(var_l))
                  .rename(None))
seq_name  type
seqA      wt       True
          m1       True
          m2       True
seqB      wt       True
          m1       True
          m2       True
seqC      wt      False
dtype: bool
answered by jezrael


option 1
using query + unstack + dropna + stack
As @jezrael pointed out, this depends on there being no NaN in the rows to be analyzed.

df.query('type in @var_l').unstack().dropna().stack()

                 A    B    C
seq_name type               
seqA     m1    2.0  2.0  2.0
         m2    3.0  3.0  3.0
         wt    1.0  1.0  1.0
seqB     m1    5.0  5.0  5.0
         m2    6.0  6.0  6.0
         wt    4.0  4.0  4.0
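
A quick sketch of that caveat (hypothetical: we first poke a NaN into one cell of the example df): a pre-existing NaN makes dropna() discard the whole unstacked row, so a seq_name can be lost even though it has all the required variant types.

import numpy as np

df_nan = df.copy()
df_nan.loc[('seqA', 'm1'), 'A'] = np.nan  # pre-existing NaN in the data

# seqA is now dropped entirely, despite having wt, m1, and m2
print (df_nan.query('type in @var_l').unstack().dropna().stack())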

Preserve the dtypes (unstack introduces NaN for missing variant types, which upcasts the integer columns to float; astype(df.dtypes) casts them back):

df.query('type in @var_l').unstack().dropna().stack().astype(df.dtypes)

               A  B  C
seq_name type         
seqA     m1    2  2  2
         m2    3  3  3
         wt    1  1  1
seqB     m1    5  5  5
         m2    6  6  6
         wt    4  4  4

option 2
using filter
It checks whether each group's sub-index, intersected with var_l, equals var_l. Note that, unlike option 1, this keeps every row of a qualifying group, including variant types outside var_l (see the m3 row for seqB below).

def correct_vars(df, v):
    # True if this group's set of variant types contains every element of v
    x = set(v)
    n = df.name  # the seq_name of the current group
    y = set(df.xs(n).index.intersection(v))
    return x == y

df.groupby(level=0).filter(correct_vars, v=var_l)

               A  B  C
seq_name type         
seqA     wt    1  1  1
         m1    2  2  2
         m2    3  3  3
seqB     wt    4  4  4
         m1    5  5  5
         m2    6  6  6
         m3    7  7  7
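
If you also want to drop the variant types outside var_l (matching option 1's output), one possible follow-up is to chain the query shown earlier:

df.groupby(level=0).filter(correct_vars, v=var_l).query('type in @var_l')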
answered by piRSquared