In Python you can simply pass a NumPy array to predict() to get predictions from your model. What is the equivalent in Java with a SavedModelBundle?
model = tf.keras.models.Sequential([
    # layers go here
])
model.compile(...)
model.fit(x_train, y_train)
predictions = model.predict(x_test_maxabs) # <= This line
SavedModelBundle model = SavedModelBundle.load(path, "serve");
model.predict() // ??? What does it take as input? A Tensor?
TensorFlow for Python automatically converts your NumPy array to a tf.Tensor. In TensorFlow Java, you manipulate tensors directly.
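To make that contrast concrete, here is a minimal sketch of the bookkeeping NumPy hides from you: a tensor is built from a shape plus a flat, row-major run of values. FlatBatch is a hypothetical helper written for this answer, not part of TensorFlow Java, and it uses only the JDK. The sample values stand in for x_test_maxabs, whose real contents are unknown.

```java
import java.nio.FloatBuffer;
import java.util.Arrays;

/** Hypothetical helper (not a TensorFlow class): packs a rectangular float[][] into shape + flat data. */
final class FlatBatch {
    final long[] shape;      // {rows, columns}
    final FloatBuffer data;  // values in row-major order, positioned at 0

    private FlatBatch(long[] shape, FloatBuffer data) {
        this.shape = shape;
        this.data = data;
    }

    static FlatBatch of(float[][] rows) {
        if (rows.length == 0) {
            throw new IllegalArgumentException("at least one row is required");
        }
        int columns = rows[0].length;
        FloatBuffer buffer = FloatBuffer.allocate(rows.length * columns);
        for (float[] row : rows) {
            if (row.length != columns) {
                throw new IllegalArgumentException("all rows must have the same length");
            }
            buffer.put(row); // append this row after the previous one
        }
        buffer.flip(); // rewind so the buffer can be read from the first value
        return new FlatBatch(new long[] {rows.length, columns}, buffer);
    }

    public static void main(String[] args) {
        // Assumed sample data: two examples with three features each.
        float[][] xTest = {{0.1f, 0.2f, 0.3f}, {0.4f, 0.5f, 0.6f}};
        FlatBatch batch = FlatBatch.of(xTest);
        System.out.println("shape = " + Arrays.toString(batch.shape));
        System.out.println("values = " + batch.data.remaining());
    }
}
```

The shape and buffer are the two pieces a tensor factory typically asks for. Which factory method accepts them depends on the TensorFlow Java version you use, so check its Javadoc rather than relying on this sketch.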
The SavedModelBundle itself does not have a predict method. Instead, obtain its session and run it through a Session.Runner, feeding it your input tensors and fetching the outputs you want.
For example, with the next generation of TF Java (https://github.com/tensorflow/java), your code ends up looking like this (note that I'm making a lot of assumptions about x_test_maxabs, since your code sample does not clearly explain where it comes from):
try (SavedModelBundle model = SavedModelBundle.load(path, "serve")) {
    try (Tensor<TFloat32> input = TFloat32.tensorOf(...);
         Tensor<TFloat32> output = model.session()
             .runner()
             .feed("input_name", input)   // placeholder: your model's input tensor name
             .fetch("output_name")        // placeholder: your model's output tensor name
             .run()
             .get(0)                      // run() returns the fetched tensors; take the first
             .expect(TFloat32.class)) {
        float prediction = output.data().getFloat();
        System.out.println("prediction = " + prediction);
    }
}
If you are not sure of the names of the input/output tensors in your graph, you can obtain them programmatically by looking at the signature definition:
model.metaGraphDef().getSignatureDefMap().get("serving_default")
The returned SignatureDef exposes getInputsMap() and getOutputsMap(); each entry's getName() is the tensor name to pass to feed() or fetch().