Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

"No Operation named [input] in the Graph" in Java

Following this Colab exercise from Google's ML Crash Course, I generated a model in Python for the MNIST database. The code looks as follows:

import pandas as pd
import tensorflow as tf


def create_model(my_learning_rate):
    """Build and compile the digit classifier.

    The network takes a 28x28 input (layer named ``'input'``), flattens it,
    applies two ReLU dense layers (256 and 128 units), a dropout of 0.2 and a
    10-unit softmax layer named ``'output'``.

    Args:
        my_learning_rate: Learning rate handed to the Adam optimizer.

    Returns:
        The compiled ``tf.keras`` Sequential model, using the
        ``'sparse_categorical_crossentropy'`` loss and the ``'accuracy'``
        metric.
    """
    model = tf.keras.models.Sequential([
        tf.keras.Input(shape=(28, 28), name='input'),
        # The Input layer above already fixes the input shape, so Flatten
        # does not need its own ``input_shape`` argument.
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=256, activation='relu'),
        tf.keras.layers.Dense(units=128, activation='relu'),
        tf.keras.layers.Dropout(rate=0.2),
        tf.keras.layers.Dense(units=10, activation='softmax', name='output'),
    ])
    model.compile(
        # ``lr`` is only a deprecated alias; ``learning_rate`` is the
        # supported keyword.
        optimizer=tf.keras.optimizers.Adam(learning_rate=my_learning_rate),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model


def train_model(model, train_features, train_label, epochs,
                batch_size=None, validation_split=0.1):
    """Fit ``model`` and return its training history in tabular form.

    Args:
        model: Object exposing a Keras-style ``fit`` method.
        train_features: Training inputs, passed to ``fit`` as ``x``.
        train_label: Training targets, passed to ``fit`` as ``y``.
        epochs: Number of epochs, forwarded to ``fit``.
        batch_size: Batch size, forwarded to ``fit``.
        validation_split: Validation fraction, forwarded to ``fit``.

    Returns:
        A ``(epoch, frame)`` tuple: the ``epoch`` attribute of the object
        returned by ``fit`` and a ``pandas.DataFrame`` built from its
        ``history`` attribute.
    """
    fit_options = {
        'x': train_features,
        'y': train_label,
        'batch_size': batch_size,
        'epochs': epochs,
        'shuffle': True,  # shuffling is always requested
        'validation_split': validation_split,
    }
    fit_result = model.fit(**fit_options)
    metrics_frame = pd.DataFrame(fit_result.history)
    return fit_result.epoch, metrics_frame


if __name__ == '__main__':
    # Load the MNIST splits and divide the image arrays by 255.0.
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train_normalized, x_test_normalized = x_train / 255.0, x_test / 255.0

    # Training hyperparameters.
    learning_rate, epochs, batch_size, validation_split = 0.003, 50, 4000, 0.2

    my_model = create_model(learning_rate)
    # ``epochs`` is rebound to the epoch list returned by train_model.
    epochs, hist = train_model(
        my_model,
        x_train_normalized,
        y_train,
        epochs,
        batch_size,
        validation_split,
    )

    # Write the trained model to the 'my_model' path.
    my_model.save('my_model')

The model is saved to the "my_model" folder, as it should. Now I load it again in my Java program:

/**
 * Loads the SavedModel stored in the "my_model" directory and runs a single
 * inference on an all-zero 28x28 input, printing the resulting class scores.
 */
public class HelloTensorFlow {
    public static void main(final String[] args) {
        final String filePath = Paths.get("my_model").toAbsolutePath().toString();
        // Tensors must be closed explicitly to release their resources, so the
        // input tensor is managed by the same try-with-resources as the bundle.
        try (final SavedModelBundle b = SavedModelBundle.load(filePath, "serve");
             // The Python model declares Input(shape=(28, 28)), so a batch of
             // one image is [1][28][28] rather than a flattened [1][784].
             final Tensor<Float> x = Tensor.create(new float[1][28][28], Float.class)) {
            final Session sess = b.session();

            // NOTE: "input" and "output" are the Keras layer names from the
            // Python script. They are kept here because they are the subject of
            // the question: the graph has no operations with these names, which
            // produces "No Operation named [input] in the Graph". The real
            // tensor names have to be read from the "serving_default" signature.
            final List<Tensor<?>> run = sess.runner()
                    .feed("input", x)
                    .fetch("output")
                    .run();
            try {
                // The output layer has 10 units, so the result for one image is
                // [1][10]; the destination array must match that shape.
                final float[][] y = run.get(0).copyTo(new float[1][10]);
                System.out.println(java.util.Arrays.toString(y[0]));
            } finally {
                // Fetched tensors are owned by the caller and must be closed too.
                run.forEach(Tensor::close);
            }
        }
    }
}

The model is loaded, but the runner does not work. When I execute the program, I get "No Operation named [input] in the Graph", even though my Input has this name. What am I doing wrong? I have the newest TensorFlow versions: 2.3.0 (Python) and 1.15.0 (Java).

like image 951
Rochen Avatar asked Oct 11 '25 08:10

Rochen


1 Answers

I solved it. TensorFlow 2 seems to have odd naming schemes, but using the MetaGraphDef, this can be deciphered. First, you need the org.tensorflow.proto dependency. Then, you can extract the information from the meta graph like so:

final MetaGraphDef metaGraphDef = MetaGraphDef.parseFrom(bundle.metaGraphDef());
final SignatureDef signatureDef = metaGraphDef.getSignatureDefMap().get("serving_default");

final TensorInfo inputTensorInfo = signatureDef.getInputsMap()
    .values()
    .stream()
    .filter(Objects::nonNull)
    .findFirst()
    .orElseThrow(() -> ...);

final TensorInfo outputTensorInfo = signatureDef.getOutputsMap()
    .values()
    .stream()
    .filter(Objects::nonNull)
    .findFirst()
    .orElseThrow(() -> ...);

Now you can feed the tensor you created into the name from inputTensorInfo.getName() and fetch the results from outputTensorInfo.getName().

like image 151
Rochen Avatar answered Oct 15 '25 13:10

Rochen