 

How do you write and retrieve TFRecord features that are lists?

Tags:

tensorflow

I have a CNN model that takes N classification labels per training example, and I am trying to create TFRecords from my dataset with a label feature that is a list of int64s.

On the shard-creation side, I am using something like the following. I have hard-coded the label data here, but it would of course differ for each sample:

import tensorflow as tf

example = tf.train.Example(features=tf.train.Features(feature={
    # ... other features ...
    'label': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[1, 2, 3, 4]))}))
writer.write(example.SerializeToString())  # writer is an open TFRecordWriter
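A minimal sketch of the writing side with more than one record, assuming TensorFlow 2.x (`tf.io.TFRecordWriter`, `tf.data.TFRecordDataset`); the helper `make_example` and the file name `labels.tfrecord` are hypothetical. Decoding the raw records with `tf.train.Example.FromString` (no parsing spec involved) shows that each record really stores the whole list:

```python
import os
import tempfile

import tensorflow as tf


def make_example(labels):
    """Hypothetical helper: one Example whose 'label' feature is a list of int64s."""
    return tf.train.Example(features=tf.train.Features(feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels)),
    }))


# Hypothetical output file in a fresh temporary directory.
path = os.path.join(tempfile.mkdtemp(), 'labels.tfrecord')
with tf.io.TFRecordWriter(path) as writer:
    for labels in ([1, 2, 3, 4], [5, 6, 7, 8]):
        writer.write(make_example(labels).SerializeToString())

# Read the raw serialized records back and decode them with protobuf only.
stored = [
    list(tf.train.Example.FromString(raw.numpy())
         .features.feature['label'].int64_list.value)
    for raw in tf.data.TFRecordDataset(path)
]
print(stored)  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

So the writer side is fine as it stands; a list of any length is stored in a single `Int64List`.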

On the reading side, I am doing something like the following. I assume a fixed number of labels (4):

features = tf.parse_single_example(
    serialized_example,
    # Defaults are not specified since both keys are required.
    features={
        # ... other stuff ...
        'label': tf.FixedLenFeature(
            [4], dtype=tf.int64, default_value=-1)})
label = features['label']
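For comparison, a minimal sketch of what the comment in that snippet describes, with `default_value` simply left out so that `'label'` becomes a required fixed-length feature. It assumes TensorFlow 2.x, where the functions live at `tf.io.parse_single_example` and `tf.io.FixedLenFeature`:

```python
import tensorflow as tf

# Writer side, as in the question: one Example with a 4-element int64 list.
serialized = tf.train.Example(features=tf.train.Features(feature={
    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 2, 3, 4])),
})).SerializeToString()

# Reader side with no default_value: 'label' is required, and every record
# must store exactly 4 values for the declared shape [4] to parse.
spec = {'label': tf.io.FixedLenFeature([4], dtype=tf.int64)}
label = tf.io.parse_single_example(serialized, spec)['label']
print(label.numpy().tolist())  # [1, 2, 3, 4]
```

Without a default, a record that lacks `'label'` fails to parse instead of being filled in.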

When I run this, TensorFlow reports:

ValueError: Cannot reshape a tensor with 1 elements to shape [4] (4 elements)

Clearly, I'm not understanding something fairly fundamental.

asked Apr 01 '16 20:04 by bobw



1 Answer

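Try setting default_value=[-1] * 4. The default must have the same shape as the feature: a scalar -1 has 1 element and cannot be reshaped to shape [4], which is exactly what the error message says.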
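A minimal sketch of the fix, assuming TensorFlow 2.x (`tf.io.parse_single_example`, `tf.io.FixedLenFeature`, `tf.io.VarLenFeature`); the helpers `serialize`, `parse_fixed` and `parse_var` are hypothetical. A record without `'label'` gets the 4-element default. For a variable number of labels per example, `VarLenFeature` is shown as an alternative: it returns a `SparseTensor` rather than a fixed shape.

```python
import tensorflow as tf


def serialize(labels=None):
    """Hypothetical helper: serialized Example; 'label' is omitted if labels is None."""
    feature = {}
    if labels is not None:
        feature['label'] = tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()


# Fixed number of labels: the default has the same shape ([4]) as the feature.
fixed_spec = {'label': tf.io.FixedLenFeature([4], dtype=tf.int64, default_value=[-1] * 4)}


def parse_fixed(serialized):
    return tf.io.parse_single_example(serialized, fixed_spec)['label'].numpy().tolist()


print(parse_fixed(serialize([1, 2, 3, 4])))  # [1, 2, 3, 4]
print(parse_fixed(serialize()))              # [-1, -1, -1, -1] (default used)

# Variable number of labels: VarLenFeature parses to a SparseTensor.
var_spec = {'label': tf.io.VarLenFeature(tf.int64)}


def parse_var(serialized):
    sparse = tf.io.parse_single_example(serialized, var_spec)['label']
    return tf.sparse.to_dense(sparse).numpy().tolist()


print(parse_var(serialize([7, 8, 9])))  # [7, 8, 9]
```

The fixed-length form gives a dense tensor that can be batched directly. The variable-length form needs padding or densifying before it can be stacked into a batch.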

answered Oct 14 '22 06:10 by Eugene Brevdo
