Filter out everything before a condition is met, keep all elements after

I was wondering if there is an easy solution to the following problem. I want to keep every element of a list from the point where a condition first becomes true. Here the condition is that a value is greater than 18: I want to remove everything before the first such value, but keep everything from it onward. Example:

Input:

p = [4,9,10,4,20,13,29,3,39]

Expected output:

p = [20,13,29,3,39]

I know that you can filter the entire list with

[x for x in p if x>18] 

But I want to stop filtering once the first value above 18 is found, and then include the rest of the values regardless of whether they satisfy the condition. It seems like an easy problem, but I haven't found the solution yet.

cinderashes asked Sep 09 '25 15:09


2 Answers

You can use itertools.dropwhile:

from itertools import dropwhile

p = [4,9,10,4,20,13,29,3,39]

# dropwhile returns a lazy iterator; wrap it in list() to get a list back
result = list(dropwhile(lambda x: x <= 18, p))
print(result)  # [20, 13, 29, 3, 39]

In my opinion, this is the easiest version to read. It also mirrors a common pattern in other functional programming languages, such as dropWhile (<= 18) p in Haskell and p.dropWhile(_ <= 18) in Scala.
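A small sketch of why dropwhile fits this problem well: it is lazy, so it works on any iterable, including generators and even infinite iterators. The use of itertools.count and itertools.islice below is just for illustration:

```python
from itertools import count, dropwhile, islice

# dropwhile skips items while the predicate holds, then yields every
# remaining item without testing the predicate again.
tail = dropwhile(lambda x: x <= 18, [4, 9, 10, 4, 20, 13, 29, 3, 39])
print(list(tail))  # [20, 13, 29, 3, 39]

# Because it never materializes its input, it also works on an infinite
# iterator, as long as only a finite number of items is taken from it.
first_three = list(islice(dropwhile(lambda x: x <= 18, count()), 3))
print(first_three)  # [19, 20, 21]
```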


Alternatively, using the walrus operator (available only in Python 3.8+):

exceeded = False  # flips to True at the first x > 18, then stays True
result = [x for x in p if (exceeded := exceeded or x > 18)]
print(result)  # [20, 13, 29, 3, 39]

Some people may not like this style, though. In that case, you can use an explicit for loop (ilkkachu's suggestion):

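for i, x in enumerate(p):
    if x > 18:
        output = p[i:]  # keep everything from the first match onward
        break
else:
    # runs only if the loop never hit break, i.e. no element > 18
    output = []  # alternatively, set output = [] before the loop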
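If you want the flag-based approach on Python older than 3.8, or want it to work lazily on any iterable, the same "latch" can be written as a generator function. This is a sketch; drop_until_exceeds and its threshold parameter are hypothetical names, not from the answers above:

```python
def drop_until_exceeds(iterable, threshold=18):
    """Yield items starting from the first one greater than threshold.

    Hypothetical helper: the same latch idea as the walrus version,
    written as a generator so it also runs on Python < 3.8.
    """
    exceeded = False
    for x in iterable:
        # Once the flag is set it stays True, so later items are kept
        # regardless of their value.
        exceeded = exceeded or x > threshold
        if exceeded:
            yield x


p = [4, 9, 10, 4, 20, 13, 29, 3, 39]
print(list(drop_until_exceeds(p)))  # [20, 13, 29, 3, 39]
```

Unlike the for/else loop, this never slices, so it also accepts iterators that cannot be indexed.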
j1-lee answered Sep 12 '25 13:09


You could use enumerate and list slicing in a generator expression, together with next:

out = next((p[i:] for i, item in enumerate(p) if item > 18), [])  # [] if no item exceeds 18

Output:

[20, 13, 29, 3, 39]
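The same one-liner generalizes to any predicate. A sketch, where tail_from is a hypothetical name: because the generator expression is lazy, next() only builds the first matching slice, and its second argument supplies the fallback when nothing matches.

```python
def tail_from(p, pred):
    """Slice p from the first index where pred(item) is true.

    Hypothetical generalization of the next()/generator-expression answer;
    returns [] when no item satisfies pred.
    """
    return next((p[i:] for i, item in enumerate(p) if pred(item)), [])


p = [4, 9, 10, 4, 20, 13, 29, 3, 39]
print(tail_from(p, lambda x: x > 18))       # [20, 13, 29, 3, 39]
print(tail_from(p, lambda x: x > 100))      # []
print(tail_from(p, lambda x: x % 2 == 0))   # [4, 9, 10, 4, 20, 13, 29, 3, 39]
```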

Which approach is fastest depends on the data structure.

The plots below compare the runtime of the answers here for various lengths of p.

If the original data is a list, using a lazy iterator as proposed by @Kelly Bundy is the clear winner:

[Plot: runtime of each approach vs. length of p, list input]
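If perfplot isn't installed, a rough comparison of two of the list-based approaches can be made with the standard library's timeit. This is only a sketch: the data size and call count are arbitrary choices of mine, and no particular timings are implied.

```python
import random
import timeit
from itertools import dropwhile


def it_dropwhile(p):
    return list(dropwhile(lambda x: x <= 18, p))


def lazy_iter(p):
    it = iter(p)
    for x in it:
        if x > 18:
            # The for loop already consumed x; *it unpacks everything after it
            return [x, *it]
    return []


# Same shape as the perfplot setup: a run of values <= 18,
# followed by random values in [-20, 30).
n = 10_000
data = random.choices(range(0, 15), k=n) + random.choices(range(-20, 30), k=n)

assert it_dropwhile(data) == lazy_iter(data)
for func in (it_dropwhile, lazy_iter):
    # f=func binds the current function at definition time
    seconds = timeit.timeit(lambda f=func: f(data), number=100)
    print(f"{func.__name__}: {seconds:.4f} s for 100 calls")
```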

But if the initial data is an ndarray, the vectorized operations proposed by @richardec and @0x263A (for large arrays) are faster. In particular, NumPy beats the list-based methods at every array size. For very large arrays, however, pandas starts to outperform NumPy. I don't know why; I (and I'm sure others) would appreciate it if anyone could explain it.

[Plot: runtime of each approach vs. length of p, ndarray input]
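One caveat with the argmax-based NumPy approach: argmax on a boolean mask returns 0 when no element is True, so it returns the whole array instead of an empty one. A sketch of a guarded version, assuming the input is a 1-D array; np_tail is a hypothetical helper name, and np.flatnonzero is used here in place of argmax:

```python
import numpy as np


def np_tail(arr, threshold=18):
    """Return arr from the first element > threshold (empty if none).

    Hypothetical helper. np.flatnonzero returns the indices of the True
    entries of the mask, so an empty result means nothing matched; a bare
    argmax() would return 0 in that case and keep the whole array.
    """
    arr = np.asarray(arr)
    idx = np.flatnonzero(arr > threshold)
    return arr[idx[0]:] if idx.size else arr[:0]


a = np.array([4, 9, 10, 4, 20, 13, 29, 3, 39])
print(np_tail(a).tolist())                    # [20, 13, 29, 3, 39]
print(np_tail(np.array([1, 2, 3])).tolist())  # []
```

Checking mask.any() before calling argmax would be an equally valid guard.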

Code used to generate the first plot:

import perfplot
import numpy as np
import pandas as pd
import random
from itertools import dropwhile

def it_dropwhile(p):
    return list(dropwhile(lambda x: x <= 18, p))

def walrus(p):
    exceeded = False
    return [x for x in p if (exceeded := exceeded or x > 18)]

def explicit_loop(p):
    for i, x in enumerate(p):
        if x > 18:
            output = p[i:]
            break
    else:
        output = []
    return output

def genexpr_next(p):
    return next((p[i:] for i, item in enumerate(p) if item > 18), [])

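def np_argmax(p):
    # argmax on a boolean mask gives the index of the first True.
    # Caveat: if no element exceeds 18 it returns 0, so all of p is returned.
    return p[(np.array(p) > 18).argmax():]

def pd_idxmax(p):
    # Same idea with pandas (same caveat when nothing exceeds 18); returns a Series
    s = pd.Series(p)
    return s[s.gt(18).idxmax():]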

def list_index(p):
    for x in p:
        if x > 18:
            return p[p.index(x):]
    return []

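def lazy_iter(p):
    it = iter(p)
    for x in it:
        if x > 18:
            # the loop and the unpacking share one iterator: the prefix is scanned once,
            # and *it copies the remaining items without testing them
            return [x, *it]
    return []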

perfplot.show(
    # 10n values in [0, 15) (all <= 18, so all dropped), then 10n values in [-20, 30)
    setup=lambda n: random.choices(range(0, 15), k=10*n) + random.choices(range(-20,30), k=10*n),
    kernels=[it_dropwhile, walrus, explicit_loop, genexpr_next, np_argmax, pd_idxmax, list_index, lazy_iter],
    labels=['it_dropwhile','walrus','explicit_loop','genexpr_next','np_argmax','pd_idxmax', 'list_index', 'lazy_iter'],
    n_range=[2 ** k for k in range(18)],
    equality_check=np.allclose,
    xlabel='~n/20'
)
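The random benchmark data almost always contains a value above 18, so equality_check rarely exercises edge cases such as an empty list or no match at all. A sketch of a separate edge-case check for the pure-Python kernels, re-declared here so the snippet runs with only the standard library; the set of cases is my own choice:

```python
from itertools import dropwhile


def it_dropwhile(p):
    return list(dropwhile(lambda x: x <= 18, p))


def walrus(p):
    exceeded = False
    return [x for x in p if (exceeded := exceeded or x > 18)]


def genexpr_next(p):
    return next((p[i:] for i, item in enumerate(p) if item > 18), [])


def lazy_iter(p):
    it = iter(p)
    for x in it:
        if x > 18:
            return [x, *it]
    return []


# Edge cases the random benchmark data is unlikely to hit
cases = {
    "empty": [],
    "no match": [1, 2, 18],
    "first item matches": [19, 0, 5],
    "last item matches": [0, 5, 19],
    "boundary (18 is not > 18)": [18, 18, 19, 18],
}

for name, p in cases.items():
    results = [f(p) for f in (it_dropwhile, walrus, genexpr_next, lazy_iter)]
    # Every kernel should agree on every case
    assert all(r == results[0] for r in results), name
    print(f"{name}: {results[0]}")
```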

Code used to generate the second plot (note that I had to modify list_index because NumPy arrays don't have an index method):

def list_index(p):
    for x in p:
        if x > 18:
            return p[np.where(p==x)[0][0]:]
    return []

perfplot.show(
    setup=lambda n: np.hstack([np.random.randint(0,15,10*n), np.random.randint(-20,30,10*n)]),
    kernels=[it_dropwhile, walrus, explicit_loop, genexpr_next, np_argmax, pd_idxmax, list_index, lazy_iter],
    labels=['it_dropwhile','walrus','explicit_loop','genexpr_next','np_argmax','pd_idxmax', 'list_index', 'lazy_iter'],
    n_range=[2 ** k for k in range(18)],
    equality_check=np.allclose,
    xlabel='~n/20'
)
