
How to break a NumPy array into smaller chunks/batches, then iterate through them

Suppose I have this NumPy array:

[[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12]]

And I want to split it into 2 batches and then iterate over them:

[[1, 2, 3],      Batch 1
[4, 5, 6]]

[[7, 8, 9],      Batch 2
[10, 11, 12]]

What is the simplest way to do it?

EDIT: I'm sorry I left this out: my concern is that splitting the array and iterating over the batches would destroy the original array. Once the batch iteration has finished, I need to restart from the first batch, so the original array must be preserved. The whole idea is to be consistent with Stochastic Gradient Descent algorithms, which require repeated iterations over batches. In a typical example, I could have a for loop of 100000 iterations over just 1000 batches, which would be replayed again and again.

Leb_Broth asked Sep 21 '16 17:09



1 Answer

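You can use numpy.split to split the array along the first axis into n equal sections, where n is the number of desired batches. Note that numpy.split requires the length of that axis to be evenly divisible by n; otherwise it raises a ValueError. The implementation would look like this -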

np.split(arr,n,axis=0) # n is number of batches

Since the default value for axis is already 0, we can skip setting it. So we would simply have -

np.split(arr,n)
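
When the number of rows is not a multiple of n, numpy.array_split is a common alternative that accepts uneven splits. The following is a minimal sketch of the difference, assuming a made-up array of shape (10, 3); the variable names are illustrative only:

```python
import numpy as np

# Hypothetical example data: 10 rows, 3 columns.
arr = np.arange(30).reshape(10, 3)

# 10 rows divide evenly into 5 batches of 2 rows each.
even_batches = np.split(arr, 5)

# 10 rows do not divide evenly into 3 batches, so np.split raises ValueError.
try:
    np.split(arr, 3)
    split_failed = False
except ValueError:
    split_failed = True

# np.array_split tolerates the remainder: the batches get 4, 3 and 3 rows.
uneven_batches = np.array_split(arr, 3)
batch_sizes = [len(b) for b in uneven_batches]
```

Whether uneven batches are acceptable depends on the training loop; if every batch must have the same size, trimming or padding the array before calling np.split may be the better choice.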

Sample runs -

In [132]: arr  # Input array of shape (10,3)
Out[132]: 
array([[170,  52, 204],
       [114, 235, 191],
       [ 63, 145, 171],
       [ 16,  97, 173],
       [197,  36, 246],
       [218,  75,  68],
       [223, 198,  84],
       [206, 211, 151],
       [187, 132,  18],
       [121, 212, 140]])

In [133]: np.split(arr,2) # Split into 2 batches
Out[133]: 
[array([[170,  52, 204],
        [114, 235, 191],
        [ 63, 145, 171],
        [ 16,  97, 173],
        [197,  36, 246]]), array([[218,  75,  68],
        [223, 198,  84],
        [206, 211, 151],
        [187, 132,  18],
        [121, 212, 140]])]

In [134]: np.split(arr,5) # Split into 5 batches
Out[134]: 
[array([[170,  52, 204],
        [114, 235, 191]]), array([[ 63, 145, 171],
        [ 16,  97, 173]]), array([[197,  36, 246],
        [218,  75,  68]]), array([[223, 198,  84],
        [206, 211, 151]]), array([[187, 132,  18],
        [121, 212, 140]])]
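
Regarding the EDIT in the question: np.split returns views into the input array rather than consuming it, so the original array stays intact and the same list of batches can be replayed as often as needed. A minimal sketch, assuming the 4x3 array from the question; n_epochs and the summing step are placeholders for a real SGD update, not part of the original answer:

```python
import numpy as np

# Array from the question: 4 rows, 3 columns.
arr = np.array([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9],
                [10, 11, 12]])
original = arr.copy()  # kept only to verify that arr is unchanged afterwards

n_batches = 2
n_epochs = 3  # hypothetical; stands in for the many passes SGD would make

# Split once; the result is a list of views into arr, no data is copied.
batches = np.split(arr, n_batches)

batch_sums = []
for epoch in range(n_epochs):
    for batch in batches:
        # Placeholder for a real gradient update on this batch.
        batch_sums.append(int(batch.sum()))

unchanged = np.array_equal(arr, original)
shares_memory = all(np.shares_memory(arr, b) for b in batches)
```

Because the batches are views, writing to a batch in place would also modify arr; call .copy() on a batch if it needs to be changed independently.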

Divakar answered Oct 01 '22 12:10