Loading a trained HDF5 model into Rust to make predictions

I trained a model to recognize digits using the MNIST dataset. The model was trained in Python using TensorFlow and Keras, and the output was saved to an HDF5 file named "sample_mnist.h5".

I would like to load the trained model from the HDF5 file into Rust to make predictions.

In Python, I can load the model from the HDF5 file and make predictions with the following code:

model = keras.models.load_model("./sample_mnist.h5")
model.predict(test_input)  # assumes test_input is the correct input type for the model

What is a Rust equivalent of this Python snippet?

asked by bpmason1
1 Answer

First off, you'll want to save the model in the TensorFlow SavedModel (.pb) format rather than HDF5 in order to port it over to Rust, as the SavedModel format captures everything about the model's execution graph that is needed to reconstruct it outside of Python. There is an open pull request from user justnoxx on the TensorFlow Rust repo that shows how to do this for a simple model (a sketch for converting your existing .h5 file follows after the Rust example). The gist is that, given some model in Python ...

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

classifier = Sequential()
classifier.add(Dense(5, activation='relu', name="test_in", input_dim=5)) # Named input
classifier.add(Dense(5, activation='relu'))
classifier.add(Dense(1, activation='sigmoid', name="test_out")) # Named output

classifier.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

classifier.fit([[0.1, 0.2, 0.3, 0.4, 0.5]], [[1]], batch_size=1, epochs=1)

classifier.save('examples/keras_single_input_saved_model', save_format='tf')

and knowing the named input "test_in", the named output "test_out", and their expected sizes, we can apply the saved model in Rust ...

use tensorflow::{Graph, SavedModelBundle, SessionOptions, SessionRunArgs, Tensor};

fn main() {

    // The Python script that generates the saved model names the Keras input
    // layer "test_in", but the resulting signature exposes it as "test_in_input";
    // "_input" is appended for a single named input (not for multiple inputs).
    let signature_input_parameter_name = "test_in_input";
    let signature_output_parameter_name = "test_out";

    // Initialize save_dir, input tensor, and an empty graph
    let save_dir =
        "examples/keras_single_input_saved_model";
    let tensor: Tensor<f32> = Tensor::new(&[1, 5])
        .with_values(&[0.1, 0.2, 0.3, 0.4, 0.5])
        .expect("Can't create tensor");
    let mut graph = Graph::new();

    // Load saved model bundle (session state + meta_graph data)
    let bundle = 
        SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, save_dir)
        .expect("Can't load saved model");

    // Get the session from the loaded model bundle
    let session = &bundle.session;

    // Get signature metadata from the model bundle
    let signature = bundle
        .meta_graph_def()
        .get_signature("serving_default")
        .unwrap();

    // Get input/output info
    let input_info = signature.get_input(signature_input_parameter_name).unwrap();
    let output_info = signature
        .get_output(signature_output_parameter_name)
        .unwrap();

    // Get input/output ops from graph
    let input_op = graph
        .operation_by_name_required(&input_info.name().name)
        .unwrap();
    let output_op = graph
        .operation_by_name_required(&output_info.name().name)
        .unwrap();
    
    // Manages inputs and outputs for the execution of the graph
    let mut args = SessionRunArgs::new();
    args.add_feed(&input_op, 0, &tensor); // Add any inputs

    let out = args.request_fetch(&output_op, 0); // Request outputs

    // Run model
    session.run(&mut args) // Pass to session to run
        .expect("Error occurred during calculations");

    // Fetch outputs after graph execution
    let out_res: f32 = args.fetch(out).unwrap()[0];

    println!("Results: {:?}", out_res);
}
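In your case the model already exists as sample_mnist.h5, so a minimal conversion sketch (assuming the file loads cleanly with keras.models.load_model on TensorFlow 2.x; the SavedModel directory name below is just an example) would be:

from tensorflow import keras

# Load the trained Keras model from the HDF5 file
model = keras.models.load_model("./sample_mnist.h5")

# Re-save it in the TensorFlow SavedModel format so the Rust bindings can load it
model.save("sample_mnist_saved_model", save_format="tf")

The signature input/output names of your MNIST model will almost certainly differ from "test_in_input"/"test_out" above; you can list them with saved_model_cli show --dir sample_mnist_saved_model --all and substitute them, together with the 28x28 MNIST input shape, into the Rust code. The Rust example also assumes the tensorflow crate is declared as a dependency in Cargo.toml.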
answered by Ian Graham