
Instead of grid lines on a plot, can matplotlib print grid crosses?

I want some grid lines on a plot, but full-length lines are too distracting, even dashed light grey ones. I manually edited the SVG output to get the effect I was looking for. Can this be done with matplotlib? I had a look at the pyplot API for grid, and the only things I can see that might get near it are the xdata and ydata Line2D kwargs.

[Image: the desired grid crosses, produced by manually editing the SVG output]

a different ben asked Apr 25 '12 11:04


2 Answers

This cannot be done through the basic API, because each grid line is created from only two points. A grid line would need a data point at every tick mark for a marker to be drawn there. This is shown in the following example:

import matplotlib.pyplot as plt

ax = plt.subplot(111)
ax.grid(clip_on=False, marker='o', markersize=10)
plt.savefig('crosses.png')
plt.show()

This results in:

[Image: plot with 'o' markers only at the two ends of each grid line]

Notice how the 'o' markers are only at the beginning and the end of the Axes edges, because the grid lines only involve two points.
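The two-point claim can be checked directly. The following is a minimal sketch, assuming a non-interactive backend (Agg) and a default empty Axes; vertex_counts is just an illustrative name. It counts the vertices of each x-axis grid line:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, so no window is needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.grid(True)

# Each grid line is a Line2D artist; get_xydata() returns its (N, 2) vertex array.
vertex_counts = [len(line.get_xydata()) for line in ax.xaxis.get_gridlines()]
print(vertex_counts)
```

Every entry should be 2, which is why a marker set on the grid only ever shows up at the two ends of a line.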

You could write a method to emulate what you want, creating the cross marks using a series of Artists, but it's quicker to just leverage the basic plotting capabilities to draw the cross pattern.

This is what I do in the following example:

import matplotlib.pyplot as plt
import numpy as np

NPOINTS=100

def set_grid_cross(ax, in_back=True):
    # Build one point for every (xtick, ytick) intersection.
    xticks = ax.get_xticks()
    yticks = ax.get_yticks()
    xgrid, ygrid = np.meshgrid(xticks, yticks)
    kywds = dict()
    if in_back:
        # zorder=0 draws the crosses behind the plotted data.
        kywds['zorder'] = 0
    ax.plot(xgrid, ygrid, 'k+', **kywds)

xvals = np.arange(NPOINTS)
yvals = np.random.random(NPOINTS) * NPOINTS

ax1 = plt.subplot(121)
ax2 = plt.subplot(122)

ax1.plot(xvals, yvals, linewidth=4)
ax1.plot(xvals, xvals, linewidth=7)
set_grid_cross(ax1)
ax2.plot(xvals, yvals, linewidth=4)
ax2.plot(xvals, xvals, linewidth=7)
set_grid_cross(ax2, in_back=False)

plt.savefig('gridpoints.png')
plt.show()

This results in the following figure:

[Image: two subplots, with the crosses behind the data on the left and on top of it on the right]

As you can see, I take the tick marks in x and y to define a series of points where I want grid marks ('+'). I use meshgrid to turn the two 1D arrays into two 2D arrays that together cover every grid point, as a double loop over the ticks would. I plot these with the '+' marker style, and I'm done... almost. By default this draws the crosses on top of the data, so I added a keyword argument, in_back, which sets the zorder of the grid marks to 0 when they should be drawn behind everything else.*****
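To make the meshgrid step concrete, here is a small sketch with made-up tick values (the arrays are illustrative, not taken from the plot above). It shows how the two 2D arrays pair up into one point per intersection:

```python
import numpy as np

# Hypothetical tick positions, chosen only for illustration.
xticks = np.array([0, 1, 2])
yticks = np.array([10, 20])

xgrid, ygrid = np.meshgrid(xticks, yticks)
# xgrid repeats the x ticks along each row, ygrid repeats the y ticks along each column.
print(xgrid)
print(ygrid)

# Pair corresponding entries to list every (x, y) intersection.
points = list(zip(xgrid.ravel().tolist(), ygrid.ravel().tolist()))
print(points)
```

Both arrays have shape (len(yticks), len(xticks)), so passing them to ax.plot covers every intersection without an explicit double loop.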

In the example, the left subplot uses the default and places the grid in back, while the right subplot disables this option. You can see the difference if you follow the green line in each plot.

If you are bothered by having grid crosses on the border, you can remove the first and last tick marks for both x and y before you define the grid in set_grid_cross, like so:

xticks = ax.get_xticks()[1:-1] #< notice the slicing
yticks = ax.get_yticks()[1:-1] #< notice the slicing
xgrid, ygrid = np.meshgrid(xticks, yticks)

I do this in the following example, using a larger, different marker to make my point:

[Image: the same plot with larger markers and no crosses on the border]

***** Thanks to the answer by @fraxel for pointing this out.
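One thing to watch for: depending on the matplotlib version, get_xticks() and get_yticks() can return tick positions that lie outside the visible range, and plotting markers there may stretch the axis limits. The following is a sketch of a variant that guards against this. The name set_grid_cross_in_view and the marker_kw argument are my own, hypothetical additions, and the Agg backend is assumed only so that it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

def set_grid_cross_in_view(ax, in_back=True, **marker_kw):
    """Hypothetical variant: draw crosses only at ticks inside the current view."""
    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    xticks = np.asarray(ax.get_xticks())
    yticks = np.asarray(ax.get_yticks())
    # Keep only the ticks that fall within the visible limits.
    xticks = xticks[(xticks >= min(xlim)) & (xticks <= max(xlim))]
    yticks = yticks[(yticks >= min(ylim)) & (yticks <= max(ylim))]
    xgrid, ygrid = np.meshgrid(xticks, yticks)
    kw = dict(marker='+', color='k', linestyle='none')
    if in_back:
        kw['zorder'] = 0  # draw behind the data
    kw.update(marker_kw)
    (crosses,) = ax.plot(xgrid.ravel(), ygrid.ravel(), **kw)
    # Restore the limits in case adding the markers changed them.
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    return crosses

fig, ax = plt.subplots()
ax.plot(np.arange(100), np.arange(100), linewidth=4)
limits_before = (ax.get_xlim(), ax.get_ylim())
crosses = set_grid_cross_in_view(ax)
limits_after = (ax.get_xlim(), ax.get_ylim())
```

Flattening the grid with ravel() also means the crosses end up in a single Line2D, which is handy if you want to restyle or remove them later.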

Yann answered Sep 28 '22 05:09


You can draw short line segments at every intersection of the tick positions. It's pretty easy to do: grab the tick locations with get_ticklocs() for both axes, then loop through all combinations, drawing short line segments with axhline and axvline, thus creating a cross-hair at every intersection. I've set zorder=0 so the cross-hairs are drawn first and sit behind the plot data. It's easy to control the color/alpha and the cross-hair size. A couple of slight 'gotchas': do the plot before you get the tick locations, and note that xmin/xmax (and ymin/ymax) are given as fractions of the axes, from 0 to 1, rather than as data values, so the tick positions need normalising.

import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot((0,2,3,5,5,5,6,7,8,6,6,4,3,32,7,99), 'r-', linewidth=4)

# Get ticks and limits after plotting. axhline/axvline take xmin/xmax and
# ymin/ymax as fractions of the axes (0 to 1), so normalise by the limits.
x_ticks = ax.xaxis.get_ticklocs()
y_ticks = ax.yaxis.get_ticklocs()
x0, x1 = ax.get_xlim()
y0, y1 = ax.get_ylim()
for yy in y_ticks[1:-1]:
    fy = (yy - y0) / (y1 - y0)
    for xx in x_ticks[1:-1]:
        fx = (xx - x0) / (x1 - x0)
        ax.axhline(y=yy, xmin=fx - 0.02, xmax=fx + 0.02,
                   color='gray', alpha=0.5, zorder=0)
        ax.axvline(x=xx, ymin=fy - 0.02, ymax=fy + 0.02,
                   color='gray', alpha=0.5, zorder=0)
plt.show()

[Image: plot with gray cross-hairs at the interior tick intersections]
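On the normalisation gotcha: the xmin/xmax arguments of axhline (and ymin/ymax of axvline) are fractions of the axes, so a data position has to be mapped into the 0 to 1 range using the current axis limits. One way to do this, sketched here with made-up limits and a single cross at (5, 25), is to compose the data transform with the inverse of the axes transform. The Agg backend and the variable names are assumptions for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xlim(0, 10)   # illustrative limits
ax.set_ylim(0, 100)

# Data coordinates -> display coordinates -> axes fraction (0 to 1).
data_to_axes = ax.transData + ax.transAxes.inverted()
fx, fy = data_to_axes.transform((5, 25))

half = 0.02  # half-length of each cross arm, as a fraction of the axes
ax.axhline(y=25, xmin=fx - half, xmax=fx + half, color='gray', zorder=0)
ax.axvline(x=5, ymin=fy - half, ymax=fy + half, color='gray', zorder=0)
print(fx, fy)
```

The fractions are only valid for the limits in effect when they were computed, so if the limits change afterwards (zooming, further plotting with autoscaling), the crosses would need to be redrawn.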

fraxel answered Sep 28 '22 05:09