
How to map a function with an additional parameter using the new Dataset API in TF 1.3?

I'm playing with the Dataset API in TensorFlow v1.3. It's great. It is possible to map a dataset with a function as described here. I would like to know how to pass a function that takes an additional argument, for example arg1:

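def _parse_function(example_proto, arg1):
  # arg1 is the extra argument that should be supplied from outside the dataset.
  # Integer features are parsed as tf.int64; tf.int32 is not a supported dtype here.
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int64, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]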

Of course,

dataset = dataset.map(_parse_function)

will not work since there is no way to pass in arg1.

asked Sep 17 '17 12:09 by AmirHJ



1 Answer

Here is an example using a lambda expression to wrap the function to which we want to pass an argument:

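import tensorflow as tf

def fun(x, arg):
    # x comes from the dataset; arg is supplied from outside.
    return x * arg

my_arg = tf.constant(2, dtype=tf.int64)
ds = tf.data.Dataset.range(5)
# The lambda takes only the dataset element and forwards my_arg to fun.
ds = ds.map(lambda x: fun(x, my_arg))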

In the above, the signature of the function passed to map must match the structure of the dataset's elements, so the lambda expression has to be written to match it. Here that is simple: each element has a single component, x, which takes the values 0 through 4.
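When each element has more than one component, the wrapper needs one parameter per component. The sketch below mimics that calling convention in plain Python, without TensorFlow. map_elements is a hypothetical stand-in for Dataset.map that unpacks tuple elements into positional arguments, and scale_pair and factor are made-up names used only for illustration.

```python
def map_elements(elements, fn):
    """Hypothetical stand-in for Dataset.map: tuple elements are unpacked
    into positional arguments, anything else is passed as one argument."""
    out = []
    for element in elements:
        if isinstance(element, tuple):
            out.append(fn(*element))
        else:
            out.append(fn(element))
    return out


def scale_pair(image, label, factor):
    # `factor` is the extra argument that is not part of the dataset.
    return image * factor, label


factor = 2
pairs = [(1, 0), (2, 1), (3, 0)]  # (image, label) stand-ins

# The lambda takes exactly the two element components; `factor` is
# captured from the enclosing scope.
scaled = map_elements(pairs, lambda image, label: scale_pair(image, label, factor))
print(scaled)
```

With a real two-component dataset, the corresponding call would presumably look like ds.map(lambda image, label: scale_pair(image, label, factor)).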

If necessary, you can pass in an arbitrary number of external arguments from outside the dataset: ds = ds.map(lambda x: my_other_fun(x, arg1, arg2, arg3)), and so on.
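A lambda is not the only way to bind the extra arguments. As a rough sketch, functools.partial or a small factory function produces an equivalent one-argument callable. The example uses Python's built-in map as a stand-in for Dataset.map, on the assumption that Dataset.map accepts any callable with a matching signature. make_scaler is a hypothetical helper name.

```python
import functools


def fun(x, arg):
    return x * arg


def make_scaler(arg):
    # Factory: returns a one-argument function with `arg` baked in.
    def scaler(x):
        return fun(x, arg)
    return scaler


my_arg = 2

# Three ways of turning the two-argument `fun` into a one-argument callable.
via_lambda = list(map(lambda x: fun(x, my_arg), range(5)))
via_partial = list(map(functools.partial(fun, arg=my_arg), range(5)))
via_factory = list(map(make_scaler(my_arg), range(5)))

print(via_lambda, via_partial, via_factory)
```

All three variants bind my_arg ahead of time, so the mapped callable only ever receives the dataset element.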

To verify that the above works, we can observe that the mapping indeed multiplies each dataset element by two:

iterator = ds.make_initializable_iterator()
next_x = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)

    while True:
        try:
            print(sess.run(next_x))
        except tf.errors.OutOfRangeError:
            break

The output:

0
2
4
6
8
answered Sep 28 '22 00:09 by mikkola