
How can I access the filenames gathered by tf.data.Dataset.list_files()?

I am using

file_data = tf.data.Dataset.list_files("../*.png")

to collect image files for training in TensorFlow, but would like to access the list of gathered filenames so I can perform a label lookup.

Calling sess.run([file_data]) has been unsuccessful:

TypeError: Fetch argument <TensorSliceDataset shapes: (), types: tf.string> has invalid type <class 'tensorflow.python.data.ops.dataset_ops.TensorSliceDataset'>, must be a string or Tensor. (Can not convert a TensorSliceDataset into a Tensor or Operation.)

Are there any other methods I can use?

ROS asked Jul 03 '18 20:07


2 Answers

With some additional experimenting, I found a way to solve this:

First, turn the Dataset into an iterator:

iterator_helper = file_data.make_one_shot_iterator()

Then, iterate through the elements in a tf Session:

with tf.Session() as sess:
    filename_temp = iterator_helper.get_next()
    # Each sess.run() call returns the next filename; tf.errors.OutOfRangeError signals the end.
    print(sess.run(filename_temp))

ROS answered Oct 11 '22 17:10


The Dataset.list_files() API uses the tf.matching_files() op to list the files matching the given pattern. You can also get the list of files as a tf.Tensor using that op, and pass it directly to sess.run():

filenames_as_tensor = tf.matching_files("../*.png")
filenames_as_array = sess.run(filenames_as_tensor)

for filename in filenames_as_array:
  print(filename)
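
Since the question mentions a label lookup, here is a minimal sketch of that step in plain Python. It assumes labels are keyed by the file's base name; the `label_for` helper, the `labels` table, and the sample filenames are hypothetical and only for illustration. Note that sess.run() returns tf.string values as bytes objects in Python 3, so they are decoded before the lookup.

```python
import os


def label_for(filename, labels, default=None):
    """Return the label for a file, keyed by its base name.

    `filename` may be bytes (as sess.run() returns for tf.string
    tensors in Python 3) or an ordinary str.
    """
    if isinstance(filename, bytes):
        filename = filename.decode("utf-8")
    return labels.get(os.path.basename(filename), default)


# Hypothetical label table and filenames, standing in for the array
# that sess.run(filenames_as_tensor) would return.
labels = {"cat_001.png": "cat", "dog_001.png": "dog"}
filenames_as_array = [b"../cat_001.png", b"../dog_001.png", b"../unknown.png"]

looked_up = [label_for(f, labels) for f in filenames_as_array]
print(looked_up)
```

Files without an entry in the table fall back to `default`, so missing labels can be detected instead of raising a KeyError.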

mrry answered Oct 11 '22 17:10