
Loading ONNX Model in Java

I have a trained PyTorch model that I would now like to export to Caffe2 using ONNX. This part seems fairly simple and well documented. However, I now want to "load" that model into a Java program in order to perform predictions within my program (a Flink streaming application). What is the best way to do this? I haven't been able to find any documentation on the website describing how to do this.

asked by igodfried, Nov 23 '17 at 23:11

1 Answer

Currently it's a bit tricky, but there is a way. You will need the following JavaCPP presets:

  • nGraph https://github.com/bytedeco/javacpp-presets/tree/master/ngraph
  • ONNX https://github.com/bytedeco/javacpp-presets/tree/master/onnx

I will use single_relu.onnx as an example:

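    // Package names assume the JavaCPP 1.5.x presets for ONNX and nGraph.
    import java.nio.file.Files;
    import java.nio.file.Paths;

    import org.bytedeco.javacpp.BytePointer;
    import org.bytedeco.javacpp.FloatPointer;
    import org.bytedeco.ngraph.*;
    import org.bytedeco.onnx.ModelProto;
    import org.bytedeco.onnx.StringVector;

    import static org.bytedeco.ngraph.global.ngraph.*;
    import static org.bytedeco.onnx.global.onnx.*;

    public class SingleReluExample {
        public static void main(String[] args) throws Exception {
            // read ONNX
            byte[] bytes = Files.readAllBytes(Paths.get("single_relu.onnx"));
            ModelProto model = new ModelProto();
            ParseProtoFromBytes(model, new BytePointer(bytes), bytes.length); // parse ONNX -> protobuf model

            // preprocess the model in any way you like (you can skip this step);
            // Optimize and ConvertVersion return a new ModelProto, so keep the result
            StringVector passes = new StringVector("eliminate_nop_transpose", "eliminate_nop_pad",
                    "fuse_consecutive_transposes", "fuse_transpose_into_gemm");
            model = Optimize(model, passes);
            model = ConvertVersion(model, 8);
            BytePointer serialized = model.SerializeAsString();

            // prepare nGraph backend
            Backend backend = Backend.create("CPU");
            Shape shape = new Shape(new SizeTVector(1, 2));
            Tensor input = backend.create_tensor(f32(), shape);
            Tensor output = backend.create_tensor(f32(), shape);

            // fill the input tensor (example values); offset and size are in bytes
            float[] x = {-1.5f, 2.0f};
            input.write(new FloatPointer(x), 0, x.length * Float.BYTES);

            Function ng_function = import_onnx_model(serialized); // convert ONNX -> nGraph
            Executable exec = backend.compile(ng_function);
            exec.call(new NgraphTensorVector(output), new NgraphTensorVector(input));

            // collect result to array
            float[] r = new float[2];
            FloatPointer p = new FloatPointer(r); // allocates native memory holding a copy of r
            output.read(p, 0, r.length * Float.BYTES);
            p.get(r); // copy the native result back into the Java array

            // print result
            int rows = (int) shape.get(0), cols = (int) shape.get(1);
            for (int i = 0; i < rows; i++) {
                System.out.print(" [");
                for (int j = 0; j < cols; j++) {
                    System.out.print(r[i * cols + j] + " ");
                }
                System.out.println("]");
            }
        }
    }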
answered by alagris, Sep 27 '22 at 20:09
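Since single_relu.onnx is assumed to contain a single Relu over a 1×2 float tensor (matching the Shape(1, 2) used in the answer), the nGraph result can be cross-checked in plain Java without any native libraries. Below is a minimal sketch; ReluCheck, relu and formatRowMajor are hypothetical helper names, not part of JavaCPP. relu computes max(0, x) element-wise, and formatRowMajor prints a flat array using the same r[i * cols + j] row-major indexing as the printing loop in the answer.

```java
public class ReluCheck {
    // Element-wise ReLU, max(0, x): a plain-Java reference for what the model should output.
    public static float[] relu(float[] x) {
        float[] y = new float[x.length];
        for (int k = 0; k < x.length; k++) {
            y[k] = Math.max(0f, x[k]);
        }
        return y;
    }

    // Formats a flat row-major array of shape rows x cols as nested brackets,
    // reading element (i, j) from r[i * cols + j].
    public static String formatRowMajor(float[] r, int rows, int cols) {
        if (r.length != rows * cols) {
            throw new IllegalArgumentException(
                    "expected " + (rows * cols) + " elements, got " + r.length);
        }
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < rows; i++) {
            sb.append(i == 0 ? "[" : ", [");
            for (int j = 0; j < cols; j++) {
                if (j > 0) {
                    sb.append(", ");
                }
                sb.append(r[i * cols + j]);
            }
            sb.append("]");
        }
        return sb.append("]").toString();
    }

    public static void main(String[] args) {
        // Same example input as fed to the nGraph tensor in the answer.
        float[] x = {-1.5f, 2.0f};
        System.out.println(formatRowMajor(relu(x), 1, 2)); // prints [[0.0, 2.0]]
    }
}
```

Feeding the same values to the nGraph program and comparing its r array with relu(x) (for example with java.util.Arrays.equals) is a quick way to confirm the model was imported and executed correctly.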
