Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Tensorflow parsing and reshaping float list in Dataset.map()

I am trying to write a 3D float list into a TFRecord. I write it successfully by flattening it first, and I can parse it, but it raises an error when I try to reshape it.

Error: ValueError: Shapes () and (8,) are not compatible

This is how I write the TFRecord file:

def _floats_feature(value):
    """Wrap a float array of any shape as a flat ``tf.train.Feature``.

    Args:
        value: A NumPy array, (nested) list, or scalar of floats. It is
            flattened in row-major (C) order, so the reader has to reshape
            it back to the original shape.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` holds every element.
    """
    import numpy as np  # local import so plain lists/scalars are accepted

    # np.asarray returns an ndarray argument unchanged, so existing callers
    # behave exactly as before; without it, ``value.flatten()`` raised
    # AttributeError for plain Python lists.
    flat_values = np.asarray(value).flatten()
    return tf.train.Feature(float_list=tf.train.FloatList(value=flat_values))

def write(output_path, data_rgb, data_depth, data_decalib):
    """Write one RGB / depth / decalibration sample to a TFRecord file.

    Each array is converted with ``_floats_feature`` (which flattens it) and
    stored as a float list under the keys 'data_rgb', 'data_depth' and
    'data_decalib'. A single serialized ``tf.train.Example`` is written.
    """
    named_arrays = (
        ('data_rgb', data_rgb),
        ('data_depth', data_depth),
        ('data_decalib', data_decalib),
    )
    with tf.python_io.TFRecordWriter(output_path) as writer:
        example = tf.train.Example(
            features=tf.train.Features(
                feature={key: _floats_feature(array) for key, array in named_arrays}
            )
        )
        writer.write(example.SerializeToString())

And this is how I read the TFRecord file:

def get_batches(date, drives, batch_size=1):
    """
    Build an endlessly repeating, batched dataset of parsed samples.

    Every element of the underlying records is decoded by ``input_parser``
    into an (rgb, depth, decalib) tuple before batching.
    :param date: date of the drive
    :param drives: array of the drive_numbers within the drive date
    :param batch_size: number of samples per batch (default 1)
    :return: tf.data.Dataset that yields batches indefinitely
    """
    tfrecord_paths = get_paths_drives(date, drives)
    return (
        tf.data.TFRecordDataset(tfrecord_paths)
        .map(input_parser)  # serialized Example -> tuple of tensors
        .repeat()           # cycle over the input forever
        .batch(batch_size)
    )

# Image dimensions are read from the [DATA_INFORMATION] section of the
# project config file; the tensor shapes below are derived from them.
config = configparser.ConfigParser()
config.read(path_helpers.get_config_file_path())

IMAGE_WIDTH, IMAGE_HEIGHT = (
    int(config['DATA_INFORMATION'][option])
    for option in ('IMAGE_WIDTH', 'IMAGE_HEIGHT')
)

INPUT_RGB_SHAPE = [IMAGE_HEIGHT, IMAGE_WIDTH, 3]    # height x width x 3 channels
INPUT_DEPTH_SHAPE = [IMAGE_HEIGHT, IMAGE_WIDTH, 1]  # height x width x 1 channel
LABEL_CALIB_SHAPE = [8]                             # flat vector of 8 floats

def input_parser(example_proto):
    """Parse one serialized ``tf.train.Example`` into image and label tensors.

    Args:
        example_proto: A scalar string tensor holding a serialized Example
            written by ``write``.

    Returns:
        Tuple ``(img_rgb, img_depth, data_decalib)`` with shapes
        ``INPUT_RGB_SHAPE``, ``INPUT_DEPTH_SHAPE`` and ``LABEL_CALIB_SHAPE``.
    """
    # The first argument of FixedLenFeature is the feature's shape. ``[]``
    # declares a scalar, which is incompatible with these multi-value float
    # lists (the source of "Shapes () and (8,) are not compatible"). Declare
    # the flattened length that was written instead.
    features = {
        'data_rgb': tf.FixedLenFeature([int(np.prod(INPUT_RGB_SHAPE))], tf.float32),
        'data_depth': tf.FixedLenFeature([int(np.prod(INPUT_DEPTH_SHAPE))], tf.float32),
        'data_decalib': tf.FixedLenFeature(LABEL_CALIB_SHAPE, tf.float32),
    }
    parsed_features = tf.parse_single_example(example_proto, features)

    # The parsed tensors already carry their static (flat) shapes, so no
    # set_shape is needed; just restore the image layouts.
    img_rgb = tf.reshape(parsed_features['data_rgb'], INPUT_RGB_SHAPE)
    img_depth = tf.reshape(parsed_features['data_depth'], INPUT_DEPTH_SHAPE)
    data_decalib = parsed_features['data_decalib']

    return img_rgb, img_depth, data_decalib
like image 783
Mark Rofail Avatar asked May 17 '18 09:05

Mark Rofail


1 Answers

It turns out I needed to change my input parser as follows:

def input_parser(example_proto):
    """Parse one serialized Example into (img_rgb, img_depth, data_decalib)."""
    # FixedLenFeature's shape must be the real number of floats stored for
    # each feature (the flattened length), not [] (which means a scalar).
    features = {'data_rgb': tf.FixedLenFeature(shape=[int(np.prod(INPUT_RGB_SHAPE))], dtype=tf.float32),
                'data_depth': tf.FixedLenFeature(shape=[int(np.prod(INPUT_DEPTH_SHAPE))], dtype=tf.float32),
                'data_decalib': tf.FixedLenFeature(shape=LABEL_CALIB_SHAPE, dtype=tf.float32)}
    parsed_features = tf.parse_single_example(example_proto, features)

    # Restore the image layouts and return the tensors; without a return the
    # parser yields None and Dataset.map fails.
    img_rgb = tf.reshape(parsed_features['data_rgb'], INPUT_RGB_SHAPE)
    img_depth = tf.reshape(parsed_features['data_depth'], INPUT_DEPTH_SHAPE)
    data_decalib = parsed_features['data_decalib']

    return img_rgb, img_depth, data_decalib

as the documentation for `tf.FixedLenFeature` (now `tf.io.FixedLenFeature`) dictates. The first argument is the shape, which I had set to `[]`. That declares each feature as a scalar, so setting the parsed tensor's shape to `[8]` failed with `ValueError: Shapes () and (8,) are not compatible`. Setting the shapes to their real (flattened) lengths fixed it.

like image 120
Mark Rofail Avatar answered Oct 24 '22 13:10

Mark Rofail