I need to find the extent of a plot including its related artists (in this case just ticks and ticklabels) in axis coordinates (as defined in the matplotlib transformations tutorial).
The background is that I am automatically creating thumbnail (inset) plots (as in this SO question) for a large number of charts, and I only add a thumbnail when I can position it so that it does not obscure data in the original plot.
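My current approach is to try a series of candidate rectangles for the thumbnail. I start at the top-right of the axes and move left, then start at the bottom-right and move left. I use the first rectangle that does not overlap the plotted data.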
The problem with this approach is that axes coordinates only span the axes themselves, from (0, 0) (bottom-left) to (1, 1) (top-right). They do not include the ticks and tick labels. The thumbnail plots have no titles, axis labels, legends or other artists.
All charts use the same font sizes, but their tick labels have different lengths (e.g. $1.5$ or $1.2345 \times 10^6$), although these are known before the inset is drawn. Is there a way to convert from font sizes (points) to axes coordinates? Alternatively, is there a better approach than the one above (bounding boxes?).
The following code implements the algorithm above:
import math
from matplotlib import pyplot, rcParams
# Draw x and y tick marks pointing outwards from the axes frame.
rcParams['xtick.direction'] = 'out'
rcParams['ytick.direction'] = 'out'
# Inset size and padding, as fractions of the parent axes: the rect built
# from these values is converted through axis.transAxes in add_inset_to_axis.
INSET_DEFAULT_WIDTH = 0.35
INSET_DEFAULT_HEIGHT = 0.25
INSET_PADDING = 0.05
# Passed as ``labelsize`` to the inset's x and y tick parameters.
INSET_TICK_FONTSIZE = 8
def axis_data_transform(axis, xin, yin, inverse=False):
    """Convert a point between axes-fraction and data coordinates.

    Only the current x/y limits of ``axis`` are consulted, so the mapping
    is purely linear; non-linear axis scales (e.g. log) are not handled.

    Args:
        axis: object providing ``get_xlim()`` and ``get_ylim()``.
        xin, yin: coordinates of the point to convert.
        inverse: if False (default), ``(xin, yin)`` is in axes coordinates
            (0..1 across the axes) and the data coordinates are returned;
            if True, ``(xin, yin)`` is in data coordinates and the axes
            coordinates are returned.

    Returns:
        ``(xout, yout)`` tuple.

    Code by Covich, from: https://stackoverflow.com/questions/29107800/
    """
    (x_lo, x_hi), (y_lo, y_hi) = axis.get_xlim(), axis.get_ylim()
    span_x = x_hi - x_lo
    span_y = y_hi - y_lo
    if inverse:
        # data -> axes fraction
        return (xin - x_lo) / span_x, (yin - y_lo) / span_y
    # axes fraction -> data
    return x_lo + xin * span_x, y_lo + yin * span_y
def add_inset_to_axis(fig, axis, rect):
    """Add a new axes to ``fig`` that occupies ``rect`` of ``axis``.

    ``rect`` is ``(left, bottom, width, height)`` in ``axis``'s axes
    coordinates. Points are mapped axes -> display (``axis.transAxes``) ->
    figure (inverse of ``fig.transFigure``); the size is the mapped
    ``(width, height)`` minus the mapped origin. Returns the axes created by
    ``fig.add_axes``.
    """
    left, bottom, width, height = rect

    def to_figure(point):
        display_point = axis.transAxes.transform(point)
        return fig.transFigure.inverted().transform(display_point)

    fig_left, fig_bottom = to_figure((left, bottom))
    extent = to_figure([width, height])
    origin = to_figure([0, 0])
    fig_width, fig_height = extent - origin
    return fig.add_axes([fig_left, fig_bottom, fig_width, fig_height])
def collide_rect(rect, fig, axis, data):
    """Return True if an inset placed at ``rect`` would cover plotted ``data``.

    Args:
        rect: ``(left, bottom, width, height)`` in ``axis``'s axes
            coordinates (fractions of the axes). Unpacked inside the body
            rather than in the signature so the function also runs on
            Python 3, which removed tuple-parameter unpacking.
        fig: unused; kept so existing call sites keep working.
        axis: the parent axes; only its x/y limits are read (through
            ``axis_data_transform``).
        data: sequence of y-values. Assumes ``data[i]`` is plotted at
            x == i, as in ``axis.plot(range(len(data)), data)`` -- confirm
            for other callers.

    The rect's left/right edges are converted to data x-values and widened
    to the enclosing integer sample positions, so the line segments that
    cross into the rect from either side are included. A collision is
    reported when the y-range of those samples overlaps the rect's y-range.
    Because the plot is a continuous line, any overlap means the line passes
    through the rect, wherever the rect sits vertically.
    """
    left, bottom, width, height = rect

    # Data x-values of the rect's left and right edges.
    x_left_float, _ = axis_data_transform(axis, left, 0, inverse=False)
    x_right_float, _ = axis_data_transform(axis, left + width, 0, inverse=False)

    # Take the sample just outside each edge (hence +1 on the exclusive
    # slice end) and clamp to valid indices: a negative start would
    # otherwise wrap around to the end of ``data``.
    start = max(int(math.floor(x_left_float)), 0)
    stop = min(int(math.ceil(x_right_float)) + 1, len(data))
    segment = data[start:stop]
    if len(segment) == 0:
        # No samples lie under the rect, so nothing can be obscured.
        return False
    minimum_y, maximum_y = min(segment), max(segment)

    # Data y-values of the rect's bottom and top edges; sort them so an
    # inverted y axis (descending limits) is handled too.
    _, y_at_bottom = axis_data_transform(axis, 0, bottom, inverse=False)
    _, y_at_top = axis_data_transform(axis, 0, bottom + height, inverse=False)
    inset_low, inset_high = min(y_at_bottom, y_at_top), max(y_at_bottom, y_at_top)

    # Collision when [minimum_y, maximum_y] overlaps (inset_low, inset_high).
    return maximum_y > inset_low and minimum_y < inset_high
if __name__ == '__main__':
    # Demo: plot a square wave, then search for an inset position that does
    # not cover the data. Uses range(), a single-string print(...) and no
    # integer division so it runs identically on Python 2 and Python 3.
    x_data = list(range(0, 100))
    y_data = [-1.0] * 50 + [1.0] * 50  # Square wave.
    y_min, y_max = min(y_data), max(y_data)
    fig = pyplot.figure()
    axis = fig.add_subplot(111)
    axis.set_ylim(y_min - 0.1, y_max + 0.1)
    axis.plot(x_data, y_data)

    # Find a rectangle that does not collide with data. Start top-right
    # and work left, then try bottom-right and work left. All positions
    # are fractions of the parent axes.
    left_offsets = [x / 10.0 for x in range(6)]
    bottom_values = (1.0 - INSET_DEFAULT_HEIGHT - INSET_PADDING,  # top row
                     INSET_PADDING * 2)                           # bottom row
    candidates = [(left_offset, bottom)
                  for bottom in bottom_values
                  for left_offset in left_offsets]

    rect = None  # first candidate that does not collide, if any
    for left_offset, bottom in candidates:
        # candidate: (left, bottom, width, height)
        candidate = (1.0 - INSET_DEFAULT_WIDTH - left_offset - INSET_PADDING,
                     bottom, INSET_DEFAULT_WIDTH, INSET_DEFAULT_HEIGHT)
        collides = collide_rect(candidate, fig, axis, y_data)
        print('TRYING: %s RESULT: %s' % (candidate, collides))
        if not collides:
            rect = candidate
            break

    if rect is not None:
        inset = add_inset_to_axis(fig, axis, rect)
        inset.set_ylim(axis.get_ylim())
        inset.set_yticks([y_min, y_min + ((y_max - y_min) / 2.0), y_max])
        inset.xaxis.set_tick_params(labelsize=INSET_TICK_FONTSIZE)
        inset.yaxis.set_tick_params(labelsize=INSET_TICK_FONTSIZE)
        inset_xlimit = (0, int(len(y_data) / 100.0 * 2.5))  # First 2.5% of data.
        inset.set_xlim(inset_xlimit[0], inset_xlimit[1], auto=False)
        inset.plot(x_data[inset_xlimit[0]:inset_xlimit[1] + 1],
                   y_data[inset_xlimit[0]:inset_xlimit[1] + 1])

    fig.savefig('so_example.png')
And the output of this is:
TRYING: (0.6, 0.7, 0.35, 0.25) RESULT: True
TRYING: (0.5, 0.7, 0.35, 0.25) RESULT: True
TRYING: (0.4, 0.7, 0.35, 0.25) RESULT: True
TRYING: (0.30000000000000004, 0.7, 0.35, 0.25) RESULT: True
TRYING: (0.2, 0.7, 0.35, 0.25) RESULT: True
TRYING: (0.10000000000000002, 0.7, 0.35, 0.25) RESULT: False
My solution doesn't seem to detect tick marks, but it does account for the tick labels, axis labels and the figure title. That should be enough, though, since a fixed padding value can account for the tick marks.
Use `axes.get_tightbbox` to obtain a rectangle that fits around the axes, including their labels.
# get_tightbbox needs a renderer to measure text such as tick labels;
# borrow the helper that tight_layout uses to obtain one.
# NOTE(review): later matplotlib releases deprecate the public
# matplotlib.tight_layout module -- confirm it exists in your version.
from matplotlib import tight_layout
renderer = tight_layout.get_renderer(fig)
# Bounding box of the inset including its ticks and tick labels.
inset_tight_bbox = inset.get_tightbbox(renderer)
Your original rectangle set the axes bbox, `inset.bbox`. Find the rectangles in axis coordinates for these two bboxes:
# Map both bboxes (display coordinates) into the parent axis's axes
# coordinates, the same system the original inset rect was chosen in.
inv_transform = axis.transAxes.inverted()
xmin, ymin = inv_transform.transform(inset.bbox.min)
xmin_tight, ymin_tight = inv_transform.transform(inset_tight_bbox.min)
xmax, ymax = inv_transform.transform(inset.bbox.max)
xmax_tight, ymax_tight = inv_transform.transform(inset_tight_bbox.max)
Now calculate a new rectangle for the axis itself, such that the outer tight bbox will be reduced in size to the old axis bbox:
# Move each edge inwards by the margin that ticks/labels add on that side,
# so that after resizing, the tight bbox lines up with the old axes bbox.
# NOTE(review): assumes the label margins keep the same size once the
# inset is resized.
xmin_new = xmin + (xmin - xmin_tight)
ymin_new = ymin + (ymin - ymin_tight)
xmax_new = xmax - (xmax_tight - xmax)
ymax_new = ymax - (ymax_tight - ymax)
Now, just switch back to figure coordinates and reposition the inset axes:
# Convert the new corners from the parent axis's axes coordinates to figure
# coordinates (axis_to_figure_transform takes the coord and the axis), then
# move/resize the inset to the new rectangle.
x_fig, y_fig = axis_to_figure_transform([xmin_new, ymin_new], axis)
x2_fig, y2_fig = axis_to_figure_transform([xmax_new, ymax_new], axis)
inset.set_position([x_fig, y_fig, x2_fig - x_fig, y2_fig - y_fig])
The function `axis_to_figure_transform` is based on the `transform` function from your `add_inset_to_axis`:
def axis_to_figure_transform(coord, axis):
    """Convert ``coord`` from ``axis``'s axes coordinates to figure coordinates.

    The point is mapped axes -> display (``axis.transAxes``) -> figure
    (inverse of the owning figure's ``transFigure``). The figure is taken
    from ``axis.figure`` instead of a module-level ``fig`` global, so the
    function works when imported and for any figure.
    """
    fig = axis.figure
    return fig.transFigure.inverted().transform(
        axis.transAxes.transform(coord))
Note: this doesn't work with `fig.show()`, at least on my system; `tight_layout.get_renderer(fig)` causes an error. However, it works fine if you're only using `savefig()` and not displaying the plot interactively.
Finally, here's your full code with my changes and additions:
import math
from matplotlib import pyplot, rcParams, tight_layout
# Draw x and y tick marks pointing outwards from the axes frame.
rcParams['xtick.direction'] = 'out'
rcParams['ytick.direction'] = 'out'
# Inset size and padding, as fractions of the parent axes: the rect built
# from these values is converted through axis.transAxes in add_inset_to_axis.
INSET_DEFAULT_WIDTH = 0.35
INSET_DEFAULT_HEIGHT = 0.25
INSET_PADDING = 0.05
# Passed as ``labelsize`` to the inset's x and y tick parameters.
INSET_TICK_FONTSIZE = 8
def axis_data_transform(axis, xin, yin, inverse=False):
    """Convert a point between axes-fraction and data coordinates.

    Only the current x/y limits of ``axis`` are consulted, so the mapping
    is purely linear; non-linear axis scales (e.g. log) are not handled.

    Args:
        axis: object providing ``get_xlim()`` and ``get_ylim()``.
        xin, yin: coordinates of the point to convert.
        inverse: if False (default), ``(xin, yin)`` is in axes coordinates
            (0..1 across the axes) and the data coordinates are returned;
            if True, ``(xin, yin)`` is in data coordinates and the axes
            coordinates are returned.

    Returns:
        ``(xout, yout)`` tuple.

    Code by Covich, from: http://stackoverflow.com/questions/29107800/
    """
    (x_lo, x_hi), (y_lo, y_hi) = axis.get_xlim(), axis.get_ylim()
    span_x = x_hi - x_lo
    span_y = y_hi - y_lo
    if inverse:
        # data -> axes fraction
        return (xin - x_lo) / span_x, (yin - y_lo) / span_y
    # axes fraction -> data
    return x_lo + xin * span_x, y_lo + yin * span_y
def axis_to_figure_transform(coord, axis):
    """Convert ``coord`` from ``axis``'s axes coordinates to figure coordinates.

    The point is mapped axes -> display (``axis.transAxes``) -> figure
    (inverse of the owning figure's ``transFigure``). The figure is taken
    from ``axis.figure`` instead of a module-level ``fig`` global, so the
    function works when imported and for any figure.
    """
    fig = axis.figure
    return fig.transFigure.inverted().transform(
        axis.transAxes.transform(coord))
def add_inset_to_axis(fig, axis, rect):
    """Create an inset axes in ``fig`` covering ``rect`` of ``axis``.

    ``rect`` is ``(left, bottom, width, height)`` in ``axis``'s axes
    coordinates. The corners are mapped to figure coordinates with
    ``axis_to_figure_transform``; the size is the mapped ``(width, height)``
    minus the mapped origin. Returns the axes created by ``fig.add_axes``
    (with ``frameon=True``).
    """
    left, bottom, width, height = rect
    fig_origin = axis_to_figure_transform([0, 0], axis)
    fig_extent = axis_to_figure_transform([width, height], axis)
    fig_left, fig_bottom = axis_to_figure_transform((left, bottom), axis)
    fig_width, fig_height = fig_extent - fig_origin
    placement = [fig_left, fig_bottom, fig_width, fig_height]
    return fig.add_axes(placement, frameon=True)
def collide_rect(rect, fig, axis, data):
    """Return True if an inset placed at ``rect`` would cover plotted ``data``.

    Args:
        rect: ``(left, bottom, width, height)`` in ``axis``'s axes
            coordinates (fractions of the axes). Unpacked inside the body
            rather than in the signature so the function also runs on
            Python 3, which removed tuple-parameter unpacking.
        fig: unused; kept so existing call sites keep working.
        axis: the parent axes; only its x/y limits are read (through
            ``axis_data_transform``).
        data: sequence of y-values. Assumes ``data[i]`` is plotted at
            x == i, as in ``axis.plot(range(len(data)), data)`` -- confirm
            for other callers.

    The rect's left/right edges are converted to data x-values and widened
    to the enclosing integer sample positions, so the line segments that
    cross into the rect from either side are included. A collision is
    reported when the y-range of those samples overlaps the rect's y-range.
    Because the plot is a continuous line, any overlap means the line passes
    through the rect, wherever the rect sits vertically.
    """
    left, bottom, width, height = rect

    # Data x-values of the rect's left and right edges.
    x_left_float, _ = axis_data_transform(axis, left, 0, inverse=False)
    x_right_float, _ = axis_data_transform(axis, left + width, 0, inverse=False)

    # Take the sample just outside each edge (hence +1 on the exclusive
    # slice end) and clamp to valid indices: a negative start would
    # otherwise wrap around to the end of ``data``.
    start = max(int(math.floor(x_left_float)), 0)
    stop = min(int(math.ceil(x_right_float)) + 1, len(data))
    segment = data[start:stop]
    if len(segment) == 0:
        # No samples lie under the rect, so nothing can be obscured.
        return False
    minimum_y, maximum_y = min(segment), max(segment)

    # Data y-values of the rect's bottom and top edges; sort them so an
    # inverted y axis (descending limits) is handled too.
    _, y_at_bottom = axis_data_transform(axis, 0, bottom, inverse=False)
    _, y_at_top = axis_data_transform(axis, 0, bottom + height, inverse=False)
    inset_low, inset_high = min(y_at_bottom, y_at_top), max(y_at_bottom, y_at_top)

    # Collision when [minimum_y, maximum_y] overlaps (inset_low, inset_high).
    return maximum_y > inset_low and minimum_y < inset_high
if __name__ == '__main__':
    # Demo: plot a square wave, find an inset position that does not cover
    # the data, then shrink the inset so its ticks and tick labels also fit
    # inside the chosen rect. Uses range(), a single-string print(...) and
    # no integer division so it runs identically on Python 2 and Python 3.
    x_data = list(range(0, 100))
    y_data = [-1.0] * 50 + [1.0] * 50  # Square wave.
    y_min, y_max = min(y_data), max(y_data)
    fig = pyplot.figure()
    axis = fig.add_subplot(111)
    axis.set_ylim(y_min - 0.1, y_max + 0.1)
    axis.plot(x_data, y_data)

    # Find a rectangle that does not collide with data. Start top-right
    # and work left, then try bottom-right and work left. All positions
    # are fractions of the parent axes.
    left_offsets = [x / 10.0 for x in range(6)]
    bottom_values = (1.0 - INSET_DEFAULT_HEIGHT - INSET_PADDING,  # top row
                     INSET_PADDING * 2)                           # bottom row
    candidates = [(left_offset, bottom)
                  for bottom in bottom_values
                  for left_offset in left_offsets]

    rect = None  # first candidate that does not collide, if any
    for left_offset, bottom in candidates:
        # candidate: (left, bottom, width, height)
        candidate = (1.0 - INSET_DEFAULT_WIDTH - left_offset - INSET_PADDING,
                     bottom, INSET_DEFAULT_WIDTH, INSET_DEFAULT_HEIGHT)
        collides = collide_rect(candidate, fig, axis, y_data)
        print('TRYING: %s RESULT: %s' % (candidate, collides))
        if not collides:
            rect = candidate
            break

    if rect is not None:
        inset = add_inset_to_axis(fig, axis, rect)
        inset.set_ylim(axis.get_ylim())
        inset.set_yticks([y_min, y_min + ((y_max - y_min) / 2.0), y_max])
        inset.xaxis.set_tick_params(labelsize=INSET_TICK_FONTSIZE)
        inset.yaxis.set_tick_params(labelsize=INSET_TICK_FONTSIZE)
        inset_xlimit = (0, int(len(y_data) / 100.0 * 2.5))  # First 2.5% of data.
        inset.set_xlim(inset_xlimit[0], inset_xlimit[1], auto=False)
        inset.plot(x_data[inset_xlimit[0]:inset_xlimit[1] + 1],
                   y_data[inset_xlimit[0]:inset_xlimit[1] + 1])

        # Borrow tight_layout's renderer helper so get_tightbbox can measure
        # the inset including its ticks and tick labels.
        # NOTE(review): later matplotlib releases deprecate the public
        # matplotlib.tight_layout module -- confirm it exists in your version.
        renderer = tight_layout.get_renderer(fig)
        inset_tight_bbox = inset.get_tightbbox(renderer)

        # Uncomment this to show where the two bboxes are:
        # def show_bbox_on_plot(ax, bbox, color='b'):
        #     inv_transform = ax.transAxes.inverted()
        #     xmin, ymin = inv_transform.transform(bbox.min)
        #     xmax, ymax = inv_transform.transform(bbox.max)
        #     axis.add_patch(pyplot.Rectangle([xmin, ymin], xmax - xmin, ymax - ymin,
        #                                     transform=axis.transAxes, color=color))
        #
        # show_bbox_on_plot(axis, inset_tight_bbox)
        # show_bbox_on_plot(axis, inset.bbox, color='g')

        # Express both bboxes (display coordinates) in the parent axis's
        # axes coordinates, the system the rect was chosen in.
        inv_transform = axis.transAxes.inverted()
        xmin, ymin = inv_transform.transform(inset.bbox.min)
        xmin_tight, ymin_tight = inv_transform.transform(inset_tight_bbox.min)
        xmax, ymax = inv_transform.transform(inset.bbox.max)
        xmax_tight, ymax_tight = inv_transform.transform(inset_tight_bbox.max)

        # Shift the actual axes bounds inwards by the tick/label margin on
        # each side so that new size + margin equals the original rect.
        xmin_new = xmin + (xmin - xmin_tight)
        ymin_new = ymin + (ymin - ymin_tight)
        xmax_new = xmax - (xmax_tight - xmax)
        ymax_new = ymax - (ymax_tight - ymax)

        x_fig, y_fig = axis_to_figure_transform([xmin_new, ymin_new], axis)
        x2_fig, y2_fig = axis_to_figure_transform([xmax_new, ymax_new], axis)
        inset.set_position([x_fig, y_fig, x2_fig - x_fig, y2_fig - y_fig])

    fig.savefig('so_example.png')
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With