tensorflow recommends the tf.data.Dataset for importing data. Is it possible to use it for validation and training, if the validation size of the images is different to the training images?
import tensorflow as tf
import generator
import glob
import cv2
BATCH_SIZE = 4
filenames_train = glob.glob("/home/user/Datasets/MsCoco/train2017/*.jpg")
filenames_valid = glob.glob("/home/user/Datasets/Set5_14/*.png")
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename):
image_decoded = cv2.imread(filename, cv2.IMREAD_COLOR)
image_blurred_decoded = cv2.GaussianBlur(image_decoded, (1, 1), 0)
return image_decoded, image_blurred_decoded
# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, image_blurred_decoded):
image_decoded.set_shape([None, None, None])
image_blurred_decoded.set_shape([None, None, None])
image_resized = tf.cast(tf.image.resize_images(image_decoded, [288, 288]),tf.uint8)
image_blurred = tf.cast(tf.image.resize_images(image_blurred_decoded, [72, 72]),tf.uint8)
return image_resized, image_blurred
def _cast_function(image_decoded, image_blurred_decoded):
image_resized = tf.cast(image_decoded,tf.uint8)
image_blurred = tf.cast(image_blurred_decoded,tf.uint8)
return image_resized, image_blurred
dataset_train = tf.data.Dataset.from_tensor_slices(filenames_train)
dataset_train = dataset_train.map(
lambda filename: tuple(tf.py_func(
_read_py_function, [filename], [tf.uint8, tf.uint8])))
dataset_train = dataset_train.map(_resize_function)
#dataset_train = dataset_train.shuffle(buffer_size=10000)
dataset_train = dataset_train.repeat()
dataset_train = dataset_train.batch(BATCH_SIZE)
# validation dataset
dataset_valid = tf.data.Dataset.from_tensor_slices(filenames_valid)
dataset_valid = dataset_valid.map(
lambda filename: tuple(tf.py_func(
_read_py_function, [filename], [tf.uint8, tf.uint8])))
dataset_train = dataset_train.map(_cast_function)
dataset_valid = dataset_valid.batch(BATCH_SIZE)
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, dataset_train.output_types)
next_element = iterator.get_next()
training_iterator = dataset_train.make_one_shot_iterator()
validation_iterator = dataset_valid.make_initializable_iterator()
my_transformator = generator.johnson(tf.cast(next_element[1],tf.float32))
images_transformed = my_transformator.new_images
images_transformed_uint = tf.cast(images_transformed,tf.uint8)
loss_square = tf.square(tf.cast(next_element[0],tf.float32)-images_transformed)
loss_sum = tf.reduce_sum(loss_square)
loss_norm = tf.cast(tf.shape(next_element[0])[0]*tf.shape(next_element[0])[1]*tf.shape(next_element[0])[2]*tf.shape(next_element[0])[3],tf.float32)
loss = tf.reduce_sum(loss_square)/loss_norm
solver = tf.train.AdamOptimizer(learning_rate=0.001,beta1=0.5).minimize(loss)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
for i in range(200000):
curr_norm,curr_loss_sum, _, curr_loss, curr_labels, curr_transformed, curr_loss_square = sess.run([loss_norm,loss_sum, solver,loss,next_element,images_transformed_uint, loss_square], feed_dict={handle: training_handle})
if i%1000 == 0:
curr_labels, curr_transformed = sess.run([next_element, images_transformed_uint], feed_dict={handle: validation_handle})
If I try that code I get the following error:
FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element. [[Node: IteratorGetNext = IteratorGetNextoutput_shapes=[, ], output_types=[DT_UINT8, DT_UINT8], _device="/job:localhost/replica:0/task:0/device:CPU:0"]]
In the code you can see, that I don't resize the images from validation dataset. These validation images have different image size.
You've just forgot to initialize the validation_iterator
.
Just add sess.run(validation_iterator.initializer)
before running the for-loop.
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With