
Where does next_batch in the TensorFlow tutorial batch_xs, batch_ys = mnist.train.next_batch(100) come from?

I am trying out the TensorFlow tutorial and don't understand where next_batch in this line comes from:

 batch_xs, batch_ys = mnist.train.next_batch(100)

I looked at

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

And didn't see next_batch there either.

Now when trying out next_batch in my own code, I am getting

AttributeError: 'numpy.ndarray' object has no attribute 'next_batch'

So I would like to understand where next_batch comes from.

Dan asked Nov 01 '16 21:11


2 Answers

next_batch is a method of the DataSet class (see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py for more information on what's in the class).

When you load the mnist data and assign it to the variable mnist with:

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

look at the class of mnist.train. You can see it by typing:

print(mnist.train.__class__)

You'll see the following:

<class 'tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet'>

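Because mnist.train is an instance of the DataSet class, you can call the class's next_batch method on it. A plain numpy.ndarray is not a DataSet, which is why calling next_batch on one raises the AttributeError you saw. For more information on classes, check out the documentation.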
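To illustrate why the AttributeError from the question appears, here is a minimal sketch. TinyDataSet is a hypothetical stand-in written for this example, not TensorFlow's actual DataSet class, and its next_batch is deliberately simplified. It only shows that the method lives on the wrapping class and not on the NumPy arrays it holds.

```python
import numpy as np


class TinyDataSet:
    """Hypothetical stand-in for the tutorial's DataSet; not TensorFlow code."""

    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def next_batch(self, batch_size):
        # Simplified on purpose: always returns the first batch_size samples.
        return self.images[:batch_size], self.labels[:batch_size]


images = np.zeros((5, 784))   # made-up shapes, loosely mirroring flattened MNIST
labels = np.zeros((5, 10))
train = TinyDataSet(images, labels)

has_on_array = hasattr(images, "next_batch")    # plain ndarray has no such method
has_on_wrapper = hasattr(train, "next_batch")   # the class instance does
print(has_on_array, has_on_wrapper)

batch_xs, batch_ys = train.next_batch(2)
```

Wrapping your own arrays in a small class like this is one way to get a next_batch call in your own code; the arrays themselves stay ordinary ndarrays.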

Nick Becker answered Oct 18 '22 05:10


After looking through the TensorFlow repository, it seems to originate here:

https://github.com/tensorflow/tensorflow/blob/9230423668770036179a72414482d45ddde40a3b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py#L905

However, if you want to implement it in your own code (for your own dataset), it is likely much simpler to write it yourself as a method on a dataset object, as I did. As I understand it, the method shuffles the entire dataset and returns the first mb_n samples (mb_n being the mini-batch size) from the shuffled data.

Here's some pseudocode:

shuffle data.x and data.y while retaining relation
return [data.x[:mb_n], data.y[:mb_n]]
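The pseudocode above could be turned into runnable Python roughly as follows. This is only a sketch under assumptions: SimpleDataSet, its seed argument, and the toy data are hypothetical names made up for illustration, and it follows the shuffle-then-slice idea described in this answer, not TensorFlow's own implementation.

```python
import numpy as np


class SimpleDataSet:
    """Hypothetical container pairing inputs x with labels y (not TensorFlow's class)."""

    def __init__(self, x, y, seed=None):
        x, y = np.asarray(x), np.asarray(y)
        if len(x) != len(y):
            raise ValueError("x and y must have the same number of samples")
        self.x, self.y = x, y
        self._rng = np.random.default_rng(seed)

    def next_batch(self, mb_n):
        # One shared permutation keeps every x[i] paired with its y[i].
        perm = self._rng.permutation(len(self.x))
        self.x, self.y = self.x[perm], self.y[perm]
        # Return the first mb_n samples of the freshly shuffled data.
        return self.x[:mb_n], self.y[:mb_n]


# Toy data where each label is twice its input, so the pairing is easy to check.
data = SimpleDataSet(np.arange(10).reshape(10, 1), np.arange(10) * 2, seed=0)
batch_xs, batch_ys = data.next_batch(4)
print(batch_xs.ravel(), batch_ys)
```

Because this version reshuffles on every call, some samples can come up again before others have been seen at all. If you need each sample exactly once per pass, you would instead shuffle once per epoch and keep an index into the shuffled data.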

Dark Element answered Oct 18 '22 05:10