 

How do I calculate all pairs of vector differences in numpy?

I know I can use np.subtract.outer(x, x): if x has shape (n,), I end up with an array of shape (n, n). However, my x has shape (n, 3), and I want an output of shape (n, n, 3) whose entry [i, j] is x[i] - x[j]. How do I do this? Maybe with np.einsum?
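A quick shape check of what np.subtract.outer does in the two cases described above. The arrays are arbitrary example values, not data from the question. ufunc.outer combines every element of the first argument with every element of the second, so for a 2D input the result has shape x.shape + x.shape.

```python
import numpy as np

# 1-D case: outer subtraction gives an (n, n) table of v[i] - v[j]
v = np.array([1, 4, 9])
outer_1d = np.subtract.outer(v, v)
print(outer_1d.shape)  # (3, 3)

# 2-D case: every element is paired with every element,
# so the shape is (n, 3, n, 3) rather than the wanted (n, n, 3)
x = np.arange(12).reshape(4, 3)
outer_2d = np.subtract.outer(x, x)
print(outer_2d.shape)  # (4, 3, 4, 3)
```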

Neil G asked Sep 09 '15 07:09


People also ask

How do you count the number of different values in a numpy array?

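To count how many times each unique element occurs in a NumPy array, use numpy.unique() with return_counts=True. It takes the array as input and returns the unique elements in ascending order and, with that flag, a second array holding the number of occurrences of each one. The number of different values is the length of the returned unique array.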
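A minimal sketch of the numpy.unique() approach; the array a is made-up example data. return_counts is a documented parameter of np.unique.

```python
import numpy as np

a = np.array([3, 1, 3, 2, 1, 3])

# Sorted unique values and how often each one occurs
values, counts = np.unique(a, return_counts=True)
# values -> [1, 2, 3], counts -> [2, 1, 3]

# Number of distinct values
n_distinct = len(values)  # 3
```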

How do I compare vectors in NumPy?

The easiest way to compare two NumPy arrays is to create a boolean comparison array with == and then call .all() on the result to check that every element is True, i.e. (a == b).all().
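A short sketch of the == / .all() check, with two related functions that exist in NumPy: np.array_equal, which also returns False when the shapes differ, and np.allclose, which compares floats within a tolerance. The arrays are illustrative values only.

```python
import numpy as np

a = np.array([1, 2, 3])
b = np.array([1, 2, 3])

mask = a == b                # element-wise boolean array, all True here
print(mask.all())            # True: every element matches
print(np.array_equal(a, b))  # True: same shape and same elements

# Floating-point results can differ in the last bits, so exact == may fail;
# np.allclose compares within a relative/absolute tolerance instead.
x = np.array([0.1 + 0.2, 1.0])
y = np.array([0.3, 1.0])
print((x == y).all())        # False
print(np.allclose(x, y))     # True
```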

What is numpy.diff()? Give an example.

numpy.diff(arr[, n[, axis]]) calculates the n-th order discrete difference along the given axis; n defaults to 1 and axis to the last axis. The first-order difference is out[i] = arr[i+1] - arr[i] along that axis. Higher-order differences are computed by applying diff recursively, n times.
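An example of numpy.diff with the n and axis parameters described above; arr and m are arbitrary example arrays.

```python
import numpy as np

arr = np.array([1, 2, 4, 7, 0])

first = np.diff(arr)         # out[i] = arr[i+1] - arr[i] -> [1, 2, 3, -7]
second = np.diff(arr, n=2)   # same as np.diff(np.diff(arr)) -> [1, 1, -10]

m = np.array([[1, 3, 6],
              [2, 2, 9]])
down_rows = np.diff(m, axis=0)  # row 1 minus row 0 -> [[1, -1, 3]]
along_cols = np.diff(m)         # default axis=-1 -> [[2, 3], [0, 7]]
```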


1 Answer

You can use broadcasting: extend the dimensions with None/np.newaxis to get a 3D version of x with shape (n, 1, 3), then subtract the original 2D array of shape (n, 3) from it. Broadcasting expands both operands to (n, n, 3), so result[i, j] = x[i] - x[j]:

x[:, np.newaxis, :] - x
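The one-liner can be wrapped in a small helper with a shape check. pairwise_differences is a hypothetical name used only for illustration, and the random input is example data.

```python
import numpy as np

def pairwise_differences(x):
    """Hypothetical helper: return d with d[i, j] = x[i] - x[j] for an (n, k) array."""
    x = np.asarray(x)
    # (n, 1, k) - (n, k): the second operand is treated as (1, n, k),
    # so broadcasting expands both to (n, n, k)
    return x[:, np.newaxis, :] - x

x = np.random.default_rng(0).integers(0, 10, size=(5, 3))
d = pairwise_differences(x)
print(d.shape)  # (5, 5, 3)
```

The result is antisymmetric (d[i, j] == -d[j, i]) and holds n * n * k values, so memory grows quadratically with n.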

Sample run:

In [6]: x
Out[6]: 
array([[6, 5, 3],
       [4, 3, 5],
       [0, 6, 7],
       [8, 4, 1]])

In [7]: x[:,None,:] - x
Out[7]: 
array([[[ 0,  0,  0],
        [ 2,  2, -2],
        [ 6, -1, -4],
        [-2,  1,  2]],

       [[-2, -2,  2],
        [ 0,  0,  0],
        [ 4, -3, -2],
        [-4, -1,  4]],

       [[-6,  1,  4],
        [-4,  3,  2],
        [ 0,  0,  0],
        [-8,  2,  6]],

       [[ 2, -1, -2],
        [ 4,  1, -4],
        [ 8, -2, -6],
        [ 0,  0,  0]]])
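On the question's np.einsum idea: einsum cannot subtract, but a repeated subscript in one operand extracts a diagonal, so it can pull the wanted (n, n, 3) result out of the (n, 3, n, 3) array from np.subtract.outer. This is a sketch for comparison only; it builds an intermediate three times larger than the answer, so the broadcasting expression is the simpler choice.

```python
import numpy as np

x = np.array([[6, 5, 3],
              [4, 3, 5],
              [0, 6, 7],
              [8, 4, 1]])

broadcast = x[:, None, :] - x

# full[i, a, j, b] = x[i, a] - x[j, b]; repeating "a" keeps only a == b
full = np.subtract.outer(x, x)              # shape (4, 3, 4, 3)
via_einsum = np.einsum('iaja->ija', full)   # shape (4, 4, 3)

print(np.array_equal(broadcast, via_einsum))  # True
```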
Divakar answered Oct 18 '22 01:10