 

Updating a pyplot 3d scatter plot in a loop, grid lines overlap points

I am updating a 3d scatter plot with every iteration of a loop. When the plot is redrawn, the gridlines "go through" or "cover" the points, which makes my data more difficult to visualize. If I build a single 3d plot (no loop updating) this does not happen. The code below demonstrates the simplest case:

import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import time

X = np.random.rand(100, 3)*10
Y = np.random.rand(100, 3)*5

plt.ion()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2])
plt.draw()

for i in range(0, 20):
    time.sleep(3)  # make changes more apparent/easier to see

    Y = np.random.rand(100, 3)*5
    ax.cla()    
    ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2])
    plt.draw()

Has anyone else encountered this problem?

Asked by NLi10Me, Feb 06 '14 22:02


2 Answers

It looks like MaxNoe is right that the problem lies in the ax.cla() (or plt.cla()) call. In fact, it seems to be a known issue.

That leaves a problem: clearing the axes doesn't work properly for 3D plots, and for 3D scatter plots there is no clean way to change the coordinates of the data points (à la sc.set_data(new_values)), as discussed in this mailing-list thread (I didn't find anything more recent).

In that thread, however, Ben Roon points to a workaround that might be useful for you, too.
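For comparison, the kind of in-place update that 3D scatter plots lack does exist in 2D: a 2D scatter returns a PathCollection whose positions can be replaced with its public set_offsets method, with no need to clear the axes. A minimal sketch, assuming an (N, 2) array of new positions; move_points_2d is a hypothetical helper name:

```python
import numpy as np
import matplotlib.pyplot as plt


def move_points_2d(sc, points):
    """Replace the positions of an existing 2D scatter in place (hypothetical helper)."""
    sc.set_offsets(points)        # public setter on PathCollection; points has shape (N, 2)
    sc.figure.canvas.draw_idle()  # ask the canvas to repaint when it next gets the chance


fig, ax = plt.subplots()
start = np.random.rand(100, 2) * 10
sc = ax.scatter(start[:, 0], start[:, 1])
move_points_2d(sc, np.random.rand(100, 2) * 5)
```

In 3D, the depth coordinate lives outside these 2D offsets, which is why the workaround below has to reach into a private attribute instead.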

Workaround:

You need to set the new coordinates of the data points in the internal _offsets3d attribute of the collection object returned by the scatter function (a Path3DCollection in recent matplotlib versions).

Adapted to your example, it looks like this:

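import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib

X = np.random.rand(100, 3)*10

plt.ion()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
sc = ax.scatter(X[:, 0], X[:, 1], X[:, 2])  # keep a handle to the scatter collection
fig.show()

for i in range(0, 20):
    plt.pause(1)  # wait, and let the GUI event loop repaint the window

    Y = np.random.rand(100, 3)*5

    # Move the existing points instead of clearing the axes with ax.cla().
    # _offsets3d is private, so this may break in a future matplotlib release.
    sc._offsets3d = (Y[:, 0], Y[:, 1], Y[:, 2])
    plt.draw()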
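
If you would rather not write to the private _offsets3d attribute (it is internal and may change between matplotlib releases), another option that still avoids ax.cla() is to remove only the old scatter artist with Artist.remove() and draw a fresh one. A minimal sketch, assuming that clearing the whole axes is what triggers the overlap (as MaxNoe's test suggests) and that fixed axis limits of 0 to 10 suit your data; replace_scatter and run_demo are hypothetical names:

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the '3d' projection on older matplotlib)


def replace_scatter(ax, old_sc, points):
    """Swap one 3D scatter for another without calling ax.cla() (hypothetical helper)."""
    old_sc.remove()  # detach only the old points; panes, grid and labels are left alone
    return ax.scatter(points[:, 0], points[:, 1], points[:, 2])


def run_demo(n_frames=20, delay=1.0):
    """Animate random points the same way as the question's loop (hypothetical demo)."""
    plt.ion()
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # Assumed data range: fixed limits stop the view from rescaling on every frame
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)
    ax.set_zlim(0, 10)
    X = np.random.rand(100, 3) * 10
    sc = ax.scatter(X[:, 0], X[:, 1], X[:, 2])
    for _ in range(n_frames):
        plt.pause(delay)  # wait, and let the GUI event loop repaint the window
        Y = np.random.rand(100, 3) * 5
        sc = replace_scatter(ax, sc, Y)
        fig.canvas.draw_idle()
    return fig, ax, sc
```

Call run_demo() from a script or interactive session that uses a GUI backend. Recreating the artist costs a little more per frame than moving the points, but it relies only on public methods.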

Answered by mgab, Sep 28 '22 08:09


I could narrow it down to the use of cla():

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

x, y = np.meshgrid(np.linspace(-2,2), np.linspace(-2,2))

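# First render: the surface on a freshly created 3D axes
ax.plot_surface(x, y, x**2 + y**2)
fig.savefig("fig_a.png")

# Clear the axes and draw the identical surface again
ax.cla()
ax.plot_surface(x, y, x**2 + y**2)

fig.savefig("fig_b.png")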

These are the resulting plots: fig_a.png (before ax.cla()) and fig_b.png (after ax.cla()).
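
To check whether your own matplotlib version still shows a difference, without comparing the two PNGs by eye, you can render both states off-screen to pixel arrays and compare them. A minimal sketch, assuming the Agg renderer; render_rgba and cla_changes_render are hypothetical helper names. A nonzero count only says that the two renders differ, not where or why:

```python
import numpy as np
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the '3d' projection on older matplotlib)


def render_rgba(fig):
    """Draw the figure with Agg and return a copy of its RGBA pixel buffer."""
    fig.canvas.draw()
    return np.asarray(fig.canvas.buffer_rgba()).copy()


def cla_changes_render():
    """Return how many pixels differ between the render before and after ax.cla()."""
    fig = Figure()
    FigureCanvasAgg(fig)  # attach an Agg canvas so the figure renders off-screen
    ax = fig.add_subplot(111, projection='3d')
    x, y = np.meshgrid(np.linspace(-2, 2), np.linspace(-2, 2))

    ax.plot_surface(x, y, x**2 + y**2)
    before = render_rgba(fig)

    ax.cla()
    ax.plot_surface(x, y, x**2 + y**2)
    after = render_rgba(fig)

    # Count pixels where any RGBA channel changed
    return int(np.any(before != after, axis=-1).sum())


print(cla_changes_render())
```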

Answered by MaxNoe, Sep 28 '22 08:09