Numpy: make batched version of quaternion multiplication

I transformed the following function

import numpy

def quaternion_multiply(quaternion0, quaternion1):
    """Return multiplication of two quaternions.

    >>> q = quaternion_multiply([1, -2, 3, 4], [-5, 6, 7, 8])
    >>> numpy.allclose(q, [-44, -14, 48, 28])
    True

    """
    x0, y0, z0, w0 = quaternion0
    x1, y1, z1, w1 = quaternion1
    return numpy.array((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=numpy.float64)

to a batched version

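import numpy as np

def quat_multiply(self, quaternion0, quaternion1):
    # Split the (N, 4) inputs into four (N, 1) component columns.
    x0, y0, z0, w0 = np.split(quaternion0, 4, 1)
    x1, y1, z1, w1 = np.split(quaternion1, 4, 1)

    # Collecting the four (N, 1) results gives shape (4, N, 1).
    result = np.array((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=np.float64)
    # Squeeze to (4, N), then transpose back to (N, 4).
    return np.transpose(np.squeeze(result))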

This function handles quaternion0 and quaternion1 with shape (?,4). Now I want the function to handle an arbitrary number of dimensions, such as (?,?,4). How can I do this?

Derk asked Dec 01 '16 15:12


3 Answers

You can get the behavior you are after by simply passing axis=-1 to np.split, so that it splits along the last axis.

And since your arrays have that annoying size-1 trailing dimension, rather than stacking along a new dimension and then squeezing it away, you can simply concatenate them, again along the last axis (axis=-1):

import numpy as np

def quat_multiply(self, quaternion0, quaternion1):
    # Each component keeps a trailing axis of length 1: shape (..., 1).
    x0, y0, z0, w0 = np.split(quaternion0, 4, axis=-1)
    x1, y1, z1, w1 = np.split(quaternion1, 4, axis=-1)
    # Joining the four (..., 1) results along that axis gives (..., 4).
    return np.concatenate(
        (x1*w0 + y1*z0 - z1*y0 + w1*x0,
         -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
         -x1*x0 - y1*y0 - z1*z0 + w1*w0),
        axis=-1)
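A quick way to convince yourself that this batched version agrees with the original scalar function is to compare them quaternion by quaternion. The sketch below does that on a random batch; the seed, the (3, 2, 5, 4) shape and the use of np.ndindex to loop over the leading axes are arbitrary choices for illustration.

```python
import numpy as np

def quat_multiply(self, quaternion0, quaternion1):
    # Batched split/concatenate version from the answer above.
    x0, y0, z0, w0 = np.split(quaternion0, 4, axis=-1)
    x1, y1, z1, w1 = np.split(quaternion1, 4, axis=-1)
    return np.concatenate(
        (x1*w0 + y1*z0 - z1*y0 + w1*x0,
         -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
         -x1*x0 - y1*y0 - z1*z0 + w1*w0),
        axis=-1)

def quaternion_multiply(quaternion0, quaternion1):
    # Scalar reference version from the question.
    x0, y0, z0, w0 = quaternion0
    x1, y1, z1, w1 = quaternion1
    return np.array((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=np.float64)

rng = np.random.default_rng(0)
a = rng.random((3, 2, 5, 4))
b = rng.random((3, 2, 5, 4))
batched = quat_multiply(None, a, b)

# Walk over every index of the leading axes and compare with the scalar function.
for idx in np.ndindex(a.shape[:-1]):
    assert np.allclose(batched[idx], quaternion_multiply(a[idx], b[idx]))
```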

Note that, with this approach, not only can you multiply identically shaped quaternion stacks of any number of dimensions:

>>> a = np.random.rand(6, 5, 4)
>>> b = np.random.rand(6, 5, 4)
>>> quat_multiply(None, a, b).shape
(6, 5, 4)

But you also get broadcasting, which lets you, e.g., multiply a stack of quaternions by a single one without having to fiddle with the dimensions:

>>> a = np.random.rand(6, 5, 4)
>>> b = np.random.rand(4)
>>> quat_multiply(None, a, b).shape
(6, 5, 4)
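As a sanity check on the broadcasting behaviour, this minimal sketch (assuming the split/concatenate quat_multiply above, with random inputs for illustration) compares multiplying by a single quaternion against multiplying by an explicitly repeated copy made with np.broadcast_to:

```python
import numpy as np

def quat_multiply(self, quaternion0, quaternion1):
    # Same split/concatenate version as in the answer above.
    x0, y0, z0, w0 = np.split(quaternion0, 4, axis=-1)
    x1, y1, z1, w1 = np.split(quaternion1, 4, axis=-1)
    return np.concatenate(
        (x1*w0 + y1*z0 - z1*y0 + w1*x0,
         -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
         -x1*x0 - y1*y0 - z1*z0 + w1*w0),
        axis=-1)

rng = np.random.default_rng(1)
a = rng.random((6, 5, 4))
b = rng.random(4)

# Let broadcasting pair the single quaternion b with every quaternion in a ...
broadcast = quat_multiply(None, a, b)
# ... and compare with a read-only view that repeats b to a's full shape.
tiled = quat_multiply(None, a, np.broadcast_to(b, a.shape))
assert np.allclose(broadcast, tiled)
```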

Or, with minimal fiddling, compute all pairwise products between two stacks in a single line:

>>> a = np.random.rand(6, 4)
>>> b = np.random.rand(5, 4)
>>> quat_multiply(None, a[:, None], b).shape
(6, 5, 4)
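To see that a[:, None] really produces every pairing, the following sketch (random inputs chosen only for illustration) checks each entry of the (6, 5, 4) result against a separate call on one quaternion from each stack:

```python
import numpy as np

def quat_multiply(self, quaternion0, quaternion1):
    # Same split/concatenate version as in the answer above.
    x0, y0, z0, w0 = np.split(quaternion0, 4, axis=-1)
    x1, y1, z1, w1 = np.split(quaternion1, 4, axis=-1)
    return np.concatenate(
        (x1*w0 + y1*z0 - z1*y0 + w1*x0,
         -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
         -x1*x0 - y1*y0 - z1*z0 + w1*w0),
        axis=-1)

rng = np.random.default_rng(2)
a = rng.random((6, 4))
b = rng.random((5, 4))

# a[:, None] has shape (6, 1, 4); broadcasting it against b (5, 4)
# gives every product of a[i] with b[j] in a (6, 5, 4) result.
pairwise = quat_multiply(None, a[:, None], b)
for i in range(a.shape[0]):
    for j in range(b.shape[0]):
        assert np.allclose(pairwise[i, j], quat_multiply(None, a[i], b[j]))
```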
Jaime answered Nov 15 '22 15:11


You could use np.rollaxis to bring the last axis to the front, which lets us unpack the four component arrays without actually splitting them. We then perform the required operations and finally move the first axis back to the end, so that the output has the same shape as the inputs. This gives a solution for generic n-dimensional ndarrays, like so:

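import numpy as np

def quat_multiply_ndim(quaternion0, quaternion1):
    # Roll the last (component) axis to the front, so each input unpacks
    # into four arrays of shape q.shape[:-1].
    x0, y0, z0, w0 = np.rollaxis(quaternion0, -1, 0)
    x1, y1, z1, w1 = np.rollaxis(quaternion1, -1, 0)
    # Collecting the four results gives shape (4, ...).
    result = np.array((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=np.float64)
    # Roll the component axis back to the end: (..., 4).
    return np.rollaxis(result, 0, result.ndim)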

Sample run:

In [107]: # N-dim arrays
     ...: a1 = np.random.randint(0,9,(2,3,2,4))
     ...: b1 = np.random.randint(0,9,(2,3,2,4))
     ...: 

In [108]: quat_multiply_ndim(a1,b1) # New ndim approach
Out[108]: 
array([[[[ 154.,   48.,   55.,  -57.],
         [  31.,   81.,   29.,  -95.]],

        [[  31.,   14.,   88.,   12.],
         [   3.,   30.,   20.,  -51.]],

        [[ 104.,   61.,  102.,  -39.],
         [   0.,   14.,   14.,  -56.]]],


       [[[ -28.,   36.,   24.,   -8.],
         [  11.,   76.,   -7.,  -36.]],

        [[  54.,    3.,   -2.,  -19.],
         [  52.,   62.,   15.,  -55.]],

        [[  76.,   28.,   28.,  -60.],    <--------|
         [  14.,   54.,   13.,    5.]]]])          |
                                                   |
In [109]: quat_multiply(a1[1,2],b1[1,2]) # Old 2D approach
Out[109]:                                          |
array([[ 76.,  28.,  28., -60.], ------------------|
       [ 14.,  54.,  13.,   5.]])
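A side note: NumPy's documentation recommends np.moveaxis over np.rollaxis, which it keeps for backward compatibility. The same approach written with np.moveaxis, using np.stack(..., axis=-1) to put the components back at the end instead of rolling the result, might look like the sketch below. The name quat_multiply_moveaxis is just for illustration, and converting the inputs with np.asarray(..., dtype=np.float64) is an added assumption so that plain lists are accepted and the output is float64 as in the answer.

```python
import numpy as np

def quat_multiply_moveaxis(quaternion0, quaternion1):
    # Move the length-4 component axis to the front; unpacking then yields
    # four arrays of shape q.shape[:-1] (views into the converted input).
    q0 = np.moveaxis(np.asarray(quaternion0, dtype=np.float64), -1, 0)
    q1 = np.moveaxis(np.asarray(quaternion1, dtype=np.float64), -1, 0)
    x0, y0, z0, w0 = q0
    x1, y1, z1, w1 = q1
    # Stacking on a new last axis replaces the final rollaxis call.
    return np.stack((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), axis=-1)

# Same kind of input as in the sample run above.
a1 = np.random.randint(0, 9, (2, 3, 2, 4))
b1 = np.random.randint(0, 9, (2, 3, 2, 4))
print(quat_multiply_moveaxis(a1, b1).shape)                  # (2, 3, 2, 4)
print(quat_multiply_moveaxis([-5, 6, 7, 8], [1, -2, 3, 4]))  # [-44. -14.  48.  28.]
```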
Divakar answered Nov 15 '22 15:11


You're almost there! You just need to be a little careful about how you're splitting and concatenating your array:

import numpy as np

def quat_multiply(quaternion0, quaternion1):
    # Split along the last axis: four arrays of shape (..., 1).
    x0, y0, z0, w0 = np.split(quaternion0, 4, axis=-1)
    x1, y1, z1, w1 = np.split(quaternion1, 4, axis=-1)

    # Stack on a new last axis -> (..., 1, 4), then drop the length-1 axis.
    return np.squeeze(np.stack((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), axis=-1), axis=-2)

Here, we use axis=-1 both to split along the last axis and to stack the four results along a new last axis. Because each split piece keeps a length-1 last axis, stacking leaves that axis in second-to-last position, so we squeeze out axis=-2, as you correctly noticed. And to show you that it works:

>>> p0 = np.array([-5, 6, 7, 8])
>>> p1 = np.array([1, -2, 3, 4])
>>> q0 = np.tile(p0, (2, 2, 1))
>>> q0
array([[[-5,  6,  7,  8],
        [-5,  6,  7,  8]],

       [[-5,  6,  7,  8],
        [-5,  6,  7,  8]]])
>>> q1 = np.tile(p1, (2, 2, 1))
>>> q = quat_multiply(q0, q1)
>>> q
array([[[-44, -14,  48,  28],
        [-44, -14,  48,  28]],

       [[-44, -14,  48,  28],
        [-44, -14,  48,  28]]])
>>> q.shape
(2, 2, 4)

Hope that's what you needed! This should work for any number of dimensions, with any sizes along the leading axes, as long as the last axis has length 4.
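To make the shape bookkeeping of the split/stack/squeeze steps concrete, here is a small trace on dummy zero arrays (the shapes are arbitrary). It also shows why passing axis=-2 to np.squeeze matters: without it, a batch axis that happens to have length 1 is squeezed away too.

```python
import numpy as np

q0 = np.zeros((2, 2, 4))

# Splitting along the last axis keeps a trailing axis of length 1.
parts = np.split(q0, 4, axis=-1)
print([p.shape for p in parts])  # [(2, 2, 1), (2, 2, 1), (2, 2, 1), (2, 2, 1)]

# Stacking on a new last axis puts that length-1 axis second to last.
stacked = np.stack(parts, axis=-1)
print(stacked.shape)  # (2, 2, 1, 4)

# Squeezing only axis -2 removes it and leaves the batch axes alone.
print(np.squeeze(stacked, axis=-2).shape)  # (2, 2, 4)

# With a batch of one quaternion, a bare np.squeeze() drops the batch axis too.
one = np.stack(np.split(np.zeros((1, 4)), 4, axis=-1), axis=-1)  # (1, 1, 4)
print(np.squeeze(one).shape)           # (4,)
print(np.squeeze(one, axis=-2).shape)  # (1, 4)
```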

Note: np.split does not accept plain Python lists here, so you can only pass arrays to your new function, as I've done above. If you want to be able to pass lists, you can instead call

 np.split(np.asarray(quaternion0), 4, -1)

inside your function.
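For example, a version that converts its inputs with np.asarray up front (a sketch; otherwise identical to the function above) accepts plain lists as well as arrays, and np.asarray returns an existing ndarray unchanged:

```python
import numpy as np

def quat_multiply(quaternion0, quaternion1):
    # Convert lists/tuples to arrays; ndarrays pass through without a copy.
    quaternion0 = np.asarray(quaternion0)
    quaternion1 = np.asarray(quaternion1)
    x0, y0, z0, w0 = np.split(quaternion0, 4, axis=-1)
    x1, y1, z1, w1 = np.split(quaternion1, 4, axis=-1)
    return np.squeeze(np.stack((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), axis=-1), axis=-2)

print(quat_multiply([-5, 6, 7, 8], [1, -2, 3, 4]))  # [-44 -14  48  28]
print(quat_multiply([[-5, 6, 7, 8]] * 3, [[1, -2, 3, 4]] * 3).shape)  # (3, 4)
```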

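Also, the test case in your docstring appears to be wrong: I think the positions of quaternion0 and quaternion1 are swapped. In the test above I've put them back in the order that gives [-44, -14, 48, 28], with q0 built from [-5, 6, 7, 8] and q1 from [1, -2, 3, 4].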
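The swapped arguments are easy to check with the scalar function from the question; a minimal sketch:

```python
import numpy as np

def quaternion_multiply(quaternion0, quaternion1):
    # Original scalar function from the question (x, y, z, w component order).
    x0, y0, z0, w0 = quaternion0
    x1, y1, z1, w1 = quaternion1
    return np.array((
         x1*w0 + y1*z0 - z1*y0 + w1*x0,
        -x1*z0 + y1*w0 + z1*x0 + w1*y0,
         x1*y0 - y1*x0 + z1*w0 + w1*z0,
        -x1*x0 - y1*y0 - z1*z0 + w1*w0), dtype=np.float64)

# Argument order as written in the docstring: does not give [-44, -14, 48, 28].
print(quaternion_multiply([1, -2, 3, 4], [-5, 6, 7, 8]))  # [20. 30. 56. 28.]
# Swapped order: matches the expected values in the docstring.
print(quaternion_multiply([-5, 6, 7, 8], [1, -2, 3, 4]))  # [-44. -14.  48.  28.]
```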

Praveen answered Nov 15 '22 17:11