I am trying out the TensorFlow tutorial and don't understand where next_batch in this line comes from:
batch_xs, batch_ys = mnist.train.next_batch(100)
I looked at
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
I didn't see next_batch there either.
Now when trying out next_batch in my own code, I am getting
AttributeError: 'numpy.ndarray' object has no attribute 'next_batch'
So I would like to understand where next_batch comes from.
next_batch is a method of the DataSet class (see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py for more information on what's in the class).
When you load the MNIST data and assign it to the variable mnist with:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
look at the class of mnist.train. You can see it by typing:
print(mnist.train.__class__)
You'll see the following:
<class 'tensorflow.contrib.learn.python.learn.datasets.mnist.DataSet'>
Because mnist.train is an instance of the DataSet class, you can call its next_batch method. For more information on classes, check out the Python documentation.
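For reference, here is a minimal sketch (assuming the TensorFlow 1.x tensorflow.examples.tutorials.mnist module is available) showing what next_batch returns; the shapes in the comments assume the tutorial's default flattened images and one-hot labels:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
batch_xs, batch_ys = mnist.train.next_batch(100)
print(batch_xs.shape)  # (100, 784): 100 flattened 28x28 images
print(batch_ys.shape)  # (100, 10): one-hot labels because one_hot=True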
After looking through the tensorflow repository, it seems to originate here:
https://github.com/tensorflow/tensorflow/blob/9230423668770036179a72414482d45ddde40a3b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py#L905
However, if you're looking to implement it in your own code (for your own dataset), it would likely be much simpler to write it yourself in a dataset object, as I did. As I understand it, it's a method that shuffles the entire dataset and returns mini_batch_size samples from the shuffled data.
Here's some pseudocode:
shuffle data.x and data.y while retaining relation
return [data.x[:mb_n], data.y[:mb_n]]
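And here is a minimal NumPy sketch of that idea, assuming the dataset is held in two arrays x and y with matching first dimensions (the names next_batch and mb_n are illustrative, not TensorFlow's actual implementation):
import numpy as np

def next_batch(x, y, mb_n):
    # Pick mb_n random indices so x and y stay aligned after shuffling.
    idx = np.random.permutation(len(x))[:mb_n]
    return x[idx], y[idx]

# Example with dummy data: 10 samples of 2 features each.
x = np.arange(20).reshape(10, 2)
y = np.arange(10)
batch_x, batch_y = next_batch(x, y, 4)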