
Java Tensorflow + Keras Equivalent of model.predict()

In Python, you can simply pass a NumPy array to predict() to get predictions from your model. What is the equivalent in Java when using a SavedModelBundle?

Python

model = tf.keras.models.Sequential([
  # layers go here
])
model.compile(...)
model.fit(x_train, y_train)

predictions = model.predict(x_test_maxabs) # <= This line 

Java

SavedModelBundle model = SavedModelBundle.load(path, "serve");
model.predict() // ????? // What does it take as input? A Tensor?
alexdriedger asked Jan 01 '23 01:01



1 Answer

TensorFlow Python automatically converts your NumPy array to a tf.Tensor. In TensorFlow Java, you manipulate tensors directly.
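As a rough illustration of what "manipulating tensors directly" involves: a 2-D NumPy array is, in effect, a shape plus a flat row-major run of floats, and that is the information you would have to supply yourself when building an input tensor in Java. The sketch below uses plain Java only. The class name BatchToTensorData and the helpers flatten and shapeOf are hypothetical names made up for this illustration and are not part of TensorFlow Java, and the sample values stand in for x_test_maxabs, whose real contents are unknown.

```java
import java.util.Arrays;

public class BatchToTensorData {

    /** Flattens a rectangular [rows][cols] batch into a row-major float array. */
    public static float[] flatten(float[][] batch) {
        int rows = batch.length;
        int cols = rows == 0 ? 0 : batch[0].length;
        float[] flat = new float[rows * cols];
        for (int r = 0; r < rows; r++) {
            // A tensor must be rectangular, so reject ragged input early.
            if (batch[r].length != cols) {
                throw new IllegalArgumentException(
                    "row " + r + " has length " + batch[r].length + ", expected " + cols);
            }
            System.arraycopy(batch[r], 0, flat, r * cols, cols);
        }
        return flat;
    }

    /** Returns the {rows, cols} shape that goes with the flattened data. */
    public static long[] shapeOf(float[][] batch) {
        return new long[] { batch.length, batch.length == 0 ? 0 : batch[0].length };
    }

    public static void main(String[] args) {
        // Hypothetical stand-in for x_test_maxabs: 2 samples, 3 features each.
        float[][] xTest = { {0.1f, 0.2f, 0.3f}, {0.4f, 0.5f, 0.6f} };
        System.out.println(Arrays.toString(shapeOf(xTest)));  // [2, 3]
        System.out.println(Arrays.toString(flatten(xTest)));  // [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
    }
}
```

The shape and the flat data together are what a tensor factory needs. Which factory method accepts them depends on the TF Java version you are on, so check its API documentation rather than relying on this sketch.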

Note that SavedModelBundle does not have a predict method. Instead, you obtain its Session and run it through a Session.Runner, feeding it the input tensors and fetching the outputs you want.

For example, based on the next generation of TF Java (https://github.com/tensorflow/java), your code ends up looking like this (note that I'm making a lot of assumptions here about x_test_maxabs, since your code sample does not clearly explain where it comes from):

try (SavedModelBundle model = SavedModelBundle.load(path, "serve")) {
    // "input_name" and "output_name" are placeholders for the tensor names in your graph
    try (Tensor<TFloat32> input = TFloat32.tensorOf(...);
        Tensor<TFloat32> output = model.session()
            .runner()
            .feed("input_name", input)
            .fetch("output_name")
            .run()
            .get(0) // run() returns one tensor per fetch, in order
            .expect(TFloat32.class)) {

        float prediction = output.data().getFloat();
        System.out.println("prediction = " + prediction);
    }
}

If you are not sure of the names of the input/output tensors in your graph, you can obtain them programmatically by looking at the signature definition:

model.metaGraphDef().getSignatureDefMap().get("serving_default")
Karl Lessard answered Jan 12 '23 15:01
