Efficient rolling trimmed mean with Python

What's the most efficient way to calculate a rolling (aka moving window) trimmed mean with Python?

For example, for a data set of 50K rows and a window size of 50, for each row I need to take the last 50 rows, remove the top and bottom 3 values (5% of the window size, rounded up), and get the average of the remaining 44 values.

Currently, for each row, I slice out the window, sort it, and then slice again to trim it. It works, but slowly, and there has to be a more efficient way.
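
To make the trim count concrete: with a window of 50, 5% is 2.5, which rounds up to 3 from each end, leaving 44 values. For a single window that looks something like this minimal sketch (plain NumPy; the names are mine):

import math

import numpy as np

window = np.random.randint(0, 1000, 50)                  # one window of 50 values
n_trim = math.ceil(0.05 * len(window))                   # 5% of 50 = 2.5 -> 3
trimmed = np.sort(window)[n_trim:len(window) - n_trim]   # drop the 3 smallest and 3 largest
trimmed_mean = trimmed.mean()                            # average of the remaining 44 values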

Example

[10, 12, 8, 13, 7, 18, 19, 9, 15, 14]  # data used for the example; the real data is a 50K-row DataFrame

Example data set and results for a window size of 5. For each row we look at the last 5 rows, sort them, and discard the top and bottom value (5% of 5 = 0.25, rounded up to 1 from each end). Then we average the remaining middle values.

Code to generate this example set as a DataFrame

import numpy as np
import pandas as pd

df = pd.DataFrame({
    'value': [10, 12, 8, 13, 7, 18, 19, 9, 15, 14],
    'window_of_last_5_values': [
        np.nan, np.nan, np.nan, np.nan, '10,12,8,13,7', '12,8,13,7,18',
        '8,13,7,18,19', '13,7,18,19,9', '7,18,19,9,15', '18,19,9,15,14'
    ],
    'values that are counting for average': [
        np.nan, np.nan, np.nan, np.nan, '10,12,8', '12,8,13', '8,13,18',
        '13,18,9', '18,9,15', '18,15,14'
    ],
    'result': [
        np.nan, np.nan, np.nan, np.nan, 10.0, 11.0, 13.0, 13.333333333333334,
        14.0, 15.666666666666666
    ]
})

Example code for the naive implementation

window_size = 5
outliers_to_remove = 1  # ceil(5% of window_size) = 1

for index in range(window_size - 1, len(df)):
    # the last `window_size` rows up to and including the current row
    current_window = df.iloc[index - window_size + 1:index + 1]
    # sort by value, drop the extremes, average what is left
    trimmed_mean = current_window.sort_values('value')[
        outliers_to_remove:window_size - outliers_to_remove]['value'].mean()
    # save the result and the window content somewhere

A note about DataFrame vs list vs NumPy array

Just by moving the data from a DataFrame to a list, I'm getting a 3.5x speed boost with the same algorithm. Interestingly, using a NumPy array also gives almost the same speed boost. Still, there must be a better way to implement this and achieve an orders-of-magnitude boost.
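
For reference, the list-based variant is essentially the same loop as above, something like this sketch (exact code may differ slightly):

values = df['value'].tolist()
results = [None] * (window_size - 1)
for index in range(window_size - 1, len(values)):
    # sort a plain-list slice instead of a DataFrame window
    window = sorted(values[index - window_size + 1:index + 1])
    trimmed = window[outliers_to_remove:window_size - outliers_to_remove]
    results.append(sum(trimmed) / len(trimmed))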

2 Answers

One observation that could come in handy is that you do not need to sort all the values at each step. Rather, if you ensure that the window is always kept sorted, all you need to do is insert the new value at the relevant spot and remove the old one from where it was, and bisect finds both positions in O(log_2(window_size)) time. In practice, this would look something like

import bisect
import numpy as np

def rolling_mean(data):
    # first 49 values, kept sorted; the 50th is inserted in the first iteration
    x = sorted(data[:49])
    res = np.repeat(np.nan, len(data))
    for i in range(49, len(data)):
        if i != 49:
            # drop the value that just left the window
            del x[bisect.bisect_left(x, data[i - 50])]
        # insert the incoming value at its sorted position
        bisect.insort_right(x, data[i])
        # trimmed mean: discard the 3 smallest and 3 largest of the 50
        res[i] = np.mean(x[3:47])
    return res

Now, the additional benefit in this case turns out to be less than what is gained by the vectorization that scipy.stats.trim_mean relies on, and so in particular, this will still be slower than @ChrisA's solution (the scipy.stats.trim_mean answer below), but it is a useful starting point for further performance optimization.

> from scipy.stats import trim_mean
> data = pd.Series(np.random.randint(0, 1000, 50000))
> %timeit data.rolling(50).apply(lambda w: trim_mean(w, 0.06))
727 ms ± 34.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> %timeit rolling_mean(data.values)
812 ms ± 42.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Notably, Numba's JIT compiler, which is often useful in situations like these, also provides no benefit:

> from numba import jit
> rolling_mean_jit = jit(rolling_mean)
> %timeit rolling_mean_jit(data.values)
1.05 s ± 183 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

The following, seemingly far-from-optimal, approach outperforms both of the other approaches considered above:

def rolling_mean_np(data):
    res = np.repeat(np.nan, len(data))
    for i in range(len(data)-49):
        x = np.sort(data[i:i+50])
        res[i+49] = x[3:47].mean()
    return res

Timing:

> %timeit rolling_mean_np(data.values)
564 ms ± 4.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

What is more, this time around, JIT compilation does help:

> rolling_mean_np_jit = jit(rolling_mean_np)
> %timeit rolling_mean_np_jit(data.values)
94.9 ms ± 605 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

While we're at it, let's just quickly verify that this actually does what we expect it to:

> np.all(rolling_mean_np_jit(data.values)[49:] == data.rolling(50).apply(lambda w: trim_mean(w, 0.06)).values[49:])
True

In fact, by helping out the sorter just a little bit, we can squeeze out another factor of 2, taking the total time down to 57 ms:

def rolling_mean_np_manual(data):
    x = np.sort(data[:50])
    res = np.repeat(np.nan, len(data))
    for i in range(50, len(data)+1):
        res[i-1] = x[3:47].mean()
        if i != len(data):
            # overwrite the outgoing value with the incoming one,
            # then re-sort the now almost-sorted window
            idx_old = np.searchsorted(x, data[i-50])
            x[idx_old] = data[i]
            x.sort()
    return res

> %timeit rolling_mean_np_manual(data.values)
580 ms ± 23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> rolling_mean_np_manual_jit = jit(rolling_mean_np_manual)
> %timeit rolling_mean_np_manual_jit(data.values)
57 ms ± 5.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> np.all(rolling_mean_np_manual_jit(data.values)[49:] == data.rolling(50).apply(lambda w: trim_mean(w, 0.06)).values[49:])
True

Now, the "sorting" that is going on in this example of course just boils down to placing the new element in the right place, while shifting everything in between by one. Doing this by hand will make the pure Python code slower, but the jitted version gains another factor of 2, taking us below 30 ms:

def rolling_mean_np_shift(data):
    x = np.sort(data[:50])
    res = np.repeat(np.nan, len(data))
    for i in range(50, len(data)+1):
        res[i-1] = x[3:47].mean()
        if i != len(data):
            # positions of the outgoing and incoming values in the sorted window
            idx_old, idx_new = np.searchsorted(x, [data[i-50], data[i]])
            if idx_old < idx_new:
                # shift the values in between one step to the left
                x[idx_old:idx_new-1] = x[idx_old+1:idx_new]
                x[idx_new-1] = data[i]
            elif idx_new < idx_old:
                # shift the values in between one step to the right
                x[idx_new+1:idx_old+1] = x[idx_new:idx_old]
                x[idx_new] = data[i]
            else:
                # outgoing and incoming values occupy the same position
                x[idx_new] = data[i]
    return res

> %timeit rolling_mean_np_shift(data.values)
937 ms ± 97.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> rolling_mean_np_shift_jit = jit(rolling_mean_np_shift)
> %timeit rolling_mean_np_shift_jit(data.values)
26.4 ms ± 693 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
> np.all(rolling_mean_np_shift_jit(data.values)[49:] == data.rolling(50).apply(lambda w: trim_mean(w, 0.06)).values[49:])
True

At this point, most of the time is spent in np.searchsorted, so let us make the search itself JIT-friendly. Adapting the source code of bisect, we let

@jit
def binary_search(a, x):
    lo = 0
    hi = 50  # the window size is fixed at 50
    while lo < hi:
        mid = (lo+hi)//2
        if a[mid] < x: lo = mid+1
        else: hi = mid
    return lo

@jit
def rolling_mean_np_jitted_search(data):
    x = np.sort(data[:50])
    res = np.repeat(np.nan, len(data))
    for i in range(50, len(data)+1):
        res[i-1] = x[3:47].mean()
        if i != len(data):
            idx_old = binary_search(x, data[i-50])
            idx_new = binary_search(x, data[i])
            if idx_old < idx_new:
                x[idx_old:idx_new-1] = x[idx_old+1:idx_new]
                x[idx_new-1] = data[i]
            elif idx_new < idx_old:
                x[idx_new+1:idx_old+1] = x[idx_new:idx_old]
                x[idx_new] = data[i]
            else:
                x[idx_new] = data[i]
    return res

This takes us down to 12 ms, a 60x improvement over the raw pandas + SciPy approach:

> %timeit rolling_mean_np_jitted_search(data.values)
12 ms ± 210 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

You might try using scipy.stats.trim_mean:

from scipy.stats import trim_mean

df['value'].rolling(5).apply(lambda x: trim_mean(x, 0.2))

[output]

0          NaN
1          NaN
2          NaN
3          NaN
4    10.000000
5    11.000000
6    13.000000
7    13.333333
8    14.000000
9    15.666667

Note that I had to use rolling(5) and proportiontocut=0.2 for your toy data set.

For your real data you should use rolling(50) and trim_mean(x, 0.06) to remove the top and bottom 3 values from the rolling window.
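
Applied to the real data, that would look something like the sketch below. proportiontocut=0.06 works because trim_mean cuts int(proportiontocut * n) values from each tail, so 0.06 * 50 gives 3, whereas 0.05 * 50 would truncate to 2. (raw=True is optional; it just passes NumPy arrays to the lambda, which tends to be a bit faster.)

from scipy.stats import trim_mean

# 50-row window, trimming the 3 smallest and 3 largest values
df['value'].rolling(50).apply(lambda x: trim_mean(x, 0.06), raw=True)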
