 

How to implement TensorFlow's next_batch for your own data

In the TensorFlow MNIST tutorial, the mnist.train.next_batch(100) function comes in very handy. I am now trying to implement a simple classification myself. I have my training data in a NumPy array. How can I implement a similar function for my own data that gives me the next batch?

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
Xtr, Ytr = loadData()
for it in range(1000):
    batch_x = Xtr.next_batch(100)
    batch_y = Ytr.next_batch(100)
asked Dec 06 '16 by blckbird

2 Answers

The tutorial you linked says: "we get a 'batch' of one hundred random data points from our training set". In my example I use a standalone function (not a method, as in your example), so the syntax differs slightly.

In my function you'll need to pass the number of samples wanted and the data array.

Here is working code that keeps each sample paired with its correct label:

import numpy as np

def next_batch(num, data, labels):
    '''
    Return a total of `num` random samples and labels.
    '''
    idx = np.arange(0, len(data))
    np.random.shuffle(idx)
    idx = idx[:num]
    data_shuffle = [data[i] for i in idx]
    labels_shuffle = [labels[i] for i in idx]

    return np.asarray(data_shuffle), np.asarray(labels_shuffle)

Xtr, Ytr = np.arange(0, 10), np.arange(0, 100).reshape(10, 10)
print(Xtr)
print(Ytr)

Xtr, Ytr = next_batch(5, Xtr, Ytr)
print('\n5 random samples')
print(Xtr)
print(Ytr)

And a demo run:

[0 1 2 3 4 5 6 7 8 9]
[[ 0  1  2  3  4  5  6  7  8  9]
 [10 11 12 13 14 15 16 17 18 19]
 [20 21 22 23 24 25 26 27 28 29]
 [30 31 32 33 34 35 36 37 38 39]
 [40 41 42 43 44 45 46 47 48 49]
 [50 51 52 53 54 55 56 57 58 59]
 [60 61 62 63 64 65 66 67 68 69]
 [70 71 72 73 74 75 76 77 78 79]
 [80 81 82 83 84 85 86 87 88 89]
 [90 91 92 93 94 95 96 97 98 99]]

5 random samples
[9 1 5 6 7]
[[90 91 92 93 94 95 96 97 98 99]
 [10 11 12 13 14 15 16 17 18 19]
 [50 51 52 53 54 55 56 57 58 59]
 [60 61 62 63 64 65 66 67 68 69]
 [70 71 72 73 74 75 76 77 78 79]]
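
To drop this into the training loop from the question, a minimal sketch could look like this (assuming the usual MNIST-tutorial placeholders x and y_ and a train_step op, none of which are defined in this snippet):

# Hypothetical wiring into the question's loop; x, y_ and
# train_step are assumed to come from your graph definition.
for it in range(1000):
    batch_x, batch_y = next_batch(100, Xtr, Ytr)
    sess.run(train_step, feed_dict={x: batch_x, y_: batch_y})

Note that each call draws a fresh random subset, so across iterations the same example can land in several batches; the second answer below handles strict one-pass-per-epoch coverage.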
answered Sep 21 '22 by edo

To shuffle and sample each mini-batch without repetition, you also need to track which samples have already been drawn in the current epoch. Here is an implementation that uses the data from the answer above.

import numpy as np


class Dataset:

    def __init__(self, data):
        self._index_in_epoch = 0
        self._epochs_completed = 0
        self._data = data
        self._num_examples = data.shape[0]

    @property
    def data(self):
        return self._data

    def next_batch(self, batch_size, shuffle=True):
        # Note: the `shuffle` argument is accepted but not used here;
        # this implementation always shuffles.
        start = self._index_in_epoch
        # Shuffle on the very first call so the first epoch is random too
        if start == 0 and self._epochs_completed == 0:
            idx = np.arange(0, self._num_examples)  # all possible indexes
            np.random.shuffle(idx)
            self._data = self.data[idx]

        if start + batch_size > self._num_examples:
            # Epoch boundary: take the leftover samples of this epoch,
            # reshuffle, then top the batch up from the next epoch
            self._epochs_completed += 1
            rest_num_examples = self._num_examples - start
            data_rest_part = self.data[start:self._num_examples]
            idx0 = np.arange(0, self._num_examples)
            np.random.shuffle(idx0)
            self._data = self.data[idx0]

            start = 0
            # Handles the case where num_examples is not an integer
            # multiple of batch_size
            self._index_in_epoch = batch_size - rest_num_examples
            end = self._index_in_epoch
            data_new_part = self._data[start:end]
            return np.concatenate((data_rest_part, data_new_part), axis=0)
        else:
            self._index_in_epoch += batch_size
            end = self._index_in_epoch
            return self._data[start:end]


dataset = Dataset(np.arange(0, 10))
for i in range(10):
    print(dataset.next_batch(5))

The output is:

[2 8 6 3 4]
[1 5 9 0 7]
[1 7 3 0 8]
[2 6 5 9 4]
[1 0 4 8 3]
[7 6 2 9 5]
[9 5 4 6 2]
[0 1 8 7 3]
[9 7 8 1 6]
[3 5 2 4 0]

The first and second mini-batches (likewise the third and fourth, and so on) together cover exactly one whole epoch.
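
If you also need labels kept in sync while preserving this one-pass-per-epoch behaviour, a generator is a compact alternative. This is my own sketch, not part of the answer above: it shuffles a single index array once per epoch and applies it to both arrays, assuming data and labels share their first dimension.

import numpy as np

def epoch_batches(data, labels, batch_size):
    # Yield shuffled (data, labels) mini-batches covering exactly one epoch
    idx = np.random.permutation(len(data))
    for start in range(0, len(data), batch_size):
        batch_idx = idx[start:start + batch_size]
        yield data[batch_idx], labels[batch_idx]

Xtr, Ytr = np.arange(0, 10), np.arange(0, 100).reshape(10, 10)
for batch_x, batch_y in epoch_batches(Xtr, Ytr, 5):
    print(batch_x, batch_y[:, 0])

Unlike the Dataset class, this yields a short final batch when the number of samples is not a multiple of batch_size, rather than topping it up from the next epoch.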

answered Sep 23 '22 by Brother_Mumu