 

draw a border around subplots in matplotlib


Does anyone know how to draw a border around an individual subplot within a figure in matplotlib? I'm using pyplot.

For example:

import matplotlib.pyplot as plt
f = plt.figure()
ax1 = f.add_subplot(211)
ax2 = f.add_subplot(212)
# ax1.set_edgecolor('black')

...but Axes objects have no `edgecolor`, and I can't find a way to outline the plot from the figure level either.
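For reference, while an Axes has no `set_edgecolor`, the frame around its plotting area is drawn by its spines, and each spine is a Patch that does support `set_edgecolor` and `set_linewidth`. A minimal sketch that thickens and recolours only `ax1`'s frame (note that this outlines the data area only, not the tick labels or title, so it is not a full "around the subplot" border):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

f = plt.figure()
ax1 = f.add_subplot(211)
ax2 = f.add_subplot(212)

# ax.spines maps 'left', 'right', 'top', 'bottom' to Spine artists;
# Spine subclasses Patch, so edge colour and width can be changed per axes.
for spine in ax1.spines.values():
    spine.set_edgecolor("black")
    spine.set_linewidth(3)

f.savefig("outlined_spines.png")
```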

I'm actually wrapping mpl code and adding a wx UI with controls whose context should depend on which subplot is selected. That is: the user clicks on a subplot within the figure canvas; the subplot becomes "selected" (an outline, ideally sawtooth, is drawn around it); and the GUI updates to present controls for modifying that specific subplot.
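For the selection part, matplotlib's event API can report which subplot was clicked: a `button_press_event` callback receives a `MouseEvent` whose `inaxes` attribute is the Axes under the cursor (or `None`). A minimal sketch, assuming a dashed outline is an acceptable stand-in for the sawtooth one; the helper `select_axes`, the `state` dict and the padding values are hypothetical, and the place to update the wx controls is only marked with a comment:

```python
import matplotlib
matplotlib.use("Agg")  # in the real app, use your GUI backend (e.g. WXAgg)
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

fig = plt.figure()
ax1 = fig.add_subplot(211)
ax2 = fig.add_subplot(212)

state = {"selected": None, "outline": None}  # hypothetical selection state

def select_axes(ax):
    """Hypothetical helper: outline `ax` and remember it as the selection."""
    if state["outline"] is not None:
        state["outline"].remove()  # drop the previous subplot's outline
    # Axes-fraction coordinates: (0, 0) is the lower-left corner of the axes,
    # (1, 1) the upper-right; clip_on=False lets the box extend past the axes.
    outline = Rectangle((-0.05, -0.1), 1.1, 1.2, transform=ax.transAxes,
                        fill=False, linestyle="--", linewidth=2,
                        edgecolor="red", clip_on=False)
    ax.add_patch(outline)
    state["selected"], state["outline"] = ax, outline
    fig.canvas.draw_idle()
    # ... update the wx controls for `ax` here ...

def on_click(event):
    # event.inaxes is the Axes under the mouse, or None outside all axes.
    if event.inaxes is not None and event.inaxes is not state["selected"]:
        select_axes(event.inaxes)

fig.canvas.mpl_connect("button_press_event", on_click)
```

Keeping the outline as an ordinary patch means deselection is just `outline.remove()`, and the rest of the subplot is untouched.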

Adam Fraser asked Jan 08 '10 at 12:01


2 Answers

You essentially want to draw outside of the axes, right?

I adapted this from here. It would need some cleanup, since I used some hard-coded "fudge factors" in there.

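#!/usr/bin/env python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

def f(t):
    """Damped oscillation: cos(2*pi*t) * exp(-t)."""
    return np.cos(2 * np.pi * t) * np.exp(-t)

t1 = np.arange(0.0, 5.0, 0.1)
t2 = np.arange(0.0, 5.0, 0.02)
t3 = np.arange(0.0, 2.0, 0.01)

plt.figure(figsize=(4, 4))
sub1 = plt.subplot(211)
plt.plot(t1, f(t1), 'bo', t2, f(t2), 'k--', markerfacecolor='green')
plt.grid(True)
plt.title('A tale of 2 subplots')
plt.ylabel('Damped oscillation')

## I ADDED THIS
# axis() returns the current limits (xmin, xmax, ymin, ymax) in data units.
autoAxis = sub1.axis()
# The offsets below are hard-coded "fudge factors" in data units,
# chosen so the box also encloses the tick labels and title.
rec = Rectangle((autoAxis[0] - 0.7, autoAxis[2] - 0.2),
                (autoAxis[1] - autoAxis[0]) + 1,
                (autoAxis[3] - autoAxis[2]) + 0.4,
                fill=False, lw=2)
rec = sub1.add_patch(rec)
# Without this, the patch would be clipped to the axes area.
rec.set_clip_on(False)

plt.subplot(212)
plt.plot(t3, np.cos(2 * np.pi * t3), 'r.')
plt.grid(True)
plt.xlabel('time (s)')
plt.ylabel('Undamped')

plt.savefig('test.png')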

Produces:

[Image: the resulting figure, test.png]
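The data-unit fudge factors above stop fitting as soon as the data range changes. One possible variation, sketched here as an assumption rather than a tested recipe, is to give the rectangle in axes coordinates (`transform=sub1.transAxes`, where (0, 0) is the lower-left and (1, 1) the upper-right corner of the axes), so the padding is a fraction of the axes size and no longer depends on the data. The padding values are still hand-picked and may need tuning:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

t = np.arange(0.0, 5.0, 0.02)
fig = plt.figure(figsize=(4, 4))
sub1 = fig.add_subplot(211)
sub1.plot(t, np.cos(2 * np.pi * t) * np.exp(-t), 'k--')
sub1.set_title('A tale of 2 subplots')
sub1.set_ylabel('Damped oscillation')

# Padding as fractions of the axes width/height (hand-picked, not computed);
# more room on the left and top for the y label and the title.
left, right, bottom, top = 0.15, 0.05, 0.15, 0.2
rec = Rectangle((-left, -bottom), 1 + left + right, 1 + bottom + top,
                transform=sub1.transAxes,  # axes coords, independent of data
                fill=False, lw=2, clip_on=False)
sub1.add_patch(rec)

sub2 = fig.add_subplot(212)
sub2.plot(t, np.cos(2 * np.pi * t), 'r.')
fig.savefig('test_axes_coords.png')
```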

Mark answered Sep 21 '22 at 23:09


An alternative solution, derived from this answer on SO, adds the Rectangle patch directly to the figure rather than to an individual axes:

import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(nrows=2, ncols=1)
axes[0].plot(np.cumsum(np.random.randn(100)))
axes[1].plot(np.cumsum(np.random.randn(100)))

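rect = plt.Rectangle(
    # (lower-left corner), width, height, given as fractions of the
    # figure (0 to 1) because transform=fig.transFigure
    (0.02, 0.5), 0.97, 0.49, fill=False, color="k", lw=2,
    zorder=1000, transform=fig.transFigure, figure=fig
)
# Attach the patch to the figure itself rather than to an Axes
fig.patches.extend([rect])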

plt.tight_layout()
plt.show()

Result:

[Image: the resulting figure]
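Instead of hand-picking `(0.02, 0.5), 0.97, 0.49`, the rectangle could be computed from the subplot's tight bounding box, which covers its tick labels, axis labels and title. A sketch, assuming matplotlib 3.1 or later (for `Figure.add_artist`) and the Agg canvas (for `get_renderer`); the helper `outline_axes` and its `pad` value are hypothetical:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # Agg canvas provides get_renderer()
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

def outline_axes(fig, ax, pad=0.01, **kwargs):
    """Hypothetical helper: add a figure-level rectangle around `ax`,
    including its labels and title. `pad` is in figure-fraction units."""
    fig.canvas.draw()  # lay out text so its extent is known
    renderer = fig.canvas.get_renderer()
    # Tight bbox in display (pixel) coordinates...
    bbox = ax.get_tightbbox(renderer)
    # ...converted to figure-fraction coordinates.
    bbox = bbox.transformed(fig.transFigure.inverted())
    rect = Rectangle((bbox.x0 - pad, bbox.y0 - pad),
                     bbox.width + 2 * pad, bbox.height + 2 * pad,
                     fill=False, transform=fig.transFigure, **kwargs)
    fig.add_artist(rect)
    return rect

fig, axes = plt.subplots(nrows=2, ncols=1)
axes[0].plot(np.cumsum(np.random.randn(100)))
axes[1].plot(np.cumsum(np.random.randn(100)))
axes[0].set_title("selected")
plt.tight_layout()  # finalize the layout before measuring

rect = outline_axes(fig, axes[0], color="k", lw=2, zorder=1000)
fig.savefig("outlined_tightbbox.png")
```

Because the box is measured once, it will not follow the subplot if the layout changes afterwards (e.g. on a window resize); in that case it would need to be removed and recomputed.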

MPA answered Sep 19 '22 at 23:09