Is there unstack in NumPy?

Tags:

python

numpy

NumPy has np.stack, but is there an opposite operation, an np.unstack like TensorFlow's tf.unstack?

Arty asked Sep 28 '20 07:09


People also ask

What is NumPy stacking?

Stacking joins a sequence of arrays along a new axis. All input arrays must have the same shape, and the result has one more dimension than each input; the axis argument chooses where the new axis is inserted. This is what distinguishes stacking from concatenation, which joins arrays along an existing axis.
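A short sketch of the shape behaviour described above; the arrays x and y are illustrative only. It shows np.stack inserting a new length-2 axis at different positions, np.concatenate keeping the number of dimensions, and mismatched shapes being rejected.

```python
import numpy as np

# Two illustrative arrays with the same shape (3, 4)
x = np.arange(12).reshape(3, 4)
y = x + 100

# np.stack inserts a NEW axis of length 2 (one slot per input) at position `axis`
s0 = np.stack([x, y], axis=0)
s1 = np.stack([x, y], axis=1)
s2 = np.stack([x, y], axis=2)
print(s0.shape, s1.shape, s2.shape)  # (2, 3, 4) (3, 2, 4) (3, 4, 2)

# np.concatenate joins along an EXISTING axis, so the number of dimensions is unchanged
c0 = np.concatenate([x, y], axis=0)
print(c0.shape)  # (6, 4)

# Inputs with different shapes cannot be stacked
try:
    np.stack([x, np.zeros((2, 4))])
except ValueError as err:
    print("ValueError:", err)
```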

How do I stack rows in NumPy?

The vstack() function stacks arrays in sequence vertically (row-wise). This is equivalent to concatenation along the first axis after 1-D arrays of shape (N,) have been reshaped to (1, N). The arrays must have the same shape along all but the first axis, so 1-D inputs must all have the same length.
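A minimal sketch of the equivalence stated above, using two illustrative 1-D arrays: np.vstack gives the same result as reshaping each array to (1, N) and concatenating along axis 0. The last lines show that 2-D inputs only need matching column counts.

```python
import numpy as np

a = np.array([1, 2, 3])  # shape (3,)
b = np.array([4, 5, 6])  # shape (3,)

v = np.vstack([a, b])    # rows stacked vertically -> shape (2, 3)

# Equivalent: reshape each 1-D array to (1, N), then concatenate along axis 0
manual = np.concatenate([a.reshape(1, -1), b.reshape(1, -1)], axis=0)

print(v)
# [[1 2 3]
#  [4 5 6]]
print(np.array_equal(v, manual))  # True

# 2-D inputs only need the same number of columns; row counts may differ
m = np.vstack([np.ones((2, 3)), np.zeros((1, 3))])
print(m.shape)  # (3, 3)
```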

What is the split function in NumPy?

Splitting is the reverse of joining: joining merges multiple arrays into one, and splitting breaks one array into several. np.array_split() takes the array to split and the number of pieces; unlike np.split(), it also accepts a number of pieces that does not divide the array evenly.
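A small sketch of the difference between the two splitting functions, on an illustrative 7-element array: np.array_split tolerates an uneven division, np.split raises, and concatenating the pieces restores the original.

```python
import numpy as np

arr = np.arange(7)

# np.array_split tolerates an uneven division: the first pieces get one extra element
parts = np.array_split(arr, 3)
print([p.tolist() for p in parts])  # [[0, 1, 2], [3, 4], [5, 6]]

# np.split requires the length to divide evenly and raises ValueError otherwise
try:
    np.split(arr, 3)
except ValueError as err:
    print("np.split:", err)

# Both functions also accept a list of split indices instead of a count
print([p.tolist() for p in np.split(arr, [2, 5])])  # [[0, 1], [2, 3, 4], [5, 6]]

# Splitting undoes joining: concatenating the pieces restores the original
assert np.array_equal(np.concatenate(parts), arr)
```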


3 Answers

Coming across this late, here is a much simpler answer:

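import numpy as np

def unstack(a, axis=0):
    # Move `axis` to the front; iterating over the result yields the pieces
    return np.moveaxis(a, axis, 0)
#    Alternative: return list(np.moveaxis(a, axis, 0)) to get a list of arrays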

As a bonus, the result is still a NumPy array (a view of the input, so no data is copied). The actual unstacking happens when you unpack it with ordinary Python iteration:

A, B = unstack([[1, 2], [3, 4]], axis=1)
assert list(A) == [1, 3]
assert list(B) == [2, 4]
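Newer NumPy releases answer the question directly: NumPy 2.1 added np.unstack(x, axis=0), which returns a tuple of arrays. A minimal sketch that prefers it when the installed NumPy has it and otherwise falls back to the moveaxis trick above; the helper name unstack_compat is hypothetical.

```python
import numpy as np

def unstack_compat(a, axis=0):
    """Split `a` along `axis` into a tuple of arrays (hypothetical helper).

    Uses np.unstack when this NumPy provides it (added in NumPy 2.1);
    otherwise moves `axis` to the front and iterates over it.
    """
    a = np.asarray(a)
    if hasattr(np, "unstack"):
        return np.unstack(a, axis=axis)
    return tuple(np.moveaxis(a, axis, 0))

b = np.stack([np.array([1, 2, 3]), np.array([4, 5, 6])], axis=1)  # shape (3, 2)
first, second = unstack_compat(b, axis=1)
print(first, second)  # [1 2 3] [4 5 6]
```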

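It is also the fastest. Note that np.moveaxis by itself only creates a view, so the fair comparison with the list-producing approaches is list(np.moveaxis(...)), which is still more than ten times faster than the np.split/np.squeeze version: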

# np.squeeze
❯ python -m timeit -s "import numpy as np; a=np.array(np.meshgrid(np.arange(1000), np.arange(1000)));" "C = [np.squeeze(e, 1) for e in np.split(a, a.shape[1], axis = 1)]"
100 loops, best of 5: 2.64 msec per loop
    
# np.take
❯ python -m timeit -s "import numpy as np; a=np.array(np.meshgrid(np.arange(1000), np.arange(1000)));" "C = [np.take(a, i, axis = 1) for i in range(a.shape[1])]"       
50 loops, best of 5: 5.08 msec per loop

# np.moveaxis
❯ python -m timeit -s "import numpy as np; a=np.array(np.meshgrid(np.arange(1000), np.arange(1000)));" "C = np.moveaxis(a, 1, 0)"
100000 loops, best of 5: 3.89 usec per loop

# list(np.moveaxis)
❯ python -m timeit -s "import numpy as np; a=np.array(np.meshgrid(np.arange(1000), np.arange(1000)));" "C = list(np.moveaxis(a, 1, 0))"
1000 loops, best of 5: 205 usec per loop
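Because np.moveaxis only returns a view, its timing above involves no data movement at all, and even after list(...) each piece is a view into the original buffer. A short check with np.shares_memory on an illustrative array, which also shows that writing into a piece modifies the source array:

```python
import numpy as np

a = np.arange(12).reshape(3, 4)

pieces = list(np.moveaxis(a, 1, 0))  # unstack along axis 1 -> 4 columns of length 3

# Each piece is a view into a's buffer, not a copy
print(all(np.shares_memory(p, a) for p in pieces))  # True

# Consequently, writing into a piece changes the original array
pieces[0][0] = -1
print(a[0, 0])  # -1

# Call .copy() on a piece if it must be independent of the source
independent = pieces[1].copy()
independent[0] = 99
print(a[0, 1])  # 1
```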
Ivorius answered Oct 17 '22 01:10


Thanks to a suggestion from @hpaulj, I solved the task efficiently using np.split.


import numpy as np

def unstack(a, axis = 0):
    # Split into a.shape[axis] slabs of thickness 1, then squeeze that axis away
    return [np.squeeze(e, axis) for e in np.split(a, a.shape[axis], axis = axis)]

a = [np.array([[1,2,3],[4,5,6]]), np.array([[7,8,9],[10,11,12]])]

for axis in range(len(a[0].shape) + 1):
    b = np.stack(a, axis)
    c = unstack(b, axis)
    # Check that we have same "c" as input "a"
    assert len(c) == len(a) and all(np.all(sc == sa) for sc, sa in zip(c, a)), (c, a)
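A quick consistency check, assuming the np.split version above and the np.moveaxis version from the other answer are defined side by side (renamed unstack_split and unstack_moveaxis here for illustration): both yield the same pieces for every axis of an illustrative 3-D array, including negative axes.

```python
import numpy as np

def unstack_split(a, axis=0):
    # np.split approach: length-1 slabs, then squeeze the split axis away
    return [np.squeeze(e, axis) for e in np.split(a, a.shape[axis], axis=axis)]

def unstack_moveaxis(a, axis=0):
    # np.moveaxis approach: bring `axis` to the front and iterate over it
    return list(np.moveaxis(a, axis, 0))

b = np.arange(24).reshape(2, 3, 4)

for axis in range(-b.ndim, b.ndim):
    s = unstack_split(b, axis)
    m = unstack_moveaxis(b, axis)
    assert len(s) == len(m) == b.shape[axis]
    assert all(np.array_equal(x, y) for x, y in zip(s, m))

print("both approaches agree on every axis")
```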
Arty answered Oct 17 '22 03:10


It's as simple as:

list(array)

This iterates over the first axis (axis 0).

An example:

import numpy as np

a = np.array([1,2,3])
b = np.array([4,5,6])
c = np.stack([a,b])
d = list(c)

Output of c:

array([[1, 2, 3],
       [4, 5, 6]])

Output of d:

[array([1, 2, 3]), array([4, 5, 6])]
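Because list() only iterates over axis 0, unstacking along another axis means moving that axis to the front first. A brief sketch rebuilding the c from this answer and extracting its columns:

```python
import numpy as np

c = np.stack([np.array([1, 2, 3]), np.array([4, 5, 6])])  # shape (2, 3)

rows = list(c)                          # iterates axis 0 -> 2 rows
cols = list(c.T)                        # transpose first -> 3 columns
cols_alt = list(np.moveaxis(c, 1, 0))   # same result; generalises to any axis of an N-D array

print([r.tolist() for r in rows])  # [[1, 2, 3], [4, 5, 6]]
print([k.tolist() for k in cols])  # [[1, 4], [2, 5], [3, 6]]
```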
Gabriel Barbosa answered Oct 17 '22 01:10