 

connecting all numpy array plot points to each other using plt.plot() from matplotlib

I have a numpy array with xy co-ordinates for points. I have plotted each of these points and want a line connecting each point to every other point (a complete graph). The array is a 2x50 structure so I have transposed it and used a view to let me iterate through the rows. However, I am getting an 'index out of bounds' error with the following:

     plt.plot(*zip(*v.T)) #to plot all the points
     viewVX = (v[0]).T
     viewVY = (v[1]).T
     for i in range(0, 49):
        xPoints = viewVX[i], viewVX[i+1]
        print("xPoints is", xPoints)
        yPoints = viewVY[i+2], viewVY[i+3]
        print("yPoints is", yPoints)
        xy = xPoints, yPoints
        plt.plot(*zip(*xy), ls ='-')

I was hoping that the indexing would wrap around, so that for the yPoints it would start again at y0, y1, etc. Is there an easier way to accomplish what I'm trying to achieve?

asked Feb 21 '23 23:02 by wot


2 Answers

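import matplotlib.pyplot as plt
import numpy as np
import itertools

v=np.random.random((2,50))  # 2x50 array: row 0 = x-coordinates, row 1 = y-coordinates
# combinations(v.T, 2) yields every pair of points, chain.from_iterable flattens
# the pairs into one sequence of points, and zip(*...) splits it into x- and y-values
plt.plot(
    *zip(*itertools.chain.from_iterable(itertools.combinations(v.T,2))),
    marker='o', markerfacecolor='red')  # one call draws every edge plus a marker at each point
plt.show()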

The advantage of doing it this way is that everything is drawn with a single call to plt.plot. This should be significantly faster than methods that make O(N**2) calls to plt.plot (one call per edge, since a complete graph on N points has N*(N-1)/2 edges).
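To make the data flow concrete without opening a plot window, here is a small sketch that builds exactly the x and y sequences the single plt.plot call receives, for a random 2x50 array like the question's. The helper name complete_graph_polyline is hypothetical, introduced only for illustration.

```python
import itertools
import math

import numpy as np


def complete_graph_polyline(v):
    """Hypothetical helper: return the (xs, ys) sequences passed to plt.plot.

    v is assumed to be a 2xN array: row 0 holds x-coordinates, row 1 y-coordinates.
    """
    pairs = itertools.combinations(v.T, 2)          # every unordered pair of points
    points = itertools.chain.from_iterable(pairs)   # flatten the pairs into one point sequence
    xs, ys = zip(*points)                           # split the points into x- and y-values
    return xs, ys


rng = np.random.default_rng(0)
v = rng.random((2, 50))            # 50 random points, same shape as in the question
xs, ys = complete_graph_polyline(v)

n_edges = math.comb(50, 2)         # number of edges in a complete graph on 50 points
print(n_edges, len(xs), len(ys))   # 1225 2450 2450
```

So one plt.plot(xs, ys) call draws a single line with 2450 vertices, where a per-edge loop would call plt.plot 1225 times.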

Note also that you do not need to plot the points separately: the marker='o' parameter draws a marker at every vertex of the line, and every point is one of those vertices.


Explanation: I think the easiest way to understand this code is to see how it operates on a simple v:

In [4]: import numpy as np
In [5]: import itertools
In [7]: v=np.arange(8).reshape(2,4)
In [8]: v
Out[8]: 
array([[0, 1, 2, 3],
       [4, 5, 6, 7]])

itertools.combinations(...,2) generates all possible pairs of points:

In [10]: list(itertools.combinations(v.T,2))
Out[10]: 
[(array([0, 4]), array([1, 5])),
 (array([0, 4]), array([2, 6])),
 (array([0, 4]), array([3, 7])),
 (array([1, 5]), array([2, 6])),
 (array([1, 5]), array([3, 7])),
 (array([2, 6]), array([3, 7]))]
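A quick check on the same 4-point v: working with indices instead of coordinates shows that combinations yields each unordered pair exactly once, in the same order as the output above, so a complete graph on N points gets N*(N-1)/2 pairs.

```python
import itertools
import math

import numpy as np

v = np.arange(8).reshape(2, 4)   # the 4-point example used in the answer
pairs = list(itertools.combinations(v.T, 2))

# combinations(..., 2) yields each unordered pair once, in lexicographic index order
index_pairs = list(itertools.combinations(range(v.shape[1]), 2))
print(index_pairs)   # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]

# so the number of pairs (edges) is N * (N - 1) / 2
print(len(pairs), math.comb(v.shape[1], 2))   # 6 6
```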

Now we use itertools.chain.from_iterable to convert this list of pairs of points into a (flattened) list of points:

In [11]: list(itertools.chain.from_iterable(itertools.combinations(v.T,2)))
Out[11]: 
[array([0, 4]),
 array([1, 5]),
 array([0, 4]),
 array([2, 6]),
 array([0, 4]),
 array([3, 7]),
 array([1, 5]),
 array([2, 6]),
 array([1, 5]),
 array([3, 7]),
 array([2, 6]),
 array([3, 7])]
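The same flattened sequence can also be built with NumPy fancy indexing. This is an alternative I'm adding as a sketch, not part of the original answer: it flattens the index pairs first and then gathers the points from v.T in one step, giving a (2 * number_of_pairs, 2) array instead of a list of small arrays.

```python
import itertools

import numpy as np

v = np.arange(8).reshape(2, 4)
n = v.shape[1]

# Flattened sequence of points, as in the answer (a list of length-2 arrays)
chained = list(itertools.chain.from_iterable(itertools.combinations(v.T, 2)))

# Alternative: flatten the index pairs, then gather the points with fancy indexing
idx = np.array(list(itertools.combinations(range(n), 2))).ravel()
points = v.T[idx]

print(idx)                                         # [0 1 0 2 0 3 1 2 1 3 2 3]
print(np.array_equal(points, np.array(chained)))   # True
```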

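If we plot these points one after another, connected by lines, we get our complete graph. The segments that join the end of one pair to the start of the next (for example from [1, 5] back to [0, 4]) are themselves edges of the complete graph, so no spurious lines appear; some edges are simply drawn twice. The only problem is that plt.plot(x,y) expects x to be a sequence of x-values, and y to be a sequence of y-values.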
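To back up the claim that drawing the flattened points as one polyline gives exactly the complete graph, this sketch uses point indices instead of coordinates and collects every segment the polyline would draw (each pair of consecutive vertices). It then checks that no segment joins a point to itself and that the set of drawn segments equals the set of all pairs.

```python
import itertools

import numpy as np

v = np.arange(8).reshape(2, 4)
n = v.shape[1]

# Vertex order of the single polyline, expressed as point indices
order = list(itertools.chain.from_iterable(itertools.combinations(range(n), 2)))

# Every pair of consecutive vertices becomes one drawn line segment (direction ignored)
segments = {frozenset(seg) for seg in zip(order, order[1:])}
edges = {frozenset(pair) for pair in itertools.combinations(range(n), 2)}

print(all(len(s) == 2 for s in segments))  # True: no segment from a point to itself
print(segments == edges)                   # True: drawn segments are exactly the edges
```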

We can use zip to convert the list of points into a list of x-values and y-values:

In [12]: list(zip(*itertools.chain.from_iterable(itertools.combinations(v.T,2))))
Out[12]: [(0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3), (4, 5, 4, 6, 4, 7, 5, 6, 5, 7, 6, 7)]
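Here is a short sketch of that transposition step on the same 4-point v. In Python 3, zip returns a lazy iterator, so it has to be consumed (by list() or by tuple unpacking) before you can see its contents. For comparison, the last lines show that transposing a NumPy array of the points gives the same x- and y-values.

```python
import itertools

import numpy as np

v = np.arange(8).reshape(2, 4)
points = list(itertools.chain.from_iterable(itertools.combinations(v.T, 2)))

# zip(*points) groups all first coordinates together and all second coordinates
# together, i.e. it transposes a list of (x, y) points into an x-tuple and a y-tuple
xs, ys = list(zip(*points))
print([int(x) for x in xs])  # [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3]
print([int(y) for y in ys])  # [4, 5, 4, 6, 4, 7, 5, 6, 5, 7, 6, 7]

# For a NumPy array of points, a transpose does the same job
xs_np, ys_np = np.array(points).T
print(np.array_equal(xs_np, xs), np.array_equal(ys_np, ys))  # True True
```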

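The splat operator (*) in zip(*...) and plt.plot(*...) unpacks a sequence into separate positional arguments: zip(*points) receives each point as its own argument, and plt.plot(*zip(...)) receives the x-values and the y-values as two separate arguments.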
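A small illustration of the two uses of * in the expression. The function plot_like is a hypothetical stand-in for plt.plot that only records its arguments, so the example runs without matplotlib.

```python
def plot_like(x, y, **kwargs):
    """Hypothetical stand-in for plt.plot(x, y, ...): it just records its inputs."""
    return {"x": tuple(x), "y": tuple(y), **kwargs}


points = [(0, 4), (1, 5), (2, 6)]

# Inner splat: zip(*points) is the same call as zip((0, 4), (1, 5), (2, 6))
assert list(zip(*points)) == list(zip((0, 4), (1, 5), (2, 6)))

# Outer splat: plot_like(*zip(*points)) passes the x-tuple and the y-tuple
# as two positional arguments, i.e. plot_like((0, 1, 2), (4, 5, 6), marker="o")
call = plot_like(*zip(*points), marker="o")
print(call)  # {'x': (0, 1, 2), 'y': (4, 5, 6), 'marker': 'o'}
```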

Thus we've managed to massage the data into the right form to be fed to plt.plot.

answered Feb 24 '23 21:02 by unutbu


With a 2 by 50 array,

 for i in range(0, 49):
    xPoints = viewVX[i], viewVX[i+1]
    print("xPoints is", xPoints)
    yPoints = viewVY[i+2], viewVY[i+3]

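would go out of bounds for i = 47 and i = 48, since you use i+2 and i+3 as indices into viewVY, which has only 50 elements (valid indices 0 to 49). The loop stops with an IndexError at i = 47, when viewVY[i+3] asks for index 50.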
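A sketch that reproduces the failure and shows one way to get the wrap-around the question hoped for. The array here is a stand-in for v[1] (the assumption is simply that it has 50 elements); positive indices past the end are not wrapped by NumPy, but taking the index modulo the length is. This only fixes the indexing: pairing consecutive indices still would not connect every point to every other point.

```python
import numpy as np

viewVY = np.arange(50)   # stand-in for v[1]: 50 values, valid indices 0..49

failed_at = None
for i in range(0, 49):
    try:
        y_points = viewVY[i + 2], viewVY[i + 3]
    except IndexError:
        failed_at = i     # first i for which i + 3 reaches 50
        break
print(failed_at)          # 47

# Wrapping the index modulo the length makes it start again at 0, 1, ...
n = len(viewVY)
print([int(viewVY[(i + k) % n]) for i in (47, 48) for k in (2, 3)])  # [49, 0, 0, 1]
```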

answered Feb 24 '23 20:02 by Daniel Fischer