 

tf.nn.in_top_k: targets out of range

Tags:

tensorflow

I adapted the cifar10 network from TensorFlow to address my own classification problem. I have trained the network and am now trying to evaluate the trained model with cifar10_eval.py:

top_k_op = tf.nn.in_top_k(logits, labels, 1)

But I get the following error. After further investigation, I found that the reported target index varies between 2, 3 and 4:

tensorflow.python.framework.errors.InvalidArgumentError: targets[3] is out of range

By now, I understand that something is wrong with my label tensor. It's an int32 tensor with shape (50,), shown below:

labels = {Tensor} Tensor("batch_processing/Reshape_1:0", shape=(50,), dtype=int32, device=/device:CPU:0)

My dataset has just 2 classes/labels. Maybe this is the problem? Does anyone know what is going on?

Simon asked Feb 06 '23 21:02



1 Answer

To sum it up, the function tf.nn.in_top_k(predictions, targets, k) (see the documentation) takes the following arguments:

  • predictions: shape [batch_size, num_classes], type float32
  • targets (the correct label): shape [batch_size], type int32 or int64
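To make that contract concrete, here is a plain-Python sketch of what in_top_k checks. It is a simplified illustration, not TensorFlow's implementation; the name `in_top_k_reference` is hypothetical. Each target is used as a column index into its row of predictions, which is exactly why a label outside [0, num_classes) cannot be looked up.

```python
from typing import List, Sequence


def in_top_k_reference(predictions: Sequence[Sequence[float]],
                       targets: Sequence[int], k: int) -> List[bool]:
    """Simplified plain-Python sketch of the in_top_k contract (hypothetical helper)."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same batch size")
    result = []
    for i, (row, target) in enumerate(zip(predictions, targets)):
        # The target is a column index into its row, so it must lie in [0, num_classes).
        if not 0 <= target < len(row):
            raise ValueError(f"targets[{i}] is out of range")
        # Count the classes scoring strictly higher than the target class;
        # classes tied with the target do not push it out of the top k.
        higher = sum(1 for score in row if score > row[target])
        result.append(higher < k)
    return result


preds = [[0.1, 0.9], [0.8, 0.2]]  # batch_size=2, num_classes=2
print(in_top_k_reference(preds, [1, 1], k=1))  # [True, False]
```

Counting strictly higher scores is one simple way to treat ties generously, in the spirit of the tie rule described in the TensorFlow docs; the real op also has behavior (e.g. for non-finite predictions) that this sketch ignores.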

The function raises InvalidArgumentError: targets[i] is out of range when targets[i] is not a valid class index for predictions[i], i.e. when it is not in the range [0, num_classes).

For instance, suppose there are 2 classes (num_classes=2) and targets=[1, 3]. With these targets, you will get InvalidArgumentError: targets[1] is out of range, because targets[1] = 3 is not a valid index into predictions[1], which has only 2 entries (indices 0 and 1).
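The arithmetic of that example can be reproduced with a tiny helper. `first_out_of_range` is hypothetical (not part of TensorFlow); it simply returns the position of the first label that falls outside [0, num_classes), which in this example is the entry the error message names.

```python
def first_out_of_range(targets, num_classes):
    """Return the first index i whose label is not in [0, num_classes), or None."""
    for i, target in enumerate(targets):
        if not 0 <= target < num_classes:
            return i
    return None


# The example from the text: 2 classes, targets = [1, 3].
print(first_out_of_range([1, 3], num_classes=2))  # 1, matching "targets[1] is out of range"
print(first_out_of_range([0, 1], num_classes=2))  # None, all labels are valid
```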


To check that your labels are correct, you can print their maximum:

import tensorflow as tf  # TensorFlow 1.x graph API (tf.Session)

labels = ...
labels_max = tf.reduce_max(labels)  # largest label in the batch

sess = tf.Session()
print(sess.run(labels_max))

If the printed value is greater than or equal to num_classes, you have a problem: valid labels run from 0 to num_classes - 1.
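The same check can also be run on the raw label values outside TensorFlow, and looking at the minimum as well catches negative labels, which the maximum alone would miss. This is a minimal sketch with a hypothetical `label_report` helper; the example batch is invented for illustration (two classes stored as 1 and 2 instead of 0 and 1).

```python
from collections import Counter


def label_report(labels, num_classes):
    """Summarise integer labels against the valid range [0, num_classes)."""
    # Count every label value that in_top_k would reject as out of range.
    bad = Counter(label for label in labels if not 0 <= label < num_classes)
    return {"min": min(labels), "max": max(labels), "invalid": dict(bad)}


# Hypothetical batch: two classes encoded as 1 and 2 instead of 0 and 1.
batch = [1, 2, 2, 1, 2]
print(label_report(batch, num_classes=2))
# {'min': 1, 'max': 2, 'invalid': {2: 3}}
```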

Olivier Moindrot answered Feb 28 '23 05:02
