 

Un-broadcasting Numpy arrays

In a large code base, I am using np.broadcast_to to broadcast arrays (just using simple examples here):

In [1]: x = np.array([1,2,3])

In [2]: y = np.broadcast_to(x, (2,1,3))

In [3]: y.shape
Out[3]: (2, 1, 3)
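To make concrete what np.broadcast_to produces, the sketch below inspects y. It shows that y is a read-only view sharing memory with x, and that the repeated leading axis is implemented with a stride of 0 bytes rather than by copying data. Nothing here beyond the public NumPy API is assumed.

```python
import numpy as np

x = np.array([1, 2, 3])
y = np.broadcast_to(x, (2, 1, 3))

# broadcast_to returns a view: no element data is copied from x.
print(np.shares_memory(x, y))  # True
# The view is read-only.
print(y.flags.writeable)       # False
# The new leading axis has a stride of 0 bytes, so stepping along it
# revisits the same three values of x.
print(y.strides[0])            # 0
```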

Elsewhere in the code, I use third-party functions that can operate in a vectorized way on Numpy arrays but are not ufuncs. These functions don't understand broadcasting, so calling one on an array like y computes every element of the full broadcast shape, even though most of those elements are repeats. Solutions such as Numpy's vectorize aren't good either: while vectorize understands broadcasting, it calls the function once per element in a Python-level loop, which is very inefficient.
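As a rough illustration of the np.vectorize point, the sketch below counts how often vectorize calls a scalar function when applied to y. The names scalar_func and call_log are hypothetical, used only for this demonstration; otypes is passed so that vectorize does not make an extra call just to infer the output type.

```python
import numpy as np

call_log = []

def scalar_func(v):
    # Hypothetical scalar function; records every call made by np.vectorize.
    call_log.append(v)
    return v * 2

x = np.array([1, 2, 3])
y = np.broadcast_to(x, (2, 1, 3))

# With otypes given, vectorize calls scalar_func once per element of y.
vfunc = np.vectorize(scalar_func, otypes=[int])
out = vfunc(y)

print(out.shape, len(call_log))  # (2, 1, 3) 6
```

That is six Python-level calls for only three distinct input values.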

Ideally, I would like a function, called e.g. unbroadcast, that returns an array with the minimal shape that can be broadcast back to the full size when needed. For example:

In [4]: z = unbroadcast(y)

In [5]: z.shape
Out[5]: (1, 1, 3)

I can then run the third-party functions on z, then broadcast the result back to y.shape.
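The intended workflow can be sketched as follows, assuming the third-party function is elementwise (each output element depends only on the matching input element). Here third_party_func is a hypothetical stand-in, and z is built by hand with x.reshape(1, 1, 3) to play the role of the array unbroadcast(y) should return.

```python
import numpy as np

def third_party_func(a):
    # Hypothetical stand-in for a vectorized, elementwise, non-ufunc function.
    return a ** 2 + 1

x = np.array([1, 2, 3])
y = np.broadcast_to(x, (2, 1, 3))

# Minimal-shape array, built manually here for illustration.
z = x.reshape(1, 1, 3)

# Do the work on 3 elements instead of 6...
result_small = third_party_func(z)
# ...then broadcast the result back to the full shape (a read-only view).
result = np.broadcast_to(result_small, y.shape)

print(result.shape)  # (2, 1, 3)
```

This shortcut is only valid for elementwise functions; a function that mixes values across an axis (a sum or a sort, say) would give different results on z than on y.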

Is there a way to implement unbroadcast that relies on Numpy's public API? If not, are there any hacks that would produce the desired result?

astrofrog asked Mar 11 '23 13:03



1 Answer

I have a possible solution, so I will post it here (if anyone has a better one, please feel free to reply too!). One approach is to check the strides attribute of the array, which is 0 along broadcast dimensions:

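def unbroadcast(array):
    """Return a view of array with each broadcast (stride-0) axis cut to length 1."""
    slices = []
    for i in range(array.ndim):
        if array.strides[i] == 0:
            # Broadcast axis: every entry repeats, so keep only the first.
            slices.append(slice(0, 1))
        else:
            # Axis holding real data: keep all of it.
            slices.append(slice(None))
    # Index with a tuple; recent NumPy no longer accepts a list of slices here.
    return array[tuple(slices)]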

This gives:

In [14]: unbroadcast(y).shape
Out[14]: (1, 1, 3)
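Putting the pieces together, a minimal sketch of the full round trip might look like the following. It uses a compact, tuple-indexing version of unbroadcast; apply_unbroadcast and third_party_func are hypothetical names used only for illustration, and the approach assumes the function is elementwise.

```python
import numpy as np

def unbroadcast(array):
    # Keep one element along every axis whose stride is 0 (a broadcast axis).
    slices = tuple(slice(0, 1) if s == 0 else slice(None) for s in array.strides)
    return array[slices]

def apply_unbroadcast(func, array):
    # Hypothetical helper: run an elementwise func on the minimal array,
    # then broadcast the result back to the original shape.
    small = unbroadcast(array)
    return np.broadcast_to(func(small), array.shape)

def third_party_func(a):
    # Hypothetical elementwise function that is not a ufunc.
    return np.where(a > 1, a * 10, a)

x = np.array([1, 2, 3])
y = np.broadcast_to(x, (2, 1, 3))

result = apply_unbroadcast(third_party_func, y)
print(unbroadcast(y).shape, result.shape)  # (1, 1, 3) (2, 1, 3)
```

Arrays without broadcast axes pass through unbroadcast unchanged in shape. The returned result is a read-only view, so code that needs to modify it should take a copy first.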
astrofrog answered Mar 15 '23 19:03
