Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

How to flatten only some dimensions of a numpy array

People also ask

How do I flatten a dimension in numpy?

Flatten a NumPy array with reshape(-1) You can also use reshape() to convert the shape of a NumPy array to one dimension. If you use -1 , the size is calculated automatically, so you can flatten a NumPy array with reshape(-1) . reshape() is provided as a method of numpy.

How do you flatten a numpy array in Python?

flatten() function we can flatten a matrix to one dimension in python. order:'C' means to flatten in row-major. 'F' means to flatten in column-major. 'A' means to flatten in column-major order if a is Fortran contiguous in memory, row-major order otherwise.


Take a look at numpy.reshape .

>>> arr = numpy.zeros((50,100,25))
>>> arr.shape
# (50, 100, 25)

>>> new_arr = arr.reshape(5000,25)
>>> new_arr.shape   
# (5000, 25)

# One shape dimension can be -1. 
# In this case, the value is inferred from 
# the length of the array and remaining dimensions.
>>> another_arr = arr.reshape(-1, arr.shape[-1])
>>> another_arr.shape
# (5000, 25)

A slight generalization to Alexander's answer - np.reshape can take -1 as an argument, meaning "total array size divided by product of all other listed dimensions":

e.g. to flatten all but the last dimension:

>>> arr = numpy.zeros((50,100,25))
>>> new_arr = arr.reshape(-1, arr.shape[-1])
>>> new_arr.shape
# (5000, 25)

A slight generalization to Peter's answer -- you can specify a range over the original array's shape if you want to go beyond three dimensional arrays.

e.g. to flatten all but the last two dimensions:

arr = numpy.zeros((3, 4, 5, 6))
new_arr = arr.reshape(-1, *arr.shape[-2:])
new_arr.shape
# (12, 5, 6)

EDIT: A slight generalization to my earlier answer -- you can, of course, also specify a range at the beginning of the of the reshape too:

arr = numpy.zeros((3, 4, 5, 6, 7, 8))
new_arr = arr.reshape(*arr.shape[:2], -1, *arr.shape[-2:])
new_arr.shape
# (3, 4, 30, 7, 8)

An alternative approach is to use numpy.resize() as in:

In [37]: shp = (50,100,25)
In [38]: arr = np.random.random_sample(shp)
In [45]: resized_arr = np.resize(arr, (np.prod(shp[:2]), shp[-1]))
In [46]: resized_arr.shape
Out[46]: (5000, 25)

# sanity check with other solutions
In [47]: resized = np.reshape(arr, (-1, shp[-1]))
In [48]: np.allclose(resized_arr, resized)
Out[48]: True