
How can I use values read from TFRecords as arguments to tf.reshape?

import tensorflow as tf

def read_and_decode(filename_queue):
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
      serialized_example,
      # Defaults are not specified since both keys are required.
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
          'height': tf.FixedLenFeature([], tf.int64),
          'width': tf.FixedLenFeature([], tf.int64),
          'depth': tf.FixedLenFeature([], tf.int64)
      })
  # height = tf.cast(features['height'],tf.int32)
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image = tf.reshape(image,[32, 32, 3])
  image = tf.cast(image,tf.float32)
  label = tf.cast(features['label'], tf.int32)
  return image, label

I'm using a TFRecord to store all my data. The function read_and_decode is from the TFRecords example provided by TensorFlow. Currently I reshape by having predefined values:

image = tf.reshape(image,[32, 32, 3])

However, the data that I will be using now is of different dimensions. For example, I could have an image that is [40, 30, 3] (scaling this is not an option as I don't want it to be warped). I would like to read in data of different dimensions and use random_crop in the data augmentation stage to circumvent this problem. What I need is something like the following.

height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
image = tf.reshape(image,[height, width, 3])

However, I can't seem to find a way to do this. Thanks for your help!

EDIT:

ValueError: All shapes must be fully defined: [TensorShape([Dimension(None), Dimension(None), Dimension(None)]), TensorShape([])]

image = tf.reshape(image, tf.pack([height, width, 3]))
image = tf.reshape(image, [32,32,3])

The problem is definitely with these two lines. The hard-coded shape works, but the version with tf.pack() raises the ValueError above.

jkschin asked Feb 25 '16 06:02



1 Answer

You're very close to having a working solution! tf.reshape() expects its shape argument to be a single tensor, and right now there's no automatic way to give TensorFlow a list that mixes tensors and plain numbers and have it build that tensor for you. The answer is to use tf.stack() (named tf.pack() before TensorFlow 1.0), which explicitly takes a list of N-dimensional tensors (or things convertible to tensors) and packs them into an (N+1)-dimensional tensor. Here the three scalars height, width and 3 become one 1-dimensional shape tensor.

This means you can write:

features = ...  # Parse from an example proto.
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)

image = tf.reshape(image, tf.stack([height, width, 3]))
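
For anyone who wants to try this end to end, here is a minimal sketch. It assumes TensorFlow 2.x with eager execution (hence the tf.io.* names rather than the TF 1.x names in the question) and NumPy. make_example and decode are hypothetical helper names for illustration, not part of TensorFlow.

```python
import numpy as np
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution


def make_example(pixels):
    """Hypothetical helper: serialize a uint8 array of shape [height, width, depth]."""
    height, width, depth = pixels.shape

    def int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    feature = {
        'image_raw': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[pixels.tobytes()])),
        'height': int64_feature(height),
        'width': int64_feature(width),
        'depth': int64_feature(depth),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()


def decode(serialized_example):
    """Hypothetical helper: parse one record and reshape with its stored dimensions."""
    features = tf.io.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.io.FixedLenFeature([], tf.string),
            'height': tf.io.FixedLenFeature([], tf.int64),
            'width': tf.io.FixedLenFeature([], tf.int64),
            'depth': tf.io.FixedLenFeature([], tf.int64),
        })
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    depth = tf.cast(features['depth'], tf.int32)
    image = tf.io.decode_raw(features['image_raw'], tf.uint8)
    # tf.stack packs the three scalar tensors into one rank-1 shape tensor.
    return tf.reshape(image, tf.stack([height, width, depth]))


# Round-trip a 40x30x3 image, the non-square size mentioned in the question.
original = (np.arange(40 * 30 * 3) % 256).astype(np.uint8).reshape(40, 30, 3)
decoded = decode(make_example(original))
```

Because the shape is read from each record, it is only known when the record is actually decoded. In a graph-mode pipeline like the one in the question, the reshaped tensor may therefore have no fully defined static shape until a fixed-size step (such as the random_crop the question plans to use) is applied.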
mrry answered Nov 01 '22 17:11