Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Finding the extent of a matplotlib plot (including ticklabels) in axis coordinates

I need to find the extent of a plot including its related artists (in this case just ticks and ticklabels) in axis coordinates (as defined in the matplotlib transformations tutorial).

The background to this is that I am automatically creating thumbnail plots (as in this SO question) for a large number of charts, only when I can position the thumbnail so that it does not obscure data in the original plot.

This is my current approach:

  1. Create a number of candidate rectangles to test, starting at the top-right of the original plot and working left, then moving to the bottom-right of the original plot and working left again.
  2. For each candidate rectangle:
    1. Using code from this SO question convert the left and right hand side of the rect (in axis coordinates) into data coordinates, to find which slice of the x-data the rectangle will cover.
    2. Find the minimum / maximum y-value for the slice of data the rectangle covers.
    3. Find the top and bottom of the rectangle in data coordinates.
    4. Using the above, determine whether the rectangle overlaps with any data. If not, draw the thumbnail plot in the current rectangle, otherwise continue.

The problem with this approach is that axis coordinates only describe the extent of the axes, from (0,0) (bottom-left of the axes) to (1,1) (top-right), and do not include ticks and ticklabels (the thumbnail plots do not have titles, axis labels, legends or other artists).

All charts use the same font sizes, but the charts have ticklabels of different lengths (e.g. 1.5 or 1.2345 * 10^6), although these are known before the inset is drawn. Is there a way to convert from font sizes / points to axis coordinates? Alternatively, maybe there is a better approach than the one above (bounding boxes?).

The following code implements the algorithm above:

import math

from matplotlib import pyplot, rcParams
rcParams['xtick.direction'] = 'out'
rcParams['ytick.direction'] = 'out'

INSET_DEFAULT_WIDTH = 0.35
INSET_DEFAULT_HEIGHT = 0.25
INSET_PADDING = 0.05
INSET_TICK_FONTSIZE = 8


def axis_data_transform(axis, xin, yin, inverse=False):
    """Convert a point between axis coordinates and data coordinates.

    The mapping is a plain linear interpolation over the current limits
    reported by ``axis.get_xlim()`` and ``axis.get_ylim()``.

    With ``inverse`` left False, ``(xin, yin)`` is read as axis coordinates
    and the matching data coordinates are returned. With ``inverse`` True,
    ``(xin, yin)`` is read as data coordinates and the matching axis
    coordinates are returned.

    Code by Covich, from: https://stackoverflow.com/questions/29107800/
    """
    x_low, x_high = axis.get_xlim()
    y_low, y_high = axis.get_ylim()
    x_span = x_high - x_low
    y_span = y_high - y_low
    if inverse:
        # Data -> axis: offset from the lower limit as a fraction of the span.
        return (xin - x_low) / x_span, (yin - y_low) / y_span
    # Axis -> data: lower limit plus the requested fraction of the span.
    return x_low + xin * x_span, y_low + yin * y_span


def add_inset_to_axis(fig, axis, rect):
    """Add a new axes to ``fig`` covering ``rect`` of an existing axes.

    ``rect`` is ``(left, bottom, width, height)`` expressed in the axis
    coordinates of ``axis``. Each point is pushed through ``axis.transAxes``
    and then through the inverse of ``fig.transFigure`` to obtain the figure
    coordinates that ``fig.add_axes`` is called with.

    Returns whatever ``fig.add_axes`` returns.
    """
    display_to_figure = fig.transFigure.inverted()

    def axes_to_figure(point):
        return display_to_figure.transform(axis.transAxes.transform(point))

    left, bottom, width, height = rect
    fig_left, fig_bottom = axes_to_figure((left, bottom))
    # A size is not a position: subtract the transformed origin so that only
    # the scaling part of the transform is applied to width and height.
    size_corner = axes_to_figure([width, height])
    origin = axes_to_figure([0, 0])
    fig_width, fig_height = size_corner - origin
    return fig.add_axes([fig_left, fig_bottom, fig_width, fig_height])


def collide_rect(rect, fig, axis, data):
    """Report whether a candidate inset rectangle would cover plotted data.

    ``rect`` is ``(left, bottom, width, height)`` in the axis coordinates of
    ``axis``. ``data`` is the sequence of plotted y-values; it is sliced by
    x position, so each value is taken to sit at x == its index.
    ``fig`` is not used; it is kept so existing calls keep working.

    Returns True when the rect collides with the data, False otherwise. A
    rect that covers no data points at all is reported as not colliding.

    NOTE: the test is a heuristic keyed on ``bottom``: rects in the upper
    half are checked against the data maximum, rects in the lower half
    against the data minimum. A rect with ``bottom`` exactly 0.5 matches
    neither branch and is reported as not colliding.
    """
    # Unpack in the body: tuple parameters are a syntax error on Python 3.
    left, bottom, width, height = rect
    # Find the values on the x-axis of left and right edges of the rect.
    x_left_float, _ = axis_data_transform(axis, left, 0, inverse=False)
    x_right_float, _ = axis_data_transform(axis, left + width, 0, inverse=False)
    # Clamp both edges into [0, len(data)]. The axis limits can extend past
    # the data, and a negative slice bound would wrap around to the end of
    # the sequence instead of meaning "before the first sample".
    x_left = min(max(int(math.floor(x_left_float)), 0), len(data))
    x_right = min(max(int(math.ceil(x_right_float)), x_left), len(data))
    covered = data[x_left:x_right]
    if len(covered) == 0:
        # Nothing is plotted under the rect; min()/max() would raise here.
        return False
    # Find the highest and lowest y-value in that segment of data.
    minimum_y = min(covered)
    maximum_y = max(covered)
    # Convert the bottom and top of the rect to data coordinates.
    _, inset_top = axis_data_transform(axis, 0, bottom + height, inverse=False)
    _, inset_bottom = axis_data_transform(axis, 0, bottom, inverse=False)
    # Detect collision.
    hits_from_below = bottom > 0.5 and maximum_y > inset_bottom  # inset at top
    hits_from_above = bottom < 0.5 and minimum_y < inset_top     # inset at bottom
    return bool(hits_from_below or hits_from_above)


if __name__ == '__main__':
    # Demo: plot a square wave, search for a spot where an inset thumbnail
    # does not cover the data, draw the thumbnail there and save the figure.
    x_data, y_data = list(range(0, 100)), [-1.0] * 50 + [1.0] * 50  # Square wave.
    y_min, y_max = min(y_data), max(y_data)
    fig = pyplot.figure()
    axis = fig.add_subplot(111)
    axis.set_ylim(y_min - 0.1, y_max + 0.1)
    axis.plot(x_data, y_data)
    # Find a rectangle that does not collide with data. Start top-right
    # and work left, then try bottom-right and work left.
    inset_collides = False
    # range() and floor division behave the same on Python 2 and Python 3;
    # xrange() and a float repeat count only work on Python 2.
    left_offsets = [x / 10.0 for x in range(6)] * 2
    per_row = len(left_offsets) // 2
    bottom_values = (([1.0 - INSET_DEFAULT_HEIGHT - INSET_PADDING] * per_row)
                     + ([INSET_PADDING * 2] * per_row))
    for left_offset, bottom in zip(left_offsets, bottom_values):
        # rect: (left, bottom, width, height)
        rect = (1.0 - INSET_DEFAULT_WIDTH - left_offset - INSET_PADDING,
                bottom, INSET_DEFAULT_WIDTH, INSET_DEFAULT_HEIGHT)
        inset_collides = collide_rect(rect, fig, axis, y_data)
        # %-formatting prints the same text on Python 2 and Python 3.
        print('TRYING: %s RESULT: %s' % (rect, inset_collides))
        if not inset_collides:
            break
    if not inset_collides:
        inset = add_inset_to_axis(fig, axis, rect)
        inset.set_ylim(axis.get_ylim())
        inset.set_yticks([y_min, y_min + ((y_max - y_min) / 2.0), y_max])
        inset.xaxis.set_tick_params(labelsize=INSET_TICK_FONTSIZE)
        inset.yaxis.set_tick_params(labelsize=INSET_TICK_FONTSIZE)
        inset_xlimit = (0, int(len(y_data) / 100.0 * 2.5))  # First 2.5% of data.
        inset.set_xlim(inset_xlimit[0], inset_xlimit[1], auto=False)
        inset.plot(x_data[inset_xlimit[0]:inset_xlimit[1] + 1],
                   y_data[inset_xlimit[0]:inset_xlimit[1] + 1])
    fig.savefig('so_example.png')

And the output of this is:

TRYING: (0.6, 0.7, 0.35, 0.25) RESULT: True
TRYING: (0.5, 0.7, 0.35, 0.25) RESULT: True
TRYING: (0.4, 0.7, 0.35, 0.25) RESULT: True
TRYING: (0.30000000000000004, 0.7, 0.35, 0.25) RESULT: True
TRYING: (0.2, 0.7, 0.35, 0.25) RESULT: True
TRYING: (0.10000000000000002, 0.7, 0.35, 0.25) RESULT: False

script output

like image 636
snim2 Avatar asked Jan 04 '17 11:01

snim2


1 Answer

My solution doesn't seem to detect tick marks, but does take care of the tick labels, axis labels and the figure title. Hopefully it's enough though, since a fixed pad value should be fine to account for the tick marks.

Use axes.get_tightbbox to obtain a rectangle that fits around the axes including labels.

# get_tightbbox() takes a renderer; borrow the helper from matplotlib's
# tight_layout module to obtain one for this figure.
from matplotlib import tight_layout
renderer = tight_layout.get_renderer(fig)
# Bounding box around the inset axes including its labels.
inset_tight_bbox = inset.get_tightbbox(renderer)

Your original rectangle, by contrast, set the axes bbox, inset.bbox. Find the rectangles in axis coordinates for these two bboxes:

# Inverse of the parent axes' transAxes: maps the bbox corners below into
# the axis coordinates of `axis`.
inv_transform = axis.transAxes.inverted() 

# Lower-left corners of the plain inset bbox and of the tight bbox.
xmin, ymin = inv_transform.transform(inset.bbox.min)
xmin_tight, ymin_tight = inv_transform.transform(inset_tight_bbox.min) 

# Upper-right corners of the same two bboxes.
xmax, ymax = inv_transform.transform(inset.bbox.max)
xmax_tight, ymax_tight = inv_transform.transform(inset_tight_bbox.max)

Now calculate a new rectangle for the axis itself, such that the outer tight bbox will be reduced in size to the old axis bbox:

# Move each edge inwards by the margin between the tight bbox and the axes
# bbox on that side, so that the new axes plus its margin fits inside the
# old axes bbox.
xmin_new = xmin + (xmin - xmin_tight)
ymin_new = ymin + (ymin - ymin_tight)
xmax_new = xmax - (xmax_tight - xmax)
ymax_new = ymax - (ymax_tight - ymax)     

Now, just switch back to figure coordinates and reposition the inset axes:

# Convert both corners of the shrunken rectangle from axis coordinates to
# figure coordinates. axis_to_figure_transform takes the axes as its second
# argument, so it has to be passed here.
[x_fig, y_fig] = axis_to_figure_transform([xmin_new, ymin_new], axis)
[x2_fig, y2_fig] = axis_to_figure_transform([xmax_new, ymax_new], axis)

# set_position is given [left, bottom, width, height].
inset.set_position([x_fig, y_fig, x2_fig - x_fig, y2_fig - y_fig])

The function axis_to_figure_transform is based on your transform function from add_inset_to_axis:

def axis_to_figure_transform(coord, axis):
    """Convert ``coord`` from the axis coordinates of ``axis`` to figure coordinates.

    The point is first mapped through ``axis.transAxes`` and then through the
    inverse of the owning figure's ``transFigure``. The figure is looked up
    as ``axis.figure`` rather than read from a global ``fig``, so the
    function works for any axes it is handed.
    """
    display_coord = axis.transAxes.transform(coord)
    return axis.figure.transFigure.inverted().transform(display_coord)

Note: this doesn't work with fig.show(), at least on my system; tight_layout.get_renderer(fig) causes an error. However, it works fine if you're only using savefig() and not displaying the plot interactively.

Finally, here's your full code with my changes and additions:

import math

from matplotlib import pyplot, rcParams, tight_layout
rcParams['xtick.direction'] = 'out'
rcParams['ytick.direction'] = 'out'

INSET_DEFAULT_WIDTH = 0.35
INSET_DEFAULT_HEIGHT = 0.25
INSET_PADDING = 0.05
INSET_TICK_FONTSIZE = 8

def axis_data_transform(axis, xin, yin, inverse=False):
    """Convert a point between axis coordinates and data coordinates.

    The mapping is a plain linear interpolation over the current limits
    reported by ``axis.get_xlim()`` and ``axis.get_ylim()``.

    With ``inverse`` left False, ``(xin, yin)`` is read as axis coordinates
    and the matching data coordinates are returned. With ``inverse`` True,
    ``(xin, yin)`` is read as data coordinates and the matching axis
    coordinates are returned.

    Code by Covich, from: http://stackoverflow.com/questions/29107800/
    """
    x_low, x_high = axis.get_xlim()
    y_low, y_high = axis.get_ylim()
    x_span = x_high - x_low
    y_span = y_high - y_low
    if inverse:
        # Data -> axis: offset from the lower limit as a fraction of the span.
        return (xin - x_low) / x_span, (yin - y_low) / y_span
    # Axis -> data: lower limit plus the requested fraction of the span.
    return x_low + xin * x_span, y_low + yin * y_span

def axis_to_figure_transform(coord, axis):
    """Convert ``coord`` from the axis coordinates of ``axis`` to figure coordinates.

    The point is first mapped through ``axis.transAxes`` and then through the
    inverse of the owning figure's ``transFigure``. The figure is looked up
    as ``axis.figure`` rather than read from a global ``fig``, so the
    function works for any axes it is handed.
    """
    display_coord = axis.transAxes.transform(coord)
    return axis.figure.transFigure.inverted().transform(display_coord)

def add_inset_to_axis(fig, axis, rect):
    """Add a new framed axes to ``fig`` covering ``rect`` of an existing axes.

    ``rect`` is ``(left, bottom, width, height)`` expressed in the axis
    coordinates of ``axis``. The values are converted with
    ``axis_to_figure_transform`` and handed to ``fig.add_axes`` together
    with ``frameon=True``.

    Returns whatever ``fig.add_axes`` returns.
    """
    left, bottom, width, height = rect
    anchor = axis_to_figure_transform((left, bottom), axis)
    size_corner = axis_to_figure_transform([width, height], axis)
    origin = axis_to_figure_transform([0, 0], axis)
    fig_left, fig_bottom = anchor
    # A size is not a position: subtract the transformed origin so that only
    # the scaling part of the transform is applied to width and height.
    fig_width, fig_height = size_corner - origin
    return fig.add_axes([fig_left, fig_bottom, fig_width, fig_height], frameon=True)


def collide_rect(rect, fig, axis, data):
    """Report whether a candidate inset rectangle would cover plotted data.

    ``rect`` is ``(left, bottom, width, height)`` in the axis coordinates of
    ``axis``. ``data`` is the sequence of plotted y-values; it is sliced by
    x position, so each value is taken to sit at x == its index.
    ``fig`` is not used; it is kept so existing calls keep working.

    Returns True when the rect collides with the data, False otherwise. A
    rect that covers no data points at all is reported as not colliding.

    NOTE: the test is a heuristic keyed on ``bottom``: rects in the upper
    half are checked against the data maximum, rects in the lower half
    against the data minimum. A rect with ``bottom`` exactly 0.5 matches
    neither branch and is reported as not colliding.
    """
    # Unpack in the body: tuple parameters are a syntax error on Python 3.
    left, bottom, width, height = rect
    # Find the values on the x-axis of left and right edges of the rect.
    x_left_float, _ = axis_data_transform(axis, left, 0, inverse=False)
    x_right_float, _ = axis_data_transform(axis, left + width, 0, inverse=False)
    # Clamp both edges into [0, len(data)]. The axis limits can extend past
    # the data, and a negative slice bound would wrap around to the end of
    # the sequence instead of meaning "before the first sample".
    x_left = min(max(int(math.floor(x_left_float)), 0), len(data))
    x_right = min(max(int(math.ceil(x_right_float)), x_left), len(data))
    covered = data[x_left:x_right]
    if len(covered) == 0:
        # Nothing is plotted under the rect; min()/max() would raise here.
        return False
    # Find the highest and lowest y-value in that segment of data.
    minimum_y = min(covered)
    maximum_y = max(covered)
    # Convert the bottom and top of the rect to data coordinates.
    _, inset_top = axis_data_transform(axis, 0, bottom + height, inverse=False)
    _, inset_bottom = axis_data_transform(axis, 0, bottom, inverse=False)
    # Detect collision.
    hits_from_below = bottom > 0.5 and maximum_y > inset_bottom  # inset at top
    hits_from_above = bottom < 0.5 and minimum_y < inset_top     # inset at bottom
    return bool(hits_from_below or hits_from_above)


if __name__ == '__main__':
    # Demo: plot a square wave, search for a spot where an inset thumbnail
    # does not cover the data, draw the thumbnail there, shrink it so its
    # tick labels stay inside the chosen rectangle, and save the figure.
    x_data, y_data = list(range(0, 100)), [-1.0] * 50 + [1.0] * 50  # Square wave.
    y_min, y_max = min(y_data), max(y_data)
    fig = pyplot.figure()
    axis = fig.add_subplot(111)
    axis.set_ylim(y_min - 0.1, y_max + 0.1)
    axis.plot(x_data, y_data)
    # Find a rectangle that does not collide with data. Start top-right
    # and work left, then try bottom-right and work left.
    inset_collides = False
    # range() and floor division behave the same on Python 2 and Python 3;
    # xrange() and a float repeat count only work on Python 2.
    left_offsets = [x / 10.0 for x in range(6)] * 2
    per_row = len(left_offsets) // 2
    bottom_values = (([1.0 - INSET_DEFAULT_HEIGHT - INSET_PADDING] * per_row)
                     + ([INSET_PADDING * 2] * per_row))
    for left_offset, bottom in zip(left_offsets, bottom_values):
        # rect: (left, bottom, width, height)
        rect = (1.0 - INSET_DEFAULT_WIDTH - left_offset - INSET_PADDING,
                bottom, INSET_DEFAULT_WIDTH, INSET_DEFAULT_HEIGHT)
        inset_collides = collide_rect(rect, fig, axis, y_data)
        # %-formatting prints the same text on Python 2 and Python 3.
        print('TRYING: %s RESULT: %s' % (rect, inset_collides))
        if not inset_collides:
            break

    # Everything that touches `inset` lives under this guard: when every
    # candidate rectangle collides, no inset exists and the chart is saved
    # as it is.
    if not inset_collides:
        inset = add_inset_to_axis(fig, axis, rect)
        inset.set_ylim(axis.get_ylim())
        inset.set_yticks([y_min, y_min + ((y_max - y_min) / 2.0), y_max])
        inset.xaxis.set_tick_params(labelsize=INSET_TICK_FONTSIZE)
        inset.yaxis.set_tick_params(labelsize=INSET_TICK_FONTSIZE)
        inset_xlimit = (0, int(len(y_data) / 100.0 * 2.5))  # First 2.5% of data.
        inset.set_xlim(inset_xlimit[0], inset_xlimit[1], auto=False)
        inset.plot(x_data[inset_xlimit[0]:inset_xlimit[1] + 1],
                   y_data[inset_xlimit[0]:inset_xlimit[1] + 1])

        # borrow this function from tight_layout
        renderer = tight_layout.get_renderer(fig)
        inset_tight_bbox = inset.get_tightbbox(renderer)

        # uncomment this to show where the two bboxes are
        # def show_bbox_on_plot(ax, bbox, color='b'):
        #     inv_transform = ax.transAxes.inverted()
        #     xmin, ymin = inv_transform.transform(bbox.min)
        #     xmax, ymax = inv_transform.transform(bbox.max)
        #     axis.add_patch(pyplot.Rectangle([xmin, ymin], xmax-xmin, ymax-ymin, transform=axis.transAxes, color = color))
        #
        # show_bbox_on_plot(axis, inset_tight_bbox)
        # show_bbox_on_plot(axis, inset.bbox, color = 'g')

        # Express the corners of both bboxes in the axis coordinates of the
        # parent axes.
        inv_transform = axis.transAxes.inverted()

        xmin, ymin = inv_transform.transform(inset.bbox.min)
        xmin_tight, ymin_tight = inv_transform.transform(inset_tight_bbox.min)

        xmax, ymax = inv_transform.transform(inset.bbox.max)
        xmax_tight, ymax_tight = inv_transform.transform(inset_tight_bbox.max)

        # shift actual axis bounds inwards by "margin" so that new size + margin
        # is original axis bounds
        xmin_new = xmin + (xmin - xmin_tight)
        ymin_new = ymin + (ymin - ymin_tight)
        xmax_new = xmax - (xmax_tight - xmax)
        ymax_new = ymax - (ymax_tight - ymax)

        x_fig, y_fig = axis_to_figure_transform([xmin_new, ymin_new], axis)
        x2_fig, y2_fig = axis_to_figure_transform([xmax_new, ymax_new], axis)

        # set_position is given [left, bottom, width, height].
        inset.set_position([x_fig, y_fig, x2_fig - x_fig, y2_fig - y_fig])

    fig.savefig('so_example.png')
like image 119
Ben Schmidt Avatar answered Oct 19 '22 07:10

Ben Schmidt