 

How to convert Matplotlib figure to PIL Image object (without saving image)

As the title states, I am trying to convert a fig to a PIL.Image. I can currently do so by first saving the fig to disk and then opening that file with Image.open(), but the process takes longer than expected, and I am hoping that skipping the local save step will make it a bit faster.

Here is what I have so far:

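# imports (assumed from context; the original snippet omits them)
import matplotlib.pyplot as plt
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from PIL import Image

# build fig
figsize, dpi = self._calc_fig_size_res(img_height)
fig = plt.Figure(figsize=figsize)
canvas = FigureCanvas(fig)
ax = fig.add_subplot(111)
ax.imshow(torch.from_numpy(S).flip(0), cmap=cmap)
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
ax.axis('tight')
ax.axis('off')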

# export
fig.savefig(export_path, dpi = dpi)

# open image as PIL object
img = Image.open(export_path)

I have tried doing this after I build the fig (it would be right before the export stage):

pil_img = Image.frombytes('RGB', canvas.get_width_height(), canvas.tostring_rgb())

But it's not showing the entire image. It looks like it's a crop of the top left corner, but it could just be a weird representation of the data -- I'm working with spectrograms so the images are fairly abstract.
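One way to check the sizes involved is sketched below, assuming the Agg canvas (FigureCanvasAgg) and Pillow. get_width_height() reports figsize multiplied by the figure's own dpi, while savefig(dpi=dpi) renders at whatever dpi it is given, and the canvas buffer only holds the figure's pixels after canvas.draw() has run. Whether such a mismatch explains the crop above is an assumption; the thread does not confirm it. The figsize, dpi and random data are placeholders standing in for the asker's _calc_fig_size_res output and spectrogram S.

```python
import io

import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from PIL import Image

figsize, dpi = (4, 3), 50  # placeholder values for illustration

# Create the figure at the same dpi that savefig will use later.
fig = Figure(figsize=figsize, dpi=dpi)
canvas = FigureCanvasAgg(fig)
ax = fig.add_subplot(111)
ax.imshow(np.random.default_rng(0).random((32, 64)), cmap="magma")
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
ax.axis("off")

# Render first: the pixel buffer only reflects the figure after a draw.
canvas.draw()
from_canvas = Image.fromarray(np.asarray(canvas.buffer_rgba())).convert("RGB")

# Compare with what savefig produces at the same dpi.
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=dpi)
buf.seek(0)
from_file = Image.open(buf)

print(canvas.get_width_height())  # (200, 150), i.e. figsize * dpi
print(from_canvas.size, from_file.size)
```

If the figure were created at the default dpi instead, the canvas image and the saved file would have different pixel sizes.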

Zach asked Aug 01 '19 20:08



People also ask

How do I save a Matplotlib (plt) plot as an image in Python?

To save Matplotlib figures to disk as image files programmatically, all you need is the matplotlib.pyplot.savefig() function. Pass the desired filename (optionally including a directory path) and the figure is written to disk.

How do I save a Matplotlib figure as a JPEG?

To save a figure as a JPG or PNG file, call savefig() on matplotlib.pyplot (or on the figure itself) and pass the file name, including its extension, as a string argument; the extension determines the output format.
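A short sketch of this file-based route, assuming Pillow is installed (Matplotlib uses it to write JPEG files). The temporary directory and file names are placeholders; the Agg backend is selected only so the sketch also runs on a machine without a display.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, safe on headless machines
import matplotlib.pyplot as plt
from PIL import Image

plt.plot([0, 1, 2], [0, 1, 4])

out_dir = tempfile.mkdtemp()
png_path = os.path.join(out_dir, "plot.png")
jpg_path = os.path.join(out_dir, "plot.jpg")

plt.savefig(png_path)  # format inferred from the ".png" extension
plt.savefig(jpg_path)  # format inferred from the ".jpg" extension

print(Image.open(png_path).format, Image.open(jpg_path).format)  # PNG JPEG
```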

How do I save a figure without showing in Matplotlib?

Plots generated with Matplotlib can be saved with savefig() (or, for a raw image array, imsave()). In interactive mode, a plot may be displayed as soon as it is created; to avoid that, turn interactive mode off with ioff() and close the figure with close() after saving it.
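A minimal sketch of both options, assuming a reasonably recent Matplotlib. The first uses pyplot with ioff() and close(); the second builds a matplotlib.figure.Figure directly, which pyplot never tracks, so plt.show() cannot display it. File names are placeholders, and Agg is selected only so the sketch runs headless.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

# Option 1: pyplot, with interactive mode off and the figure closed afterwards.
plt.ioff()
fig1, ax1 = plt.subplots()
ax1.plot([1, 2, 3])
path1 = os.path.join(tempfile.mkdtemp(), "pyplot.png")
fig1.savefig(path1)
plt.close(fig1)  # drop it from pyplot's registry so plt.show() won't display it

# Option 2: a Figure built without pyplot is never registered, so never shown.
fig2 = Figure()
ax2 = fig2.add_subplot()
ax2.plot([3, 2, 1])
path2 = os.path.join(tempfile.mkdtemp(), "figure.png")
fig2.savefig(path2)
```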


2 Answers

EDIT #2

fig.canvas.draw()  # the canvas only holds the figure's pixels after it has been drawn
PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())

takes around 2 ms, compared with the 35 to 40 ms of the methods below.

This is the fastest way I can find so far.
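For reference, a minimal sketch of the same canvas route, assuming an Agg-based canvas. Two details matter here: the canvas must be drawn before its buffer is read, and tostring_rgb() was deprecated in Matplotlib 3.8, so this version reads buffer_rgba() instead and lets Pillow drop the alpha channel. The function name fig_to_pil is hypothetical.

```python
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from PIL import Image


def fig_to_pil(fig):
    """Render fig with Agg and return its pixels as an RGB PIL image."""
    canvas = fig.canvas
    if not isinstance(canvas, FigureCanvasAgg):
        canvas = FigureCanvasAgg(fig)  # attach an Agg canvas if the figure lacks one
    canvas.draw()  # pixels are only valid after a draw
    rgba = np.asarray(canvas.buffer_rgba())  # shape (height, width, 4)
    return Image.fromarray(rgba).convert("RGB")  # convert() also copies the pixels


fig = Figure(figsize=(3, 2), dpi=100)
fig.add_subplot().plot([0, 1, 0])
img = fig_to_pil(fig)
print(img.size, img.mode)  # (300, 200) RGB
```

Reading the buffer through NumPy takes the height and width from the buffer itself, so the reshape-order question does not arise.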


I've also been looking at this today.

The Matplotlib documentation for savefig includes this parameter:

pil_kwargs : dict, optional. Additional keyword arguments that are passed to PIL.Image.save when saving the figure. Only applicable for formats that are saved using Pillow, i.e. JPEG, TIFF, and (if the keyword is set to a non-None value) PNG.

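This suggests that, for those formats, the rendered figure already passes through Pillow before it is written out, but that intermediate PIL image is not exposed to the caller.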
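In practice, pil_kwargs is a way to pass Pillow save options such as JPEG quality. A minimal sketch of using it while writing to an in-memory buffer instead of disk, assuming Pillow with JPEG support; the quality value of 90 is an arbitrary choice.

```python
import io

from matplotlib.figure import Figure
from PIL import Image

fig = Figure(figsize=(2, 2), dpi=50)
fig.add_subplot().plot([1, 3, 2])

buf = io.BytesIO()
# pil_kwargs is forwarded to PIL.Image.save for Pillow-written formats such as JPEG.
fig.savefig(buf, format="jpeg", pil_kwargs={"quality": 90})
buf.seek(0)
img = Image.open(buf)
print(img.format, img.size)  # JPEG (100, 100)
```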

You could follow the answers to "Matplotlib: save plot to numpy array" to get the figure into a NumPy array, and then do:

PIL.Image.fromarray(array)

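If the array came from OpenCV (e.g. cv2.imread), which stores channels in BGR order, reverse them to RGB first with array[:, :, ::-1]; arrays read from the Matplotlib canvas are already RGB(A).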
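A tiny sketch of what the [:, :, ::-1] reversal does, using a hand-made one-pixel array rather than real image data:

```python
import numpy as np
from PIL import Image

# One pure-red pixel stored in OpenCV's BGR order: (B, G, R) = (0, 0, 255).
bgr = np.array([[[0, 0, 255]]], dtype=np.uint8)

# Reversing the last axis turns BGR into RGB; ascontiguousarray gives Pillow a
# plain C-ordered copy instead of a negatively strided view.
rgb = np.ascontiguousarray(bgr[:, :, ::-1])

img = Image.fromarray(rgb)
print(img.getpixel((0, 0)))  # (255, 0, 0)
```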

EDIT:

I've tested each approach suggested so far.

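import io

import cv2
import numpy as np
import PIL.Image

# These helpers assume fig is an existing Matplotlib figure with an Agg-based canvas.

def save_plot_and_get():
    fig.savefig("test.jpg")
    img = cv2.imread("test.jpg")                 # OpenCV returns BGR channel order
    return PIL.Image.fromarray(img[:, :, ::-1])  # reverse to RGB for Pillow

def buffer_plot_and_get():
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    return PIL.Image.open(buf)  # Image.open is lazy: pixels are decoded on first use

def from_canvas():
    fig.canvas.draw()  # render so the buffer holds the current figure
    width, height = fig.canvas.get_width_height()
    rgb = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    # image arrays are (rows, columns, channels), i.e. (height, width, 3)
    return PIL.Image.fromarray(rgb.reshape(height, width, 3))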

Results

%timeit save_plot_and_get()

35.5 ms ± 148 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit save_plot_and_get()

35.5 ms ± 142 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit buffer_plot_and_get()

40.4 ms ± 152 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
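%timeit is an IPython magic; outside IPython, the standard-library timeit module gives comparable measurements. A sketch, assuming in-memory variants of the buffer and canvas approaches re-declared under the hypothetical names via_buffer and via_canvas so the block runs on its own. Unlike buffer_plot_and_get above, via_buffer calls load() so that the lazy Image.open actually decodes the PNG inside the timed call. Absolute numbers depend on the machine, figure size, and backend, so none are claimed here.

```python
import io
import timeit

import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from PIL import Image

fig = Figure(figsize=(4, 3), dpi=100)
FigureCanvasAgg(fig)  # attaches an Agg canvas to the figure
fig.add_subplot().plot(np.arange(10))


def via_buffer():
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    img = Image.open(buf)
    img.load()  # force decoding so the timing includes it
    return img


def via_canvas():
    fig.canvas.draw()
    return Image.fromarray(np.asarray(fig.canvas.buffer_rgba())).convert("RGB")


timings = {}
for fn in (via_buffer, via_canvas):
    # best of 3 repeats, each timing 10 calls
    best = min(timeit.repeat(fn, number=10, repeat=3)) / 10
    timings[fn.__name__] = best
    print(f"{fn.__name__}: {best * 1000:.1f} ms per call")
```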

Lewis Morris answered Sep 24 '22 19:09



I use the following function:

import io

from PIL import Image

def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it"""
    buf = io.BytesIO()
    fig.savefig(buf)  # render into memory instead of a file on disk
    buf.seek(0)
    img = Image.open(buf)
    return img

Example usage:

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

x = np.arange(-3,3)
plt.plot(x)
fig = plt.gcf()

img = fig2img(fig)

img.show()
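One caveat: Image.open is lazy, so the image returned by fig2img still reads from buf the first time its pixels are used. A variant that decodes immediately and makes the format and resolution explicit might look like the sketch below; the dpi parameter and the name fig2img_loaded are additions for illustration, not part of the answer.

```python
import io

from matplotlib.figure import Figure
from PIL import Image


def fig2img_loaded(fig, dpi=None):
    """Render fig to PNG in memory and return a fully decoded PIL image."""
    with io.BytesIO() as buf:
        # dpi=None falls back to rcParams["savefig.dpi"] (by default the figure's dpi)
        fig.savefig(buf, format="png", dpi=dpi)
        buf.seek(0)
        img = Image.open(buf)
        img.load()  # decode now, while the buffer is still open
    return img


fig = Figure(figsize=(2, 1), dpi=100)
fig.add_subplot().plot([0, 1])
img = fig2img_loaded(fig, dpi=80)
print(img.size)  # (160, 80)
```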

kotchwane answered Sep 22 '22 19:09
