Multithreading in tensorflow/keras

I would like to train several different models with model.fit() in parallel in one Python application. The models don't necessarily have anything in common; they are started within the same application at different times.

First I start one model.fit() in a thread separate from the main thread, which works without problems. If I then start a second model.fit(), I get the following error message:

Exception in thread Thread-1:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Node 'hidden_1/BiasAdd': Unknown input node 'hidden_1/MatMul'

Both are started from a method by the same lines of code:

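import tensorflow as tf
from keras import backend as K

def start_learn(self):
    tf_session = K.get_session()  # creates a new session if one doesn't exist already
    tf_graph = tf.get_default_graph()

    # keep a reference to the thread object so that it can be started
    learning_results = keras_learn_thread.Learn(learning_data, model, self.env_cont, tf_session, tf_graph)
    learning_results.start()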

The called class/method looks like this:

def run(self):
    tf_session = self.tf_session  # take that from __init__()
    tf_graph = self.tf_graph  # take that from __init__()

    with tf_session.as_default():
        with tf_graph.as_default():
            self.learn(self.learning_data, self.model, self.env_cont)
            # now my learn method where model.fit() is located is being started

I think I somehow have to assign a new tf_session and a new tf_graph to each thread, but I am not sure about that. I would be glad for any ideas, since I have been stuck on this for too long now.

Thanks

Asdf11 asked Sep 01 '17 10:09




1 Answer

I don't know if you fixed your issue, but this looks like another question I recently answered.

  • You need to finish the graph creation in the main thread before starting the others.
  • In the case of keras, the graph is initialized the first time the fit or predict function is called. You can force the graph creation by calling some of the inner functions of model:

    model._make_predict_function()
    model._make_test_function()
    model._make_train_function()
    

    If that doesn't work, try to warm up the model by calling it on dummy data.

  • Once you finish the graph creation, call finalize() on your main graph so it can be safely shared with different threads (that will make it read-only).

  • Finalizing the graph will also help you find other places where your graph is being unintentionally modified.
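
As a rough illustration of the steps above, here is a minimal sketch that builds a graph in the main thread, finalizes it, and only then lets worker threads run it through one shared session. It makes several assumptions that are not part of the question: it uses the TF1-style API through `tf.compat.v1`, it uses a single matmul in place of a Keras model, and the names `worker`, `results`, `x`, `w` and `y` are made up for the example.

```python
import threading

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF1-style graph/session API is available

# 1. Build the complete graph in the main thread.
graph = tf.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 3], name="x")
    w = tf1.get_variable("w", shape=[3, 1], initializer=tf1.ones_initializer())
    y = tf.matmul(x, w, name="y")
    init = tf1.global_variables_initializer()

# 2. Make the graph read-only; adding a node afterwards raises RuntimeError.
graph.finalize()

session = tf1.Session(graph=graph)
session.run(init)

results = {}  # hypothetical container for the per-thread outputs


def worker(name, batch):
    # The worker only runs ops that already exist in the finalized graph.
    results[name] = session.run(y, feed_dict={x: batch})


# 3. Start the threads only after the graph is complete and finalized.
threads = [
    threading.Thread(
        target=worker,
        args=(f"t{i}", np.full((2, 3), float(i), dtype=np.float32)),
    )
    for i in (1, 2)
]
for t in threads:
    t.start()
for t in threads:
    t.join()

session.close()
```

With a real Keras model, the graph-building step would be the `_make_*_function()` calls or the warm-up call mentioned above. If some code path still tries to add nodes from a thread, the finalized graph should fail loudly at that point instead of producing an error like the one in the question.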

Hope that helps you.

Julio Daniel Reyes answered Oct 14 '22 20:10
