I write a test code,and when I run it ,it said Fetch argument cannot be interpreted as a Tensor.I really don't know what's going on .Can somebody tell me how to fix it? thank you very much .Here is the code
# coding=utf-8
from color_1 import read_and_decode, get_batch, get_test_batch
import color_inference
import cv2
import os
import time
import numpy as np
import tensorflow as tf
import color_train
import math
EVAL_INTERVAL_SECS=10
batch_size=128
num_examples = 10000
crop_size=56
def test(test_x, test_y):
with tf.Graph().as_default() as g:
image_holder = tf.placeholder(tf.float32, [batch_size, 56, 56, 3], name='x-input')
label_holder = tf.placeholder(tf.int32, [batch_size], name='y-input')
y=color_inference.inference(image_holder)
num_iter = int(math.ceil(num_examples / batch_size))
true_count = 0
total_sample_count = num_iter * batch_size
saver=tf.train.Saver()
top_k_op = tf.nn.in_top_k(y, label_holder, 1)
while True:
with tf.Session() as sess:
ckpt=tf.train.get_checkpoint_state(color_train.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess,ckpt.model_checkpoint_path)
global_step=ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
image_batch, label_batch = sess.run([test_x, test_y])
predictions = sess.run([top_k_op], feed_dict={image_holder: image_batch,
label_holder: label_batch})
true_count += np.sum(predictions)
precision = true_count * 1.0 / total_sample_count
print("After %s training step,the prediction is :%g",global_step,precision)
else:
print('No checkpoint file found')
return
time.sleep(EVAL_INTERVAL_SECS)
def main(argv=None):
test_image, test_label = read_and_decode('val.tfrecords')
test_images, test_labels = get_test_batch(test_image, test_label, batch_size, crop_size)
test(test_images, test_labels)
if __name__=='__main__':
tf.app.run()
And the error is here:
File "/home/vrview/tensorflow/example/char/tfrecords/color_test.py", line 57, in <module>
tf.app.run()
File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 44, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "/home/vrview/tensorflow/example/char/tfrecords/color_test.py", line 54, in main
test(test_images, test_labels)
File "/home/vrview/tensorflow/example/char/tfrecords/color_test.py", line 39, in test
image_batch, label_batch = sess.run([test_x, test_y])
File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 767, in run
run_metadata_ptr)
File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 952, in _run
fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)
File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 408, in __init__
self._fetch_mapper = _FetchMapper.for_fetch(fetches)
File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 230, in for_fetch
return _ListFetchMapper(fetch)
File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 337, in __init__
self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 238, in for_fetch
return _ElementFetchMapper(fetches, contraction_fn)
File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 274, in __init__
'Tensor. (%s)' % (fetch, str(e)))
ValueError: Fetch argument <tf.Tensor 'batch:0' shape=(128, 56, 56, 3) dtype=float32> cannot be interpreted as a Tensor. (Tensor Tensor("batch:0", shape=(128, 56, 56, 3), dtype=float32) is not an element of this graph.)
You focused on the wrong part of the error message. The relevant part is
Tensor is not an element of this graph.
The problem is that you create a graph g
in your function test
, that is not the same one in which placeholders test_x
and test_y
provided as arguments have been created.
The easiest solution would be to create your graph g
in main
,
def main(argv=None):
test_image, test_label = read_and_decode('val.tfrecords')
with tf.Graph().as_default():
test_images, test_labels = get_test_batch(test_image, test_label,
batch_size, crop_size)
test(test_images, test_labels)
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With