I am updating a 3D scatter plot with every iteration of a loop. When the plot is redrawn, the gridlines "go through" or "cover" the points, which makes my data more difficult to visualize. If I build a single 3D plot (no loop updating) this does not happen. The code below demonstrates the simplest case:
```python
import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import time

X = np.random.rand(100, 3)*10
Y = np.random.rand(100, 3)*5

plt.ion()
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X[:, 0], X[:, 1], X[:, 2])
plt.draw()

for i in range(0, 20):
    time.sleep(3)  # make changes more apparent/easy to see
    Y = np.random.rand(100, 3)*5
    ax.cla()
    ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2])
    plt.draw()
```
Has anyone else encountered this problem?
It looks like MaxNoe is right in the sense that the problem is in the `ax.cla()` or `plt.cla()` call. In fact, it seems to be a known issue.

That is a problem, since clearing the axes doesn't work properly in 3D plots, and for 3D scatters there is no clean way to change the coordinates of the data points (à la `sc.set_data(new_values)`), as discussed in this mailing-list thread (I didn't find anything more recent).
In that thread, however, Ben Roon points to a workaround that might be useful for you too: set the new coordinates of the data points in the internal `_offsets3d` attribute of the collection object returned by the `scatter` function (a `Path3DCollection` in current Matplotlib).
Your example adapted would look like:
```python
import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

X = np.random.rand(100, 3)*10
Y = np.random.rand(100, 3)*5

plt.ion()
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
sc = ax.scatter(X[:, 0], X[:, 1], X[:, 2])
fig.show()

for i in range(0, 20):
    plt.pause(1)
    Y = np.random.rand(100, 3)*5
    # Move the existing points instead of clearing the axes with cla()
    sc._offsets3d = (Y[:, 0], Y[:, 1], Y[:, 2])
    plt.draw()
```
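If you would rather not rely on the private `_offsets3d` attribute, one alternative, sketched here under the assumption that single-colour markers are enough, is to draw the points with `ax.plot(..., linestyle='', marker='o')`. That returns a `Line3D`, whose points can be moved with `set_data()` for x/y followed by `set_3d_properties()` for z, again without calling `cla()`. The helper name `update_points` is hypothetical.

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on older Matplotlib)


def update_points(line, xyz):
    """Hypothetical helper: move the markers of a Line3D to new (N, 3) coordinates."""
    # set_data() stores x/y; set_3d_properties() reads them back and attaches z.
    line.set_data(xyz[:, 0], xyz[:, 1])
    line.set_3d_properties(xyz[:, 2])


if __name__ == "__main__":
    X = np.random.rand(100, 3) * 10
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    # Markers only, no connecting line, so it looks like a single-colour scatter.
    (line,) = ax.plot(X[:, 0], X[:, 1], X[:, 2], linestyle="", marker="o")
    plt.ion()
    for i in range(5):
        update_points(line, np.random.rand(100, 3) * 5)
        plt.pause(0.5)  # redraws the figure and processes GUI events
```

Unlike `scatter`, `plot` does not support per-point colours or sizes, and `set_data()` does not rescale the axes, so set the limits yourself if new points can fall outside the original range.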
I could narrow it down to the use of cla():
```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
x, y = np.meshgrid(np.linspace(-2, 2), np.linspace(-2, 2))

ax.plot_surface(x, y, x**2 + y**2)
fig.savefig("fig_a.png")

ax.cla()
ax.plot_surface(x, y, x**2 + y**2)
fig.savefig("fig_b.png")
```
These are the resulting plots, saved as fig_a.png (before `cla()`) and fig_b.png (after `cla()`).
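Assuming, as the answers above suggest, that `cla()` is what triggers the problem, another way to avoid it is to remove only the old scatter artist and create a new one, leaving the rest of the axes untouched. This is a sketch, not a confirmed fix; the helper name `replace_scatter` is hypothetical, while `Artist.remove()` and `ax.scatter()` are standard Matplotlib calls.

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on older Matplotlib)


def replace_scatter(ax, old, xyz, **kwargs):
    """Hypothetical helper: swap one 3D scatter artist for a new one."""
    if old is not None:
        # Detaches only this collection; the panes, ticks and grid are not reset.
        old.remove()
    return ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], **kwargs)


if __name__ == "__main__":
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    sc = replace_scatter(ax, None, np.random.rand(100, 3) * 10)
    plt.ion()
    for i in range(5):
        plt.pause(0.5)
        sc = replace_scatter(ax, sc, np.random.rand(100, 3) * 5)
        fig.canvas.draw_idle()  # request a redraw with the new points
```

Because the axes are never cleared, the axis limits are not reset the way `cla()` resets them, so they may keep the range of earlier data; this approach does keep full `scatter` features such as per-point colours.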