 

Reusing Tensorflow session in multiple threads causes crash

Background:

I have some complex reinforcement learning algorithm that I want to run in multiple threads.

Problem

When trying to call sess.run in a thread, I get the following error message:

RuntimeError: The Session graph is empty. Add operations to the graph before calling run().

Code reproducing the error:

import tensorflow as tf

import threading

def thread_function(sess, i):
    inn = [1.3, 4.5]
    A = tf.placeholder(dtype=float, shape=(None), name="input")
    P = tf.Print(A, [A])
    Q = tf.add(A, P)
    sess.run(Q, feed_dict={A: inn})

def main(sess):

    thread_list = []
    for i in range(0, 4):
        t = threading.Thread(target=thread_function, args=(sess, i))
        thread_list.append(t)
        t.start()

    for t in thread_list:
        t.join()

if __name__ == '__main__':

    sess = tf.Session()
    main(sess)

If I run the same code outside a thread it works properly.

Can someone give some insight on how to use TensorFlow sessions properly with Python threads?

Andreas Pasternak asked Oct 01 '18


2 Answers

It is not only the Session that has a per-thread default: the default graph is per-thread as well. Even though you pass the session into each thread and call run on it, the ops you build inside the thread are added to that thread's default graph, which is not the graph the session was created with, so the session's graph stays empty.
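
A quick way to see this (a small check added here for illustration, not part of the original code) is to compare the thread's default graph with the session's graph inside a worker thread:

import threading

import tensorflow as tf

def check_graphs(sess):
    # In a worker thread, the default graph is a different object from the
    # graph the session was created with, so ops built here never reach it.
    print(tf.get_default_graph() is sess.graph)  # expected to print False

sess = tf.Session()
t = threading.Thread(target=check_graphs, args=(sess,))
t.start()
t.join()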

You can amend your thread_function like this to make it work:

def thread_function(sess, i):
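    # Enter the session's graph so the ops created below are added to it,
    # not to the thread's own default graph.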
    with sess.graph.as_default():
        inn = [1.3, 4.5]
        A = tf.placeholder(dtype=float, shape=(None), name="input")
        P = tf.Print(A, [A])
        Q = tf.add(A, P)
        sess.run(Q, feed_dict={A: inn})

However, I wouldn't hope for any significant speedup. Python threads are not what threads are in some other languages: because of the global interpreter lock, only certain operations, such as I/O, actually run in parallel, so threading is of little use for CPU-heavy work. Multiprocessing can run code truly in parallel, but you cannot share the same session across processes.
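
For reference, here is a rough multiprocessing sketch (hypothetical, not part of the original answer) in which each worker process builds its own graph and session instead of sharing one:

import multiprocessing

import tensorflow as tf

def worker(i):
    # Each process owns a private graph and session; nothing is shared.
    with tf.Graph().as_default():
        A = tf.placeholder(dtype=tf.float32, shape=None, name="input")
        Q = tf.add(A, A)
        with tf.Session() as sess:
            print(i, sess.run(Q, feed_dict={A: [1.3, 4.5]}))

if __name__ == '__main__':
    procs = [multiprocessing.Process(target=worker, args=(i,)) for i in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()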

de1 answered Oct 28 '22


Extending de1's answer with another resource on GitHub: tensorflow/tensorflow#28287 (comment)

The following resolved TensorFlow's multithreading compatibility issue for me:

import tensorflow as tf
from tensorflow import keras as k

# on thread 1
session = tf.Session(graph=tf.Graph())
with session.graph.as_default():
    k.backend.set_session(session)
    model = k.models.load_model(filepath)

# on thread 2
with session.graph.as_default():
    k.backend.set_session(session)
    model.predict(x)

This keeps both the Session and the Graph available to other threads.
The model is loaded in their "context" (instead of the default ones) and kept for other threads to use.
(By default, the model would be loaded into the default Session and the default Graph.)
Another plus is that they are kept in the same object, which makes them easier to handle.
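
Putting the two snippets together, a minimal threaded sketch might look like this (assumptions, none of which come from the original answer: a saved model at "model.h5", dummy inputs of shape (1, 10), and k being tf.keras):

import threading

import numpy as np
import tensorflow as tf
from tensorflow import keras as k

# Created once (e.g. in the main thread): a dedicated graph/session pair,
# with the model loaded inside their context.
session = tf.Session(graph=tf.Graph())
with session.graph.as_default():
    k.backend.set_session(session)
    model = k.models.load_model("model.h5")  # hypothetical model file

def predict_worker(x):
    # Worker threads re-enter the same graph and session before predicting.
    with session.graph.as_default():
        k.backend.set_session(session)
        print(model.predict(x))

threads = [threading.Thread(target=predict_worker, args=(np.zeros((1, 10)),))
           for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()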

EliadL answered Oct 28 '22