
Get data set as numpy array from TFRecordDataset

I'm using the new tf.data API to create an iterator for the CIFAR10 dataset. I'm reading the data from two .tfrecord files: one that holds the training data (train.tfrecords) and another that holds the test data (test.tfrecords). This all works fine. At some point, however, I need both data sets (training and test) as NumPy arrays.

Is it possible to retrieve a data set as a NumPy array from a tf.data.TFRecordDataset object?

Marius Mosbach asked Feb 19 '18 17:02



1 Answer

You can use the tf.data.Dataset.batch() transformation and tf.contrib.data.get_single_element() to do this. As a refresher, dataset.batch(n) will take up to n consecutive elements of dataset and convert them into one element by stacking each component along a new leading dimension. This requires every element to have the same shape for a given component. If n is larger than the number of elements in dataset (or if n doesn't divide the number of elements exactly), then the last batch can be smaller. Therefore, you can choose a value for n that is at least as large as the dataset, so that the whole dataset ends up in a single batch, and do the following:

import numpy as np
import tensorflow as tf

# Insert your own code for building `dataset`. For example:
dataset = tf.data.TFRecordDataset(...)  # A dataset of tf.string records.
dataset = dataset.map(...)  # Extract components from each tf.string record.

# Choose a value of `max_elems` that is at least as large as the dataset.
max_elems = np.iinfo(np.int64).max
dataset = dataset.batch(max_elems)

# Extracts the single element of a dataset as one or more `tf.Tensor` objects.
# No iterator needed in this case!
whole_dataset_tensors = tf.contrib.data.get_single_element(dataset)

# Create a session and evaluate `whole_dataset_tensors` to get arrays.
with tf.Session() as sess:
    whole_dataset_arrays = sess.run(whole_dataset_tensors)
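
If you are on TensorFlow 2.x, where tf.contrib is no longer available and eager execution is the default, the same idea can be sketched without a session. This is only a sketch under a few assumptions: it assumes TensorFlow 2.6 or later (for the Dataset.get_single_element() method), and it uses a small in-memory dataset built with from_tensor_slices as a hypothetical stand-in for your parsed TFRecordDataset. The names images, labels, image_array and label_array are illustrative only.

```python
import numpy as np
import tensorflow as tf  # assumption: TensorFlow >= 2.6, eager execution enabled

# Hypothetical stand-in for a parsed TFRecordDataset: 5 (image, label) pairs.
images = np.arange(10, dtype=np.float32).reshape(5, 2)
labels = np.arange(5, dtype=np.int64)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

# Same trick as above: a batch size at least as large as the dataset,
# so that every element lands in one batch.
max_elems = np.iinfo(np.int64).max
whole_dataset_tensors = dataset.batch(max_elems).get_single_element()

# In eager mode each tf.Tensor can be converted to a NumPy array directly.
image_array, label_array = (t.numpy() for t in whole_dataset_tensors)
print(image_array.shape, label_array.shape)
```

As with the session-based version, this materializes the whole dataset in memory at once, so it only makes sense when the dataset fits in RAM (which should be the case for CIFAR10).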
mrry answered Sep 26 '22 02:09
