 

How to load pickle files by tensorflow's tf.data API

I have my data stored in multiple pickle files on disk. I want to use TensorFlow's tf.data.Dataset to load it into my training pipeline. My code is:

def _parse_file(path):
    image, label = *load pickle file*
    return image, label
paths = glob.glob('*.pkl')
print(len(paths))
dataset = tf.data.Dataset.from_tensor_slices(paths)
dataset = dataset.map(_parse_file)
iterator = dataset.make_one_shot_iterator()

The problem is that I don't know how to implement the _parse_file function. Its argument, path, is a tensor of string type. I tried

def _parse_file(path):
    with tf.Session() as s:
        p = s.run(path)
        image, label = pickle.load(open(p, 'rb'))
    return image, label

and got error message:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'arg0' with dtype string
     [[Node: arg0 = Placeholder[dtype=DT_STRING, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

After some searching on the Internet I still have no idea how to do it. I would be grateful to anyone who can provide a hint.

Zhao Chen, asked Jun 15 '18

People also ask

What does TF data dataset From_tensor_slices do?

With the tf.data.Dataset.from_tensor_slices() method, we can get the slices of an array in the form of dataset objects: each entry along the first dimension becomes one element of the dataset.
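A minimal sketch of this behavior, assuming TF 2.x eager execution (the file names are placeholders; the files need not exist, since from_tensor_slices only slices the string tensor):

```python
import tensorflow as tf

# Each entry of the list becomes one element of the dataset.
dataset = tf.data.Dataset.from_tensor_slices(['a.pkl', 'b.pkl', 'c.pkl'])

for path in dataset:
    # Each element is a scalar string tensor holding one path.
    print(path.numpy().decode())
```
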

Which are the three main methods of getting data into a TensorFlow program?

Feeding: Python code provides the data when running each step. Reading from files: an input pipeline reads the data from files at the beginning of a TensorFlow graph. Preloaded data: a constant or variable in the TensorFlow graph holds all the data (for small data sets).


1 Answer

I have solved this myself. The solution is to wrap the Python loading code with tf.py_func, as described in the TensorFlow documentation.

Zhao Chen, answered Oct 03 '22