 

How to identify outliers with density plot

I'm trying to identify outliers with my density plot, which I'm currently drawing with the seaborn library. How would I go about identifying outliers? I have been looking at implementing the Z-score with the stats library. Is that the only way this can be achieved, or could it be done within a density plot?

B.Billy asked Nov 07 '22 18:11


1 Answer

Kernel density estimation estimates the unknown probability density function (pdf) of the distribution that generated the given data. The question is which data points should be treated as outliers. Outliers are rare data points, i.e. points where the pdf is extremely low. We don't know the pdf, but we do have its estimate, so we can use that estimate to identify outliers.
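
In symbols (the standard one-dimensional kernel density estimate; the notation is introduced here for illustration and is not from the original answer): for data x_1, ..., x_n, a kernel K such as the Gaussian K(u) = exp(-u^2/2)/sqrt(2*pi), a bandwidth h (the role played by `bandwidth` in the code further down) and a chosen cut-off tau, the estimate and the outlier rule are

$$
\hat{f}_h(x) = \frac{1}{n h} \sum_{i=1}^{n} K\!\left(\frac{x - x_i}{h}\right),
\qquad
x_j \text{ is flagged as an outlier if } \hat{f}_h(x_j) < \tau .
$$

Because the sum includes the term i = j, every point contributes K(0)/(nh) to its own density, so even a completely isolated point gets a small positive estimate rather than zero.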

So the basic idea is to: 1) compute the kernel density estimate at every data point; 2) find the points where this estimate is lower than some predefined threshold. Those points are the outliers.
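
The two steps can be sketched without any library. This is a minimal, dependency-free illustration that assumes a Gaussian kernel, a hand-picked bandwidth and a fixed density threshold; the helper names `make_gaussian_kde` and `kde_outliers` and the sample values are made up for this example.

```python
import math

def make_gaussian_kde(points, bandwidth):
    """Return a function evaluating a 1-D Gaussian kernel density estimate."""
    n = len(points)
    norm = 1.0 / (n * bandwidth * math.sqrt(2 * math.pi))

    def density(x):
        # One Gaussian bump per data point, each centred on that point.
        return norm * sum(math.exp(-0.5 * ((x - p) / bandwidth) ** 2) for p in points)

    return density

def kde_outliers(points, bandwidth, threshold):
    """Indices of points whose estimated density is below `threshold`."""
    density = make_gaussian_kde(points, bandwidth)
    # Step 1: evaluate the estimate at every data point.
    densities = [density(p) for p in points]
    # Step 2: points in low-density regions are flagged as outliers.
    return [i for i, d in enumerate(densities) if d < threshold]

data = [-0.5, -0.2, 0.0, 0.1, 0.3, 0.4, 0.6, 8.0]
print(kde_outliers(data, bandwidth=0.5, threshold=0.2))  # prints [7]
```

With 8 points and a bandwidth of 0.5, the isolated value 8.0 only receives its own kernel's contribution (about 0.1), while every clustered point has a density above 0.3, so a threshold of 0.2 separates them. A suitable absolute threshold depends on n and the bandwidth; a relative cut-off such as a percentile, as in the code below, avoids choosing it by hand.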

Let's write some code to illustrate this.

```python
import numpy as np
from matplotlib import pyplot as plt
# seaborn could draw the density plot, but scikit-learn's KernelDensity makes it
# easy to evaluate the estimated pdf at each data point.
from sklearn.neighbors import KernelDensity

# 100 normally distributed data points plus 10 points drawn uniformly from [0, 100)
# at the end of the array; most of those 10 will be outliers.
data = np.r_[np.random.randn(100), np.random.rand(10) * 100][:, np.newaxis]

# kernel='gaussian' can be used instead of 'tophat'
kde = KernelDensity(kernel='tophat', bandwidth=0.75).fit(data)

log_pdf = kde.score_samples(data)  # log of the estimated pdf at each data point

# Treat the ~10% of points with the smallest estimated density as outliers.
# A zero density (log = -inf, possible with the tophat kernel) would break the
# percentile, so compute the cut-off from finite values only; any -inf point
# still falls below it and is flagged.
threshold = np.percentile(log_pdf[np.isfinite(log_pdf)], 10)
outlier_inds = np.where(log_pdf < threshold)[0]
non_outlier_inds = np.where(log_pdf >= threshold)[0]
print(outlier_inds)
print(non_outlier_inds)
```

Output of one run (no random seed is set, so the exact indices will differ between runs):

```
[ 33  49 100 101 102 103 105 106 107 108 109]
[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  34  35  36
  37  38  39  40  41  42  43  44  45  46  47  48  50  51  52  53  54  55
  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72  73
  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90  91
  92  93  94  95  96  97  98  99 104]
```
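
One property of the percentile rule is worth keeping in mind: it always flags roughly the lowest 10% of points, whether or not any real outliers are present. A small sketch with made-up density values, using the standard library's `statistics.quantiles` (Python 3.8+) as a stand-in for `np.percentile` (the two interpolate slightly differently, so cut points can differ a little); `flag_lowest_decile` is a hypothetical helper name:

```python
import statistics

def flag_lowest_decile(densities):
    """Indices of values below the 10th percentile of `densities`."""
    # quantiles(..., n=10) returns the 9 decile cut points; the first is the 10th percentile.
    cutoff = statistics.quantiles(densities, n=10)[0]
    return [i for i, d in enumerate(densities) if d < cutoff]

# Made-up density values with no outlier: all lie in a narrow range.
densities = [0.30, 0.31, 0.32, 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39]
print(flag_lowest_decile(densities))  # prints [0]

# A fixed threshold flags nothing on the same values.
print([i for i, d in enumerate(densities) if d < 0.1])  # prints []
```

If the data might contain no outliers at all, a fixed threshold on the density, as in the two-step description above, avoids flagging ordinary points.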

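```python
# The data mix negative values (the normal sample) with values up to 100, so a
# symmetric log scale is used on the y-axis; applying np.log directly would turn
# the negative values into NaN and drop them from the plot.
plt.plot(non_outlier_inds, data[non_outlier_inds, 0], 'ro', label='non-outliers')
plt.plot(outlier_inds, data[outlier_inds, 0], 'bo', label='outliers')
plt.yscale('symlog')
plt.xlabel('Index')
plt.ylabel('data (symlog scale)')
plt.legend()
plt.show()
```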


bubble answered Nov 14 '22 11:11