
Why does numpy.median scale so well?

A question I recently got at a job interview was:

Write a data structure that supports two operations.
1. Adding a number to the structure.
2. Calculating the median.
Both operations, adding a number and calculating the median, must have the lowest possible time complexity.

My implementation was pretty simple: keep the elements sorted. This way, adding an element costs O(log(n)) instead of O(1), but the median is O(1) instead of O(n*log(n)).
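As an aside, the same sorted-list idea can be sketched with the standard library's bisect module. This is an illustrative alternative, not the MedianList class benchmarked in this question, and the name SortedMedianList is made up for the example. One caveat worth stating: bisect finds the insertion position in O(log(n)), but inserting into a Python list still shifts the elements after that position, so each add is O(n) overall.

```python
import bisect


class SortedMedianList:
    """Hypothetical helper: keeps its values sorted on every insert."""

    def __init__(self, initial_values=None):
        self.values = sorted(initial_values or [])

    def add_element(self, element):
        # Binary search for the position is O(log n); the list insert
        # itself still shifts trailing elements, which is O(n).
        bisect.insort(self.values, element)

    def median(self):
        n = len(self.values)
        if n == 0:
            return float("nan")
        mid = n // 2
        if n % 2 == 1:
            return self.values[mid]
        # Even count: average the two middle values.
        return (self.values[mid - 1] + self.values[mid]) / 2
```

With the list kept sorted, the median is a constant-time index lookup, which is the trade-off described above.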

I also added a naive implementation, and one that stores the elements in a numpy array:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from random import randint, random
import math
from time import time

class MedianList():
    def __init__(self, initial_values = []):
        self.values = sorted(initial_values)
        self.size = len(initial_values)

    def add_element(self, element):
        index = self.find_pos(self.values, element)
        self.values = self.values[:index] + [element] + self.values[index:]
        self.size += 1

    def find_pos(self, values, element):
        if len(values) == 0: return 0
        index = len(values) // 2
        if element > values[index]: 
            return self.find_pos(values[index+1:], element) + index +  1
        if element < values[index]:
            return self.find_pos(values[:index], element)
        if element == values[index]: return index

    def median(self):
        if self.size == 0: return np.nan
        split = math.floor(self.size/2)
        if self.size % 2 == 1:
            return self.values[split]
        return (self.values[split] + self.values[split-1])/2

class NaiveMedianList():
    def __init__(self, initial_values = []):
        self.values = sorted(initial_values)

    def add_element(self, element):
        self.values.append(element)

    def median(self):
        split = math.floor(len(self.values)/2)
        sorted_values = sorted(self.values)
        if len(self.values) % 2 == 1:
            return sorted_values[split]
        return (sorted_values[split] + sorted_values[split-1])/2

class NumpyMedianList():
    def __init__(self, initial_values = []):
        self.values = np.array(initial_values)

    def add_element(self, element):
        self.values = np.append(self.values, element)

    def median(self):
        return np.median(self.values)

def time_performance(median_list, total_elements = 10**5):
    elements = [randint(0, 100) for _ in range(total_elements)]
    times = []
    start = time()
    for element in elements:
        median_list.add_element(element)
        median_list.median()
        times.append(time() - start)
    return times

ml_times = time_performance(MedianList())
nl_times = time_performance(NaiveMedianList())
npl_times = time_performance(NumpyMedianList())
times = pd.DataFrame()
times['MedianList'] = ml_times
times['NaiveMedianList'] = nl_times
times['NumpyMedianList'] = npl_times
times.plot()
plt.show()

And here is how the performance looks for 10^4 elements:

[image: timing plot]

And for 10^5 elements, the naive numpy implementation is actually faster:

[image: timing plot]

My question is: how come? Even if numpy is faster by a constant factor, how does its median function scale so well if it does not keep a sorted version of the array?

Hristo Buyukliev asked Jun 17 '18 18:06




1 Answer

We can inspect the NumPy source code for median:

def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
    ...

    if overwrite_input:
        if axis is None:
            part = a.ravel()
            part.partition(kth)
        else:
            a.partition(kth, axis=axis)
            part = a
    else:
        part = partition(a, kth, axis=axis)

...
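To see what the partition call in this excerpt guarantees, here is a small sketch. The array values are made up and odd in length, so a single index kth is enough. After partitioning, the element at index kth sits where it would be in a fully sorted array, with nothing larger before it and nothing smaller after it, but the two sides are not themselves sorted.

```python
import numpy as np

# Made-up sample data with an odd number of elements.
a = np.array([7, 1, 5, 3, 9, 2, 8])
kth = len(a) // 2

# np.partition returns a rearranged copy; a itself is left untouched.
part = np.partition(a, kth)

# The kth element is in its final sorted position.
assert part[kth] == np.sort(a)[kth]
# Everything before it is <= it, everything after it is >= it.
assert (part[:kth] <= part[kth]).all()
assert (part[kth + 1:] >= part[kth]).all()

median_via_partition = part[kth]
print(median_via_partition, np.median(a))
```

For an odd-length array the middle element after partitioning is the median, so no full sort is ever needed.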

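The key function is partition, which according to the docs uses introselect. As @zython commented, this is a variant of Quickselect, which finds the k-th smallest element in O(n) time on average without fully sorting the array. Computing the median therefore costs O(n) per call instead of the O(n*log(n)) of a full sort, which provides the critical performance boost.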
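To illustrate the selection idea, here is a minimal pure-Python sketch of plain Quickselect with a random pivot. It is not NumPy's implementation: introselect additionally falls back to a slower but safer strategy to bound the worst case, and NumPy's version runs in compiled code. The function names quickselect and median_by_selection are chosen for this example only.

```python
import random


def quickselect(values, k):
    """Return the k-th smallest element (0-indexed) without a full sort."""
    items = list(values)
    if not 0 <= k < len(items):
        raise IndexError("k out of range")
    while True:
        pivot = random.choice(items)
        lows = [x for x in items if x < pivot]
        n_equal = sum(1 for x in items if x == pivot)
        if k < len(lows):
            # The answer lies among the elements smaller than the pivot.
            items = lows
        elif k < len(lows) + n_equal:
            # The pivot itself occupies position k.
            return pivot
        else:
            # Discard the lows and pivots, then search the larger elements.
            k -= len(lows) + n_equal
            items = [x for x in items if x > pivot]


def median_by_selection(values):
    """Median via selection; averages the two middle values for even n."""
    n = len(values)
    if n == 0:
        raise ValueError("median of empty sequence")
    if n % 2 == 1:
        return quickselect(values, n // 2)
    return (quickselect(values, n // 2 - 1) + quickselect(values, n // 2)) / 2
```

Each pass only keeps the side that contains position k, so on average the work shrinks geometrically and sums to O(n), which is why no sorted copy of the array has to be maintained.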

Yatharth Agarwal answered Sep 28 '22 17:09
