
Passing multiple inputs to keras model from tf.dataset API?

Tags:

keras

My Keras model has two inputs and three outputs, and each record in my TFRecords file holds a pair of images and their labels. With fit_generator this works fine: I wrote my own generator that feeds the two images to the two model inputs and the three labels to the three model outputs. But I want to use model.fit, where I can pass the dataset instance directly. Does anyone know how to pass a tuple of (x1, x2), (y1, y2, y3) to a Keras model via the tf.data API?

What I used before:

import numpy as np
import tensorflow as tf

def _parse_function_all(example_proto):
    # Feature spec: two raw image strings and three integer labels.
    features = {'image_raw1': tf.FixedLenFeature([], tf.string),
                'image_raw2': tf.FixedLenFeature([], tf.string),
                'label1': tf.FixedLenFeature([], tf.int64),
                'label2': tf.FixedLenFeature([], tf.int64),
                'label3': tf.FixedLenFeature([], tf.int64),
                }

    features = tf.parse_single_example(example_proto, features)
    image1 = tf.decode_raw(features['image_raw1'], tf.uint8)
    image2 = tf.decode_raw(features['image_raw2'], tf.uint8)

    image1.set_shape([224 * 224 * 3])
    image2.set_shape([224 * 224 * 3])

    image1 = tf.reshape(image1, (224, 224, 3))
    image2 = tf.reshape(image2, (224, 224, 3))

    label1 = tf.cast(features['label1'], tf.int32)
    label2 = tf.cast(features['label2'], tf.int32)
    label3 = tf.cast(features['label3'], tf.int32)

    # Stack into single tensors: images (2, 224, 224, 3), labels (3,).
    image_pair = tf.stack([image1, image2], 0)
    label_pair = tf.stack([label1, label2, label3], 0)

    return image_pair, label_pair


def data_gen(sess=None):
    # val_files, batch_size, num_classes and one_hot are defined elsewhere.
    dataset = tf.data.TFRecordDataset(val_files, num_parallel_reads=8)
    dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=4 * batch_size))
    dataset = dataset.map(_parse_function_all, num_parallel_calls=4)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)

    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    sess.run(iterator.initializer)
    while True:
        try:
            next_val = sess.run(next_element)
            images = np.array(next_val[0])
            labels = np.array(next_val[1])

            y_true1 = one_hot(labels[:, 0], num_classes=num_classes)
            y_true2 = one_hot(labels[:, 1], num_classes=num_classes)
            y_true_3 = labels[:, 2]

            # Split the stacked batch back into per-input and per-output arrays.
            yield ({'input_1': images[:, 0], 'input_2': images[:, 1]},
                   {'out_1': y_true1, 'out_2': y_true2, 'concatenate': y_true_3})
        except tf.errors.OutOfRangeError:
            break

model.fit_generator(generator=data_gen(sess))

What I want to use:

def _parse_function_all(example_proto):
    features = {'image_raw1': tf.FixedLenFeature([], tf.string),
                'image_raw2': tf.FixedLenFeature([], tf.string),
                'label1': tf.FixedLenFeature([], tf.int64),
                'label2': tf.FixedLenFeature([], tf.int64),
                'label3': tf.FixedLenFeature([], tf.int64),
                }

    features = tf.parse_single_example(example_proto, features)
    image1 = tf.decode_raw(features['image_raw1'], tf.uint8)
    image2 = tf.decode_raw(features['image_raw2'], tf.uint8)

    image1.set_shape([224 * 224 * 3])
    image2.set_shape([224 * 224 * 3])

    image1 = tf.reshape(image1, (224, 224, 3))
    image2 = tf.reshape(image2, (224, 224, 3))

    label1 = tf.cast(features['label1'], tf.int32)
    label2 = tf.cast(features['label2'], tf.int32)
    label3 = tf.cast(features['label3'], tf.int32)

    image_pair = tf.stack([image1, image2], 0)
    label_pair = tf.stack([label1, label2, label3], 0)

    # This line raised an error; returning nested tuples this way did not work for me.
    return ((image1, image2), (label1, label2, label3))

dataset = tf.data.TFRecordDataset(val_files, num_parallel_reads=8)
dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=4 * batch_size))
dataset = dataset.map(_parse_function_all, num_parallel_calls=4)
dataset = dataset.batch(batch_size)
dataset_val = dataset.prefetch(tf.contrib.data.AUTOTUNE)

model.fit(dataset_val)

So is there any way to pass a tuple of (images, labels) from a dataset to a Keras model that has multiple inputs?

Asked by W. Sam, Sep 14 '18 03:09


1 Answer

In newer versions of TensorFlow (1.14 and above), tf.keras allows me to pass a dataset that yields multiple inputs and multiple outputs directly to model.fit.
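For illustration, here is a minimal sketch of that approach. It is not the original model: it assumes TensorFlow 2.x (so tf.io and tf.data.AUTOTUNE rather than the tf.contrib calls in the question), uses tiny synthetic records held in memory instead of a TFRecords file, and the output name out_3, the image size and the layer sizes are all hypothetical. The idea is that the parse function returns an (inputs, targets) pair of dicts keyed by the model's input and output names, so the dataset can go straight into model.fit.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 4        # hypothetical number of classes
IMG_SHAPE = (8, 8, 3)  # small stand-in for (224, 224, 3)

def make_example(rng):
    """Serialize one synthetic record using the question's feature names."""
    def _bytes(arr):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[arr.tobytes()]))
    def _int(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))
    feature = {
        'image_raw1': _bytes(rng.integers(0, 256, IMG_SHAPE, dtype=np.uint8)),
        'image_raw2': _bytes(rng.integers(0, 256, IMG_SHAPE, dtype=np.uint8)),
        'label1': _int(rng.integers(NUM_CLASSES)),
        'label2': _int(rng.integers(NUM_CLASSES)),
        'label3': _int(rng.integers(2)),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

def parse_fn(example_proto):
    """Return (inputs, targets) dicts keyed by model input/output names."""
    spec = {
        'image_raw1': tf.io.FixedLenFeature([], tf.string),
        'image_raw2': tf.io.FixedLenFeature([], tf.string),
        'label1': tf.io.FixedLenFeature([], tf.int64),
        'label2': tf.io.FixedLenFeature([], tf.int64),
        'label3': tf.io.FixedLenFeature([], tf.int64),
    }
    f = tf.io.parse_single_example(example_proto, spec)

    def decode(raw):
        img = tf.reshape(tf.io.decode_raw(raw, tf.uint8), IMG_SHAPE)
        return tf.cast(img, tf.float32) / 255.0

    inputs = {'input_1': decode(f['image_raw1']),
              'input_2': decode(f['image_raw2'])}
    targets = {'out_1': tf.one_hot(f['label1'], NUM_CLASSES),
               'out_2': tf.one_hot(f['label2'], NUM_CLASSES),
               'out_3': tf.reshape(tf.cast(f['label3'], tf.float32), [1])}
    return inputs, targets

def make_dataset(n=8, batch=4, seed=0):
    rng = np.random.default_rng(seed)
    serialized = [make_example(rng) for _ in range(n)]
    ds = tf.data.Dataset.from_tensor_slices(serialized)
    ds = ds.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(batch).prefetch(tf.data.AUTOTUNE)

def build_model():
    """Toy two-input, three-output model (not the asker's architecture)."""
    in1 = tf.keras.Input(shape=IMG_SHAPE, name='input_1')
    in2 = tf.keras.Input(shape=IMG_SHAPE, name='input_2')
    h1 = tf.keras.layers.Dense(16, activation='relu')(tf.keras.layers.Flatten()(in1))
    h2 = tf.keras.layers.Dense(16, activation='relu')(tf.keras.layers.Flatten()(in2))
    merged = tf.keras.layers.Concatenate()([h1, h2])
    outputs = {
        'out_1': tf.keras.layers.Dense(NUM_CLASSES, activation='softmax', name='out_1')(h1),
        'out_2': tf.keras.layers.Dense(NUM_CLASSES, activation='softmax', name='out_2')(h2),
        'out_3': tf.keras.layers.Dense(1, activation='sigmoid', name='out_3')(merged),
    }
    model = tf.keras.Model(inputs={'input_1': in1, 'input_2': in2}, outputs=outputs)
    model.compile(optimizer='adam',
                  loss={'out_1': 'categorical_crossentropy',
                        'out_2': 'categorical_crossentropy',
                        'out_3': 'binary_crossentropy'})
    return model

if __name__ == '__main__':
    model = build_model()
    model.fit(make_dataset(), epochs=1, verbose=0)  # dataset passed directly
```

With a real TFRecords file, the in-memory from_tensor_slices source would presumably be replaced by tf.data.TFRecordDataset(val_files), with the same parse_fn mapped over it. Keying by name avoids depending on the order of inputs and outputs, which is why dicts are used here instead of nested tuples.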

Answered by W. Sam, Jan 02 '23 10:01