
Pairwise vdot using Numpy

I'm trying to compute the pairwise np.vdot of a complex 2D array x with itself. So the behaviour I want is:

X = np.empty((x.shape[0], x.shape[0]), dtype='complex128')
for i in range(x.shape[0]):
    for j in range(x.shape[0]):
        X[i, j] = np.vdot(x[i], x[j])

Is there a way to do this without the explicit loops? I tried using pairwise_kernels from sklearn, but it assumes the input arrays are real. I also tried broadcasting, but vdot flattens its inputs.
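For reference, np.vdot conjugates its first argument and flattens both inputs before taking the dot product, which is why it cannot be broadcast row-wise. The following is a minimal sketch on a small random complex array (the array and its shape are arbitrary illustration choices). It checks that one vdot call equals a conjugate-multiply-sum, and shows that broadcasting does work once the elementwise product and the sum over the last axis are written out explicitly:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((4, 3)) + 1j * rng.random((4, 3))

# np.vdot conjugates its first argument: vdot(a, b) == sum(conj(a) * b)
assert np.isclose(np.vdot(x[0], x[1]), np.sum(np.conj(x[0]) * x[1]))

# Passing the whole 2-D array flattens it, so the result is a single scalar
print(np.ndim(np.vdot(x, x)))  # 0

# Broadcasting alternative: (N, 1, M) * (1, N, M) -> (N, N, M),
# then sum over the last axis to get the (N, N) pairwise result
X = (np.conj(x)[:, None, :] * x[None, :, :]).sum(axis=-1)
print(X.shape)  # (4, 4)
```

Note that this builds an N x N x M intermediate array, so memory use grows quickly for large inputs.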

Peter asked Feb 08 '23 07:02


1 Answer

X = np.einsum('ik,jk->ij', np.conj(x), x)

is equivalent to

X = np.empty((x.shape[0], x.shape[0]), dtype='complex128')
for i in range(x.shape[0]):
    for j in range(x.shape[0]):
        X[i, j] = np.vdot(x[i], x[j])

np.einsum computes a sum of products. The subscript string 'ik,jk->ij' tells np.einsum that the second argument, np.conj(x), is an array with subscripts ik, and the third argument, x, has subscripts jk. Thus the product np.conj(x)[i,k]*x[j,k] is computed for all i, j, k. The sum is taken over the repeated subscript k, and since that leaves i and j, they become the subscripts of the resulting array. Conjugating the first operand mirrors np.vdot, which conjugates its first argument.
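To make the subscript reading concrete, here is a small sketch (the random array and the chosen indices i, j are arbitrary). It checks one entry of the einsum result against the explicit sum over k, and shows that the same contraction can be written as the matrix product np.conj(x) @ x.T. The matmul form may be faster for large arrays because it can use an optimized BLAS routine, but that is untested here. The result is also Hermitian, since X[j, i] is the conjugate of X[i, j]:

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.random((5, 3)) + 1j * rng.random((5, 3))

X = np.einsum('ik,jk->ij', np.conj(x), x)

# Spell out what 'ik,jk->ij' means for a single entry (i, j):
# sum over the shared index k of conj(x[i, k]) * x[j, k]
i, j = 2, 4
manual = sum(np.conj(x[i, k]) * x[j, k] for k in range(x.shape[1]))
assert np.isclose(X[i, j], manual)

# The same contraction written as a matrix product
assert np.allclose(X, np.conj(x) @ x.T)

# The pairwise vdot matrix is Hermitian: X == X^H
assert np.allclose(X, X.conj().T)
```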

For example,

import numpy as np

N, M = 10, 20
a = np.random.random((N,M))
b = np.random.random((N,M))
x = a + b*1j

def orig(x):
    X = np.empty((x.shape[0], x.shape[0]), dtype='complex128')
    for i in range(x.shape[0]):
        for j in range(x.shape[0]):
            X[i, j] = np.vdot(x[i], x[j])
    return X

def alt(x):
    return np.einsum('ik,jk->ij', np.conj(x), x)

assert np.allclose(orig(x), alt(x))

In [307]: %timeit orig(x)
10000 loops, best of 3: 143 µs per loop

In [308]: %timeit alt(x)
100000 loops, best of 3: 8.63 µs per loop
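The timings above come from IPython's %timeit magic. To reproduce the comparison in a plain Python script, the standard-library timeit module can be used instead. This is a sketch: the repeat/number settings are arbitrary choices, and the numbers it prints will depend on the machine and NumPy build:

```python
import timeit

import numpy as np

N, M = 10, 20
x = np.random.random((N, M)) + np.random.random((N, M)) * 1j

def orig(x):
    # Explicit double loop over row pairs
    X = np.empty((x.shape[0], x.shape[0]), dtype='complex128')
    for i in range(x.shape[0]):
        for j in range(x.shape[0]):
            X[i, j] = np.vdot(x[i], x[j])
    return X

def alt(x):
    # Vectorized: conjugate-multiply and sum over the shared axis k
    return np.einsum('ik,jk->ij', np.conj(x), x)

for f in (orig, alt):
    # Best of 3 repeats, each running the function 1000 times
    best = min(timeit.repeat(lambda: f(x), number=1000, repeat=3)) / 1000
    print(f"{f.__name__}: {best * 1e6:.2f} µs per call")
```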
unutbu answered Feb 10 '23 12:02
