I just trained a CNN to recognise sunspots with TensorFlow. My model is pretty much the same as this. The problem is that I cannot find a clear explanation anywhere of how to make predictions with the checkpoint generated during training.
I tried the standard restore method:
import tensorflow as tf
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('./model/model.ckpt.meta')
    saver.restore(sess, './model/model.ckpt')
but then I could not figure out how to run predictions with the restored graph.
I also tried tf.estimator.Estimator.predict(), like this:
# Create the Estimator (should reload the last checkpoint but it doesn't)
sunspot_classifier = tf.estimator.Estimator(
    model_fn=cnn_model_fn, model_dir="./model")
# Set up logging for predictions
# Log the values in the "Softmax" tensor with label "probabilities"
tensors_to_log = {"probabilities": "softmax_tensor"}
logging_hook = tf.train.LoggingTensorHook(
    tensors=tensors_to_log, every_n_iter=50)
# Predict with the model and print the results
pred_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": pred_data},
    shuffle=False)
pred_results = sunspot_classifier.predict(input_fn=pred_input_fn)
print(pred_results)
but all it does is print <generator object Estimator.predict at 0x10dda6bf8>.
If I use the same code with tf.estimator.Estimator.evaluate() instead, it works fine: it reloads the model, performs the evaluation and sends the results to TensorBoard.
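The difference between the two calls can be illustrated without TensorFlow. As far as I can tell, Estimator.predict() is written as a Python generator function: calling it only creates the generator, and the checkpoint is restored and the model run only once you start iterating. evaluate(), by contrast, runs immediately and returns a dict of metrics. The sketch below uses a hypothetical stand-in, fake_predict, to mimic that lazy behaviour; it is not the TensorFlow implementation, and the dict keys are assumed to match a tutorial-style cnn_model_fn.

```python
events = []


def fake_predict(inputs):
    """Hypothetical stand-in for Estimator.predict: a generator function."""
    # Nothing in this body runs until the first next() call on the generator.
    events.append("restore checkpoint")
    for x in inputs:
        # One predictions dict per input example, like a tutorial-style model_fn.
        yield {"classes": int(x > 0.5), "probabilities": [1.0 - x, x]}


pred_results = fake_predict([0.25, 0.75])
print(pred_results)            # <generator object fake_predict at 0x...>
restored_early = bool(events)  # False: the body has not started yet

first = next(pred_results)     # runs the body up to the first yield
print(events)                  # ['restore checkpoint']
print(first)                   # {'classes': 0, 'probabilities': [0.75, 0.25]}
```

This is why print(pred_results) shows only the generator object, and why the checkpoint seems not to be reloaded: with the real Estimator, the restore happens on the first iteration.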
I know there are many similar questions, but I couldn't find an approach that worked for me.
sunspot_classifier.predict(input_fn=pred_input_fn) returns a generator, so pred_results is a generator object. To get a value out of it, you need to iterate over it, for example with next(pred_results). The solution is:
print(next(pred_results))
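Note that a single next() call returns only the first prediction. To get one result per input row, iterate over the whole generator. This is a minimal sketch, assuming the model_fn returns a predictions dict with "classes" and "probabilities" keys (as in the TensorFlow layers CNN tutorial; your cnn_model_fn may use different keys). A hypothetical generator built from fake_outputs stands in for sunspot_classifier.predict(input_fn=pred_input_fn) so the snippet runs without TensorFlow.

```python
# Hypothetical stand-in for sunspot_classifier.predict(input_fn=pred_input_fn):
# a generator yielding one predictions dict per input example.
fake_outputs = [
    {"classes": 1, "probabilities": [0.25, 0.75]},
    {"classes": 0, "probabilities": [0.5, 0.5]},
    {"classes": 0, "probabilities": [0.75, 0.25]},
]
pred_results = (p for p in fake_outputs)

# Consume the generator once; it cannot be rewound afterwards.
predicted_classes = []
for i, p in enumerate(pred_results):
    print(i, p["classes"], p["probabilities"])
    predicted_classes.append(p["classes"])

print(predicted_classes)  # [1, 0, 0]

# The generator is now exhausted, so next() with a default returns that default.
exhausted = next(pred_results, None) is None
```

With the real Estimator you would call predict() again to get a fresh generator; list(pred_results) also works if all predictions fit in memory.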