I'm trying to do the TensorFlow equivalent of torchvision.transforms.Resize(TRAIN_IMAGE_SIZE), which resizes the smallest image dimension to TRAIN_IMAGE_SIZE. Something like this:
def transforms(filename):
    parts = tf.strings.split(filename, '/')
    label = parts[-2]
    image = tf.io.read_file(filename)
    image = tf.image.decode_jpeg(image)
    image = tf.image.convert_image_dtype(image, tf.float32)
    # this doesn't work with Dataset.map() because image.shape=(None,None,3) from Dataset.map()
    image = largest_sq_crop(image)
    image = tf.image.resize(image, (256, 256))
    return image, label

list_ds = tf.data.Dataset.list_files('{}/*/*'.format(DATASET_PATH))
images_ds = list_ds.map(transforms).batch(4)
The simple answer is here: Tensorflow: Crop largest central square region of image

But when I use the method with tf.data.Dataset.map(transforms), I get shape=(None,None,3) from inside largest_sq_crop(image). The method works fine when I call it normally.
I found the answer. My resize method worked fine under eager execution (tf.executing_eagerly() == True) but failed inside dataset.map(), where the function is traced as a graph and tf.executing_eagerly() == False.

My error was in how I unpacked the image shape to get the dimensions for scaling. In graph mode, tensor.shape is the static shape, and any dimension that is not known at trace time comes back as None. The actual sizes are only available at run time through tf.shape().
# wrong
b, h, w, c = img.shape
print("ERR> ", h, w, c)
# ERR> None None 3

# also wrong
b = img.shape[0]
h = img.shape[1]
w = img.shape[2]
c = img.shape[3]
print("ERR> ", h, w, c)
# ERR> None None 3

# but this works!!!
shape = tf.shape(img)
b = shape[0]
h = shape[1]
w = shape[2]
c = shape[3]
img = tf.reshape(img, (-1, h, w, c))
print("OK> ", h, w, c)
# OK> Tensor("strided_slice_2:0", shape=(), dtype=int32) Tensor("strided_slice_3:0", shape=(), dtype=int32) Tensor("strided_slice_4:0", shape=(), dtype=int32)
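The difference between the two execution modes can be checked with a small probe. This is only a sketch: probe, demo_images and observed are made-up names, and the generator stands in for decoded JPEGs of unknown size. Python side effects inside the mapped function run once, while Dataset.map() traces it.

```python
import numpy as np
import tensorflow as tf

observed = {}

def probe(image):
    # This Python code runs once, at trace time, not once per element.
    observed["eager"] = tf.executing_eagerly()
    observed["static_shape"] = image.shape.as_list()
    return image

def demo_images():
    # Hypothetical stand-in for a decoded image of arbitrary height and width.
    yield np.zeros((100, 200, 3), dtype=np.float32)

spec = tf.TensorSpec(shape=(None, None, 3), dtype=tf.float32)
ds = tf.data.Dataset.from_generator(demo_images, output_signature=spec).map(probe)

print("outside map:", tf.executing_eagerly())
print("inside map: ", observed)
```

Inside the mapped function the height and width are reported as None, which is why unpacking image.shape there gives nothing usable for arithmetic.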
I was using the shape dimensions downstream in my dataset.map() function, and it threw the following exception because it was getting None instead of a value:

TypeError: Failed to convert object of type <class 'tuple'> to Tensor. Contents: (-1, None, None, 3). Consider casting elements to a supported type.

When I switched to unpacking the shape manually from tf.shape(), everything worked fine.
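For reference, here is a minimal sketch of both operations from the question written against tf.shape(), so they can run inside Dataset.map(). It is not the code from the linked answer: largest_sq_crop is reimplemented here under the same name, resize_smallest_side is a hypothetical helper, and both assume a single unbatched HxWxC image. The rounding of the longer side is my own choice and may differ slightly from what torchvision does.

```python
import tensorflow as tf

def largest_sq_crop(image):
    """Crop the largest centered square from an unbatched HxWxC image."""
    shape = tf.shape(image)  # dynamic shape, known at run time
    h, w = shape[0], shape[1]
    side = tf.minimum(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    return image[top:top + side, left:left + side, :]

def resize_smallest_side(image, size):
    """Scale the image so its smaller side equals `size`, keeping aspect ratio."""
    shape = tf.shape(image)
    h = tf.cast(shape[0], tf.float32)
    w = tf.cast(shape[1], tf.float32)
    scale = tf.cast(size, tf.float32) / tf.minimum(h, w)
    new_h = tf.cast(tf.round(h * scale), tf.int32)
    new_w = tf.cast(tf.round(w * scale), tf.int32)
    # tf.image.resize accepts the target size as a 1-D int32 tensor.
    return tf.image.resize(image, tf.stack([new_h, new_w]))
```

Either function can replace the crop and resize steps in transforms(). Note that resize_smallest_side leaves images with different shapes, so .batch(4) would need a crop to a common size first (or padded_batch instead).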