I am using the tutorial code from https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py, and it works fine until I try to make a prediction instead of just evaluating the model. I wrote another input function for prediction that looks like this (I just removed the parameter y):
def input_fn_predict(data_file, num_epochs, shuffle):
    """Input builder function."""
    df_data = pd.read_csv(
        tf.gfile.Open(data_file),
        names=CSV_COLUMNS,
        skipinitialspace=True,
        engine="python",
        skiprows=1)
    # remove NaN elements
    df_data = df_data.dropna(how="any", axis=0)
    labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
    return tf.estimator.inputs.pandas_input_fn(  # removed parameter y
        x=df_data,
        batch_size=100,
        num_epochs=num_epochs,
        shuffle=shuffle,
        num_threads=5)
And I call it like this:
predictions = m.predict(
    input_fn=input_fn_predict(test_file_name, num_epochs=1, shuffle=True)
)
for i, p in enumerate(predictions):
    print(i, p)
Each prediction is printed like this:
{'probabilities': array([ 0.78595656, 0.21404342], dtype=float32), 'logits': array([-1.3007226], dtype=float32), 'classes': array(['0'], dtype=object), 'class_ids': array([0]), 'logistic': array([ 0.21404341], dtype=float32)}
How do I read this output?
You need to set shuffle=False: to predict labels for new data, the order of the rows must be preserved so that each prediction can be matched back to its input row.
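To illustrate why the order matters, here is a minimal sketch in plain Python. The fake_predict function is a hypothetical stand-in for m.predict (it is not TensorFlow code), and reversing the rows stands in for shuffling. Results come back as a plain sequence with no row identifier, so they can only be matched to inputs by position.

```python
# Hypothetical stand-in for a trained model: "predicts" 1 when age > 40.
# It yields one result per row, in the order the rows are fed to it.
def fake_predict(rows):
    for row in rows:
        yield {"class_ids": [1 if row["age"] > 40 else 0]}

rows = [{"age": age} for age in [25, 52, 38, 61, 47, 19]]

# Order preserved (like shuffle=False): result i belongs to row i.
ordered = list(fake_predict(rows))
aligned = all(
    res["class_ids"][0] == int(row["age"] > 40)
    for row, res in zip(rows, ordered)
)

# Order changed (a reversal stands in for shuffle=True): the results are
# still a plain sequence, so zipping them with the original rows pairs
# some rows with predictions that were computed for other rows.
reordered = list(fake_predict(list(reversed(rows))))
mismatches = sum(
    res["class_ids"][0] != int(row["age"] > 40)
    for row, res in zip(rows, reordered)
)

print("aligned without shuffling:", aligned)
print("mismatched rows after reordering:", mismatches)
```

With a real estimator the same reasoning applies: the generator returned by predict carries no index back to the input file, so position is the only link.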
Below is my code to run the prediction (I've tested it). The input file is formatted like the test data (CSV), but it has no label column.
def predict_input_fn(data_file):
    # The prediction file has no label column, so drop the last column name.
    # A local copy is used so CSV_COLUMNS is not shortened on every call.
    columns = CSV_COLUMNS[:-1]
    df_data = pd.read_csv(
        tf.gfile.Open(data_file),
        names=columns,
        skipinitialspace=True,
        engine='python',
        skiprows=1
    )
    # remove NaN elements
    df_data = df_data.dropna(how='any', axis=0)
    return tf.estimator.inputs.pandas_input_fn(
        x=df_data,
        num_epochs=1,
        shuffle=False
    )
To call it:
predict_file_name = 'tutorials/data/adult.predict'
results = m.predict(
    input_fn=predict_input_fn(predict_file_name)
)
for result in results:
    print('result: {}'.format(result))
The prediction result for one sample is below:
{
'probabilities': array([0.78595656, 0.21404342], dtype = float32),
'logits': array([-1.3007226], dtype = float32),
'classes': array(['0'], dtype = object),
'class_ids': array([0]),
'logistic': array([0.21404341], dtype = float32)
}
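What each field means:
probabilities: the predicted probability of each class, in the order [class 0, class 1]. Here the model gives about 0.79 to class 0 (income not >50K) and about 0.21 to class 1 (income >50K).
logits: the raw model output before the sigmoid is applied.
classes: the predicted class label as a string, here '0'.
class_ids: the predicted class index as an integer, here 0.
logistic: the sigmoid of the logit, which is the probability of class 1; it matches probabilities[1].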
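As a quick check of how these fields relate to each other, the sketch below recomputes them from the logit using only the standard library. The numbers are copied from the sample result above; plain lists stand in for the NumPy arrays, and the variable names are only for illustration.

```python
import math

# Values copied from the sample result; lists stand in for NumPy arrays.
result = {
    "probabilities": [0.78595656, 0.21404342],
    "logits": [-1.3007226],
    "classes": ["0"],
    "class_ids": [0],
    "logistic": [0.21404341],
}

def sigmoid(x):
    """Logistic function: maps a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

# 'logistic' is the sigmoid of the logit, i.e. the probability of class 1.
p_class_1 = sigmoid(result["logits"][0])
# 'probabilities' is [P(class 0), P(class 1)], so the first entry is the complement.
p_class_0 = 1.0 - p_class_1
# 'class_ids' is the index of the more probable class.
predicted_id = 0 if p_class_0 >= p_class_1 else 1

print("P(class 1):", round(p_class_1, 4))
print("P(class 0):", round(p_class_0, 4))
print("predicted class id:", predicted_id)
```

In the tutorial the label is 1 when income_bracket contains ">50K", so class id 0 here means the model predicts an income that is not above 50K.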