mayavi 3d object in matplotlib Axes3D

I sometimes find myself frustrated with the lack of certain rendering features in matplotlib's mplot3d. In most of these cases, I do find that I can get what I want in mayavi, but still the matplotlib 3d axes are preferable, if only for aesthetics, like LaTeX-ified labels and visual consistency with my other figures.

My question here is about the obvious hack: is it possible to draw some 3d object (a surface or 3d scatter plot or whatever) in mayavi without axes, export that image, then place it in a matplotlib Axes3D of correct size, orientation, coordinate projection, etc.? Can anyone think of an outline of what would be needed to accomplish this, or perhaps even offer a skeleton solution?

I fiddled around with this some time ago and found I had no trouble in exporting a transparent background mayavi figure and placing it in an empty matplotlib Axes3D (with ticks, labels, and so on), but I didn't get far in getting the camera configurations of mayavi and matplotlib to match. Simply setting the three common parameters of azimuth, elevation, and distance equal in both environments didn't do the trick; presumably what's needed is some consideration of the perspective (or other) transformations going on to render the whole scene, and I'm fairly clueless in that area.

It seems like this might be useful: http://docs.enthought.com/mayavi/mayavi/auto/example_mlab_3D_to_2D.html

asked Nov 01 '22 06:11 by tsj


1 Answer

I produced a proof-of-concept solution for Mayavi -> PGFPlots using the mlab_3D_to_2D.py example and the "Support for External Three-Dimensional Graphics" section of the PGFPlots manual.

Procedure:

  1. Run the modified mlab_3D_to_2D.py with Mayavi to generate img.png. Four random points are printed to the console; copy these to the clipboard. Note that the figure size and resolution are hard-coded into the script; these should be edited, or extracted automatically, for different image sizes.
  2. Paste the points into mlab_pgf.tex.
  3. Run LaTeX on mlab_pgf.tex.
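
The hard-coded sizes mentioned in step 1 could be replaced by a small helper along these lines. This is only a sketch: the name `pixels_to_pgf_pt` is hypothetical, and the 100 dpi default is an assumption inferred from the script's own numbers (400 px mapping to 288 pt at 72 pt per inch). The y coordinate is flipped in the same way the script does it, since the display origin is in the upper left corner.

```python
import numpy as np


def pixels_to_pgf_pt(disp_xy, height_px, dpi=100.0):
    """Convert (x, y) pixel coordinates with a top-left origin into
    points (1/72 inch) measured from the bottom-left corner.

    Hypothetical helper: dpi=100 is an assumption inferred from the
    400 px -> 288 pt figures in the script, not a value read from Mayavi.
    """
    xy = np.asarray(disp_xy, dtype=float).reshape(-1, 2)
    pt_per_px = 72.0 / dpi
    x_pt = xy[:, 0] * pt_per_px
    # Flip y so that it is measured upwards from the bottom edge.
    y_pt = (height_px - xy[:, 1]) * pt_per_px
    return np.column_stack((x_pt, y_pt))


# The bottom-left and top-right corners of a 400 x 298 pixel image.
corners_pt = pixels_to_pgf_pt([[0, 298], [400, 0]], height_px=298)
```

In the script, `height_px` would come from the shape of the screenshot array rather than from a literal.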

Result:

[Image: the Mayavi rendering placed inside the PGFPlots axes]

Modified mlab_3D_to_2D.py:

# Modified mlab_3D_to_2D.py from https://docs.enthought.com/mayavi/mayavi/auto/example_mlab_3D_to_2D.html

# Original copyright notice:
# Author: S. Chris Colbert <[email protected]>
# Copyright (c) 2009, S. Chris Colbert
# License: BSD Style

from __future__ import print_function

# this import is here because we need to ensure that matplotlib uses the
# wx backend and having regular code outside the main block is PyTaboo.
# It needs to be imported first, so that matplotlib can impose the
# version of Wx it requires.
import matplotlib
# matplotlib.use('WXAgg')
import pylab as pl


import numpy as np
from mayavi import mlab
from mayavi.core.ui.mayavi_scene import MayaviScene

def get_world_to_view_matrix(mlab_scene):
    """returns the 4x4 matrix that is a concatenation of the modelview transform and
    perspective transform. Takes as input an mlab scene object."""

    if not isinstance(mlab_scene, MayaviScene):
        raise TypeError('argument must be an instance of MayaviScene')


    # The VTK method needs the aspect ratio and near and far clipping planes
    # in order to return the proper transform. So we query the current scene
    # object to get the parameters we need.
    scene_size = tuple(mlab_scene.get_size())
    clip_range = mlab_scene.camera.clipping_range
    aspect_ratio = float(scene_size[0])/float(scene_size[1])

    # this actually just gets a vtk matrix object, we can't really do anything with it yet
    vtk_comb_trans_mat = mlab_scene.camera.get_composite_projection_transform_matrix(
                                aspect_ratio, clip_range[0], clip_range[1])

    # get the vtk mat as a numpy array
    np_comb_trans_mat = vtk_comb_trans_mat.to_array()

    return np_comb_trans_mat


def get_view_to_display_matrix(mlab_scene):
    """ this function returns a 4x4 matrix that will convert normalized
        view coordinates to display coordinates. It's assumed that the view should
        take up the entire window and that the origin of the window is in the
        upper left corner"""

    if not (isinstance(mlab_scene, MayaviScene)):
        raise TypeError('argument must be an instance of MayaviScene')

    # this gets the client size of the window
    x, y = tuple(mlab_scene.get_size())

    # normalized view coordinates have the origin in the middle of the space
    # so we need to scale by width and height of the display window and shift
    # by half width and half height. The matrix accomplishes that.
    view_to_disp_mat = np.array([[x/2.0,      0.,   0.,   x/2.0],
                                 [   0.,  -y/2.0,   0.,   y/2.0],
                                 [   0.,      0.,   1.,      0.],
                                 [   0.,      0.,   0.,      1.]])

    return view_to_disp_mat


def apply_transform_to_points(points, trans_mat):
    """a function that applies a 4x4 transformation matrix to an array of
        homogeneous points. The array of points should have shape Nx4"""

    if not trans_mat.shape == (4, 4):
        raise ValueError('transform matrix must be 4x4')

    if not points.shape[1] == 4:
        raise ValueError('point array must have shape Nx4')

    return np.dot(trans_mat, points.T).T

def test_surf():
    """Test surf on regularly spaced co-ordinates like MayaVi."""
    def f(x, y):
        sin, cos = np.sin, np.cos
        return sin(x + y) + sin(2 * x - y) + cos(3 * x + 4 * y)

    x, y = np.mgrid[-7.:7.05:0.1, -5.:5.05:0.05]
    z = f(x, y)
    s = mlab.surf(x, y, z)
    #cs = contour_surf(x, y, f, contour_z=0)
    return x, y, z, s

if __name__ == '__main__':
    f = mlab.figure()
    f.scene.parallel_projection = True

    N = 4

    # x, y, z, m = test_mesh()
    x, y, z, s = test_surf()

    mlab.move(forward=2.0)

    # now we're going to create a single N x 4 array of our points
    # adding a fourth column of ones expresses the world points in
    # homogeneous coordinates
    W = np.ones(x.flatten().shape)
    hmgns_world_coords = np.column_stack((x.flatten(), y.flatten(), z.flatten(), W))

    # applying the first transform will give us 'unnormalized' view
    # coordinates we also have to get the transform matrix for the
    # current scene view
    comb_trans_mat = get_world_to_view_matrix(f.scene)
    view_coords = \
            apply_transform_to_points(hmgns_world_coords, comb_trans_mat)

    # to get normalized view coordinates, we divide through by the fourth
    # element
    norm_view_coords = view_coords / (view_coords[:, 3].reshape(-1, 1))

    # the last step is to transform from normalized view coordinates to
    # display coordinates.
    view_to_disp_mat = get_view_to_display_matrix(f.scene)
    disp_coords = apply_transform_to_points(norm_view_coords, view_to_disp_mat)

    # at this point disp_coords is an Nx4 array of homogeneous coordinates
    # where X and Y are the pixel coordinates of the X and Y 3D world
    # coordinates, so let's take a screenshot of the mlab view and open it
    # with matplotlib so we can check the accuracy
    img = mlab.screenshot(figure=f, mode='rgba', antialiased=True)
    pl.imsave("img.png", img)
    pl.imshow(img)
    # mlab.close(f)

    idx = np.random.choice(range(disp_coords[:, 0:2].shape[0]), N, replace=False)

    for i in idx:
        # print('Point %d:  (x, y) ' % i, disp_coords[:, 0:2][i], hmgns_world_coords[:, 0:3][i])
        a = hmgns_world_coords[:, 0:3][i]
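        # tolist() yields plain Python floats, so the printed text stays
        # free of NumPy scalar wrappers on any NumPy version
        a = str(a.tolist()).replace('[', '(').replace(']', ')')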
        # See note below about 298.
        b = np.array([0, 298]) - disp_coords[:, 0:2][i]
        b = b * np.array([-1, 1])
        # Important! These values are not constant.
        # The image is 400 x 298 pixels, or 288 x 214.6 pt.
        b[0] = b[0] / 400 * 288
        b[1] = b[1] / 298 * 214.6
        # tolist() yields plain Python floats (see the comment for `a`)
        b = str(b.tolist()).replace('[', '(').replace(']', ')')
        print(a, "=>", b)
        pl.plot([disp_coords[:, 0][i]], [disp_coords[:, 1][i]], 'ro')

    pl.show()

    # you should check that the printed coordinates correspond to the
    # proper points on the screen

    mlab.show()

#EOF
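
The chain of transforms in the script (world to view, divide by the fourth component, view to display) can be tried out without Mayavi. The following is a minimal sketch under stated assumptions: `project_points` is a hypothetical name, and the `world_to_view` matrix is a hand-built stand-in for an orthographic camera, not something taken from a real Mayavi scene.

```python
import numpy as np


def project_points(world_xyz, world_to_view, width_px, height_px):
    """Project Nx3 world points to Nx2 pixel coordinates (top-left origin)."""
    xyz = np.asarray(world_xyz, dtype=float).reshape(-1, 3)
    # Append w = 1 to express the points in homogeneous coordinates.
    hom = np.column_stack((xyz, np.ones(len(xyz))))
    # Apply the combined modelview/projection matrix to every point.
    view = hom @ np.asarray(world_to_view, dtype=float).T
    # Normalise by the fourth component.
    ndc = view / view[:, 3:4]
    # Normalised coordinates span [-1, 1]; scale and shift them to
    # pixels, flipping y because the display origin is at the top left.
    x_px = (ndc[:, 0] + 1.0) * width_px / 2.0
    y_px = (1.0 - ndc[:, 1]) * height_px / 2.0
    return np.column_stack((x_px, y_px))


# Stand-in matrix (an assumption, not a real Mayavi camera): an
# orthographic view mapping x in [-7, 7] and y in [-5, 5] onto [-1, 1].
world_to_view = np.diag([1.0 / 7.0, 1.0 / 5.0, 1.0, 1.0])
pixels = project_points(
    [[0, 0, 0], [7, 5, 0], [-7, -5, 0]], world_to_view, 400, 298)
```

With this stand-in matrix the world origin lands at the image centre and the two extreme corners land on opposite corners of the 400 x 298 image, which is a quick sanity check on the sign of the y flip.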

mlab_pgf.tex:

\documentclass{standalone}

\usepackage{pgfplots}
\pgfplotsset{compat=1.17}

\begin{document}

\begin{tikzpicture}
\begin{axis}[
  grid=both,minor tick num=1,
  xlabel=$x$,ylabel=$y$,zlabel=$z$,
  xmin=-7,
  xmax=7,
  ymin=-5,
  ymax=5,
  zmin=-3,
  zmax=3,
  ]
  \addplot3 graphics [
  points={% important, paste points generated by `mlab_3D_to_2D.py`
    (5.100000000000001, -3.8, 2.9491697063900895) => (69.82857610254948, 129.60245304203693)
    (-6.2, -3.0999999999999996, 0.6658335107904079) => (169.834990346303, 158.6375879061911)
    (-1.7999999999999998, 0.4500000000000002, -1.0839565197346115) => (162.75120267070378, 103.53696636434113)
    (-5.3, -4.9, 0.6627774166307937) => (147.33354714145847, 162.93938533017257)
  },
  ] {img.png};
\end{axis}
\end{tikzpicture}

\end{document}
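
The copy-and-paste step could be reduced by letting Python write the lines that go inside `points={...}`. The sketch below assumes the same `(x, y, z) => (u, v)` layout used in the file above; `format_pgf_points` is a hypothetical helper name, and the sample numbers are made up for illustration.

```python
def format_pgf_points(world_xyz, image_pt):
    """Return one '(x, y, z) => (u, v)' line per point pair."""
    lines = []
    for (x, y, z), (u, v) in zip(world_xyz, image_pt):
        # float() strips any NumPy scalar type so repr() prints bare numbers.
        lines.append("({!r}, {!r}, {!r}) => ({!r}, {!r})".format(
            float(x), float(y), float(z), float(u), float(v)))
    return "\n".join(lines)


# Made-up sample values, for illustration only.
snippet = format_pgf_points([(5.1, -3.8, 2.5), (-6.2, -3.1, 0.5)],
                            [(69.5, 129.0), (169.75, 158.5)])
print(snippet)
```

The resulting text can be pasted between the braces of `points={...}`, or written to a file and pulled in with `\input`.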
answered Nov 15 '22 03:11 by tsj