
Writing and Reading lists to TFRecord example

I want to write a list of integers (or any multidimensional NumPy array) to a single TFRecord example. For both a single value and a list of multiple values, I can create the TFRecord file without error. I also know how to read a single value back from the TFRecord file, as shown in the code sample below, which I compiled from various sources.

# Making an example TFRecord (TensorFlow 1.x API)
import tensorflow as tf

my_example = tf.train.Example(features=tf.train.Features(feature={
    'my_ints': tf.train.Feature(int64_list=tf.train.Int64List(value=[5]))
}))

my_example_str = my_example.SerializeToString()
with tf.python_io.TFRecordWriter('my_example.tfrecords') as writer:
    writer.write(my_example_str)

# Reading it back via a Dataset

featuresDict = {'my_ints': tf.FixedLenFeature([], dtype=tf.int64)}

def parse_tfrecord(example):
    features = tf.parse_single_example(example, featuresDict)
    return features

Dataset = tf.data.TFRecordDataset('my_example.tfrecords')
Dataset = Dataset.map(parse_tfrecord)
iterator = Dataset.make_one_shot_iterator()
with tf.Session() as sess:
    print(sess.run(iterator.get_next()))

But how can I read back a list of values (e.g. [5, 6]) from one example? The featuresDict defines the feature to be of type int64, and parsing fails when the feature holds multiple values, with the error below:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: my_ints.  Can't parse serialized Example.
Asked by Shahriar49, Jun 02 '19 20:06


1 Answer

You can achieve this by using tf.train.SequenceExample. I've edited your code to return both 1D and 2D data. First, create a list of features and place it in a tf.train.FeatureList. The 2D data is converted to bytes.

import numpy as np
import tensorflow as tf

vals = [5, 5]
vals_2d = [np.zeros((5, 5), dtype=np.uint8), np.ones((5, 5), dtype=np.uint8)]

# One Feature per sequence step; 2D arrays are stored as raw bytes
# (tobytes() replaces the deprecated tostring()).
features = [tf.train.Feature(int64_list=tf.train.Int64List(value=[val])) for val in vals]
features_2d = [tf.train.Feature(bytes_list=tf.train.BytesList(value=[val.tobytes()])) for val in vals_2d]
featureList = tf.train.FeatureList(feature=features)
featureList_2d = tf.train.FeatureList(feature=features_2d)

To recover the correct shape of the 2D feature, we need to provide context (non-sequential data). This is done with a context dictionary.

context_dict = {
    'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[vals_2d[0].shape[0]])),
    'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[vals_2d[0].shape[1]])),
    'length': tf.train.Feature(int64_list=tf.train.Int64List(value=[len(vals_2d)])),
}
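The need for these context entries can be seen with NumPy alone: converting an array to bytes keeps only the flat buffer, so the shape has to be stored separately. A minimal sketch, assuming a uint8 array like the ones above; the variable names are illustrative and not part of the original answer:

```python
import numpy as np

# A small 2D array standing in for one element of vals_2d.
original = np.arange(6, dtype=np.uint8).reshape(2, 3)

# tobytes() keeps only the raw buffer; shape and dtype are not stored.
raw = original.tobytes()

# Decoding yields a flat 1D array, so the shape must come from elsewhere
# (in the answer, from the 'height' and 'width' context features).
flat = np.frombuffer(raw, dtype=np.uint8)
height, width = original.shape
restored = flat.reshape(height, width)
```

The dtype passed when decoding has to match the dtype used when writing, which is why the reading code further down decodes with tf.uint8.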

Then place each FeatureList in a tf.train.FeatureLists dictionary. Finally, this goes into a tf.train.SequenceExample along with the context dictionary.

my_example = tf.train.SequenceExample(
    feature_lists=tf.train.FeatureLists(feature_list={'1D': featureList,
                                                      '2D': featureList_2d}),
    context=tf.train.Features(feature=context_dict))
my_example_str = my_example.SerializeToString()
with tf.python_io.TFRecordWriter('my_example.tfrecords') as writer:
    writer.write(my_example_str)
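One way to check what was serialized, without building a parsing graph, is to decode the bytes back into the protocol buffer message. The sketch below assumes only that TensorFlow is importable. It rebuilds a small SequenceExample instead of reading the file, and the names `example`, `decoded`, `seq_values` and `length` are illustrative:

```python
import tensorflow as tf

# Build a tiny SequenceExample with one sequence feature and one context feature.
steps = [tf.train.Feature(int64_list=tf.train.Int64List(value=[v])) for v in [5, 6]]
example = tf.train.SequenceExample(
    feature_lists=tf.train.FeatureLists(
        feature_list={'1D': tf.train.FeatureList(feature=steps)}),
    context=tf.train.Features(feature={
        'length': tf.train.Feature(int64_list=tf.train.Int64List(value=[len(steps)]))}))

serialized = example.SerializeToString()

# FromString is the standard protobuf decoder; no Session or Dataset is needed.
decoded = tf.train.SequenceExample.FromString(serialized)
seq_values = [f.int64_list.value[0]
              for f in decoded.feature_lists.feature_list['1D'].feature]
length = decoded.context.feature['length'].int64_list.value[0]
```

This only inspects the message contents; it does not replace the tf.data parsing step, which is still needed to get tensors.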

To read it back into TensorFlow, use tf.FixedLenSequenceFeature for the sequential data and tf.FixedLenFeature for the context data. The bytes are decoded back to integers, and the context data is parsed to restore the correct shape.

# Reading it back via a Dataset
featuresDict = {'1D': tf.FixedLenSequenceFeature([], dtype=tf.int64),
                '2D': tf.FixedLenSequenceFeature([], dtype=tf.string)}
contextDict = {'height': tf.FixedLenFeature([], dtype=tf.int64),
               'width': tf.FixedLenFeature([], dtype=tf.int64),
               'length': tf.FixedLenFeature([], dtype=tf.int64)}

def parse_tfrecord(example):
    context, features = tf.parse_single_sequence_example(
        example,
        sequence_features=featuresDict,
        context_features=contextDict
    )

    height = context['height']
    width = context['width']
    seq_length = context['length']
    vals = features['1D']
    vals_2d = tf.decode_raw(features['2D'], tf.uint8)
    vals_2d = tf.reshape(vals_2d, [seq_length, height, width])
    return vals, vals_2d

Dataset = tf.data.TFRecordDataset('my_example.tfrecords')
Dataset = Dataset.map(parse_tfrecord)
iterator = Dataset.make_one_shot_iterator()
with tf.Session() as sess:
    print(sess.run(iterator.get_next()))

This will output the sequence [5, 5] and the 2D NumPy arrays. This blog post takes a more in-depth look at defining sequences with TFRecords: https://dmolony3.github.io/Working%20with%20image%20sequences.html
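If every example holds the same number of integers, a plain tf.train.Example may be enough: the parse error in the question is consistent with declaring a scalar shape ([]) for a feature that holds two values. A hedged sketch, assuming TensorFlow 2.x with eager execution (the code above targets the 1.x API, where the same feature spec would be passed to tf.parse_single_example inside a Session). The names `fixed_spec` and `var_spec` are illustrative:

```python
import tensorflow as tf

example = tf.train.Example(features=tf.train.Features(feature={
    'my_ints': tf.train.Feature(int64_list=tf.train.Int64List(value=[5, 6]))
}))
serialized = example.SerializeToString()

# Shape [2] tells the parser to expect exactly two int64 values.
fixed_spec = {'my_ints': tf.io.FixedLenFeature([2], dtype=tf.int64)}
fixed = tf.io.parse_single_example(serialized, fixed_spec)
fixed_values = fixed['my_ints'].numpy().tolist()

# VarLenFeature accepts any length and returns a SparseTensor.
var_spec = {'my_ints': tf.io.VarLenFeature(dtype=tf.int64)}
var = tf.io.parse_single_example(serialized, var_spec)
var_values = tf.sparse.to_dense(var['my_ints']).numpy().tolist()
```

The fixed-shape spec suits lists whose length is known in advance. The variable-length spec, or the SequenceExample approach described above, fits data whose length differs between examples.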

Answered by DMolony, Nov 17 '22 03:11