
TensorFlow: access trained variables after closing the session

I recently started using TensorFlow and tried a linear regression model on some simulated data. In the code below I train two variables, w and b (which I initialize randomly using NumPy), with GradientDescentOptimizer. After training finishes, I want to inspect these variables and see how close they are to the true values. (I have omitted the other parts of the code, as they are irrelevant to the question.)

Once the with tf.Session() ... block exits, I open a new session with sess = tf.Session() and call sess.run(w), but I get the error "Attempting to use uninitialized value train_weights". That is expected. I then tried sess.run(tf.global_variables_initializer()), but that simply resets w to its initial value. So the question is: how do I access the final value of a variable after it has been updated in a TensorFlow session, given that the session has been closed?

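import tensorflow as tf

# Model definition (w, b, x_train_batch, y_train_batch, cost, optimizer) and
# data preparation (X, Y, batch_ind, n_epochs, num_batches, display_rate) omitted.

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(n_epochs):
        for batch_pos in range(num_batches):
            x_batch = X[batch_ind[batch_pos]:batch_ind[batch_pos + 1]]
            y_batch = Y[batch_ind[batch_pos]:batch_ind[batch_pos + 1]]
            # One gradient-descent step on this mini-batch
            sess.run(optimizer, feed_dict={x_train_batch: x_batch,
                                           y_train_batch: y_batch})
            # Cost on the same batch after the update
            cost_ = sess.run(cost, feed_dict={x_train_batch: x_batch,
                                              y_train_batch: y_batch})
        if epoch % display_rate == 0:
            print('Epoch:', epoch + 1, 'Cost:', cost_)

# The with block has closed the session, so a new session is opened here
sess = tf.Session()
# sess.run(tf.global_variables_initializer())  # resets w to its initial value
sess.run(w)  # Attempting to use uninitialized value train_weights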
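For comparison, the question's training loop (mini-batch gradient descent on w and b over the ranges given by batch_ind) can be sketched in plain Python without TensorFlow, which makes it easy to see what "close to the true values" should look like. This is only an illustration under assumptions: the loss is assumed to be mean squared error, and the true values (w = 3, b = 2), noise level, learning rate, batch size and the function name train_linear_regression are all hypothetical. The function returns ordinary Python floats, so the trained values stay usable after it returns.

```python
import random


def train_linear_regression(xs, ys, n_epochs=50, batch_size=10, lr=0.1, seed=0):
    """Mini-batch gradient descent on mean squared error for y ~ w*x + b.

    Hypothetical helper; returns the final (w, b) as plain Python floats.
    """
    rng = random.Random(seed)
    # Random initial values (assumed range), like the question's NumPy init
    w = rng.uniform(-1.0, 1.0)
    b = rng.uniform(-1.0, 1.0)
    # Batch boundaries, playing the role of batch_ind in the question
    batch_ind = list(range(0, len(xs), batch_size)) + [len(xs)]
    for epoch in range(n_epochs):
        for pos in range(len(batch_ind) - 1):
            x_batch = xs[batch_ind[pos]:batch_ind[pos + 1]]
            y_batch = ys[batch_ind[pos]:batch_ind[pos + 1]]
            n = len(x_batch)
            # Residuals of the current model on this batch
            errors = [w * x + b - y for x, y in zip(x_batch, y_batch)]
            # Gradients of mean((w*x + b - y)**2) with respect to w and b
            grad_w = 2.0 / n * sum(e * x for e, x in zip(errors, x_batch))
            grad_b = 2.0 / n * sum(errors)
            w -= lr * grad_w
            b -= lr * grad_b
    return w, b


# Simulated data: y = 3x + 2 plus a little Gaussian noise (assumed true values)
data_rng = random.Random(42)
xs = [data_rng.uniform(-1.0, 1.0) for _ in range(200)]
ys = [3.0 * x + 2.0 + data_rng.gauss(0.0, 0.1) for x in xs]

w, b = train_linear_regression(xs, ys)
print(f"w = {w:.3f} (true 3.0), b = {b:.3f} (true 2.0)")
```

Both data generation and initialization use seeded generators, so repeated runs give the same result.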
asked Aug 10 '17 at 08:08 by Clock Slave



1 Answer

Once a session is closed, the values of the variables it held are discarded, so there is no way to read the trained weights from it.

In general, there are two ways around this:

1) Don't close the session (that is, don't use the with tf.Session() as sess: construction), so you can still call sess.run(w) after training;

2) Use a tf.train.Saver to save the variables at the end of the session, while it is still open:

# Still inside the with tf.Session() as sess: block, after training
saver = tf.train.Saver()
saver.save(sess, 'filename.chkp')

Then you can restore the saved values into a new session and read them there:

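# Assumes the same graph (w, b, ...) has already been built in this process
sess = tf.Session()
saver = tf.train.Saver()
saver.restore(sess, 'filename.chkp')  # sets the values; no initializer needed
print(sess.run(w))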
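If you only need the final numbers rather than a restorable checkpoint, a lighter alternative to tf.train.Saver is to fetch them while the session is still open, e.g. w_val, b_val = sess.run([w, b]) inside the with block; sess.run returns NumPy values that remain usable after the session closes. The plain-Python sketch below covers only the persistence step and assumes the fetched values are scalars. The names save_params, load_params and params.json are hypothetical, and writing JSON is a swapped-in technique, not how TensorFlow checkpoints work.

```python
import json
import os
import tempfile


def save_params(path, w, b):
    """Write already-fetched scalar parameter values to a JSON file."""
    with open(path, "w") as f:
        # float() also converts NumPy scalars, which json cannot serialize
        json.dump({"w": float(w), "b": float(b)}, f)


def load_params(path):
    """Read the parameter values back as plain Python floats."""
    with open(path) as f:
        params = json.load(f)
    return params["w"], params["b"]


# Hypothetical values, standing in for w_val, b_val = sess.run([w, b])
# executed inside the with tf.Session() block before it exits
w_val, b_val = 2.98, 2.01
path = os.path.join(tempfile.gettempdir(), "params.json")
save_params(path, w_val, b_val)
print(load_params(path))
```

Unlike a Saver checkpoint, this file cannot be restored into a session directly; it only preserves the numbers for inspection.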
answered Nov 07 '22 at 20:11 by Andrey Lukyanenko
