
Using TFRecords with Keras

I have converted an image database into two TFRecords files, one for training and the other for validation. I want to train a simple Keras model using these two files as input, but I get a shape-related error that I can't understand.

Here is the code (all-capital variables are defined elsewhere):

def _parse_function(proto):
    f = {
        "x": tf.FixedLenSequenceFeature([IMG_SIZE[0] * IMG_SIZE[1]], tf.float32, default_value=0., allow_missing=True),
        "label": tf.FixedLenSequenceFeature([1], tf.int64, default_value=0, allow_missing=True)
    }
    parsed_features = tf.parse_single_example(proto, f)

    x = tf.reshape(parsed_features['x'] / 255, (IMG_SIZE[0], IMG_SIZE[1], 1))
    y = tf.cast(parsed_features['label'], tf.float32)
    return x, y

def load_dataset(input_path, batch_size, shuffle_buffer):
    dataset = tf.data.TFRecordDataset(input_path)
    dataset = dataset.shuffle(shuffle_buffer).repeat()  # shuffle and repeat
    dataset = dataset.map(_parse_function, num_parallel_calls=16)
    dataset = dataset.batch(batch_size).prefetch(1)  # batch and prefetch

    return dataset.make_one_shot_iterator()

train_iterator = load_dataset(TRAIN_TFRECORDS, BATCH_SIZE, SHUFFLE_BUFFER)
val_iterator = load_dataset(VALIDATION_TFRECORDS, BATCH_SIZE, SHUFFLE_BUFFER)

model = tf.keras.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(IMG_SIZE[0], IMG_SIZE[1], 1)))
model.add(tf.keras.layers.Dense(1, 'sigmoid'))

model.compile(
    optimizer=tf.train.AdamOptimizer(),
    loss='binary_crossentropy',
    metrics=['accuracy']
)

model.fit(
    train_iterator,
    epochs=N_EPOCHS,
    steps_per_epoch=N_TRAIN // BATCH_SIZE,
    validation_data=val_iterator,
    validation_steps=N_VALIDATION // BATCH_SIZE
)

And here is the error I obtain:

tensorflow.python.framework.errors_impl.InvalidArgumentError: data[0].shape = [3] does not start with indices[0].shape = [2]
     [[Node: training/TFOptimizer/gradients/loss/dense_loss/Mean_grad/DynamicStitch = DynamicStitch[N=2, T=DT_INT32, _class=["loc:@training/TFOptimizer/gradients/loss/dense_loss/Mean_grad/floordiv"], _device="/job:localhost/replica:0/task:0/device:GPU:0"](training/TFOptimizer/gradients/loss/dense_loss/Mean_grad/range, training/TFOptimizer/gradients/loss/dense_loss/Mean_3_grad/Maximum, training/TFOptimizer/gradients/loss/dense_loss/Mean_grad/Shape/_35, training/TFOptimizer/gradients/loss/dense_loss/Mean_3_grad/Maximum/_41)]]

(I know this is not a good model for image analysis; I just took the simplest possible architecture that reproduces the error.)

Gianluca Micchi asked Jan 30 '19 12:01




1 Answer

Change:

"label": tf.FixedLenSequenceFeature([1]...

into:

"label": tf.FixedLenSequenceFeature([]...

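Unfortunately, this is not explained in the documentation on the website, but the docstring of FixedLenSequenceFeature on GitHub gives some explanation. The shape argument describes only the dimensions that follow the variable-length sequence dimension, so a feature parsed from a single example has shape [None] + shape. With [1], each label comes out as [None, 1] and becomes three-dimensional once batched, while the model output is two-dimensional; this is consistent with the [3] versus [2] in the error message. With [], each label comes out as [None] and batches to two dimensions. In short, if your data consists of a single dimension (plus a batch dimension), you don't need to specify it.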
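As a rough illustration of why the fix works, here is a minimal sketch of the shape bookkeeping in plain Python. It does not call TensorFlow; it only applies the rule from the FixedLenSequenceFeature docstring that a parsed feature has shape [None] + shape. The helper names parsed_shape and batched_shape and the batch size of 32 are hypothetical, and the sequence length of 1 assumes each record stores exactly one label value.

```python
# Plain-Python shape bookkeeping; nothing here calls TensorFlow.
# Rule assumed from the FixedLenSequenceFeature docstring: parsing one
# example yields a tensor of shape [None] + shape, and batching the
# dataset prepends a batch dimension.

def parsed_shape(feature_shape, sequence_length):
    """Shape of one parsed feature: sequence dimension first, then `shape`."""
    return [sequence_length] + list(feature_shape)

def batched_shape(feature_shape, sequence_length, batch_size):
    """Shape after dataset.batch(batch_size) stacks the parsed features."""
    return [batch_size] + parsed_shape(feature_shape, sequence_length)

BATCH = 32                 # hypothetical batch size
model_output = [BATCH, 1]  # Dense(1) produces (batch, 1)

# One label value per record, so the sequence dimension has length 1.
labels_with_1 = batched_shape([1], sequence_length=1, batch_size=BATCH)
labels_with_empty = batched_shape([], sequence_length=1, batch_size=BATCH)

print(labels_with_1)                            # [32, 1, 1]
print(labels_with_empty)                        # [32, 1]
print(len(labels_with_1) == len(model_output))  # False: rank 3 vs rank 2
print(labels_with_empty == model_output)        # True
```

With [1] the labels end up one rank higher than the model output, whereas with [] the two shapes agree.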

sdcbr answered Oct 13 '22 18:10
