
TensorFlow - Read all examples from a TFRecords at once?

How do you read all examples from a TFRecords at once?

I've been using tf.parse_single_example to read out individual examples using code similar to that given in the method read_and_decode in the example of the fully_connected_reader. However, I want to run the network against my entire validation dataset at once, and so would like to load them in their entirety instead.

I'm not entirely sure, but the documentation seems to suggest I can use tf.parse_example instead of tf.parse_single_example to load the entire TFRecords file at once. I can't seem to get this to work though. I'm guessing it has to do with how I specify the features, but I'm not sure how in the feature specification to state that there are multiple examples.

In other words, my attempt of using something similar to:

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_example(serialized_example, features={
    'image_raw': tf.FixedLenFeature([], tf.string),
    'label': tf.FixedLenFeature([], tf.int64),
})

isn't working, and I assume it's because the feature specification isn't expecting multiple examples at once (but again, I'm not sure). It fails with ValueError: Shape () must have rank 1.

Is this the proper way to read all the records at once? If so, what do I need to change to actually read them? Thank you very much!

Asked May 11 '16 at 02:05 by golmschenk


4 Answers

Just for clarity: I have a few thousand images in a single .tfrecords file. They're 720 by 720 RGB PNG files, and each label is one of 0, 1, 2 or 3.

I also tried using parse_example and couldn't make it work, but this solution works with parse_single_example.

The downside is that right now I have to know how many records are in each .tfrecords file, which is kind of a bummer. If I find a better way, I'll update the answer. Also, be careful not to loop past the number of records in the .tfrecords file: if you do, the reader starts over at the first record.
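The record count doesn't have to be hard-coded. In TF 1.x, sum(1 for _ in tf.python_io.tf_record_iterator(path)) counts the records in a file. As a stdlib-only alternative, here is a minimal sketch (not part of the original answer) that walks the TFRecord framing directly. It assumes an uncompressed file with the documented per-record layout (8-byte little-endian length, 4-byte masked CRC of the length, payload, 4-byte masked CRC of the payload) and does not verify the checksums; count_tfrecords is a hypothetical helper name.

```python
import os
import struct


def count_tfrecords(path):
    """Count records in an uncompressed TFRecord file (hypothetical helper).

    Assumed layout per record:
      uint64 length | uint32 masked CRC of length | payload | uint32 masked CRC of payload
    Checksums are NOT verified; only the length headers are used to skip ahead.
    """
    size = os.path.getsize(path)
    count = 0
    with open(path, "rb") as f:
        while f.tell() < size:
            header = f.read(8)
            if len(header) < 8:
                raise ValueError("truncated length header")
            (length,) = struct.unpack("<Q", header)
            # Skip the length CRC (4 bytes), the payload, and the payload CRC (4 bytes).
            f.seek(4 + length + 4, os.SEEK_CUR)
            if f.tell() > size:
                raise ValueError("truncated record")
            count += 1
    return count
```

The returned count can then replace the hard-coded range(2053) in the reading loop.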

The trick was to have the queue runner use a coordinator.

I left in some code that saves each image as it is read, so that you can verify the images are correct.

from PIL import Image
import numpy as np
import tensorflow as tf


def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since all keys are required.
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
        })
    # The image was stored as raw uint8 bytes; decode it to a flat uint8 vector.
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    label = tf.cast(features['label'], tf.int32)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    depth = tf.cast(features['depth'], tf.int32)
    return image, label, height, width, depth


def get_all_records(FILE):
    with tf.Session() as sess:
        filename_queue = tf.train.string_input_producer([FILE])
        image, label, height, width, depth = read_and_decode(filename_queue)
        # Restore the (height, width, channels) shape; tf.stack was called tf.pack before TF 1.0.
        image = tf.reshape(image, tf.stack([height, width, 3]))
        image.set_shape([720, 720, 3])
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # 2053 is the number of records in this particular file; looping past it
        # wraps around to the first record.
        for i in range(2053):
            example, l = sess.run([image, label])
            img = Image.fromarray(example, 'RGB')
            img.save("output/" + str(i) + '-train.png')  # output/ must already exist
            print(example, l)
        coord.request_stop()
        coord.join(threads)


get_all_records('/path/to/train-0.tfrecords')

Answered Nov 15 '22 at 23:11 by Andrew Pierno


If you need to read all the data from a TFRecord file at once, you can write a much simpler solution in just a few lines of code using tf_record_iterator:

An iterator that reads the records from a TFRecords file.

To do this, you just:

  1. create a tf.train.Example
  2. iterate over the records returned by the iterator
  3. parse each record and read each feature according to its type
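To make step 2 concrete: a TFRecords file is a sequence of length-prefixed, checksummed byte strings, and tf_record_iterator yields those raw strings one at a time. Below is a stdlib-only sketch of that framing, not TensorFlow's implementation. It assumes the documented uncompressed layout (uint64 length, masked CRC-32C of the length, payload, masked CRC-32C of the payload, all little-endian). write_records and iter_records are hypothetical helper names; each payload they yield is a serialized tf.train.Example that you would still pass to ParseFromString.

```python
import struct

# CRC-32C (Castagnoli), reflected polynomial 0x82F63B78, used for TFRecord checksums.
_CRC32C_TABLE = []
for _n in range(256):
    _c = _n
    for _ in range(8):
        _c = (_c >> 1) ^ 0x82F63B78 if _c & 1 else _c >> 1
    _CRC32C_TABLE.append(_c)


def crc32c(data):
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    """Rotate the CRC right by 15 bits and add a constant, per the TFRecord format."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def write_records(path, records):
    """Hypothetical helper: write byte strings using TFRecord framing."""
    with open(path, "wb") as f:
        for data in records:
            header = struct.pack("<Q", len(data))
            f.write(header)
            f.write(struct.pack("<I", masked_crc(header)))
            f.write(data)
            f.write(struct.pack("<I", masked_crc(data)))


def _read_exact(f, n):
    buf = f.read(n)
    if len(buf) != n:
        raise IOError("truncated TFRecord file")
    return buf


def iter_records(path):
    """Hypothetical helper: yield each record's payload, verifying both checksums."""
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                return  # clean end of file
            if len(header) != 8:
                raise IOError("truncated TFRecord file")
            (length,) = struct.unpack("<Q", header)
            (length_crc,) = struct.unpack("<I", _read_exact(f, 4))
            if length_crc != masked_crc(header):
                raise IOError("corrupt length header")
            data = _read_exact(f, length)
            (data_crc,) = struct.unpack("<I", _read_exact(f, 4))
            if data_crc != masked_crc(data):
                raise IOError("corrupt record payload")
            yield data
```

Reading everything at once is then just list(iter_records(path)).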

Here is an example showing how to read each feature type.

import tensorflow as tf

tfrecord_file = '/path/to/train-0.tfrecords'  # replace with your own file

example = tf.train.Example()
for record in tf.python_io.tf_record_iterator(tfrecord_file):
    example.ParseFromString(record)
    f = example.features.feature
    # The keys below are placeholders; use the feature names the file was written with.
    v1 = f['int64 feature'].int64_list.value[0]
    v2 = f['float feature'].float_list.value[0]
    v3 = f['bytes feature'].bytes_list.value[0]
    # For bytes you might want to convert them back to what they were before saving,
    # e.g. raw image bytes with NumPy: np.frombuffer(f['img'].bytes_list.value[0], dtype=np.uint8)
    # Now do something with v1/v2/v3
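Following up on converting a bytes feature back with NumPy: if the image was stored as raw uint8 pixels together with its dimensions, np.frombuffer plus a reshape does the same job that tf.decode_raw followed by tf.reshape does in the queue-based answer. np.frombuffer is used here because np.fromstring is deprecated for binary data. This is a sketch assuming the pixels were saved in row-major order (for example with ndarray.tobytes()); decode_raw_image is a hypothetical helper name.

```python
import numpy as np


def decode_raw_image(raw_bytes, height, width, depth):
    """Hypothetical helper: turn raw uint8 pixel bytes into a (height, width, depth) array.

    Assumes the image was serialized row-major as uint8, e.g. via ndarray.tobytes().
    """
    pixels = np.frombuffer(raw_bytes, dtype=np.uint8)
    expected = height * width * depth
    if pixels.size != expected:
        raise ValueError("got %d bytes, expected %d" % (pixels.size, expected))
    return pixels.reshape(height, width, depth)
```

Inside the loop, assuming the first answer's schema (image_raw, height, width, depth), you would call it with f['image_raw'].bytes_list.value[0] and the three int64_list.value[0] entries.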

Answered Nov 15 '22 at 23:11 by Salvador Dali


To read all the data exactly once, pass num_epochs=1 to string_input_producer. Once every record has been read, running the reader's read op raises tf.errors.OutOfRangeError, which you can catch. Simplified example:

import tensorflow as tf


def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string)
        })
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    return image


def get_all_records(FILE):
    with tf.Session() as sess:
        # num_epochs=1 closes the filename queue after one pass over the file.
        filename_queue = tf.train.string_input_producer([FILE], num_epochs=1)
        image = read_and_decode(filename_queue)
        # The epoch counter is a *local* variable, so local variables must be initialized too.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while True:
                example = sess.run(image)  # process `example` here
        except tf.errors.OutOfRangeError as e:
            # Raised once every record has been read.
            coord.request_stop(e)
        finally:
            coord.request_stop()
            coord.join(threads)


get_all_records('/path/to/train-0.tfrecords')

To use tf.parse_example instead (it is faster than tf.parse_single_example), you first need to batch the serialized examples, like this:

batch = tf.train.batch([serialized_example], num_examples, capacity=num_examples)
parsed_examples = tf.parse_example(batch, feature_spec)

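Unfortunately, this way you need to know the number of examples (num_examples) beforehand; feature_spec is the same feature dictionary you would pass to parse_single_example.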
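If the exact count is unknown, tf.train.batch also accepts allow_smaller_final_batch=True; combined with num_epochs=1, the last batch is then returned smaller instead of requiring the batch size to equal the total count. The batching semantics can be pictured with the plain-Python sketch below. It is only an illustration, not TensorFlow code, and iter_batches is a hypothetical helper; each batch corresponds to the 1-D list of serialized strings that tf.parse_example expects.

```python
def iter_batches(items, batch_size):
    """Hypothetical helper: group an iterable into lists of batch_size.

    The final batch may be smaller, mimicking allow_smaller_final_batch=True,
    so the total number of items never has to be known in advance.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # leftover items that don't fill a whole batch
```

For example, list(iter_batches(range(7), 3)) gives [[0, 1, 2], [3, 4, 5], [6]].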

Answered Nov 15 '22 at 22:11 by sygi


You can also use tf.python_io.tf_record_iterator to manually iterate all examples in a TFRecord.

I tested it with the illustrative code below:

import tensorflow as tf

X = [[1, 2],
     [3, 4],
     [5, 6]]


def _int_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def dump_tfrecord(data, out_file):
    writer = tf.python_io.TFRecordWriter(out_file)
    for x in data:
        example = tf.train.Example(
            features=tf.train.Features(feature={
                'x': _int_feature(x)
            })
        )
        writer.write(example.SerializeToString())
    writer.close()


def load_tfrecord(file_name):
    features = {'x': tf.FixedLenFeature([2], tf.int64)}
    data = []
    for s_example in tf.python_io.tf_record_iterator(file_name):
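        example = tf.parse_single_example(s_example, features=features)
        # Add a leading batch dimension so the examples can be concatenated.
        data.append(tf.expand_dims(example['x'], 0))
    # TF 0.12 argument order; from TF 1.0 on this is tf.concat(data, 0).
    return tf.concat(0, data)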


if __name__ == "__main__":
    dump_tfrecord(X, 'test_tfrecord')
    print('dump ok')
    data = load_tfrecord('test_tfrecord')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        Y = sess.run([data])
        print(Y)

Of course you have to use your own feature specification.

The disadvantage is that I don't know how to use multiple threads this way. However, the most common reason to read all examples at once is to evaluate on a validation set, which is usually not very large, so I don't think efficiency will be a bottleneck.

I ran into another issue while testing: I have to specify the feature length. Instead of tf.FixedLenFeature([], tf.int64), I have to write tf.FixedLenFeature([2], tf.int64); otherwise, an InvalidArgumentError occurs. I don't know how to avoid this.
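This behavior is expected: tf.FixedLenFeature(shape, dtype) tells the parser that every example holds exactly as many values as the shape has elements. An empty shape [] denotes a scalar, i.e. exactly one value, so a feature written with two int64 values only parses with [2]. For features whose length varies, tf.VarLenFeature(tf.int64) avoids fixing the length and returns a SparseTensor (tf.sparse_tensor_to_dense converts it back if needed). The arithmetic behind the check is sketched below; check_fixed_len is a hypothetical helper that models the rule, not TensorFlow's parser.

```python
import operator
from functools import reduce


def num_elements(shape):
    """Number of values a dense feature of this shape holds; [] -> 1 (a scalar)."""
    return reduce(operator.mul, shape, 1)


def check_fixed_len(values, shape):
    """Hypothetical model of the FixedLenFeature length check for one example."""
    expected = num_elements(shape)
    if len(values) != expected:
        raise ValueError(
            "feature has %d values but shape %r needs %d" % (len(values), shape, expected))
    return values
```

With the data above, each x has two values, so check_fixed_len([1, 2], [2]) passes while check_fixed_len([1, 2], []) fails, matching the InvalidArgumentError.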

Python: 3.4
TensorFlow: 0.12.0

Answered Nov 15 '22 at 21:11 by Frederic Shen