I am currently using the TensorFlow Dataset API to perform some augmentations on images at a specified path. The filename itself indicates whether the file should be augmented. For each file read from the dataset, I want to search the filename for a specific substring; if it is found, I set a boolean flag and replace the substring with "".
The error I get is:
AttributeError: 'Tensor' object has no attribute 'find'
I can't call find on a tensor of dtype string because find is not a method of Tensor, so I am trying to figure out how to perform the action described above. The code below shows what I am trying to do. Performance is important, so if I am using the Dataset API incorrectly, I would prefer to learn the correct way.
def preproc_img(filenames):
    def parse_fn(filename):
        augment_inst = False
        if cfg.SPLIT_INTO_INST:
            #*****************************************************
            #*** THIS IS WHERE THE LOGIC IS CURRENTLY BREAKING ***
            #*****************************************************
            if filename.find('_data_augmentation') != -1:
                augment_inst = True
                filename = filename.replace('_data_augmentation', '')
        image_string = tf.read_file(filename)
        img = tf.image.decode_image(image_string, channels=3)
        return dict(zip([filename], [img]))

    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    dataset = dataset.map(parse_fn)
    iterator = dataset.make_one_shot_iterator()
    return iterator.get_next()

def perform_train():
    if __name__ == '__main__':
        filenames = helper.get_image_paths()
        next_batch = preproc_img(filenames)
        with tf.Session() as sess:
            with sess.graph.as_default():
                sess.run(tf.local_variables_initializer())
                sess.run(tf.global_variables_initializer())
                dat = sess.run(next_batch)
                # I would now go about calling any of my tf op code below
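Once augment_inst is computed with TensorFlow ops, it is a boolean tensor rather than a Python bool, so the decision to augment also has to be made with a tensor op. A minimal sketch using tf.cond, assuming TF 2.x; maybe_augment is a hypothetical helper name, and tf.image.flip_left_right merely stands in for whatever augmentation is actually applied.

```python
import tensorflow as tf

def maybe_augment(img, augment_inst):
    """Apply an augmentation only when the boolean tensor augment_inst is True."""
    # tf.cond picks a branch at run time, so it also works inside Dataset.map,
    # where augment_inst is a symbolic tensor and not a Python bool.
    return tf.cond(
        augment_inst,
        lambda: tf.image.flip_left_right(img),  # placeholder augmentation (assumption)
        lambda: img,
    )

# Tiny 2x2 "image" with 3 channels and values 0..11, just to show the effect.
img = tf.reshape(tf.range(12), [2, 2, 3])
print(maybe_augment(img, tf.constant(True))[0, 0].numpy())   # first pixel after the flip
print(maybe_augment(img, tf.constant(False))[0, 0].numpy())  # first pixel left unchanged
```

Both branches of tf.cond must return tensors of the same dtype and compatible shape, which holds here because a horizontal flip preserves both.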
In TF 1.x, you can use tf.regex_replace to replace text in a tf.string tensor:
filename = tf.regex_replace(filename, "_data_augmentation", "")
In TF 2.x, the op lives under tf.strings:
filename = tf.strings.regex_replace(filename, "_data_augmentation", "")
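The replacement alone does not produce the boolean flag the question asks for. A minimal sketch, assuming TF 2.x with eager execution: tf.strings.regex_full_match computes the flag and tf.strings.regex_replace strips the marker, both inside Dataset.map. The file names and the parse_name helper are hypothetical, and file reading is left out so the example runs on its own.

```python
import tensorflow as tf

MARKER = "_data_augmentation"

def parse_name(filename):
    # Boolean scalar tensor: True when the marker occurs anywhere in the name.
    # regex_full_match must match the entire string, hence the ".*" on both sides.
    augment_inst = tf.strings.regex_full_match(filename, ".*" + MARKER + ".*")
    # Remove every occurrence of the marker (replace_global defaults to True).
    clean_name = tf.strings.regex_replace(filename, MARKER, "")
    return clean_name, augment_inst

# Hypothetical paths for illustration only.
filenames = ["img_001_data_augmentation.png", "img_002.png"]
dataset = tf.data.Dataset.from_tensor_slices(filenames).map(parse_name)

for name, flag in dataset:
    print(name.numpy().decode(), bool(flag.numpy()))
```

The marker contains no regex metacharacters, so it can be used as a pattern directly; a marker that did would need escaping. An equivalent flag is tf.not_equal(clean_name, filename), which is True exactly when the replacement changed the string.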