
Predict in Tensorflow estimator using input fn

I am using the tutorial code from https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py. The code works fine until I try to make predictions instead of just evaluating the model. I wrote another input function for prediction that looks like this (just removing the parameter y):

def input_fn_predict(data_file, num_epochs, shuffle):
  """Input builder function."""
  df_data = pd.read_csv(
      tf.gfile.Open(data_file),
      names=CSV_COLUMNS,
      skipinitialspace=True,
      engine="python",
      skiprows=1)
  # remove NaN elements
  df_data = df_data.dropna(how="any", axis=0)
  labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
  return tf.estimator.inputs.pandas_input_fn(  # removed parameter y
      x=df_data,
      batch_size=100,
      num_epochs=num_epochs,
      shuffle=shuffle,
      num_threads=5)

And to call it like this:

predictions = m.predict(
    input_fn=input_fn_predict(test_file_name, num_epochs=1, shuffle=True)
)
for i, p in enumerate(predictions):
    print(i, p)

  • Am I doing it right?
  • Why do I get 81404 predictions instead of 16282 (the number of lines in the test file)?
  • Each prediction looks something like this:

{'probabilities': array([ 0.78595656, 0.21404342], dtype=float32), 'logits': array([-1.3007226], dtype=float32), 'classes': array(['0'], dtype=object), 'class_ids': array([0]), 'logistic': array([ 0.21404341], dtype=float32)}

How do I read that?

Gregorius Edwadr asked Oct 26 '17 07:10


1 Answer

You need to set shuffle=False, because the predictions must come back in the same order as the input rows so that each one can be matched to its row.
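
To illustrate why the order matters, here is a minimal sketch in plain Python. The rows, the prediction values, and the helper name attach_predictions are all made up for illustration; the only assumption is that predictions arrive in the same order as the rows, which is what shuffle=False is meant to give you.

```python
# Hypothetical stand-ins: two input rows and the predictions an estimator
# might yield for them, in the same order (as with shuffle=False).
rows = [
    {"age": 25, "education": "11th"},
    {"age": 38, "education": "HS-grad"},
]
predictions = [
    {"class_ids": [0], "probabilities": [0.79, 0.21]},
    {"class_ids": [1], "probabilities": [0.35, 0.65]},
]


def attach_predictions(rows, predictions):
    """Pair the i-th row with the i-th prediction.

    This is only valid when the order of the rows was preserved.
    """
    paired = []
    for row, pred in zip(rows, predictions):
        labelled = dict(row)  # copy so the input rows are not modified
        labelled["predicted_class"] = int(pred["class_ids"][0])
        paired.append(labelled)
    return paired


for item in attach_predictions(rows, predictions):
    print(item)
```

With shuffle=True the same zip would silently attach predictions to the wrong rows, since nothing in a prediction dict identifies the row it came from.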

Below is my code to run the prediction (I've tested it). The input file has the same format as the test data (CSV), but without the label column.



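    def predict_input_fn(data_file):
        # The prediction file has no label column, so drop the last column
        # name. Use a local copy: reassigning the global CSV_COLUMNS would
        # strip one more column on every call.
        feature_columns = CSV_COLUMNS[:-1]
        df_data = pd.read_csv(
            tf.gfile.Open(data_file),
            names=feature_columns,
            skipinitialspace=True,
            engine='python',
            skiprows=1
        )

        # remove NaN elements
        df_data = df_data.dropna(how='any', axis=0)

        # shuffle=False keeps predictions in the same order as the input rows
        return tf.estimator.inputs.pandas_input_fn(
            x=df_data,
            num_epochs=1,
            shuffle=False
        )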

To call it:



    predict_file_name = 'tutorials/data/adult.predict'
    results = m.predict(
        input_fn=predict_input_fn(predict_file_name)
    )
    for result in results:
        print('result: {}'.format(result))

The prediction result for one sample is below:



    {
        'probabilities': array([0.78595656, 0.21404342], dtype = float32),
        'logits': array([-1.3007226], dtype = float32),
        'classes': array(['0'], dtype = object),
        'class_ids': array([0]),
        'logistic': array([0.21404341], dtype = float32)
    }

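What each field means:

  • 'probabilities': array([0.78595656, 0.21404342], dtype = float32)
    The model predicts class 0 (in this case <=50K) with confidence 0.78595656; the second entry is the probability of class 1 (>50K).
  • 'logits': array([-1.3007226], dtype = float32)
    This is the value of z in the sigmoid 1/(1+e^(-z)), about -1.3. Plugging it in gives about 0.214, which is the 'logistic' field, i.e. the probability of class 1.
  • 'classes': array(['0'], dtype = object)
    The predicted class label is 0; 'class_ids' holds the same result as an integer index.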
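
As a rough check of how these fields relate, the sketch below recomputes them from the logit using only the standard library. It assumes a binary classifier whose 'logistic' value is the sigmoid of the logit and whose 'probabilities' are [1 - logistic, logistic]; the numbers in the sample are consistent with that, but this is an illustration of how to read the output, not TensorFlow's own code.

```python
import math

# Values copied from the sample prediction above.
prediction = {
    "probabilities": [0.78595656, 0.21404342],
    "logits": [-1.3007226],
    "classes": ["0"],
    "class_ids": [0],
    "logistic": [0.21404341],
}


def sigmoid(z):
    """Logistic function 1 / (1 + e^(-z))."""
    return 1.0 / (1.0 + math.exp(-z))


# Probability of class 1 (>50K), recomputed from the logit.
p_class_1 = sigmoid(prediction["logits"][0])
# Probability of class 0 (<=50K) is the complement.
p_class_0 = 1.0 - p_class_1
# The predicted class is the index of the largest probability.
predicted_id = max(range(2), key=lambda i: prediction["probabilities"][i])

print("P(class 1) from logit:", round(p_class_1, 4))
print("P(class 0) from logit:", round(p_class_0, 4))
print("predicted class id:", predicted_id)
```

The recomputed values should agree with the 'logistic' and 'probabilities' fields up to float32 rounding, and the arg-max index should equal 'class_ids'.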

impulse answered Nov 11 '22 20:11
