
Inference with a model trained with tf.Dataset

I have trained a model using the tf.data.Dataset API, so my training code looks something like this:

import tensorflow as tf

with graph.as_default():
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    dataset = dataset.map(scale_features, num_parallel_calls=n_workers)
    dataset = dataset.shuffle(10000)
    dataset = dataset.padded_batch(batch_size, padded_shapes={...})
    # Feedable iterator: the handle selects which dataset it reads from
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(handle,
                                                   dataset.output_types,
                                                   dataset.output_shapes)
    batch = iterator.get_next()
    ...
    # Model code
    ...
    # Concrete iterator over the training data
    train_iterator = dataset.make_initializable_iterator()

with tf.Session(graph=graph) as sess:
    train_handle = sess.run(train_iterator.string_handle())
    sess.run(tf.global_variables_initializer())
    for epoch in range(n_epochs):
        # Restart the training iterator at the beginning of each epoch
        sess.run(train_iterator.initializer)
        while True:
            try:
                sess.run(optimizer, feed_dict={handle: train_handle})
            except tf.errors.OutOfRangeError:
                break

Now that the model is trained, I want to run inference on examples that are not in the datasets, and I am not sure how to go about it.

Just to be clear, I know how to use another dataset: for example, at test time I just pass a handle to my test set.

The question is this: given the scaling scheme and the fact that the network expects a handle, how would I make a prediction for a new example that is not written to a TFRecord?

If I fed the batch directly, I would be responsible for doing the scaling beforehand, which I would like to avoid if possible.

So how should I run inference on single examples with a model trained the tf.data.Dataset way? (This is not for production purposes; it is for evaluating what happens if I change specific features.)

bluesummers asked Jun 20 '18 05:06




1 Answer

When you use the Dataset API, the graph contains a tensor named "IteratorGetNext:0", so you can set the input directly in the following way:

import numpy as np

# Get the iterator's output tensor from the graph
input_tensor = graph.get_tensor_by_name("IteratorGetNext:0")
# Define the target tensor you want to evaluate for your prediction
predictions = ...
# Finally, run the session, feeding the input tensor directly
sess.run(predictions, feed_dict={input_tensor: np.asanyarray(images), ...})
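
For illustration, here is a minimal self-contained sketch of that idea. It assumes TensorFlow 2.x with the `tf.compat.v1` graph-mode API. The `scale_features` function, the in-memory dataset and the toy `predictions` op are hypothetical stand-ins for the question's pipeline and model, not the asker's actual code. The tensor is looked up through `batch.name`, which is typically "IteratorGetNext:0" for the first `get_next()` call in a graph.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def scale_features(x):
    # Hypothetical stand-in for the question's scale_features.
    return x / 10.0


graph = tf.Graph()
with graph.as_default():
    # Toy in-memory data instead of a TFRecord file (assumption).
    data = np.arange(8, dtype=np.float32).reshape(4, 2)
    dataset = tf.data.Dataset.from_tensor_slices(data)
    dataset = dataset.map(scale_features).batch(2)
    iterator = tf.data.make_initializable_iterator(dataset)
    batch = iterator.get_next()
    # Toy "model": sums the features of each example.
    predictions = tf.reduce_sum(batch, axis=1)

with tf.Session(graph=graph) as sess:
    # Look up the iterator output by name, as in the answer above.
    input_tensor = graph.get_tensor_by_name(batch.name)
    new_example = np.array([[30.0, 50.0]], dtype=np.float32)
    # Feeding this tensor skips the whole dataset pipeline, including map(),
    # so the scaling has to be applied by hand before feeding.
    result = sess.run(predictions,
                      feed_dict={input_tensor: scale_features(new_example)})
    print(result)
```

Note that because the fed value replaces the iterator's output, the iterator never runs and does not need to be initialized, but any preprocessing done in `dataset.map` is bypassed as well.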
zhao yufei answered Sep 21 '22 14:09
