
Plotting the density of the sum of two random variables in SymPy

Tags:

python

plot

sympy

I would like to make and plot a mixed random variable with sympy.

I have a Gaussian mixture of two evenly weighted normal distributions, one with a mean of 1, and one with a mean of 2.

from sympy.stats import Normal, E
mixed = 0.5 * Normal('n1', 1, 1) + 0.5 * Normal('n2', 2, 1)
E(mixed)

Out: 1.5

This is correct, but I can't plot this distribution:

import sympy as sp
x = sp.symbols('x')
sp.plot(mixed(x), x)

 ---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-56-a1c8047b8c4a> in <module>()
----> sp.plot(mixed(x), x)

TypeError: 'Add' object is not callable

When I try to plot its density instead, I get a long error ending in:

from sympy.stats import density
sp.plot(density(mixed)(x), x)
...
UnboundLocalError: local variable 'reprec' referenced before assignment

Any ideas on why the plotting feature doesn't like the sum here?

asked Sep 15 '25 06:09 by Scott Staniewicz


1 Answer

That mixed(x) throws an error is by design. RandomSymbol objects are not callable (i.e., cannot be treated as functions), and neither are their sums. Plotting density(mixed)(x) is the correct approach. But being correct does not always imply being successful.
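A quick way to see this is to inspect the object directly. This is a minimal sketch using the float weights from the question; it only checks the type of the sum and catches the same TypeError:

```python
import sympy as sp
from sympy.stats import Normal

# Same expression as in the question: a weighted sum of two random symbols
mixed = 0.5 * Normal('n1', 1, 1) + 0.5 * Normal('n2', 2, 1)

# The sum is an ordinary SymPy Add expression, not a function of x
print(type(mixed).__name__)
print(callable(mixed))

# Calling it fails the same way as in the question
try:
    mixed(sp.Symbol('x'))
except TypeError as err:
    print("TypeError:", err)
```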

With the parameter evaluate=False you will see that SymPy sets up a reasonable integral for the density:

density(mixed, evaluate=False)(x)
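The sketch below (assuming a reasonably recent SymPy release) checks that the unevaluated density really is an integral rather than a closed form. For a sum like this, density returns a Lambda in a dummy variable, which is why it can be applied to x even though mixed itself cannot:

```python
import sympy as sp
from sympy.stats import Normal, density

x = sp.Symbol('x')
mixed = 0.5 * Normal('n1', 1, 1) + 0.5 * Normal('n2', 2, 1)

# evaluate=False stops SymPy before it tries to compute the integral
dens = density(mixed, evaluate=False)
unevaluated = dens(x)  # substitute x for the dummy variable

print(unevaluated)
print(unevaluated.has(sp.Integral))
```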

The problem is evaluating that integral, which fails for your example. This can be fixed by using rational numbers instead of floating-point numbers like 0.5:

mixed = Normal('n1', 1, 1) / 2 + Normal('n2', 2, 1) / 2
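The difference is visible in how SymPy stores the coefficient. In this small sketch, 0.5 * n1 keeps a Float coefficient while n1 / 2 keeps the exact Rational 1/2. If the weights come from elsewhere as floats, nsimplify is one way to turn a literal such as 0.5 into an exact rational:

```python
import sympy as sp
from sympy.stats import Normal

n1 = Normal('n1', 1, 1)

float_term = 0.5 * n1  # coefficient stored as a Float
exact_term = n1 / 2    # coefficient stored as the Rational 1/2

# as_coeff_Mul splits off the numeric coefficient of a product
float_coeff, _ = float_term.as_coeff_Mul()
exact_coeff, _ = exact_term.as_coeff_Mul()
print(isinstance(float_coeff, sp.Float), isinstance(exact_coeff, sp.Rational))

# Convert a float literal into an exact rational
print(sp.nsimplify(0.5))
```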

SymPy's symbolic integration is more likely to succeed with exact rational coefficients than with floats. And indeed, the following works:

from sympy import plot, symbols
from sympy.stats import Normal, density

# Exact weights (1/2) instead of the float 0.5
mixed = Normal('n1', 1, 1) / 2 + Normal('n2', 2, 1) / 2
x = symbols("x")
d = density(mixed)(x)

The formula returned by density looks complicated, but after simplification (d = d.simplify()) it is exactly what one would expect, namely the density of a normal distribution with mean 3/2 and variance 1/2:

exp(-x**2 + 3*x - 9/4)/sqrt(pi)
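Why this is expected: a linear combination of independent normal variables is again normal. Here the mean is (1/2)(1) + (1/2)(2) = 3/2 and the variance is (1/2)^2 + (1/2)^2 = 1/2. The sketch below compares the simplified formula with SymPy's own normal density at a few points; note that Normal takes the standard deviation, so it is given sqrt(1/2). The variable name y is only a placeholder:

```python
import sympy as sp
from sympy.stats import Normal, density

x = sp.Symbol('x')

# Simplified density of n1/2 + n2/2 as printed above
claimed = sp.exp(-x**2 + 3*x - sp.Rational(9, 4)) / sp.sqrt(sp.pi)

# Independent n1 ~ N(1, 1) and n2 ~ N(2, 1), both weighted by 1/2
w = sp.Rational(1, 2)
mean = w * 1 + w * 2            # 3/2
variance = w**2 * 1 + w**2 * 1  # 1/2

# Normal(name, mean, std): the third argument is the standard deviation
reference = density(Normal('y', mean, sp.sqrt(variance)))(x)

# The difference should vanish at every sample point
for v in [-1, 0, sp.Rational(3, 2), 4]:
    print(v, sp.simplify((reference - claimed).subs(x, v)))
```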

Finally, plot(d, (x, -3, 5)) gives:

[Figure: plot of the density d for x from -3 to 5]
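To sanity-check the curve numerically rather than by eye, lambdify turns the formula into an ordinary Python function. This sketch assumes the simplified formula above (typed in directly rather than recomputed) and uses a hand-written trapezoidal rule so that only the math module is needed. The total area should come out very close to 1, and the peak should sit at x = 3/2 with height 1/sqrt(pi):

```python
import math
import sympy as sp

x = sp.Symbol('x')
# Simplified density from the answer, entered directly
d = sp.exp(-x**2 + 3*x - sp.Rational(9, 4)) / sp.sqrt(sp.pi)

# Plain Python function built on the math module
f = sp.lambdify(x, d, modules="math")

# Trapezoidal rule over a range that holds essentially all of the mass
a, b, n = -6.0, 9.0, 3000
h = (b - a) / n
total = h * (0.5 * f(a) + sum(f(a + i * h) for i in range(1, n)) + 0.5 * f(b))

print(total)                            # area under the curve
print(f(1.5), 1 / math.sqrt(math.pi))   # height at the mean 3/2
```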