 

Why does the TensorFlow Estimator API take inputs as a lambda?

The tf.estimator API takes "input functions" that return Datasets. For example, Estimator.train() takes an input_fn argument (documentation).

In the examples I've seen, whenever this function is supplied manually, it is an argumentless lambda.

Doesn't that mean that the function always returns the same value? Or is it invoked multiple times with no arguments? I wasn't able to find documentation about this. Why don't functions like train() just take input as a Dataset explicitly?

jkff asked Apr 04 '18 03:04


2 Answers

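Dataset objects are also backed by nodes in the computational graph. The Estimator constructs a new computational graph on each call to train(), evaluate(), etc., and invokes input_fn inside that graph, so input_fn is called once per call to train() rather than once overall. A Dataset built beforehand would belong to a different graph and could not be used there. By building everything per call, the Estimator API keeps operations on different Estimator objects isolated, with their Tensors and Datasets in independent graphs.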

Some code pointers (for TensorFlow 1.7) in case you're interested:

  • Estimator.train() eventually invokes this
  • Which in turn invokes input_fn

Hope that helps.

ash answered Sep 28 '22 20:09


According to the TensorFlow documentation:

"Estimators expect an input_fn to take no arguments. To work around this restriction, we use lambda to capture the arguments and provide the expected interface."

https://www.tensorflow.org/guide/datasets_for_estimators
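A minimal sketch of that workaround, assuming a hypothetical `make_dataset(features, labels, batch_size)` helper and a hypothetical `train()` that, like Estimator.train(), calls its `input_fn` with no arguments. Plain Python lists stand in for a real Dataset. The lambda captures the arguments; `functools.partial` is shown as an alternative way to bind them. The last lines show that the lambda's body runs again on every call, so it does not hand back one fixed object.

```python
import functools


def make_dataset(features, labels, batch_size):
    """Hypothetical input function that needs arguments (returns plain Python data)."""
    # Pair each feature with its label and chunk the pairs into batches.
    pairs = list(zip(features, labels))
    batches = [pairs[i:i + batch_size] for i in range(0, len(pairs), batch_size)]
    return {"batches": batches}


def train(input_fn):
    """Hypothetical trainer that, like Estimator.train(), calls input_fn with no arguments."""
    data = input_fn()
    return len(data["batches"])


features = [1.0, 2.0, 3.0, 4.0, 5.0]
labels = [0, 1, 0, 1, 0]

# The lambda captures the arguments and exposes the zero-argument interface.
n_lambda = train(lambda: make_dataset(features, labels, batch_size=2))

# functools.partial binds the same arguments without writing a lambda.
n_partial = train(functools.partial(make_dataset, features, labels, batch_size=2))

# Each call re-runs make_dataset: the results are equal but distinct objects.
fn = lambda: make_dataset(features, labels, batch_size=2)
first, second = fn(), fn()
```

Whether a given TensorFlow version accepts a `functools.partial` object as `input_fn` is worth checking against its documentation; the lambda form is the one the guide uses.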

Michal Drygajlo answered Sep 28 '22 22:09