
How can I plot the probability density function for a fitted Gaussian mixture model under scikit-learn?

I'm struggling with a rather simple task. I have a vector of floats to which I would like to fit a Gaussian mixture model with two Gaussian kernels:

from sklearn.mixture import GMM

gmm = GMM(n_components=2)
gmm.fit(values)  # values is numpy vector of floats

I would now like to plot the probability density function for the mixture model I've created, but I can't seem to find any documentation on how to do this. How should I best proceed?

Edit:

Here is the vector of data I'm fitting. And below is a more detailed example of how I'm doing things:

from sklearn.mixture import GMM
from matplotlib.pyplot import *
import numpy as np

try:
    import cPickle as pickle
except ImportError:
    import pickle

with open('/path/to/kde.pickle', 'rb') as f:  # open the data file provided above (binary mode for pickle)
    kde = pickle.load(f)

gmm = GMM(n_components=2)
gmm.fit(kde)

x = np.linspace(np.min(kde), np.max(kde), len(kde))

# Plot the data to which the GMM is being fitted
figure()
plot(x, kde, color='blue')


# My half-baked attempt at replicating the scipy example
fit = gmm.score_samples(x)[0]
plot(x, fit, color='red')

The fitted curve doesn't look anything like what I'd expect. It doesn't even seem Gaussian, which is a bit strange given it was produced by a Gaussian process. Am I crazy?

Louis Thibault asked May 12 '14 12:05



3 Answers

I followed some examples mentioned in this thread and elsewhere and got closer to a solution, but the final probability density function does not appear to integrate to one. I will probably post a separate question about that.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture

np.random.seed(1)

mus =  np.array([[0.2], [0.8]])
sigmas = np.array([[0.1], [0.1]]) ** 2
gmm = GaussianMixture(2)
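# Note: fit() below re-estimates these parameters, so the values set here are overwritten
gmm.means_ = mus
gmm.covariances_ = sigmas  # GaussianMixture uses covariances_, not covars_
gmm.weights_ = np.array([0.5, 0.5])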

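# Fit the GMM to random samples drawn from the corresponding Gaussians.
# Note: np.random.normal takes a standard deviation as its second argument,
# while sigmas holds variances (0.01), so these samples have a spread of 0.01.
gaus_samples_1 = np.random.normal(mus[0], sigmas[0], 10).reshape(10,1)
gaus_samples_2 = np.random.normal(mus[1], sigmas[1], 10).reshape(10,1)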
fit_samples = np.concatenate((gaus_samples_1, gaus_samples_2))
gmm.fit(fit_samples)

fig = plt.figure()
ax = fig.add_subplot(111)
x = np.linspace(0, 1, 1000).reshape(1000,1)
logprob = gmm.score_samples(x)
pdf = np.exp(logprob)
# print(np.max(pdf)) -> 19.8409464401 !?
ax.plot(x, pdf, '-k')
plt.show()

Here is the resulting plot
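
A density value above 1 does not by itself mean that the curve fails to integrate to one: a narrow Gaussian has a tall peak. The sketch below is one way to check this numerically. It uses only NumPy and hypothetical mixture parameters (two equally weighted components with standard deviation 0.01, roughly what the sampling call np.random.normal(mus[0], sigmas[0], 10) above would produce); mixture_pdf is an illustrative helper, not part of scikit-learn.

```python
import numpy as np

def mixture_pdf(x, weights, means, variances):
    """Density of a 1D Gaussian mixture evaluated at the points x."""
    x = np.asarray(x, dtype=float)[:, None]           # shape (n_points, 1)
    norm = np.sqrt(2.0 * np.pi * variances)           # shape (n_components,)
    comps = np.exp(-0.5 * (x - means) ** 2 / variances) / norm
    return comps @ weights                            # weighted sum over components

# Hypothetical parameters (assumption): two narrow components, std = 0.01
weights = np.array([0.5, 0.5])
means = np.array([0.2, 0.8])
variances = np.array([0.01, 0.01]) ** 2

x = np.linspace(0.0, 1.0, 10001)
pdf = mixture_pdf(x, weights, means, variances)

# Trapezoidal rule written out by hand to avoid depending on a NumPy version
area = np.sum(0.5 * (pdf[1:] + pdf[:-1]) * np.diff(x))
peak = pdf.max()
print(f"peak = {peak:.2f}, area = {area:.4f}")
```

With a fitted model, the same check could be run on np.exp(gmm.score_samples(x.reshape(-1, 1))) instead of mixture_pdf, provided the grid covers essentially all of the probability mass.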

rauldg answered Sep 25 '22 01:09


Take a look at this link:

http://www.astroml.org/book_figures/chapter4/fig_GMM_1D.html

They show how to plot a 1D GMM in 3 different ways:

1D GMM plots
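
In the same spirit, here is a minimal sketch of plotting a 1D mixture density with scikit-learn's GaussianMixture. It is not the astroML code; the synthetic samples, the grid bounds, and the plot_density helper are assumptions made for illustration. The key points are that the input must be reshaped to (n_samples, 1) and that score_samples returns log-densities, so they need to be exponentiated.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
# Synthetic 1D samples from two Gaussians (illustrative data, not the asker's)
samples = np.concatenate([rng.normal(-2.0, 0.5, 500), rng.normal(1.0, 1.0, 500)])

# scikit-learn expects a 2D array of shape (n_samples, n_features)
gmm = GaussianMixture(n_components=2, random_state=0).fit(samples.reshape(-1, 1))

x = np.linspace(samples.min() - 1.0, samples.max() + 1.0, 2000)
log_pdf = gmm.score_samples(x.reshape(-1, 1))   # log of the mixture density
pdf = np.exp(log_pdf)

# Per-component contributions: responsibility times the mixture density
component_pdfs = gmm.predict_proba(x.reshape(-1, 1)) * pdf[:, None]

def plot_density():
    """Overlay the fitted density and its components on a normalized histogram."""
    import matplotlib.pyplot as plt
    plt.hist(samples, bins=40, density=True, alpha=0.4)
    plt.plot(x, pdf, "-k")
    plt.plot(x, component_pdfs, "--")
    plt.show()
```

Calling plot_density() draws the figure. The histogram is normalized with density=True so that it is on the same scale as the fitted curve.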

Sibelius Seraphini answered Sep 27 '22 01:09


Take a look at one of the scikit-learn examples on GitHub:

https://github.com/scikit-learn/scikit-learn/blob/master/examples/mixture/plot_gmm_pdf.py

The idea is to generate a meshgrid, compute the score of each grid point with the GMM, and plot the result.

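
The meshgrid idea can be sketched roughly as follows for 2D data. This is not the code from the linked example; the synthetic data, grid limits, and the plot_contours helper are assumptions chosen for illustration.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
# Synthetic 2D data from two blobs (illustrative only)
X = np.vstack([
    rng.normal(loc=[0.0, 0.0], scale=1.0, size=(300, 2)),
    rng.normal(loc=[5.0, 5.0], scale=0.7, size=(300, 2)),
])
gmm = GaussianMixture(n_components=2, covariance_type="full", random_state=0).fit(X)

# Build a grid of points covering the data
xs = np.linspace(-4.0, 9.0, 200)
ys = np.linspace(-4.0, 9.0, 200)
XX, YY = np.meshgrid(xs, ys)
grid = np.column_stack([XX.ravel(), YY.ravel()])   # shape (200 * 200, 2)

# score_samples returns log-densities; exponentiate and reshape back to the grid
density = np.exp(gmm.score_samples(grid)).reshape(XX.shape)

def plot_contours():
    """Filled contour plot of the fitted density with the data on top."""
    import matplotlib.pyplot as plt
    plt.contourf(XX, YY, density, levels=20)
    plt.scatter(X[:, 0], X[:, 1], s=2, color="white")
    plt.show()
```

Calling plot_contours() draws the figure. Plotting the negative log-density (-gmm.score_samples(grid)) instead of the density is a common alternative when the components differ a lot in height.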

emeth answered Sep 25 '22 01:09