How to perform a string find-and-replace on a TensorFlow string tensor?

I am currently using the TensorFlow Dataset API to apply some augmentations to images at a specified path. Each filename encodes whether that file should be augmented. For each file in the dataset, I want to check whether the filename contains a specific substring; if it does, I set a boolean flag and remove the substring (replace it with "").

The error I get is:

AttributeError: 'Tensor' object has no attribute 'find'

I can't call find on the filename, because inside dataset.map the filename is a Tensor of dtype string rather than a Python str, and Tensor has no find method. So I am trying to figure out another way to do the above. The code below demonstrates what I am trying to do. Performance is important, so if I am using the Dataset API incorrectly, I would like to know the correct approach.

def preproc_img(filenames):
  def parse_fn(filename):
    augment_inst = False
    if cfg.SPLIT_INTO_INST:
      #*****************************************************
      #*** THIS IS WHERE THE LOGIC IS CURRENTLY BREAKING ***
      #*****************************************************
      if filename.find('_data_augmentation') != -1:
        augment_inst = True
        filename = filename.replace('_data_augmentation', '')

    image_string = tf.read_file(filename)
    img = tf.image.decode_image(image_string, channels=3)
    return dict(zip([filename], [img]))   

  dataset = tf.data.Dataset.from_tensor_slices(filenames)
  dataset = dataset.map(parse_fn)
  iterator = dataset.make_one_shot_iterator()
  return iterator.get_next()


def perform_train():
  if __name__ == '__main__':
    filenames = helper.get_image_paths()
    next_batch = preproc_img(filenames)

  with tf.Session() as sess:
    with sess.graph.as_default():
      sess.run(tf.local_variables_initializer())
      sess.run(tf.global_variables_initializer())

      dat = sess.run(next_batch)
      # I would now go about calling any of my tf op code below
asked Apr 25 '26 14:04 by xtr33me


1 Answer

You can use tf.regex_replace to replace text in a tf.string tensor. Note that the pattern is a regular expression, so any special characters in it must be escaped.

filename = tf.regex_replace(filename, "_data_augmentation", "")

For TF 2.x, the same op lives in the tf.strings module:

filename = tf.strings.regex_replace(filename, "_data_augmentation", "")
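The answer covers the replace step but not the boolean flag. A minimal TF 2.x sketch of the whole parse function follows, assuming TF 2.4 or later (for tf.data.AUTOTUNE). It uses tf.strings.regex_full_match to compute the flag and tf.cond to apply a conditional augmentation; tf.image.flip_left_right is only a placeholder for whatever augmentation is actually wanted, and the names strip_marker and make_dataset are hypothetical.

```python
import tensorflow as tf

# Substring that marks a file for augmentation (taken from the question).
MARKER = "_data_augmentation"


def strip_marker(filename):
    """Return (clean_filename, augment_flag) for a tf.string tensor.

    Both ops work elementwise, so `filename` may be a scalar or a 1-D tensor.
    """
    # True where the marker appears anywhere in the name. The pattern is RE2
    # syntax; MARKER contains no regex metacharacters, so no escaping is needed.
    augment_flag = tf.strings.regex_full_match(filename, ".*" + MARKER + ".*")
    # Remove every occurrence of the marker (replace_global defaults to True).
    clean_filename = tf.strings.regex_replace(filename, MARKER, "")
    return clean_filename, augment_flag


def parse_fn(filename):
    clean_filename, augment_flag = strip_marker(filename)
    # Reads the cleaned name, mirroring the question's code.
    image_string = tf.io.read_file(clean_filename)
    # expand_animations=False guarantees a 3-D image tensor.
    img = tf.image.decode_image(image_string, channels=3, expand_animations=False)
    # tf.cond makes the data-dependent branch explicit in the traced graph.
    # flip_left_right is a placeholder augmentation (assumption).
    img = tf.cond(augment_flag,
                  lambda: tf.image.flip_left_right(img),
                  lambda: img)
    return clean_filename, img, augment_flag


def make_dataset(filenames):
    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    # Parallel map for throughput; AUTOTUNE lets tf.data pick the level.
    return dataset.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
```

Like the question's code, this reads the file under the cleaned name. If the name containing the marker is the real path on disk, pass the original filename to tf.io.read_file instead.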
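An alternative that avoids tensor string ops entirely: if helper.get_image_paths() returns ordinary Python strings (an assumption; the question does not show it), the flag and the cleaned names can be computed once in plain Python before the dataset is built. The two lists can then be passed together as tf.data.Dataset.from_tensor_slices((clean_paths, augment_flags)), so each element arrives in map as a (filename, flag) pair. The helper name split_augment_flags is hypothetical.

```python
# Substring that marks a file for augmentation (taken from the question).
MARKER = "_data_augmentation"


def split_augment_flags(paths):
    """Split plain-Python filenames into (clean_paths, augment_flags)."""
    clean_paths, augment_flags = [], []
    for path in paths:
        # Ordinary str operations work here because these are Python strings,
        # not tf.string tensors.
        augment_flags.append(MARKER in path)
        clean_paths.append(path.replace(MARKER, ""))
    return clean_paths, augment_flags


clean_paths, augment_flags = split_augment_flags(
    ["img1_data_augmentation.png", "img2.png"]
)
print(clean_paths)    # ['img1.png', 'img2.png']
print(augment_flags)  # [True, False]
```

This does the string work once, outside the input pipeline, which is cheap when the file list is already in memory.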
answered Apr 27 '26 23:04 by nessuno


