
How do you decode one-hot labels in TensorFlow?

I've been looking, but can't seem to find any examples of how to decode a one-hot value back to a single integer in TensorFlow.

I used tf.one_hot and was able to train my model, but I'm a bit confused about how to make sense of the label after classification. My data is fed in via a TFRecords file that I created. I thought about storing a text label in the file but wasn't able to get it to work. It appeared that TFRecords couldn't store text strings, or maybe I was mistaken.

Matt Camp asked Dec 30 '16 16:12



1 Answer

You can find the index of the largest element in a tensor using tf.argmax. A one-hot vector is one-dimensional and contains a single 1 with 0s everywhere else, so that index is the original integer label. For a single vector:

index = tf.argmax(one_hot_vector, axis=0)

For the more standard matrix of shape (batch_size, num_classes), use axis=1 to get a result of shape (batch_size,), with one integer label per row.
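As a rough sketch of the batch case, the snippet below round-trips a few integer labels through tf.one_hot and back with tf.argmax, then looks up a readable name for each index. It assumes TensorFlow 2.x with eager execution (in the 1.x versions current when the question was asked, you would evaluate the tensors in a session instead). The class_names list and the label values are hypothetical, made up for illustration.

```python
import tensorflow as tf

# Hypothetical class names, index-aligned with the one-hot encoding.
class_names = tf.constant(["cat", "dog", "bird"])

# Integer labels -> one-hot matrix of shape (batch_size, num_classes).
labels = tf.constant([2, 0, 1, 2])
one_hot = tf.one_hot(labels, depth=3)

# argmax along the class axis recovers the integer labels, shape (batch_size,).
decoded = tf.argmax(one_hot, axis=1)

# Look up a readable name for each decoded index.
decoded_names = tf.gather(class_names, decoded)

print(decoded.numpy())
print([name.decode("utf-8") for name in decoded_names.numpy()])
```

The same tf.argmax call should also work on a model's softmax output, where it picks the most probable class for each row instead of the position of the single 1.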

martianwars answered Sep 25 '22 10:09
