
Is there a way to pass a dictionary in tf.data.Dataset with tf.py_func?

I'm using tf.data.Dataset for data processing, and I want to apply some Python code with tf.py_func.

However, I found that tf.py_func cannot return a dictionary. Is there a way to do it, or a workaround?

My code looks like this:

def map_func(images, labels):
    """mapping python function"""
    # do something
    # cannot be expressed as a tensor graph
    return {
        'images': images,
        'labels': labels,
        'new_key': new_value}

def tf_py_func(images, labels):
    return tf.py_func(map_func, [images, labels], [tf.uint8, tf.string], name='blah')

return dataset.map(tf_py_func)

===========================================================================

It's been a while, and I forgot I had asked this question. I solved it another way, and it was so easy that I felt almost stupid. The problem was:

  1. tf.py_func cannot return a dictionary.
  2. dataset.map can return a dictionary.

And the answer is: map twice.

def map_func(images, labels):
    """mapping python function"""
    # do something
    # cannot be expressed as a tensor graph
    return processed_images, processed_labels

def tf_py_func(images, labels):
    return tf.py_func(map_func, [images, labels], [tf.uint8, tf.string], name='blah')

def _to_dict(images, labels):
    return { 'images': images, 'labels': labels }

return dataset.map(tf_py_func).map(_to_dict)
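
A minimal runnable sketch of the map-twice idea, assuming TensorFlow 2.x, where tf.py_func is only available as tf.compat.v1.py_func and tf.numpy_function is its closest replacement. The sample data, the image inversion, and the 'brightness' key are made up for illustration; they stand in for the "do something" step and the new_key entry above.

```python
import numpy as np
import tensorflow as tf


def map_func(images, labels):
    """Plain Python/NumPy step: returns a tuple, never a dict."""
    # Hypothetical processing: invert the image and compute its mean.
    processed_images = (255 - images).astype(np.uint8)
    brightness = np.float32(processed_images.mean())
    return processed_images, labels, brightness


def tf_py_func(images, labels):
    # Assumption: tf.numpy_function stands in for tf.py_func here.
    # Tout must list one dtype per value returned by map_func.
    imgs, lbls, brightness = tf.numpy_function(
        map_func, [images, labels], [tf.uint8, tf.string, tf.float32])
    return imgs, lbls, brightness


def _to_dict(images, labels, brightness):
    # An ordinary map function can return a dict of tensors.
    return {'images': images, 'labels': labels, 'brightness': brightness}


# Made-up sample data: four 2x2 black images with string labels.
images = np.zeros((4, 2, 2), dtype=np.uint8)
labels = np.array([b'a', b'b', b'c', b'd'])

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.map(tf_py_func).map(_to_dict)

first = next(iter(dataset))
print(sorted(first.keys()))
```

Tensors coming out of the Python step have unknown static shapes, so a later stage that needs them (batching into a model, for example) may require an explicit tf.ensure_shape or set_shape call in _to_dict.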
Hongjoo Lee asked Mar 29 '19 06:03



1 Answer

You could turn the dictionary into a string, return that string, and then split it back into a dictionary.

It could look something like this:

return (images + " " + labels + " " + new_value)

and then in your other function:

l = map_func(image, label).split(" ")
d = {}
d['images'] = l[0]
d[
...
SdahlSean answered Oct 30 '22 04:10
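
The string round trip suggested in this answer can be sketched in plain Python. This is only an illustration under assumptions: the helper names pack and unpack, the KEYS tuple, and the sample values are hypothetical, every value is treated as a string, and no value may contain the separator.

```python
def pack(record, keys, sep=" "):
    """Join the values of `record` into one string, in `keys` order."""
    return sep.join(str(record[k]) for k in keys)


def unpack(packed, keys, sep=" "):
    """Split the string back into a dict; every value comes back as str."""
    return dict(zip(keys, packed.split(sep)))


# Hypothetical key order shared by both sides of the round trip.
KEYS = ("images", "labels", "new_key")

packed = pack({"images": "img_001.png", "labels": "cat", "new_key": "42"}, KEYS)
restored = unpack(packed, KEYS)
print(restored)
```

Because everything is flattened to text, this only suits values that are already strings or are cheap to parse back; image arrays would need a real serialization step.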