 

How can I merge multiple tfrecords file into one file?

Tags:

merge

My question is: if I create one tfrecords file for my data, it will take approximately 15 days to finish. The data has 500,000 pairs of templates, and each template is 32 frames (images). To save time, since I have 3 GPUs, I thought I could create three tfrecords files, one file per GPU, and finish creating the tfrecords in 5 days. But then I searched for a way to merge these three files into one file and couldn't find a proper solution.

So, is there any way to merge these three files into one file? Or is there any way to train my network by feeding it batches of examples extracted from the three tfrecords files, given that I am using the Dataset API?
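
For reference, the parallel writing I have in mind is roughly the following sketch. Here make_example and template_paths are placeholders for my real preprocessing, not working code:

import multiprocessing as mp

import tensorflow as tf

def write_shard(shard_paths, out_file):
    # Each worker process writes its own independent tfrecords shard.
    with tf.python_io.TFRecordWriter(out_file) as writer:
        for path in shard_paths:
            example = make_example(path)  # placeholder: builds a tf.train.Example
            writer.write(example.SerializeToString())

if __name__ == '__main__':
    shards = [template_paths[i::3] for i in range(3)]  # split the work three ways
    jobs = [mp.Process(target=write_shard, args=(s, 'train-%d.tfrecord' % i))
            for i, s in enumerate(shards)]
    for j in jobs:
        j.start()
    for j in jobs:
        j.join()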

Asked May 10 '18 by W. Sam


2 Answers

As the question was asked two months ago, I assume you may have already found a solution. For those who come later, the answer is NO, you do not need to create a single HUGE tfrecords file. Just use the new Dataset API:

import os
import tensorflow as tf

dataset = tf.data.TFRecordDataset(filenames_to_read,
    compression_type=None,    # or 'GZIP'/'ZLIB' if you compressed your data
    buffer_size=10240,        # any read buffer size you want, or 0 for no buffering
    num_parallel_reads=os.cpu_count()  # omit it to read the files sequentially
)

# Maybe you want to prefetch some data first.
dataset = dataset.prefetch(buffer_size=batch_size)

# Decode the serialized examples.
dataset = dataset.map(single_example_parser, num_parallel_calls=os.cpu_count())

dataset = dataset.shuffle(buffer_size=number_larger_than_batch_size)
dataset = dataset.batch(batch_size).repeat(num_epochs)
...
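
The single_example_parser above is whatever function decodes one serialized example into tensors. A minimal sketch, assuming each record stores a raw image byte string under 'image' and an integer label under 'label' (these feature names are assumptions; match them to however you wrote the records):

def single_example_parser(serialized_example):
    # The feature spec must mirror the schema used when writing the records.
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    image = tf.decode_raw(features['image'], tf.uint8)
    label = tf.cast(features['label'], tf.int32)
    return image, label

Note that filenames_to_read can simply be the list of your three shard files, e.g. ['train-0.tfrecord', 'train-1.tfrecord', 'train-2.tfrecord'], so there is no need to merge anything before training.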

For details, check the documentation.

Answered Sep 22 '22 by holmescn


The answer by MoltenMuffins works for higher versions of TensorFlow. However, if you are using a lower version, you have to iterate through the three tfrecords and save their records into a new record file, as follows. This works for tf versions 1.0 and above.

import tensorflow as tf

def comb_tfrecord(tfrecords_path, save_path, batch_size=128):
    # Read the raw (still-serialized) records from all input files in
    # batches and write them straight into a single output file.
    with tf.Graph().as_default(), tf.Session() as sess:
        ds = tf.data.TFRecordDataset(tfrecords_path).batch(batch_size)
        batch = ds.make_one_shot_iterator().get_next()
        writer = tf.python_io.TFRecordWriter(save_path)
        while True:
            try:
                records = sess.run(batch)
                for record in records:
                    writer.write(record)
            except tf.errors.OutOfRangeError:
                break
        writer.close()  # flush and finalize the combined file
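
For example, you could call it with the three shard files like this (the paths are hypothetical):

tfrecords_path = ['train-0.tfrecord', 'train-1.tfrecord', 'train-2.tfrecord']
comb_tfrecord(tfrecords_path, 'train-combined.tfrecord')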
Answered Sep 19 '22 by Deepak Sridhar