 

How can I convert from scatter size to data coordinates in matplotlib?

I would like to programmatically test whether two scatterplot glyphs will overlap in matplotlib. So given a pair of (x, y) coordinates and a size (which, as I understand it, is the area of the marker in points squared), I would like to plot

plt.scatter(x, y, s=s)

and then have a function called points_overlap that takes these parameters and returns True if the points will overlap and False otherwise.

def points_overlap(x, y, s):
    if ...:
        return True
    else:
        return False

I know there are transformation matrices to take me between the different matplotlib coordinate systems, but I can't figure out the right steps for writing this function.

asked Dec 02 '14 21:12 by mwaskom



1 Answer

This needs some testing, but it might work. All of the inputs should be in display space:

import numpy as np

def overlap(x, y, sx, sy):
    # x, y: marker centers in display (pixel) space; sx, sy: marker radii in pixels
    return np.linalg.norm(x - y) < sx + sy
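Getting the sizes into display space takes one conversion step. The following is a minimal sketch, not part of the original answer: it assumes the default circular marker, ignores the marker edge width, and relies on matplotlib's convention that `s` is given in points squared, so that `sqrt(s)` is the marker width in points. The name `marker_radius_px` is hypothetical.

```python
import math

def marker_radius_px(s, dpi):
    """Approximate radius, in display pixels, of a circular scatter marker.

    s is the scatter size in points**2; dpi is the figure resolution.
    """
    diameter_pt = math.sqrt(s)       # marker width in points
    radius_pt = diameter_pt / 2.0    # circle radius in points
    return radius_pt * dpi / 72.0    # 72 points per inch
```

In a real figure, `dpi` would typically be `fig.dpi`, so the result is in the same pixel units as the output of `ax.transData.transform`.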

Test:

In [227]: X = np.array([[1, 1], [2, 1], [2.5, 1]])
In [228]: s = np.array([20, 10000, 10000])

In [229]: fig, ax = plt.subplots()

In [230]: ax.scatter(X[:, 0], X[:, 1], s=s)
Out[230]: <matplotlib.collections.PathCollection at 0x10c32f28>

In [231]: plt.draw()

Test every pair:

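from itertools import product

# Marker centers in display (pixel) coordinates
Xt = ax.transData.transform(X)
# s is in points**2, so sqrt(s) is the marker diameter in points;
# halve it and scale by dpi / 72 to get the radius in pixels
st = np.sqrt(s) / 2 * fig.dpi / 72

n = len(X)
pairs = product(Xt, Xt)
sizes = product(st, st)

for i, ((x, y), (sx, sy)) in enumerate(zip(pairs, sizes)):
    h = i % n
    j = i // n
    # Skip the diagonal, where a point is compared with itself
    if h != j and overlap(x, y, sx, sy):
        print((i, h, j))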


There's a lot of room for improvement. It would probably be better to transform all of your data first and pass the result into the points_overlap function, rather than doing the transform inside it.
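One way that suggestion might look is sketched below. This is an illustration rather than the answerer's code: the signature of `points_overlap` is an assumption (it takes already-transformed centers and radii instead of the `x, y, s` arguments in the question), and it treats every marker as a circle with no edge width.

```python
import numpy as np

def points_overlap(centers_px, radii_px):
    """Return an (n, n) boolean matrix; entry [i, j] is True if markers i and j overlap.

    centers_px: (n, 2) marker centers in display (pixel) coordinates.
    radii_px:   (n,) marker radii in pixels.
    """
    centers_px = np.asarray(centers_px, dtype=float)
    radii_px = np.asarray(radii_px, dtype=float)
    # Pairwise center-to-center distances in pixels
    diff = centers_px[:, None, :] - centers_px[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    # Two circles overlap when their centers are closer than the sum of their radii
    result = dist < radii_px[:, None] + radii_px[None, :]
    np.fill_diagonal(result, False)  # do not report a marker as overlapping itself
    return result
```

Here `centers_px` would come from `ax.transData.transform(X)` and `radii_px` from `np.sqrt(s) / 2 * fig.dpi / 72`. Because display coordinates depend on the figure size, dpi, and axis limits, the result only holds for the figure as it was when the transform was computed.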

answered Oct 20 '22 15:10 by TomAugspurger