How to reset initialization in TensorFlow 2

If I try to change parallelism in TensorFlow 2 after initializing a tf.Variable,

import tensorflow as tf
_ = tf.Variable([1])
tf.config.threading.set_inter_op_parallelism_threads(1)

I get an error

RuntimeError: Inter op parallelism cannot be modified after initialization.

I understand why that might be, but this behaviour (and possibly other factors) is causing my tests to interfere with each other. For example:

def test_model():  # this test
    v = tf.Variable([1])
    ...

def test_threading():  # is breaking this test
    tf.config.threading.set_inter_op_parallelism_threads(1)
    ...

How do I reset the TensorFlow state so that I can set the threading?

asked Jan 06 '20 17:01 by joel


1 Answer

This is achievable in a "hacky" way, but I'd recommend doing it the right way (i.e. by setting up the config at the beginning, before any TensorFlow op runs).

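import tensorflow as tf
from tensorflow.python.eager import context  # private module, not public API

_ = tf.Variable([1])  # running any op initializes the eager context

# Discard the initialized context and create a fresh, uninitialized one.
# This relies on private internals that may change between TF versions.
context._context = None
context._create_context()

tf.config.threading.set_inter_op_parallelism_threads(1)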
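Why does resetting the context help? Below is a simplified toy model, not TensorFlow's actual code, of the mechanism as I understand it: the eager context is a lazily created module-level singleton, the first op initializes its runtime handle, and from then on the threading setter raises. A brand-new context object has no handle yet, so the setter works again. All names here (`ToyContext`, `get_context`, `create_variable`, `reset_context`) are hypothetical.

```python
class ToyContext:
    """Hypothetical stand-in for TensorFlow's eager Context (not TF code)."""

    def __init__(self):
        self._handle = None         # runtime handle, created lazily on first op
        self._inter_op_threads = 0  # 0 means "let the runtime choose"

    def ensure_initialized(self):
        # The first op "starts the runtime"; settings are fixed from here on.
        if self._handle is None:
            self._handle = object()

    @property
    def inter_op_threads(self):
        return self._inter_op_threads

    @inter_op_threads.setter
    def inter_op_threads(self, n):
        if self._handle is not None:
            raise RuntimeError(
                "Inter op parallelism cannot be modified after initialization.")
        self._inter_op_threads = n


_context = None  # module-level singleton, analogous to context._context


def get_context():
    # Create the singleton on first access, but do not initialize it.
    global _context
    if _context is None:
        _context = ToyContext()
    return _context


def create_variable():
    # Stand-in for tf.Variable(...): running any op forces initialization.
    get_context().ensure_initialized()


def reset_context():
    # Analogous to `context._context = None; context._create_context()`.
    global _context
    _context = None
    get_context()


if __name__ == "__main__":
    create_variable()
    try:
        get_context().inter_op_threads = 1
    except RuntimeError as err:
        print("before reset:", err)
    reset_context()
    get_context().inter_op_threads = 1
    print("after reset:", get_context().inter_op_threads)
```

Run as a script, it prints the RuntimeError message before the reset and `1` after it. The model also shows the hack's weak spot: it only works because the old context object is simply dropped, which is an implementation detail rather than a supported API.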

Edit: here is what I mean by setting up the config at the beginning:

import tensorflow as tf

# Configure threading first, before any op initializes the eager context.
tf.config.threading.set_inter_op_parallelism_threads(1)
_ = tf.Variable([1])

There may be circumstances where you cannot do this; I'm merely pointing out the conventional way of setting up the config in TF. If your circumstances don't let you set tf.config at the beginning, you have to reset the eager context (tensorflow.python.eager.context) as shown in the solution above.
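If you would rather avoid private APIs entirely, one alternative (not part of the original answer) is to run a test that needs its own threading config in a separate Python process, where no eager context has been created yet. The sketch below assumes TensorFlow is installed in the same environment as the test runner; the helper `run_in_fresh_interpreter` and the `THREADING_TEST` snippet are hypothetical names.

```python
import subprocess
import sys


def run_in_fresh_interpreter(code: str) -> str:
    """Run `code` in a brand-new Python process and return its stdout.

    A new process has no TensorFlow eager context yet, so settings such as
    tf.config.threading.set_inter_op_parallelism_threads can be applied
    before any op initializes the runtime.
    """
    result = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError if the child process fails
    )
    return result.stdout


# Hypothetical test body; it only runs inside the child process.
THREADING_TEST = """
import tensorflow as tf
tf.config.threading.set_inter_op_parallelism_threads(1)
_ = tf.Variable([1])
print(tf.config.threading.get_inter_op_parallelism_threads())
"""


def test_threading():
    assert run_in_fresh_interpreter(THREADING_TEST).strip() == "1"
```

The trade-off is speed: each isolated test pays for a fresh interpreter and a fresh TensorFlow import, so this suits a handful of configuration tests rather than a whole suite.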

answered Sep 21 '22 20:09 by thushv89