 

matplotlib (mplot3d) - how to increase the size of an axis (stretch) in a 3D Plot?

I have this so far:

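import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib

x, y, z = data.nonzero()  # data is my 3D array

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x, y, z, zdir='z', c='red')
plt.savefig("plot.png")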

Which creates:

[image: the resulting 3D scatter plot]

What I'd like to do is stretch this out so that the Z axis is 9 times taller while X and Y stay the same. I'd like to keep the same coordinates, though.

So far I tried this guy:

fig = plt.figure(figsize=(4.,35.)) 

But that just stretches out the plot.png image.
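For comparison: matplotlib 3.3 and later provide `Axes3D.set_box_aspect`, which changes the proportions of the 3D box itself rather than the size of the saved image, and leaves the data coordinates and tick values alone. The sketch below assumes matplotlib 3.3+; `plot_stretched` is a hypothetical helper name, and the random array is a hypothetical stand-in for the `data` array above.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the script runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_stretched(data, z_stretch=9.0, filename="plot.png"):
    """Scatter the nonzero cells of a 3D array with a taller z axis.

    Assumes matplotlib >= 3.3, where Axes3D.set_box_aspect exists.
    """
    x, y, z = data.nonzero()
    fig = plt.figure(figsize=(4, 8))  # a tall figure leaves room for the tall box
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(x, y, z, c="red")
    # Box proportions x:y:z = 1:1:z_stretch; the data coordinates are unchanged
    ax.set_box_aspect((1, 1, z_stretch))
    fig.savefig(filename)
    return fig, ax


# Hypothetical stand-in for the asker's `data` array
rng = np.random.default_rng(0)
data = (rng.random((10, 10, 10)) > 0.9).astype(int)
fig, ax = plot_stretched(data)
```

Unlike `figsize`, which only resizes the canvas, the aspect tuple here controls how long each axis is drawn relative to the others.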

asked May 13 '15 18:05 by Greg




2 Answers

The code example below provides a way to scale each axis relative to the others. However, to do so you need to modify the Axes3D.get_proj function. Below is an example based on the one provided by matplotlib: http://matplotlib.org/1.4.0/mpl_toolkits/mplot3d/tutorial.html#line-plots

(There is a shorter version at the end of this answer)

from mpl_toolkits.mplot3d.axes3d import Axes3D
from mpl_toolkits.mplot3d import proj3d

import matplotlib as mpl
import numpy as np
import matplotlib.pyplot as plt

# Make sure these are floating point values:
scale_x = 1.0
scale_y = 2.0
scale_z = 3.0

# Axes are scaled down to fit in scene
max_scale = max(scale_x, scale_y, scale_z)

scale_x = scale_x / max_scale
scale_y = scale_y / max_scale
scale_z = scale_z / max_scale

# Create scaling matrix
scale = np.array([[scale_x, 0, 0, 0],
                  [0, scale_y, 0, 0],
                  [0, 0, scale_z, 0],
                  [0, 0, 0, 1]])
print(scale)


# Replacement for Axes3D.get_proj, mirroring the matplotlib 1.4 implementation.
# Later releases changed these internals (e.g. self.dist), so this function
# may need adapting on current versions.
def get_proj_scale(self):
    """
    Create the projection matrix from the current viewing position.

    elev stores the elevation angle in the z plane
    azim stores the azimuth angle in the x,y plane

    dist is the distance of the eye viewing point from the object
    point.
    """
    relev, razim = np.pi * self.elev / 180, np.pi * self.azim / 180

    xmin, xmax = self.get_xlim3d()
    ymin, ymax = self.get_ylim3d()
    zmin, zmax = self.get_zlim3d()

    # transform to uniform world coordinates 0-1.0,0-1.0,0-1.0
    worldM = proj3d.world_transformation(
        xmin, xmax,
        ymin, ymax,
        zmin, zmax)

    # look into the middle of the new coordinates
    R = np.array([0.5, 0.5, 0.5])

    xp = R[0] + np.cos(razim) * np.cos(relev) * self.dist
    yp = R[1] + np.sin(razim) * np.cos(relev) * self.dist
    zp = R[2] + np.sin(relev) * self.dist
    E = np.array((xp, yp, zp))

    self.eye = E
    self.vvec = R - E
    self.vvec = self.vvec / np.linalg.norm(self.vvec)  # same as proj3d.mod(self.vvec)

    if abs(relev) > np.pi / 2:
        # upside down
        V = np.array((0, 0, -1))
    else:
        V = np.array((0, 0, 1))
    zfront, zback = -self.dist, self.dist

    viewM = proj3d.view_transformation(E, R, V)
    perspM = proj3d.persp_transformation(zfront, zback)
    M0 = np.dot(viewM, worldM)
    M = np.dot(perspM, M0)

    # Multiply the usual projection by the scaling matrix
    return np.dot(M, scale)


Axes3D.get_proj = get_proj_scale

# You need to include all the code above.
# From here on you should be able to plot as usual.

mpl.rcParams['legend.fontsize'] = 10

fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection='3d')
theta = np.linspace(-4 * np.pi, 4 * np.pi, 100)
z = np.linspace(-2, 2, 100)
r = z**2 + 1
x = r * np.sin(theta)
y = r * np.cos(theta)
ax.plot(x, y, z, label='parametric curve')
ax.legend()

plt.show()

Standard output:

[image: Normal Scale]

Scaled by (1, 2, 3):

[image: Scale_x=1, Scale_y=2, Scale_z=3]

Scaled by (1, 1, 3):

[image: Scale_x=1, Scale_y=1, Scale_z=3]

This is the reason I particularly like this method: swap z and x, and scale by (3, 1, 1):

[image: Swap z and x, scale_x=4]

Below is a shorter version of the code.

from mpl_toolkits.mplot3d.axes3d import Axes3D
import matplotlib as mpl
import numpy as np
import matplotlib.pyplot as plt

mpl.rcParams['legend.fontsize'] = 10

fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection='3d')
theta = np.linspace(-4 * np.pi, 4 * np.pi, 100)
z = np.linspace(-2, 2, 100)
r = z**2 + 1
x = r * np.sin(theta)
y = r * np.cos(theta)

# Scaling is done from here...
x_scale = 1
y_scale = 1
z_scale = 2

scale = np.diag([x_scale, y_scale, z_scale, 1.0])
scale = scale * (1.0 / scale.max())
scale[3, 3] = 1.0


def short_proj():
    # Call the original class-level get_proj, then apply the scaling matrix
    return np.dot(Axes3D.get_proj(ax), scale)


ax.get_proj = short_proj
# ...to here

ax.plot(z, y, x, label='parametric curve')
ax.legend()

plt.show()
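Both listings normalize the scale factors before building the matrix, but not in quite the same way. The small NumPy sketch below (the function names are hypothetical) reproduces each normalization. They agree whenever the largest factor is at least 1; the shorter version computes `scale.max()` over a matrix that still contains the 1.0 in the bottom-right corner, so if every factor is below 1 it leaves them unnormalized.

```python
import numpy as np


def scale_matrix_long(scale_x, scale_y, scale_z):
    # Mirrors the first listing: divide each factor by the largest one
    m = max(scale_x, scale_y, scale_z)
    return np.array([[scale_x / m, 0, 0, 0],
                     [0, scale_y / m, 0, 0],
                     [0, 0, scale_z / m, 0],
                     [0, 0, 0, 1]], dtype=float)


def scale_matrix_short(x_scale, y_scale, z_scale):
    # Mirrors the shorter listing: scale the whole diagonal, then reset [3, 3]
    scale = np.diag([x_scale, y_scale, z_scale, 1.0])
    scale = scale * (1.0 / scale.max())  # max() also sees the 1.0 at [3, 3]
    scale[3, 3] = 1.0
    return scale


print(scale_matrix_long(1, 1, 2))
print(scale_matrix_short(1, 1, 2))
print(scale_matrix_long(0.5, 0.5, 0.5))   # normalized up to ones
print(scale_matrix_short(0.5, 0.5, 0.5))  # stays at 0.5
```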
answered Sep 25 '22 17:09 by Christian Sarofeen


Please note that the answer below simplifies the patch, but uses the same underlying principle as the answer by @ChristianSarofeen.

Solution

As already indicated in other answers, this was not a built-in feature of matplotlib at the time of writing. However, since what you are requesting is simply a 3D transformation that can be applied to the existing projection matrix used by matplotlib, and since Python lets you replace a method on an object at runtime, this problem can be solved with a simple one-liner:

ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag([scale_x, scale_y, scale_z, 1])) 

where scale_x, scale_y and scale_z are values from 0 to 1 that rescale your plot along the corresponding axes, and ax is the 3D axes, which can be obtained with ax = fig.add_subplot(111, projection='3d') (older matplotlib versions also accepted ax = fig.gca(projection='3d')).
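The one-liner can be wrapped in a small helper. The sketch below assumes a matplotlib version where `fig.add_subplot(..., projection="3d")` works; `scale_axes` is a hypothetical name, not part of matplotlib. Because it assigns to the instance attribute `ax.get_proj`, only that one axes is affected, whereas assigning to `Axes3D.get_proj` on the class (as in the previous answer) changes every 3D axes in the process.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D


def scale_axes(ax, scale_x, scale_y, scale_z):
    """Hypothetical helper: patch ONE Axes3D so its projection includes a scaling.

    The factors are expected to lie in (0, 1], as in the one-liner.
    """
    scale = np.diag([scale_x, scale_y, scale_z, 1.0])
    # Call the class-level get_proj explicitly; calling ax.get_proj here
    # would recurse into this lambda.
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), scale)
    return ax


fig = plt.figure(figsize=(8, 4))
ax = fig.add_subplot(1, 2, 1, projection="3d")
other = fig.add_subplot(1, 2, 2, projection="3d")

theta = np.linspace(-4 * np.pi, 4 * np.pi, 100)
z = np.linspace(-2, 2, 100)
r = z ** 2 + 1
for a in (ax, other):
    a.plot(r * np.sin(theta), r * np.cos(theta), z)

scale_axes(ax, 0.5, 0.5, 1.0)  # only the left axes is rescaled
fig.canvas.draw()              # render once with the patched projection
```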

Explanation

To explain, the function get_proj of Axes3D generates the projection matrix from the current viewing position. Multiplying it by a scaling matrix:

scale_x, 0,       0,       0
0,       scale_y, 0,       0
0,       0,       scale_z, 0
0,       0,       0,       1

includes the scaling into the projection used by the renderer. So, what we are doing here is substituting the original get_proj function with an expression taking the result of the original get_proj and multiplying it by the scaling matrix.
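In homogeneous coordinates, right-multiplying the projection matrix by this scaling matrix is the same as scaling each data point (x, y, z, 1) before projecting it. A small NumPy sketch of that identity; the matrix `M` is an arbitrary stand-in, not the matrix matplotlib actually computes:

```python
import numpy as np

# Hypothetical stand-in for the 4x4 matrix returned by Axes3D.get_proj(ax);
# any 4x4 matrix works for this illustration.
M = np.arange(16, dtype=float).reshape(4, 4)

S = np.diag([0.5, 0.5, 1.0, 1.0])    # scale_x, scale_y, scale_z, 1
p = np.array([2.0, -4.0, 3.0, 1.0])  # data point (x, y, z) in homogeneous form

patched = np.dot(np.dot(M, S), p)    # what the patched get_proj applies to p
prescaled = np.dot(M, np.dot(S, p))  # original projection of the pre-scaled point

print(np.dot(S, p))                   # x and y halved, z and w unchanged
print(np.allclose(patched, prescaled))
```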

Example

To illustrate the result with the standard parametric function example:

from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
theta = np.linspace(-4 * np.pi, 4 * np.pi, 100)
z = np.linspace(-2, 2, 100)
r = z ** 2 + 1
x = r * np.sin(theta)
y = r * np.cos(theta)

# OUR ONE LINER ADDED HERE:
ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag([0.5, 0.5, 1, 1]))

ax.plot(x, y, z)
plt.show()

For values 0.5, 0.5, 1, we get:

[image: parametric curve scaled by (0.5, 0.5, 1)]

while for values 0.2, 1.0, 0.2, we get:

[image: parametric curve scaled by (0.2, 1.0, 0.2)]

answered Sep 24 '22 17:09 by Andrzej Pronobis