
Running Tensorflow on big data

I want to train a TensorFlow neural network on a large dataset of 100 million rows with around 15,000 features per row. Training on a single machine would likely be too slow, so I want to run it distributed.

All the distributed examples I've seen so far start by loading the whole dataset into memory and then sending it to the slaves, which would be too expensive in my case.

Does anyone know how to set up the slaves to stream in their training data? Currently the data is stored in Google Cloud Storage, but we can be flexible about this.

asked Nov 07 '22 by Daniel Slater

1 Answer

Although I have never tried it with this many samples (just out of interest, what dataset are you training on?), I think you should use QueueRunner objects!

They can be found on this page: https://www.tensorflow.org/programmers_guide/reading_data, in the section "Creating threads to prefetch using QueueRunner objects".

A quote on how they work:

The short version: many of the tf.train functions listed above add tf.train.QueueRunner objects to your graph. These require that you call tf.train.start_queue_runners before running any training or inference steps, or it will hang forever. This will start threads that run the input pipeline, filling the example queue so that the dequeue to get the examples will succeed. This is best combined with a tf.train.Coordinator to cleanly shut down these threads when there are errors.

The recommended code pattern from that page, combining the queue runners with a Coordinator, is:

# Create the graph, etc.
init_op = tf.global_variables_initializer()

# Create a session for running operations in the Graph.
sess = tf.Session()

# Initialize the variables (like the epoch counter).
sess.run(init_op)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    while not coord.should_stop():
        # Run training steps or whatever
        sess.run(train_op)

except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()
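
That snippet hides the actual input pipeline behind "# Create the graph, etc.". Here is a minimal sketch of what that part could look like for your case, streaming sharded TFRecord files directly from Google Cloud Storage instead of loading anything into memory up front. The gs:// path, feature names, and shapes are assumptions about how the data was written; TensorFlow's file readers can open gs:// paths directly when built with GCS support.

import tensorflow as tf

# Match hypothetical sharded TFRecord files in Google Cloud Storage.
filenames = tf.train.match_filenames_once("gs://my-bucket/train-*.tfrecord")

# Queue of filenames; string_input_producer registers a QueueRunner
# on the graph behind the scenes.
filename_queue = tf.train.string_input_producer(
    filenames, num_epochs=10, shuffle=True)

# Read one serialized example at a time from the current file.
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

# Parse the dense feature vector and the label (the names and shapes
# here are assumptions about how the records were written).
parsed = tf.parse_single_example(
    serialized_example,
    features={
        "features": tf.FixedLenFeature([15000], tf.float32),
        "label": tf.FixedLenFeature([], tf.int64),
    })

# shuffle_batch also registers a QueueRunner; it prefetches and
# batches examples on background threads, so only a small buffer of
# examples is ever held in memory.
feature_batch, label_batch = tf.train.shuffle_batch(
    [parsed["features"], parsed["label"]],
    batch_size=128,
    capacity=10000,
    min_after_dequeue=1000,
    num_threads=4)

One caveat: match_filenames_once and the num_epochs argument create local variables, so the init_op in the pattern above also needs to run tf.local_variables_initializer(). And since each worker can run its own copy of this pipeline over a different subset of the file shards, no single machine ever has to hold the full dataset.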

Although Stack Overflow prefers complete explanations over links to relevant pages, there is much more information on the page I linked above!

Would love to know if this solved your problem, and good luck!

answered Nov 15 '22 by rmeertens