 

How to use tf.data.Dataset.padded_batch with a nested shape?

I am building a dataset with two tensors of shape [batch,width,height,3] and [batch,class] for each element. For simplicity, let's say class = 5.

What value do you pass as the padded_shapes argument of dataset.padded_batch(1000, padded_shapes) so that the image is padded along its width and height axes?

I have tried the following:

tf.TensorShape([[None,None,None,3],[None,5]])
[tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5])]
[[None,None,None,3],[None,5]]
([None,None,None,3],[None,5])
(tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5]))

Each of these raises a TypeError.

The docs state:

padded_shapes: A nested structure of tf.TensorShape or tf.int64 vector tensor-like objects representing the shape to which the respective component of each input element should be padded prior to batching. Any unknown dimensions (e.g. tf.Dimension(None) in a tf.TensorShape or -1 in a tensor-like object) will be padded to the maximum size of that dimension in each batch.
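To make the padding rule in that doc passage concrete, here is a small sketch, assuming TensorFlow 2.x with eager execution. The toy dataset of variable-length vectors is made up for illustration. An unknown dimension (None in a tf.TensorShape, or -1 in a list) is padded to the longest element in each batch, while a fixed size pads every batch to that size; the default padding value for numeric tensors is 0.

```python
import tensorflow as tf

# Toy dataset (illustrative only): [1], [2, 2], [3, 3, 3], [4, 4, 4, 4]
dataset = tf.data.Dataset.range(1, 5).map(lambda x: tf.fill([x], x))

# Unknown dimension as tf.Dimension(None): pad to the longest vector per batch.
dynamic = [b.numpy().tolist()
           for b in dataset.padded_batch(2, padded_shapes=tf.TensorShape([None]))]

# Unknown dimension as -1 in a tensor-like object: same behaviour.
minus_one = [b.numpy().tolist()
             for b in dataset.padded_batch(2, padded_shapes=[-1])]

# Known dimension: every batch is padded to length 5.
fixed = [b.numpy().tolist()
         for b in dataset.padded_batch(2, padded_shapes=[5])]

print(dynamic)  # [[[1, 0], [2, 2]], [[3, 3, 3, 0], [4, 4, 4, 4]]]
print(fixed)    # [[[1, 0, 0, 0, 0], [2, 2, 0, 0, 0]], [[3, 3, 3, 0, 0], [4, 4, 4, 4, 0]]]
```

For elements that are tuples, padded_shapes becomes a tuple with one such shape per component.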

The relevant code:

dataset = tf.data.Dataset.from_generator(generator,tf.float32)
shapes = (tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5]))
batch = dataset.padded_batch(1,shapes)
asked Nov 3 '17 at 19:11 by fezzik



1 Answer

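Thanks to mrry for finding the solution. It turns out that the output types passed to from_generator have to match the structure of each element the generator yields: one dtype per tensor, so here a tuple of two dtypes rather than a single tf.float32. Otherwise the dataset treats each element as a single tensor, and a tuple of two padded shapes cannot be matched against it.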
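One way to see this mismatch is to inspect element_spec, which describes the structure that padded_shapes has to mirror. The sketch below assumes TensorFlow 2.x; the generator is hypothetical and only stands in for the asker's. Passing output types positionally still works in TF 2.x but is deprecated in favour of output_signature, so a deprecation warning may be printed.

```python
import numpy as np
import tensorflow as tf


def generator():
    # Hypothetical generator: each element is an (images, labels) pair.
    yield (np.zeros((1, 2, 2, 3), dtype=np.float32),
           np.zeros((1, 5), dtype=np.float32))


# A single dtype declares each element as ONE tensor...
single = tf.data.Dataset.from_generator(generator, tf.float32)
# ...while a tuple of dtypes declares each element as a PAIR of tensors.
pair = tf.data.Dataset.from_generator(generator, (tf.float32, tf.float32))

# padded_shapes must have the same nesting as element_spec:
# a single shape for `single`, a 2-tuple of shapes for `pair`.
print(single.element_spec)
print(pair.element_spec)
```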

New code:

import tensorflow as tf

# One dtype per tensor yielded by the generator: (image, label)
dataset = tf.data.Dataset.from_generator(generator,(tf.float32,tf.float32))
shapes = (tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5]))
batch = dataset.padded_batch(1,shapes)
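A self-contained version of the fix, as a sketch assuming TensorFlow 2.x in eager mode. The generator and its sizes are hypothetical: it yields (images, labels) pairs whose leading, width and height dimensions differ between elements, matching the rank-4/rank-2 shapes used above. output_shapes is optional here but gives the dataset known ranks; with a batch size of 2, each batch is zero-padded to the largest size of every unknown dimension.

```python
import numpy as np
import tensorflow as tf


def generator():
    # Hypothetical data: (n, width, height) varies from element to element.
    for n, w, h in [(1, 4, 6), (2, 5, 3)]:
        images = np.ones((n, w, h, 3), dtype=np.float32)
        labels = np.ones((n, 5), dtype=np.float32)
        yield images, labels


shapes = (tf.TensorShape([None, None, None, 3]), tf.TensorShape([None, 5]))

dataset = tf.data.Dataset.from_generator(
    generator,
    output_types=(tf.float32, tf.float32),  # one dtype per yielded tensor
    output_shapes=shapes,                   # optional: fixes the ranks
)
batch = dataset.padded_batch(2, shapes)

images, labels = next(iter(batch))
print(images.shape)  # (2, 2, 5, 6, 3): max n=2, max width=5, max height=6
print(labels.shape)  # (2, 2, 5)
```

Padded positions are filled with 0, so for example the second "image" slot of the first batch entry is all zeros, because that element only had n=1.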
answered Sep 20 '22 at 22:09 by fezzik
