
TensorFlow in_top_k evaluation input arguments

Tags:

tensorflow

I am following the tutorial in this link and trying to change the evaluation method for the model (at the bottom). I would like to get a top-5 evaluation, and I'm trying to use the following code:

    topFiver = tf.nn.in_top_k(y, y_, 5, name=None)

However, this yields the following error:

File "AlexNet.py", line 111, in <module>
    topFiver = tf.nn.in_top_k(pred, y, 5, name=None)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_nn_ops.py", line 346, in in_top_k
    targets=targets, k=k, name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 486, in apply_op
    _Attr(op_def, input_arg.type_attr))
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 59, in _SatisfiesTypeConstraint
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list)))
TypeError: DataType float32 for attr 'T' not in list of allowed values: int32, int64

As far as I can tell, the problem is that tf.nn.in_top_k() only works for tf.int32 or tf.int64 data, but my data is in tf.float32 format. Is there any workaround for this?

asked Mar 18 '16 at 09:03 by havakok


1 Answer

The targets argument to tf.nn.in_top_k(predictions, targets, k) must be a vector of integer class IDs (i.e. indices of columns in the predictions matrix), one per example. This means that it only works for single-label classification problems, where each example has exactly one correct class.
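To make that requirement concrete, here is a minimal NumPy sketch of the check in_top_k performs for each row. It is not TensorFlow's implementation: the function name in_top_k_reference is hypothetical, it follows TF's documented tie rule (classes tied at the top-k boundary all count as "in the top k"), and it ignores TF's special handling of non-finite scores and out-of-range class IDs.

```python
import numpy as np


def in_top_k_reference(predictions, targets, k):
    """Hypothetical NumPy sketch of the per-row in_top_k check.

    predictions: float array of shape [batch, num_classes]
    targets:     integer array of shape [batch] holding column indices
    Returns a bool array of shape [batch].
    """
    predictions = np.asarray(predictions, dtype=float)
    targets = np.asarray(targets)
    # Mirror the op's constraints: targets are a vector of integer class IDs,
    # not a float one-hot matrix.
    if not np.issubdtype(targets.dtype, np.integer):
        raise TypeError("targets must be integer class IDs, got %s" % targets.dtype)
    if targets.ndim != 1:
        raise ValueError("targets must be a vector of shape [batch]")
    # Score the model assigned to the true class of each example.
    target_scores = predictions[np.arange(len(targets)), targets]
    # The true class is in the top k when fewer than k classes score strictly
    # higher, so ties straddling the boundary count as "in the top k".
    num_higher = (predictions > target_scores[:, None]).sum(axis=1)
    return num_higher < k


preds = np.array([[0.1, 0.7, 0.2],
                  [0.5, 0.3, 0.2]])
print(in_top_k_reference(preds, [2, 2], k=1))  # [False False]
print(in_top_k_reference(preds, [2, 2], k=2))  # [ True False]
```

Passing float targets (for example a one-hot y_) to this sketch raises a TypeError, which parallels the error in the question.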

If your problem is a single-label problem, then I assume that your y_ tensor is a one-hot encoding of the true labels for your examples (for example because you also pass them to an op like tf.nn.softmax_cross_entropy_with_logits()). In that case, you have two options:

  • If the labels were originally stored as integer labels, pass them directly to tf.nn.in_top_k() without converting them to one-hot. (Also, consider using tf.nn.sparse_softmax_cross_entropy_with_logits() as your loss function, because it may be more efficient.)
  • If the labels were originally stored in the one-hot format, you can convert them to integers using tf.argmax():

    labels = tf.argmax(y_, 1)                 # one-hot rows -> integer class IDs
    topFiver = tf.nn.in_top_k(y, labels, 5)   # TF 1.x order: (predictions, targets, k)
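For readers on TensorFlow 2.x with eager execution, a minimal sketch of both options follows. It assumes made-up scores for 3 examples over 6 classes in place of the model's y. In 2.x the positional order is tf.math.in_top_k(targets, predictions, k), so the sketch passes keyword arguments to avoid depending on that order.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

# Made-up scores for 3 examples over 6 classes (stands in for the model's y).
y = tf.constant([
    [0.10, 0.50, 0.20, 0.05, 0.10, 0.05],
    [0.30, 0.02, 0.08, 0.20, 0.25, 0.15],
    [0.05, 0.05, 0.40, 0.30, 0.10, 0.10],
], dtype=tf.float32)

# Option 2: one-hot float labels (stands in for y_), converted to class IDs.
y_ = tf.one_hot([1, 1, 5], depth=6)   # float32 one-hot rows
labels = tf.argmax(y_, axis=1)        # int64 class IDs: [1, 1, 5]

# Keyword arguments sidestep the TF 1.x vs 2.x positional-order difference.
top5 = tf.math.in_top_k(targets=labels, predictions=y, k=5)
top5_accuracy = tf.reduce_mean(tf.cast(top5, tf.float32))

# Option 1: integer labels used directly, for both the metric and the loss.
int_labels = tf.constant([1, 1, 5], dtype=tf.int64)
top5_direct = tf.math.in_top_k(targets=int_labels, predictions=y, k=5)
per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=int_labels, logits=y)  # treats y as unnormalized logits

print(top5.numpy())  # [ True False  True]
print(float(top5_accuracy))
```

Averaging the boolean result, as top5_accuracy does, turns the per-example check into a top-5 accuracy for the batch. Here the second example misses because its true class has the lowest score of all six.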
    
answered Dec 24 '22 at 07:12 by mrry