
Using flat_map in Tensorflow's Dataset API

I'm using the dataset API, reading data as follows:

dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))

I now want to use flat_map in order to filter out some, while duplicating some other samples dynamically at training time (this is the input function leading to my model).

The API for flat_map requires to return a Dataset object, however I don't know how to create that. Here's a pseudo-code implementation of what I want to achieve:

def flat_map_impl(tf_example):
    # Pseudo-code:
    # if tf_example["a"] == 1:
    #     return []
    # else:
    #     return [tf_example, tf_example]

dataset.flat_map(flat_map_impl)

How can I implement this in the flat_map function?

NOTE: I guess it's possible to implement this via a py_func, but I'd prefer to avoid this.

asked May 25 '18 by knub

People also ask

How do I iterate over a TensorFlow dataset?

To iterate over the dataset several times, use .repeat(). We can enumerate each batch by using either Python's enumerate or a built-in method. The former produces a tensor, which is recommended.
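Conceptually, Dataset.repeat(count) just chains count passes over the same sequence, and enumerate pairs each element with a running index. A plain-Python sketch of that semantics (not the tf.data API itself; the helper name repeat_dataset is hypothetical):

```python
from itertools import chain, repeat

def repeat_dataset(elements, count):
    """Mimic tf.data.Dataset.repeat(count): yield the sequence `count` times."""
    return chain.from_iterable(repeat(elements, count))

# enumerate gives each element a global step index across all epochs
steps = list(enumerate(repeat_dataset([10, 20], 3)))
# steps -> [(0, 10), (1, 20), (2, 10), (3, 20), (4, 10), (5, 20)]
```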

Does tf data use GPU?

TensorFlow code and tf.keras models will transparently run on a single GPU with no code changes required.

What is the difference between dataset From_tensors and dataset From_tensor_slices?

With that knowledge, from_tensors makes a dataset where each input tensor is like a row of your dataset, and from_tensor_slices makes a dataset where each input tensor is a column of your data; so in the latter case all tensors must be the same length, and the elements (rows) of the resulting dataset are tuples with one ...
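The difference is easiest to see with a small analogy in plain Python (a sketch of the semantics, not the tf.data API): from_tensors wraps the whole input as a single dataset element, while from_tensor_slices slices along the first axis and yields one element per row.

```python
data = [[1, 2], [3, 4], [5, 6]]  # stand-in for a tensor of shape (3, 2)

# from_tensors: the whole tensor becomes ONE dataset element
from_tensors_elements = [data]

# from_tensor_slices: slice along the first axis -> one element per row
from_tensor_slices_elements = list(data)

len(from_tensors_elements)        # 1 element, shaped (3, 2)
len(from_tensor_slices_elements)  # 3 elements, each shaped (2,)
```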


1 Answer

Perhaps the most common way to create a tf.data.Dataset when returning from a Dataset.flat_map() is to use Dataset.from_tensors() or Dataset.from_tensor_slices(). In this case, because tf_example is a dictionary, it is probably easiest to use a combination of Dataset.from_tensors() and Dataset.repeat(count), where a conditional expression computes count:

dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))

def flat_map_impl(tf_example):
  count = tf.cond(tf.equal(tf_example["a"], 1),
                  lambda: tf.constant(0, dtype=tf.int64),
                  lambda: tf.constant(2, dtype=tf.int64))

  return tf.data.Dataset.from_tensors(tf_example).repeat(count)

dataset = dataset.flat_map(flat_map_impl)
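To see why the repeat(count) trick implements both filtering and duplication at once, here is a plain-Python sketch of what flat_map does (an analogy only, not the tf.data implementation): the function maps each element to a small dataset of 0 or 2 copies, and flat_map concatenates those per-element datasets.

```python
def flat_map_impl(example):
    # 0 copies if example["a"] == 1 (filtered out), else 2 copies (duplicated)
    count = 0 if example["a"] == 1 else 2
    return [example] * count

def flat_map(dataset, fn):
    """Plain-Python analogue of Dataset.flat_map: concatenate per-element datasets."""
    return [out for ex in dataset for out in fn(ex)]

examples = [{"a": 1, "x": "drop"}, {"a": 2, "x": "keep"}]
flat_map(examples, flat_map_impl)
# the "drop" example vanishes; the "keep" example appears twice
```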
answered Sep 16 '22 by mrry