I'm using the dataset API, reading data as follows:
dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))
I now want to use flat_map in order to filter out some samples, while duplicating some other samples dynamically at training time (this is the input function feeding my model). The API for flat_map requires the mapped function to return a Dataset object, but I don't know how to create one. Here's a pseudo-code implementation of what I want to achieve:
def flat_map_impl(tf_example):
  # Pseudo-code:
  # if tf_example["a"] == 1:
  #   return []
  # else:
  #   return [tf_example, tf_example]

dataset = dataset.flat_map(flat_map_impl)
How can I implement this in the flat_map function?
NOTE: I guess it's possible to implement this via a py_func, but I'd prefer to avoid this.
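For reference, feature_schema here is just the feature spec passed to tf.parse_single_example; a minimal, hypothetical version with a single scalar int64 feature "a" (the field checked in the pseudo-code above) could look like:
feature_schema = {
    "a": tf.FixedLenFeature([], tf.int64),
}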
Perhaps the most common way to create a tf.data.Dataset when returning from a Dataset.flat_map() is to use Dataset.from_tensors() or Dataset.from_tensor_slices(). In this case, because tf_example is a dictionary, it is probably easiest to use a combination of Dataset.from_tensors() and Dataset.repeat(count), where a conditional expression computes count:
dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))

def flat_map_impl(tf_example):
  # Repeat the example 0 times (i.e. drop it) when a == 1, otherwise emit it twice.
  count = tf.cond(tf.equal(tf_example["a"], 1),
                  lambda: tf.constant(0, dtype=tf.int64),
                  lambda: tf.constant(2, dtype=tf.int64))
  return tf.data.Dataset.from_tensors(tf_example).repeat(count)

dataset = dataset.flat_map(flat_map_impl)
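To sanity-check the behaviour, here is a minimal, self-contained sketch (not part of the original answer) written against the TF 1.x API used above. It assumes the hypothetical feature_schema with a single int64 feature "a" and a throwaway file path: it writes a few gzipped TFRecords, applies the same pipeline, and iterates with a one-shot iterator. Examples with a == 1 should be dropped and every other example should appear twice.

import tensorflow as tf

# Hypothetical schema and file path, matching the snippets above.
feature_schema = {"a": tf.FixedLenFeature([], tf.int64)}
filename = "/tmp/flat_map_demo.tfrecord.gz"

# Write a handful of gzipped examples so the pipeline has data to read.
options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
with tf.python_io.TFRecordWriter(filename, options=options) as writer:
    for a in [0, 1, 2]:
        example = tf.train.Example(features=tf.train.Features(feature={
            "a": tf.train.Feature(int64_list=tf.train.Int64List(value=[a]))}))
        writer.write(example.SerializeToString())

dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda s: tf.parse_single_example(s, feature_schema))

def flat_map_impl(tf_example):
    # Emit the example zero times when a == 1, otherwise twice.
    count = tf.cond(tf.equal(tf_example["a"], 1),
                    lambda: tf.constant(0, dtype=tf.int64),
                    lambda: tf.constant(2, dtype=tf.int64))
    return tf.data.Dataset.from_tensors(tf_example).repeat(count)

dataset = dataset.flat_map(flat_map_impl)

# Iterate once over the transformed dataset: expect a == 0 and a == 2 twice each.
next_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(next_element))
    except tf.errors.OutOfRangeError:
        pass

If you only ever need to drop elements (a count of 0 or 1), Dataset.filter(lambda x: tf.not_equal(x["a"], 1)) is a simpler alternative; flat_map is needed once you also want to duplicate elements.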