
How to locate the median in a (seaborn) KDE plot?

I am trying to do a Kernel Density Estimation (KDE) plot with seaborn and locate the median. The code looks something like this:

import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

sns.set_palette("hls", 1)
data = np.random.randn(30)
sns.kdeplot(data, fill=True)  # fill= replaces the deprecated shade= argument

# x_median, y_median = magic_function()
# plt.vlines(x_median, 0, y_median)

plt.show()

As you can see, I need a magic_function() to fetch the median x and y values from the kdeplot. I would then like to plot them with e.g. vlines, but I can't figure out how to do that. The result should look something like this (obviously the black median bar is in the wrong place here):

[image: KDE plot with a black vertical bar drawn at an incorrect median position]

I guess my question is not strictly related to seaborn and also applies to other kinds of matplotlib plots. Any ideas are greatly appreciated.

n1000 asked Mar 10 '15 05:03


1 Answer

You need to:

  1. Extract the data of the kde line
  2. Integrate it to calculate the cumulative distribution function (CDF)
  3. Find the x value at which the CDF equals 1/2; that is the median

import numpy as np
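from scipy.integrate import cumulative_trapezoid  # called cumtrapz in older SciPy
import seaborn as sns
import matplotlib.pyplot as plt

sns.set_palette("hls", 1)
data = np.random.randn(30)

# Draw the KDE as a plain line first: recent seaborn versions draw only a
# filled polygon (no Line2D) when fill=True, so get_lines() would be empty
ax = sns.kdeplot(data)

# x grid and density values of the KDE curve
x, y = ax.get_lines()[0].get_data()

# shade the area under the curve (replaces the deprecated shade=True)
ax.fill_between(x, y, alpha=0.25)

# careful with the argument order: y comes first
# initial=0 prepends a 0 so the result has the same length as x
cdf = cumulative_trapezoid(y, x, initial=0)

# index of the grid point where the CDF is closest to 1/2
nearest_05 = np.abs(cdf - 0.5).argmin()

x_median = x[nearest_05]
y_median = y[nearest_05]

plt.vlines(x_median, 0, y_median)
plt.show()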

Result: [image: the KDE plot with a vertical line at the median]
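
If you would rather not depend on which artists seaborn leaves on the axes, the same three steps can be done on a KDE you compute yourself. The following is only a minimal sketch using NumPy: kde_median is a hypothetical helper name, and the bandwidth rule (Scott), the grid size (200 points) and the padding (3 bandwidths) are assumptions chosen to resemble seaborn's defaults, so the curve may not match the plotted one exactly. It also interpolates linearly where the CDF crosses 1/2 instead of snapping to the nearest grid point.

```python
import numpy as np


def kde_median(data, gridsize=200, cut=3.0):
    """Return (x_median, y_median) of a Gaussian KDE of 1-D data.

    Hypothetical helper; bandwidth, gridsize and cut are assumptions.
    """
    data = np.asarray(data, dtype=float)
    n = data.size
    # Scott's rule for 1-D data: h = sample std * n**(-1/5)
    h = data.std(ddof=1) * n ** (-1.0 / 5.0)

    # Evaluation grid, padded by `cut` bandwidths on each side
    x = np.linspace(data.min() - cut * h, data.max() + cut * h, gridsize)

    # KDE: average of Gaussian bumps centred on each observation
    z = (x[:, None] - data[None, :]) / h
    y = np.exp(-0.5 * z**2).sum(axis=1) / (n * h * np.sqrt(2.0 * np.pi))

    # Cumulative trapezoid rule, with a leading 0 so cdf has the length of x
    steps = 0.5 * (y[1:] + y[:-1]) * np.diff(x)
    cdf = np.concatenate(([0.0], np.cumsum(steps)))
    cdf /= cdf[-1]  # renormalise: the truncated tails hold a little mass

    # Linear interpolation for the point where the CDF crosses 1/2
    x_median = np.interp(0.5, cdf, x)
    y_median = np.interp(x_median, x, y)
    return x_median, y_median


rng = np.random.default_rng(0)
data = rng.standard_normal(30)
x_median, y_median = kde_median(data)
print(x_median, y_median)
```

The returned pair can be passed straight to plt.vlines(x_median, 0, y_median). Note that this is the median of the smoothed density, which in general differs slightly from np.median(data).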

agomcas answered Sep 19 '22 08:09