 

Retrieve matplotlib ContourSet for SymPy plots

Using SymPy, I can create a contour plot manually with the following code (at the time of writing there was no built-in contour plotting function; update: SymPy now has one):

from sympy import init_session
init_session()  # imports SymPy's namespace (sqrt, ...) and defines symbols such as x and y
from sympy.plotting.plot import Plot, ContourSeries

# show plot centered at 0,0
x_min = -7
x_max = 7
y_min = -5
y_max = 5

# contour plot of inverted cone
my_plot = Plot(
    ContourSeries(
        sqrt(x**2 + y**2),
        (x,x_min,x_max),
        (y,y_min,y_max)
    )
)
my_plot.show()

[Figure: contour plot over the 14 x 10 region]

Currently, when SymPy calls contour(), it does not appear to save the returned ContourSet (update: I have filed an issue to see if the ContourSet can be saved):

class MatplotlibBackend(BaseBackend):
    ...
    def process_series(self):  
        ... 
        for s in self.parent._series:
            # Create the collections
            ...
            elif s.is_contour:
                self.ax.contour(*s.get_meshes()) # returned ContourSet not saved by SymPy

In other examples that modify the plot, such as adding inline labels with clabel(), the ContourSet (CS) is needed:

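import numpy as np
import matplotlib.pyplot as plt
# Sample data (assumed here; the original snippet leaves X, Y, Z undefined)
X, Y = np.meshgrid(np.linspace(-3, 3, 100), np.linspace(-2, 2, 100))
Z = np.exp(-X**2 - Y**2)
# Create a simple contour plot with labels using default colors.
plt.figure()
CS = plt.contour(X, Y, Z)  # CS is the ContourSet
plt.clabel(CS, inline=True, fontsize=10)
plt.title('Simplest default with labels')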

[Figure: contour plot with labels]

Going back to the SymPy example, my_plot._backend does provide access to the figure and axes; what workarounds are possible to keep or obtain access to the ContourSet?
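One possible workaround, sketched here under the assumption that the backend calls `Axes.contour` on an ordinary Matplotlib axes (as the excerpt above suggests), is to temporarily wrap `Axes.contour` so that every ContourSet it returns is recorded. The helper names `capture_contour_sets` and `draw_and_discard` are hypothetical; `draw_and_discard` merely stands in for library code that throws the return value away.

```python
import contextlib

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes


@contextlib.contextmanager
def capture_contour_sets():
    """Temporarily wrap Axes.contour and collect every ContourSet it returns."""
    captured = []
    original = Axes.contour

    def recording_contour(self, *args, **kwargs):
        cs = original(self, *args, **kwargs)  # call the real Axes.contour
        captured.append(cs)
        return cs

    Axes.contour = recording_contour
    try:
        yield captured
    finally:
        Axes.contour = original  # always restore the real method


def draw_and_discard(ax):
    """Hypothetical stand-in for library code that ignores contour()'s result."""
    X, Y = np.meshgrid(np.linspace(-7, 7, 200), np.linspace(-5, 5, 200))
    ax.contour(X, Y, np.sqrt(X**2 + Y**2))


fig, ax = plt.subplots()
with capture_contour_sets() as sets:
    draw_and_discard(ax)

cs = sets[0]  # the ContourSet the caller threw away
ax.clabel(cs, inline=True, fontsize=10)
```

For the SymPy plot, the idea would be to run whatever triggers `process_series` (for example `my_plot.show()`) inside the same `with` block. This has not been checked against a particular SymPy version, and labels added after the figure is first drawn may need a redraw to appear.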

asked Jan 10 '17 16:01 by chrstphrchvz


2 Answers

One option when SymPy's built-in plotting capabilities fall short of what you want is to use matplotlib directly. The key is to use lambdify to convert the SymPy expression to a NumPy function.

f = lambdify((x, y), sqrt(x**2 + y**2), 'numpy')
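To illustrate what lambdify returns, here is a minimal sketch: with the `'numpy'` module the resulting function evaluates element-wise on NumPy arrays, which is what makes it usable on a meshgrid. The sample input values are arbitrary.

```python
import numpy as np
from sympy import symbols, sqrt, lambdify

x, y = symbols("x y")
f = lambdify((x, y), sqrt(x**2 + y**2), "numpy")

# The lambdified function accepts NumPy arrays and evaluates element-wise
r = f(np.array([3.0, 0.0, 6.0]), np.array([4.0, 0.0, 8.0]))
print(r)  # [ 5.  0. 10.]
```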

The following creates a contour plot, with c as the ContourSet object.

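import numpy
import matplotlib.pyplot

a = numpy.linspace(-7, 7, 1000)
b = numpy.linspace(-5, 5, 1000)
X, Y = numpy.meshgrid(a, b)  # new names, so the SymPy symbols x, y are not overwritten
c = matplotlib.pyplot.contour(X, Y, f(X, Y))  # c is the ContourSet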
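Putting the pieces together, a self-contained sketch of this approach that also adds the inline labels from the question might look like the following. It assumes the object-oriented `plt.subplots()` interface and the Agg backend; neither comes from the original answer.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to show a window
import matplotlib.pyplot as plt
import numpy as np
from sympy import symbols, sqrt, lambdify

x, y = symbols("x y")
f = lambdify((x, y), sqrt(x**2 + y**2), "numpy")

# Separate names for the grid so the SymPy symbols x, y stay intact
X, Y = np.meshgrid(np.linspace(-7, 7, 1000), np.linspace(-5, 5, 1000))

fig, ax = plt.subplots()
cs = ax.contour(X, Y, f(X, Y))  # cs is the ContourSet, kept by us
ax.clabel(cs, inline=True, fontsize=10)  # inline labels, as in the question
ax.set_title("Inverted cone with labels")
```

Because the ContourSet is created by our own call, no access to SymPy's backend internals is needed.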


answered Nov 04 '22 10:11 by asmeurer


As this question explains, you need to keep a reference to your contour plot, from which you can retrieve the points. For your 'Simplest default with labels' example, that is CS.collections[0].get_paths().
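`ContourSet.collections` was deprecated in Matplotlib 3.8, so here is a sketch that reads the points through the `allsegs` attribute instead, which as far as I know is available both before and after that change. `allsegs[i]` is a list of `(N, 2)` vertex arrays for level `i`; the explicit levels are chosen only for the example.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np

X, Y = np.meshgrid(np.linspace(-7, 7, 200), np.linspace(-5, 5, 200))
Z = np.sqrt(X**2 + Y**2)  # the inverted cone from the question

fig, ax = plt.subplots()
cs = ax.contour(X, Y, Z, levels=[2.0, 4.0])

for level, segments in zip(cs.levels, cs.allsegs):
    for seg in segments:
        # seg is an (N, 2) array of (x, y) vertices along one contour line
        print(level, seg.shape)
```

For this cone, the vertices of each level should lie close to a circle whose radius equals the level value.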

However, I couldn't find a way to retrieve the list of plots directly from a given axes object.

In the following code, C contains the desired lines, though split into one collection per contour level and with some scaling applied:

import numpy as np
import matplotlib.pyplot as plt

X = np.arange(10)
Y = np.arange(10)
Z = np.random.random((10, 10))

fig = plt.figure()
ax = fig.add_subplot(111)
A = ax.get_children()  # children before drawing the contours
cs = ax.contour(X, Y, Z)
B = ax.get_children()  # children after drawing the contours

C = [ref for ref in B if ref not in A]  # artists added by contour()

fig2 = plt.figure()
ax2 = fig2.add_subplot(111)
ax2.add_collection(C[0])
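Recent Matplotlib versions may refuse to put a single artist in more than one figure, so moving `C[0]` into `fig2` can fail there. A sketch of an alternative, assuming the goal is just to show one level's lines in a second figure, is to copy the vertices from `cs.allsegs` into a new `LineCollection`. The seeded random generator is an assumption added for reproducibility.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection

X = np.arange(10)
Y = np.arange(10)
Z = np.random.default_rng(0).random((10, 10))  # seeded for reproducibility

fig, ax = plt.subplots()
cs = ax.contour(X, Y, Z)  # keep the ContourSet reference

# First level that actually produced lines (levels outside the data range yield none)
level_segs = next(segs for segs in cs.allsegs if len(segs) > 0)

# Build a fresh collection from copies of those vertices instead of moving
# the original artist, which still belongs to fig
lc = LineCollection([seg.copy() for seg in level_segs], colors="k")

fig2, ax2 = plt.subplots()
ax2.add_collection(lc)
ax2.autoscale_view()  # fit the view limits to the copied lines
```

Since the new collection is built from data coordinates and added to `ax2` fresh, it picks up `ax2`'s own transform rather than anything from the first axes.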

This results in two figures, fig and fig2: [Figures: fig and fig2]

answered Nov 04 '22 10:11 by berna1111