 

Matplotlib returning a plot object

I have a function that wraps pyplot.plt so I can quickly create graphs with oft-used defaults:

```python
import matplotlib.pyplot as plt

def plot_signal(time, signal, title='', xlab='', ylab='',
                line_width=1, alpha=1, color='k',
                subplots=False, show_grid=True, fig_size=(10, 5)):

    # Skipping a lot of other complexity here

    f, axarr = plt.subplots(figsize=fig_size)
    axarr.plot(time, signal, linewidth=line_width,
               alpha=alpha, color=color)
    axarr.set_xlim(min(time), max(time))
    axarr.set_xlabel(xlab)
    axarr.set_ylabel(ylab)
    axarr.grid(show_grid)

    plt.suptitle(title, size=16)
    plt.show()
```

However, there are times when I want to be able to return the plot so I can manually add to or edit it for a specific graph. For example, I want to be able to change the axis labels, or add a second line to the plot after calling the function:

```python
import numpy as np

x = np.random.rand(100)
y = np.random.rand(100)

plot = plot_signal(np.arange(len(x)), x)

plot.plt(y, 'r')
plot.show()
```

I've seen a few questions on this ("How to return a matplotlib.figure.Figure object from Pandas plot function?" and "AttributeError: 'Figure' object has no attribute 'plot'") and, as a result, I've tried adding each of the following to the end of the function:

  • return axarr

  • return axarr.get_figure()

  • return plt.axes()

However, they all raise a similar error: AttributeError: 'AxesSubplot' object has no attribute 'plt'

What's the correct way to return a plot object so it can be edited later?

Simon, asked May 11 '17 at 20:05




2 Answers

I think the error is pretty self-explanatory. There is no such thing as pyplot.plt, nor is there a plt attribute on an Axes or Figure object. plt is just the quasi-standard alias for pyplot chosen when it is imported, i.e., import matplotlib.pyplot as plt.
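A quick way to confirm this, as a minimal sketch (the names fig and ax are only illustrative): plt is the pyplot module itself, and an Axes has plotting methods such as plot but no plt attribute.

```python
import matplotlib.pyplot as plt
import matplotlib.pyplot

# "plt" is only the alias picked at import time; it is the pyplot module itself.
assert plt is matplotlib.pyplot

fig, ax = plt.subplots()

# An Axes object has plotting methods like .plot(), but no .plt attribute,
# which is why calling ax.plt(...) raises AttributeError.
print(hasattr(ax, "plot"))  # True
print(hasattr(ax, "plt"))   # False

plt.close(fig)  # release the figure; nothing is displayed in this sketch
```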

Concerning the problem, the first approach, return axarr, is the most versatile one. You get an Axes object (or an array of Axes) and can plot to it with its own methods, e.g. ax.plot(...).

The code may look like:

```python
def plot_signal(x,y, ..., **kwargs):
    # Skipping a lot of other complexity here
    f, ax = plt.subplots(figsize=fig_size)
    ax.plot(x,y, ...)
    # further stuff
    return ax

ax = plot_signal(x,y, ...)
ax.plot(x2, y2, ...)
plt.show()
```
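A runnable sketch of this approach applied to the question's function, assuming the same defaults. It returns the Axes, drops the unused subplots flag, and leaves plt.show() to the caller so lines and labels can still be added afterwards:

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_signal(time, signal, title='', xlab='', ylab='',
                line_width=1, alpha=1, color='k',
                show_grid=True, fig_size=(10, 5)):
    """Plot signal against time with the usual defaults and return the Axes."""
    fig, ax = plt.subplots(figsize=fig_size)
    ax.plot(time, signal, linewidth=line_width, alpha=alpha, color=color)
    ax.set_xlim(min(time), max(time))
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    ax.grid(show_grid)
    fig.suptitle(title, size=16)
    # No plt.show() here: the caller decides when the figure is finished.
    return ax


x = np.random.rand(100)
y = np.random.rand(100)
t = np.arange(len(x))

ax = plot_signal(t, x, title='Two signals')
ax.plot(t, y, 'r')          # add a second line after the fact
ax.set_ylabel('amplitude')  # or change a label
# plt.show()                # display once all edits are done
```

If you also need the Figure (for example to call savefig), it is reachable from the returned object as ax.figure.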
ImportanceOfBeingErnest, answered Sep 29 '22 at 10:09


This is actually a great question that took me years to figure out. A good way to do this is to pass a Figure object into your function, have the function add an Axes to it, and then return the updated figure. Here is an example:

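```python
import matplotlib.pyplot as plt
import numpy as np

fig_size = (10, 5)
f = plt.figure(figsize=fig_size)

def plot_signal(time, signal, title='', xlab='', ylab='',
                line_width=1, alpha=1, color='k',
                subplots=False, show_grid=True, fig=None):
    # Use the figure passed in, or create one if none was given
    if fig is None:
        fig = plt.figure(figsize=fig_size)

    # Skipping a lot of other complexity here

    ax = fig.add_subplot(1, 1, 1)  # here is where you add the subplot to the figure
    ax.plot(time, signal, linewidth=line_width,
            alpha=alpha, color=color)
    ax.set_xlim(min(time), max(time))
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    ax.grid(show_grid)
    ax.set_title(title, size=16)

    return fig

time = np.arange(100)         # example data
signal = np.random.rand(100)
f = plot_signal(time, signal, fig=f)
f  # in a Jupyter notebook, this displays the figure
```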
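Once the figure has been returned, the Axes the function created can be pulled back out of Figure.axes and edited like any other. A sketch, using a hypothetical trimmed-down plot_signal so the block runs on its own:

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_signal(time, signal, fig=None, color='k'):
    """Hypothetical trimmed-down version of the function above:
    add one Axes to `fig` (or to a new figure), plot, and return the figure."""
    if fig is None:
        fig = plt.figure(figsize=(10, 5))
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(time, signal, color=color)
    return fig


t = np.arange(100)
f = plot_signal(t, np.random.rand(100))

# Figure.axes is a list of every Axes on the figure; take the one the function added
ax = f.axes[0]
ax.plot(t, np.random.rand(100), 'r')  # second line on the same Axes
ax.set_xlabel('sample')
# plt.show()  # or f.savefig('signal.png')
```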
dpear_, answered Sep 29 '22 at 11:09