 

Dataset API does not pass dimensionality information for its output tensor when using py_func

To reproduce my problem, try this first (mapping with py_func):

import tensorflow as tf
import numpy as np
def image_parser(image_name):
    # Ignores its input and returns a fixed float32 vector of length 3.
    a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    return a

images = [[1, 2, 3], [4, 5, 6]]
im_dataset = tf.data.Dataset.from_tensor_slices(images)
im_dataset = im_dataset.map(
    lambda image: tuple(tf.py_func(image_parser, [image], [tf.float32])),
    num_parallel_calls=2)
im_dataset = im_dataset.prefetch(4)
iterator = im_dataset.make_initializable_iterator()
print(im_dataset.output_shapes)

This prints (TensorShape(None),), i.e. the static shape of the output is unknown.

However, if you map with a plain TensorFlow function instead of py_func:

import tensorflow as tf
import numpy as np

def image_parser(image_name):
    # No py_func involved, so TensorFlow can propagate the static shape.
    return image_name

images = [[1,2,3],[4,5,6]]
im_dataset = tf.data.Dataset.from_tensor_slices(images)
im_dataset = im_dataset.map(image_parser)
im_dataset = im_dataset.prefetch(4)
iterator = im_dataset.make_initializable_iterator()
print(im_dataset.output_shapes)

it prints the exact static shape, (3,).
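On TensorFlow 2.x, Dataset.output_shapes and make_initializable_iterator are no longer part of the Dataset API (they survive only under tf.compat.v1). The sketch below assumes TF 2.x and reproduces the same comparison with element_spec, using tf.numpy_function as the TF 2 counterpart of tf.py_func. The wrapped Python function yields an unknown static shape, while the pure TensorFlow map keeps (3,).

```python
import numpy as np
import tensorflow as tf

def image_parser(image_name):
    # Opaque Python function, as in the question: TensorFlow cannot inspect it.
    return np.array([1.0, 2.0, 3.0], dtype=np.float32)

images = [[1, 2, 3], [4, 5, 6]]
base = tf.data.Dataset.from_tensor_slices(images)

# Mapping through a wrapped Python function: the static shape is lost.
py_mapped = base.map(
    lambda image: tf.numpy_function(image_parser, [image], tf.float32))

# Mapping with a plain TensorFlow function: the static shape is propagated.
tf_mapped = base.map(lambda image: image)

print(py_mapped.element_spec.shape)
print(tf_mapped.element_spec.shape)
```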

Jiang Wenbo, asked Feb 16 '18 at 09:02



1 Answer

This is a general limitation of tf.py_func, and it is by design: TensorFlow cannot infer the output shape of an arbitrary Python function, so it leaves the shape unknown (see, for instance, this answer).

If you need a static shape, you can set it yourself by moving the tf.py_func call inside the parse function and calling set_shape on its output:

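import numpy as np
import tensorflow as tf

def parser(x):
    # x is ignored here; the point is the shape handling.
    a = np.array([1.0, 2.0, 3.0], dtype=np.float32)  # must match Tout=tf.float32
    y = tf.py_func(lambda: a, [], tf.float32)
    y.set_shape((3,))  # declare the static shape py_func cannot infer
    return y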

dataset = tf.data.Dataset.range(10)
dataset = dataset.map(parser)
print(dataset.output_shapes)  # will correctly print (3,)
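The same set_shape fix can be applied to the question's own pipeline. This is a minimal sketch, assuming TensorFlow 2.x and using tf.numpy_function in place of tf.py_func; parse_with_static_shape is a hypothetical helper name, and the fixed return value stands in for real image decoding.

```python
import numpy as np
import tensorflow as tf

def image_parser(image_name):
    # Stand-in for real image decoding: ignores its input and returns
    # a fixed float32 vector, as in the question.
    return np.array([1.0, 2.0, 3.0], dtype=np.float32)

def parse_with_static_shape(image):
    # Hypothetical helper: wrap the Python function, then declare the
    # static shape that TensorFlow cannot infer on its own.
    y = tf.numpy_function(image_parser, [image], tf.float32)
    y.set_shape((3,))
    return y

images = [[1, 2, 3], [4, 5, 6]]
dataset = tf.data.Dataset.from_tensor_slices(images)
dataset = dataset.map(parse_with_static_shape, num_parallel_calls=2)
dataset = dataset.prefetch(4)

print(dataset.element_spec.shape)
for element in dataset.as_numpy_iterator():
    print(element)
```

Note that set_shape only records the static shape and does not check the values produced at run time. If the Python function could ever return a different length, tf.ensure_shape(y, (3,)) is an alternative that also verifies the shape when the dataset is iterated.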
Olivier Moindrot, answered Oct 03 '22 at 20:10
